Gradient descent is one of the most common algorithms in machine learning. It is an iterative method, widely used to find (approximately) optimal solutions for both linear and nonlinear models: at each iteration it computes the gradient of the objective at the current point, uses it to judge in which direction and roughly how far the target lies, and adjusts the weights so that the loss function keeps decreasing.
Its purpose is to optimize an objective function: whenever a loss function has to be minimized, gradient descent is the tool of choice, and the x it solves for is a (local) minimum of that function. The parameters obtained this way can then be used, for example, to fit a linear model.
As shown in the figure below, the goal is to find the minimum of J(θ).
In fact, the true shape of J(θ) looks more like the next figure: because it is a convex function, it has only one global optimum, so there is no need to worry about ending up in a local optimum as in the figure above.
Consider a simple function as an example:
$$
f(x) = x^2 \sin(x)
$$
$$
f'(x) = 2x\sin(x) + x^2\cos(x)
$$
When iterating toward the extremum, the next value of the variable x is obtained from the current one via
$$
x \leftarrow x - \mathrm{lr} \cdot f'(x)
$$
where lr is the learning rate.
In practice, different algorithms call for different values of lr. A suitable lr reduces the amount of oscillation and reaches the extremum in fewer steps; 0.001 or 0.005 are common default choices, while for a new or unusual algorithm the appropriate lr has to be found by experimentation.
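As a minimal sketch (not part of the original text), the update rule above can be applied directly to f(x) = x²·sin(x). The starting point x = 4.0, the learning rate 0.005, and the iteration count below are illustrative assumptions:

```python
import math

def f(x):
    return x ** 2 * math.sin(x)

def f_prime(x):
    # derivative of x^2 * sin(x), as given above
    return 2 * x * math.sin(x) + x ** 2 * math.cos(x)

x = 4.0      # illustrative starting point (assumption)
lr = 0.005   # illustrative learning rate (assumption)
for step in range(1000):
    x = x - lr * f_prime(x)   # x <- x - lr * f'(x)

print("converged near x =", x, "where f(x) =", f(x))
```

With these settings the iterates slide down into the nearest local minimum; changing the starting point or lr changes which extremum is found and how quickly.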
Suppose now that a batch of data is known to follow
$$
y = w \cdot x + b + \mathrm{eps}
$$
where the Gaussian noise eps follows the distribution **N(0.01, 1)**, for example:
1.567=w*1+b+eps
3.043=w*2+b+eps
4.519=w*3+b+eps
……
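As a hedged sketch (not from the original text), data of this shape can be generated with NumPy. Here N(0.01, 1) is read as mean 0.01 and standard deviation 1, and the ground-truth values w_true = 1.5 and b_true = 0.1 are illustrative assumptions:

```python
import numpy as np

w_true, b_true = 1.5, 0.1                         # illustrative ground truth (assumption)

x = np.arange(1, 101, dtype=float)                # x = 1, 2, ..., 100
eps = np.random.normal(0.01, 1.0, size=x.shape)   # noise ~ N(0.01, 1) (assumed reading)
y = w_true * x + b_true + eps                     # noisy observations of the line

data = np.column_stack([x, y])                    # same two-column layout as data.csv below
```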
For the data given in a real problem, an approximate solution is all we need. In fact, in many cases we do not even know the underlying form of the function; we look at how the data are distributed, hypothesize a model, and check whether it fits.
To find approximate values of w and b with gradient descent, we introduce a quantity loss; the w and b that minimize loss are the approximate solution:
$$
\mathrm{loss}=\sum_{i=1}^{k}\left(w \cdot x_i + b - y_i\right)^2
$$
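Differentiating with respect to w and b gives the two gradients that the step_gradient function below accumulates (the code averages the squared error over the N points instead of summing it, which only rescales the gradients by 1/N):

$$
\frac{\partial\,\mathrm{loss}}{\partial w}=\frac{2}{N}\sum_{i=1}^{N}\bigl(w \cdot x_i + b - y_i\bigr)\,x_i,
\qquad
\frac{\partial\,\mathrm{loss}}{\partial b}=\frac{2}{N}\sum_{i=1}^{N}\bigl(w \cdot x_i + b - y_i\bigr)
$$

and each iteration applies w ← w − lr·∂loss/∂w and b ← b − lr·∂loss/∂b.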
```python
import numpy as np

# Model: y = wx + b

def compute_error_for_line_given_points(b, w, points):
    # Mean squared error of the line y = wx + b over all points
    totalError = 0
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        totalError += (y - (w * x + b)) ** 2
    return totalError / float(len(points))

def step_gradient(b_current, w_current, points, learningRate):
    # Gradients of the loss with respect to b and w (each treated as the
    # variable in turn), followed by one update step:
    # new_b = b - learningRate * dloss/db, new_w = w - learningRate * dloss/dw
    b_gradient = 0
    w_gradient = 0
    N = float(len(points))
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        b_gradient += -(2 / N) * (y - ((w_current * x) + b_current))
        w_gradient += -(2 / N) * x * (y - ((w_current * x) + b_current))
    new_b = b_current - (learningRate * b_gradient)
    new_w = w_current - (learningRate * w_gradient)
    return [new_b, new_w]

def gradient_descent_runner(points, starting_b, starting_w, learning_rate, num_iterations):
    # Run a fixed number of gradient-descent iterations
    b = starting_b
    w = starting_w
    for i in range(num_iterations):
        b, w = step_gradient(b, w, np.array(points), learning_rate)
    return [b, w]

def run():
    # Load the data, run gradient descent, and report the error before and after
    points = np.genfromtxt("data.csv", delimiter=",")
    learning_rate = 0.0001
    initial_b = 0  # initial y-intercept guess
    initial_w = 0  # initial slope guess
    num_iterations = 1000
    print("Starting gradient descent at b = {0}, w = {1}, error = {2}"
          .format(initial_b, initial_w,
                  compute_error_for_line_given_points(initial_b, initial_w, points)))
    print("Running...")
    [b, w] = gradient_descent_runner(points, initial_b, initial_w, learning_rate, num_iterations)
    print("After {0} iterations b = {1}, w = {2}, error = {3}"
          .format(num_iterations, b, w,
                  compute_error_for_line_given_points(b, w, points)))

if __name__ == '__main__':
    run()
```
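As an optional sanity check (not part of the original script), the closed-form least-squares line for the same data can be obtained with np.polyfit; the gradient-descent result should move toward it as the number of iterations grows:

```python
import numpy as np

points = np.genfromtxt("data.csv", delimiter=",")
# np.polyfit returns coefficients from the highest degree down: [slope, intercept]
w_ls, b_ls = np.polyfit(points[:, 0], points[:, 1], deg=1)
print("closed-form least squares: w = {0}, b = {1}".format(w_ls, b_ls))
```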
The contents of the file "data.csv" are as follows (the commas separating the two columns have been omitted here):
32.50234527 31.70700585
53.42680403 68.77759598
61.53035803 62.5623823
47.47563963 71.54663223
59.81320787 87.23092513
55.14218841 78.21151827
52.21179669 79.64197305
39.29956669 59.17148932
48.10504169 75.3312423
52.55001444 71.30087989
45.41973014 55.16567715
54.35163488 82.47884676
44.1640495 62.00892325
58.16847072 75.39287043
56.72720806 81.43619216
48.95588857 60.72360244
44.68719623 82.89250373
60.29732685 97.37989686
45.61864377 48.84715332
38.81681754 56.87721319
66.18981661 83.87856466
65.41605175 118.5912173
47.48120861 57.25181946
41.57564262 51.39174408
51.84518691 75.38065167
59.37082201 74.76556403
57.31000344 95.45505292
63.61556125 95.22936602
46.73761941 79.05240617
50.55676015 83.43207142
52.22399609 63.35879032
35.56783005 41.4128853
42.43647694 76.61734128
58.16454011 96.76956643
57.50444762 74.08413012
45.44053073 66.58814441
61.89622268 77.76848242
33.09383174 50.71958891
36.43600951 62.12457082
37.67565486 60.81024665
44.55560838 52.68298337
43.31828263 58.56982472
50.07314563 82.90598149
43.87061265 61.4247098
62.99748075 115.2441528
32.66904376 45.57058882
40.16689901 54.0840548
53.57507753 87.99445276
33.86421497 52.72549438
64.70713867 93.57611869
38.11982403 80.16627545
44.50253806 65.10171157
40.59953838 65.56230126
41.72067636 65.28088692
51.08863468 73.43464155
55.0780959 71.13972786
41.37772653 79.10282968
62.49469743 86.52053844
49.20388754 84.74269781
41.10268519 59.35885025
41.18201611 61.68403752
50.18638949 69.84760416
52.37844622 86.09829121
50.13548549 59.10883927
33.64470601 69.89968164
39.55790122 44.86249071
56.13038882 85.49806778
57.36205213 95.53668685
60.26921439 70.25193442
35.67809389 52.72173496
31.588117 50.39267014
53.66093226 63.64239878
46.68222865 72.24725107
43.10782022 57.81251298
70.34607562 104.2571016
44.49285588 86.64202032
57.5045333 91.486778
36.93007661 55.23166089
55.80573336 79.55043668
38.95476907 44.84712424
56.9012147 80.20752314
56.86890066 83.14274979
34.3331247 55.72348926
59.04974121 77.63418251
57.78822399 99.05141484
54.28232871 79.12064627
51.0887199 69.58889785
50.28283635 69.51050331
44.21174175 73.68756432
38.00548801 61.36690454
32.94047994 67.17065577
53.69163957 85.66820315
68.76573427 114.8538712
46.2309665 90.12357207
68.31936082 97.91982104
50.03017434 81.53699078
49.23976534 72.11183247
50.03957594 85.23200734
48.14985889 66.22495789
25.12848465 53.45439421