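• In the previous section, we updated the weight parameter w by exhaustive search.
  • In this section, we use the gradient descent algorithm to update the weight automatically. The goal is to find the weight combination that minimizes the cost.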

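Optimization problem: the task of finding the weight combination that minimizes the objective function (cost), as shown in the second equation.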
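
For the linear model used in the code in this post (y_pred = x * w, with mean squared error as the cost), the optimization problem and the update that gradient descent repeats can be sketched as follows. The symbols w^*, N and \alpha are notation assumed here for illustration; \alpha corresponds to the learning rate, which the code in this post sets to 0.01.

```latex
w^{*} = \arg\min_{w} \operatorname{cost}(w), \qquad
\operatorname{cost}(w) = \frac{1}{N} \sum_{n=1}^{N} (x_n w - y_n)^2

w \leftarrow w - \alpha \, \frac{\partial \operatorname{cost}}{\partial w}, \qquad
\frac{\partial \operatorname{cost}}{\partial w} = \frac{1}{N} \sum_{n=1}^{N} 2 x_n (x_n w - y_n)
```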

Gradient descent (GD), stochastic gradient descent (SGD), and mini-batch stochastic gradient descent (mini-batch SGD)

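  • Gradient descent: takes the partial derivative of the average loss over all samples.
  • Stochastic gradient descent: takes the partial derivative of the loss of a single sample.
  • Mini-batch stochastic gradient descent: gradient descent has good computational performance (it can exploit parallelism), while stochastic gradient descent has good optimization performance (its noisy updates may escape saddle points) but has to process the samples one after another. To balance the two, the samples are split into groups of several samples each, and one gradient update is performed per group.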

 Note: "batch" refers to all the samples; "mini-batch" is the small group of samples marked by the blue box in the figure above.
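
The only difference between the three variants is which samples the gradient is averaged over. The sketch below illustrates this on the toy data used in this post. The helper name mean_gradient and the mini-batch size of 2 are assumptions made for illustration, not part of the original notes.

```python
# Toy data from this post: the true relation is y = 2x.
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]


def mean_gradient(indices, w):
    """Average d(loss)/dw of the model y_pred = x * w over the chosen samples."""
    total = 0.0
    for i in indices:
        x, y = x_data[i], y_data[i]
        total += 2 * x * (x * w - y)
    return total / len(indices)


w = 1.0
gd_grad = mean_gradient([0, 1, 2], w)  # GD: the whole batch
sgd_grad = mean_gradient([0], w)       # SGD: a single sample
mini_grad = mean_gradient([0, 1], w)   # mini-batch: a small group (size 2 is an arbitrary choice)
print(gd_grad, sgd_grad, mini_grad)
```

All three values are negative at w = 1.0, so each variant would move w toward 2; they differ in how many samples contribute to each step.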

 

 Gradient descent implementation:

import matplotlib.pyplot as plt

# define dataset
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# initial guess of weight
w = 1.0

# define the model (forward pass)
def forward(x):
    return x * w

# define cost function: mean squared error over all samples
def cost(xs, ys):
    total = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        total += (y_pred - y) ** 2
    return total / len(xs)

# define gradient of the cost with respect to w, averaged over all samples
def gradient(xs, ys):
    grad = 0
    for x, y in zip(xs, ys):
        grad += 2 * x * (x * w - y)
    return grad / len(xs)

cost_list = []
epoch_list = []
print("Predict (before training)", 4, forward(4))
for epoch in range(100):
    epoch_list.append(epoch)
    cost_val = cost(x_data, y_data)
    cost_list.append(cost_val)
    gradient_val = gradient(x_data, y_data)
    w -= 0.01 * gradient_val  # learning rate 0.01
    print("Epoch:", epoch, "w=", w, "cost=", cost_val)

print("Predict (after training)", 4, forward(4))

plt.plot(epoch_list, cost_list)
plt.xlabel("epoch")
plt.ylabel("cost")
plt.show()
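
To illustrate why full-batch gradient descent can exploit parallelism, here is a minimal sketch of the same training loop with NumPy arrays, where the cost and gradient for all samples are computed in one vectorized expression instead of a Python loop. The variable names residual and learning_rate are assumptions made for illustration.

```python
import numpy as np

# Same toy data as the loop-based version above.
x = np.array([1.0, 2.0, 3.0])
y = np.array([2.0, 4.0, 6.0])

w = 1.0
learning_rate = 0.01

for epoch in range(100):
    residual = x * w - y                # prediction error for every sample at once
    cost_val = np.mean(residual ** 2)   # mean squared error over the whole batch
    grad = np.mean(2 * x * residual)    # d(cost)/dw averaged over the whole batch
    w -= learning_rate * grad

print("w =", w, "prediction for x = 4:", 4 * w)
```

Because every sample's contribution is independent of the others, the per-sample work inside np.mean can be done in parallel, which is not possible when w changes after each sample.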

Stochastic gradient descent implementation:

import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]  # same dataset as the gradient descent example

w = 1.0

def forward(x):
    return x * w

# loss of a single sample
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

# gradient of a single sample's loss with respect to w
def gradient(x, y):
    y_pred = forward(x)
    return 2 * x * (y_pred - y)

loss_list = []
epoch_list = []
print("Predict (before training)", 4, forward(4))
for epoch in range(100):
    epoch_list.append(epoch)
    # update w after every single sample
    for x, y in zip(x_data, y_data):
        loss_val = loss(x, y)
        grad = gradient(x, y)
        w -= 0.01 * grad
        print("\tgrad:", x, y, grad)

    # loss_val here is the loss of the last sample seen in this epoch
    print("Progress:", epoch, "w=", w, "loss=", loss_val)
    loss_list.append(loss_val)
print("Predict (after training)", 4, forward(4))

plt.plot(epoch_list, loss_list)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()
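
Mini-batch stochastic gradient descent implementation (a sketch):

The notes describe mini-batch SGD but give no code for it, so the following is one possible sketch in the same style as the two examples above, not the original author's implementation. The batch size of 2, the per-epoch shuffling, the fixed random seed, and the helper name batch_gradient are all assumptions made for illustration.

```python
import random

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = 1.0
learning_rate = 0.01
batch_size = 2  # assumed value; with 3 samples this gives groups of 2 and 1

random.seed(0)  # fixed seed so the shuffling is reproducible


def forward(x):
    return x * w


def batch_gradient(xs, ys):
    """Average gradient of the squared loss over one mini-batch."""
    grad = 0.0
    for x, y in zip(xs, ys):
        grad += 2 * x * (forward(x) - y)
    return grad / len(xs)


print("Predict (before training)", 4, forward(4))
for epoch in range(100):
    # shuffle the sample order, then walk through it one group at a time
    indices = list(range(len(x_data)))
    random.shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch = indices[start:start + batch_size]
        xs = [x_data[i] for i in batch]
        ys = [y_data[i] for i in batch]
        w -= learning_rate * batch_gradient(xs, ys)  # one update per group

print("Predict (after training)", 4, forward(4))
```

Each epoch performs two updates here instead of one (GD) or three (SGD). On larger datasets the batch size is the knob that trades parallel efficiency against update noise.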
