美文网首页
快速理解梯度下降法

快速理解梯度下降法

作者: 陈同庆 | 来源:发表于2018-09-19 15:01 被阅读0次

1. 梯度下降法的抽象理解

梯度下降法抽象图示

        梯度下降法(Gradient descent)是深度学习中的最基础的工具之一,它是一种寻找最低点的方法。可以抽象的理解为一个人要去寻找一个大水坑的最低点,最常见的思路就是,先沿着某个方向走一步,如果发现水变深了,说明方向是对的继续往这个方向走一步;反之,如果发现水变浅了,就反方向走一步,这样当发现走一步水位变化很小(梯度很小)时,就认为找到了最低点。这就是梯度下降法的抽象理解。

        梯度可以理解成斜率或者导数,这里面有一个重要的事情是:如果走的步子太大就容易扯到蛋,就是有可能一下子就调到了岸上再也找不到水坑的最低点了;而步子太小的话找到时间就会成倍增加,所以步长是很讲究的,但是可悲的是,没有什么公式可以告诉我们到底多少步长合适,需要程序猿根据经验一点点试出来。


2. 梯度下降法和深度学习的联系

        看吴恩达教授的课程时,会有小伙伴疑问,怎么上来就讲梯度下降法?其实深度学习无非就是利用一套自学习的机制,让计算机通过大量数据学习出一套公式来,使得新的数据来了可以套用这个公式得到预测结果。深度学习设计一层层的结构就是设计一个公式,比如最简单的y=wx+b,然后通过这些给定的X,Y数据集,一点点凑出来最接近的w和b。梯度下降法就是其中一种凑的方法。

        凑出这个公式的过程其实就是找到一组解(w, b)使得所有通过w, b计算的Y_hat(预测值用Y帽表示),与真实值Y之间最接近也就是误差最小。所以需要建立误差函数L,一般常用均方差来做。这里用了2m是为了计算方便,也可以用m。然后这里需要转换思路,要把w和b作为变量来找J的最小值,找的方法就是梯度下降法。

误差函数

3. 梯度下降法的实现

3.1 简单示例

        上面的J公式是关于w和b的二元函数,我们先选一个初始位置,比如w=0,b=0,计算该点的梯度,然后按这个梯度往前走一步。得到新的位置,然后再继续走,直到梯度很小。

        为了便于理解我们简化一下,假设J = 0.5w^2, 梯度函数是J'=w起步从w0=1的位置按步长0.4走。

step #1. 梯度为J'(1)=1, 新的w1 = w0 - 0.4*J'(1)=0.6 

step #2. 梯度为J'(0.6)=0.6, 新的w2 = w1- 0.4*J'(0.6)=0.36    

step #3. 梯度为J'(0.36)=0.36, 新的w3 = w2 - 0.4*J'(0.36)=0.216    

step #4. 梯度为J'(0.216)=0.216, 新的w4 = w3 - 0.4*J'(0.216)=0.1296 

        这样一步步走下去就能找到最低点0,当然实际上是这个例子中不可能找到精确的0位置,比较接近就可以了。

        二维的做法差不多,比如y=w^2+b^2,梯度函数为J'w= 2w, J'b = 2b, 从w0=1, b=2起步,步长为0.4。

step #1. 梯度为J'w(1)=2, J'b(2)=4, 新的w1 = w0 - 0.4*2=0.2, b1=b0-0.4*4=0.4

step #2. 梯度为J'w(0.2)=0.4, J'b(0.4)=0.8 新的w2 = w1- 0.4*0.4=0.16, b2=b1-0.4*0.8=0.08

        当两个方向的梯度都很小时就可以接近最低点了。

3.2 代码实现

        这里简单拟合一个线性的问题即Y=wX+b,X的维数可以任意,比如y=w0x0+w1x1+b。提前说明一下,我们这里写法按照吴恩达建议的方式,把维数按行排列,数据集的个数按照列排列和threano相同,和tesnorflow与此相反。其实无所谓,只要好好关注行列的个数就行了。

假设数据集大小m=20

X=[1,2,3,4,5,6,7...20]

Y=[3,4,6,7,9,12,16,16,19,21,22,25,27,30,30,32,36,37,40,41]

        python中的Numpy对矩阵计算很6,所以我们把数据集改成矩阵的形式

矩阵运算

即,Y=[[3,4......41]] (shape为1*m),X_=[[1,2,3..20],[1,1,1,...1]] (shape为2*m)。具体是实现如下图所示:

y=wx+b类型的梯度下降法设计  然后我们做两个测试,一个是X的维度是1,另一个X维度是2 测试结果,两个例子分别学习了5444和3793次

4. 梯度下降法的限制

        梯度下降法也存在着一些限制,尽管伟大的科学家们已经找了很多策略来减弱这些限制。

        首先,现实生活中的模型很少是这种线性的模型,也很少通过一层的结构能计算出理想的结果。比如,预测房价时,房价不可能为负;预测明天下雨的概率时结果只能为0-1之间的数等等,如果按线性模型求解,会得到很奇怪的结果。因此我们会通过不同的层来把问题“复杂化”,或者通过一定的函数限定中间的值(激活函数),比如Relu或sigmoid等。

        其次,实际模型中需要大量的数据才能得到相对稳定的结果,这会对梯度下降法中的每一次迭代都会造成庞大的计算,如果只选一个样本,又很可能得出的解不够准确。因此常用一批样本来更新梯度这种折中方案。

        再次,梯度下降法很有可能走入一个局部最优解,这个问题需要更高级的方法才可能解决。在后面的学习中,我们逐步展开,建立一些非线性的模型来求解实际问题。

相关文章

网友评论

      本文标题:快速理解梯度下降法

      本文链接:https://www.haomeiwen.com/subject/cdcinftx.html