算法描述
-
随机梯度下降法(SGD)[1]对训练数据做随机采样,其更新公式如下:
+1=−∇()wt+1=wt−ηt∇fit(wt)
其中,it是第t轮随机采样的数据标号。具体算法如下列的伪代码所示:
我们知道,机器学习问题中的经验损失函数定义为所有样本数据对应的损失函数的平均值。而我们这里用有放回随机采用获得的数据来计算梯度,是对用全部数据来计算梯度的一个无偏估计,即∇()=∇()Eit∇itf(wt)=∇f(wt),注意此处()=1∑=1∇()f(wt)=1n∑i=1n∇fi(wt))。而由于每次更新只随机采一个样本,随机梯度中的计算量大大减小。因此,随机梯度可以作为梯度的自然替代,从而大大提高学习效率。不过正如我们前面所说,优化算法的复杂度不仅包括单位计算复杂度,还包括迭代次数复杂度(与收敛率有关)。天下没有免费的午餐,随机梯度下降单位计算复杂度降低,那么也相应地会付出迭代次数复杂度增大的代价。
考虑实际每次只采一个样本比较极端,常用的形式是随机梯度下降法的一个推广:小批量(mini-batch)随机梯度下降法。该方法可以看做是在随机优化算法和确定性优化算法之间寻找某种折中,每次采一个较小的样本集合∈{1,2,...}It∈{1,2,...n}(多于单样本,少于全样本),然后执行更新公式:
+1=−∇t()=−||∑∈∇()wt+1=wt−ηt∇fIt(wt)=wt−ηt|It|∑i∈It∇fi(wt)
西南地区IT社群(QQ)
- 云南
- 【昆明网页设计交流吧】243627302
- 【昆明nodejs交流吧】 243626749
- 【VUE】838405306
- 【云南程序员总群】343606807
- 【昆明UI设计】104031254
- 【云南软件外包】15547313
- 贵州
- 【PHP/java源码/站长交流群】55692114
- 四川
- 【成都Java/JavaWeb交流】86669225
- 【vaScript+PHP+MySql】116270060
- 【UI设计/设计交流学习群】135794928
- 重庆
- 【诺基亚 JAVA游戏博物馆】 559479780
- 【PHP,Java,Python,C++接单】 442103442
- 西藏