博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
机器学习之梯度下降法
阅读量:5897 次
发布时间:2019-06-19

本文共 1974 字,大约阅读时间需要 6 分钟。

方向导数

如图,对于函数f(x,y),函数的增量与pp’两点距离之比在p’沿l趋于p时,则为函数在点p沿l方向的方向导数。记为fl=limρ0f(x+Δx,y+Δy)f(x,y)ρ,其中ρ=(Δx)2+(Δy)2。方向导数为函数f沿某方向的变化速率。

这里写图片描述

而且有如下定理:

fl=fxcosΘ+fysinΘ

梯度

梯度是一个向量,它的方向与取得最大方向导数的方向一致,梯度的模为方向导数的最大值。某点的梯度记为

gradf(x,y)=fxi+fyj

梯度的方向就是函数f在此点增长最快的方向,梯度的模为方向导数的最大值。

梯度下降

同样还是在线性回归中,假设函数为

h(x)=θ0+θ1x
那么损失函数为
J(θ)=12ni=1(h(xi)yi)2
要求最小损失,分别对θ0θ1求偏导,
J(θ)θj=θj12ni=1(h(xi)yi)2
=ni=1(h(xi)yi)θj(nj=0(θjxj)iyi)
=ni=1(h(xi)yi)xij
那么不断通过下面方式更新θ即可以逼近最低点。
θj:=θjαni=1(h(xi)yi)xij

其中α为learning rate,表现为下降的步伐。它不能太大也不能太小,太大会overshoot,太小则下降慢。通常可以尝试0.001、0.003、0.01、0.03、0.1、0.3。

这就好比站在一座山的某个位置上,往周围各个方向跨出相同步幅的一步,能够最快下降的方向就是梯度。这个方向是梯度的反方向。

这里写图片描述

这里写图片描述

这里写图片描述

另外,初始点的不同可能会出现局部最优解的情况,如下图:

这里写图片描述

伪代码

repeat until convergence{

θj:=θjαni=1(h(xi)yi)xij

for every j

}

随机梯度下降

样本太大时,每次更新都需要遍历整个样本,效率较低,这是就引入了随机梯度下降。

它可以每次只用一个样本来更新,免去了遍历整个样本。

伪代码如下

repeat until convergence{

i=random(1,n)

θj:=θjα(h(xi)yi)xij

for every j

}

另外与随机梯度下降类似的还有小批量梯度下降,它是折中的方式,取了所有样本中的一小部分。

代码实现

import numpy as npimport matplotlib.pyplot as pltlearning_rate = 0.0005theta = [0.7, 0.8, 0.9]loss = 100times = 100ite = 0expectation = 0.0001x_train = [[1, 2], [2, 1], [2, 3], [3, 5], [1, 3], [4, 2], [7, 3], [4, 5], [11, 3], [8, 7]]y_train = [7, 8, 10, 14, 8, 13, 20, 16, 28, 26]loss_array = np.zeros(times)def h(x):    return theta[0]*x[0]+theta[1]*x[1]+theta[2]while loss > expectation and ite < times:    loss = 0    sum_theta0 = 0    sum_theta1 = 0    sum_theta2 = 0    for x, y in zip(x_train, y_train):        sum_theta0 += (h(x) - y) * x[0]        sum_theta1 += (h(x) - y) * x[1]        sum_theta2 += (h(x) - y)    theta[0] -= learning_rate * sum_theta0    theta[1] -= learning_rate * sum_theta1    theta[2] -= learning_rate * sum_theta2    loss = 0    for x, y in zip(x_train, y_train):        loss += pow((h(x) - y), 2)    loss_array[ite] = loss    ite += 1plt.plot(loss_array)plt.show()

这里写图片描述

========广告时间========

鄙人的新书《Tomcat内核设计剖析》已经在京东销售了,有需要的朋友可以到 进行预定。感谢各位朋友。

=========================

欢迎关注:

这里写图片描述

你可能感兴趣的文章
[LeetCode] Find Anagram Mappings 寻找异构映射
查看>>
--Too small initial heap for new size specified
查看>>
黄聪:3分钟学会sessionStorage用法
查看>>
17monipdb根据IP获得区域
查看>>
Entity Framework 全面教程详解(转)
查看>>
模拟源码深入理解Vue数据驱动原理(2)
查看>>
Hibernate的配置中,c3p0连接池相关配置
查看>>
024-Spring Boot 应用的打包和部署
查看>>
linux的fork()函数具体解释 子进程复制父进程什么
查看>>
js 温故而知新 用typeof 来判断一个未定义的变量
查看>>
【Windows】免费图片提取文字的方法
查看>>
C# HttpWebResponse下载限速
查看>>
springboot redis多数据源设置
查看>>
AjaxToolKit之Rating控件的使用(http://www.soaspx.com/dotnet/ajax/ajaxtech/ajaxtech_20091021_1219.html)...
查看>>
Android 程式开发:(十四)显示图像 —— 14.1 Gallery和ImageView
查看>>
T-SQL性能调整——信息收集
查看>>
我眼中的领域驱动设计(转)
查看>>
[sh]. 点的含义
查看>>
【转】marquee标签简介
查看>>
未来十大最热门职业,可能消失的职业
查看>>