博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
TensorFlow+实战Google深度学习框架学习笔记(10)-----神经网络几种优化方法
阅读量:6332 次
发布时间:2019-06-22

本文共 3647 字,大约阅读时间需要 12 分钟。

神经网络的优化方法:

1、

2、

3、

4、

 

一、学习率的设置----指数衰减方法

通过指数衰减的方法设置GD的学习率。该方法可让模型在训练的前期快速接近较优解,又可保证模型在训练后期不会有太大的波动,从而更加接近局部最优。

学习率不能过大,可能让参数在极值两侧波动,不能过小,训练时间会过长。

TensorFlow提供的方法:tf.train.exponential_decay函数实现了指数衰减学习率。通过这个函数,可以先使用较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减小学习率,使得模型在训练后期更加稳定。exponential_decay函数会指数级地减少学习率:

 

 应用:

 

 二、过拟合问题:

过拟合:一个模型过于复杂时,只记得去学习训练数据中随机噪声而忘了去学习训练数据中通用的趋势。

避免过拟合的方法:正则化、数据增强(增加训练数据)、提早终止训练、参数共享、批标准化、集成方法、辅助分类节点

正则化:在损失函数中加入刻画模型复杂程度的指标。通过限制权重的大小,得到模型不能拟合训练数据中的随机噪音。

TensorFlow中的L2正则化的损失函数定义:

提高代码可读性:可以运用集合collection的思想,即将均方误差损失函数和正则化损失函数分开计算,然后放入loss【自己取名】的集合中,最终再从集合中取出求和。

三、滑动平均模型【参数的更新】

作用:

使用滑动平均模型在很多应用中都可以在一定程度提高模型在测试数据上的鲁棒性。

其实滑动平均模型,主要是通过控制衰减率来控制参数更新前后之间的差距,从而达到减缓参数的变化值(如,参数更新前是5,更新后的值是4,通过滑动平均模型之后,参数的值会在4到5之间)

如:本次结果=(1-a)本次采样值+a上次结果

目的:平滑、滤波,即使数据平滑变化,通过调整参数来调整变化的稳定性

为何在测试阶段使用:

对神经网络边的权重 weights 使用滑动平均,得到对应的影子变量 shadow_weights。在训练过程仍然使用原来不带滑动平均的权重 weights,不然无法得到 weights 下一步更新的值,又怎么求下一步 weights 的影子变量 shadow_weights。之后在测试过程中使用 shadow_weights 来代替 weights 作为神经网络边的权重,这样在测试数据上效果更好。因为 shadow_weights 的更新更加平滑,对于随机梯度下降而言,更平滑的更新说明不会偏离最优点很远;对于梯度下降 batch gradient decent,我感觉影子变量作用不大,因为梯度下降的方向已经是最优的了,loss 一定减小;对于 mini-batch gradient decent,可以尝试滑动平均,毕竟 mini-batch gradient decent 对参数的更新也存在抖动。

比如:在最后的1000次训练过程中,模型早已经训练完成,正处于抖动阶段,而滑动平均相当于将最后的1000次抖动进行了平均,这样得到的权重会更加robust。

应用:

tensorflow提供了tf.train.ExponentialMovingAverage(decay, num_updates=None, name='ExponentialMovingAverage'这个接口。

  • decay:一般设置为非常接近1的数,比如0.9999,
  • num_updates:为了在初期快速的更新,可以设置num_updates,如果num_updates = None ,那么decay将为一个固定的值。设置num_updates=global_step,那么dacay将会根据如下公式选择decay值:

                    min(decay, (1 + num_updates) / (10 + num_updates))

使用MovingAverage的三个要素。

  1. 指定decay参数创建实例: 
    • ema = tf.train.ExponentialMovingAverage(decay=0.9999)
  2. 对模型变量使用apply方法:
    • maintain_averages_op = ema.apply([var0, var1])
  3. 在优化方法使用梯度更新模型参数后执行MovingAverage:
    • with tf.control_dependencies([opt_op]):
          training_op = tf.group(maintain_averages_op)
      其中,tf.group将传入的操作捆绑成一个操作。

以下的代码有以下几点要注意:

1) 定义好ema之后,分两步,一步ema.apply,一步ema.average
2) 先apply,后average
3) apply里放的是一个list
4) Variable通过tf.assign改动

原理:

apply方法会为每个变量(也可以指定特定变量)创建各自的shadow variable, 即影子变量。之所以叫影子变量,是因为它会全程跟随训练中的模型变量。影子变量会被初始化为模型变量的值,然后,每训练一个step,就更新一次。更新的方式为:

应用例子:

import tensorflow as tfv1 = tf.Variable(0, dtype=tf.float32) #初始化v1变量step = tf.Variable(0, trainable=False) #初始化step为0ema = tf.train.ExponentialMovingAverage(0.99, step) #定义平滑类,设置参数以及stepmaintain_averages_op = ema.apply([v1]) #定义更新变量平均操作with tf.Session() as sess:    # 初始化    init_op = tf.global_variables_initializer()    sess.run(init_op)    print sess.run([v1, ema.average(v1)])    # 更新变量v1的取值    sess.run(tf.assign(v1, 5))    sess.run(maintain_averages_op)    print sess.run([v1, ema.average(v1)])     # 更新step和v1的取值    sess.run(tf.assign(step, 10000))      sess.run(tf.assign(v1, 10))    sess.run(maintain_averages_op)    print sess.run([v1, ema.average(v1)])           # 更新一次v1的滑动平均值    sess.run(maintain_averages_op)    print sess.run([v1, ema.average(v1)])

书中详细解释

四、批标准化:(batch normalization,BN)

https://www.cnblogs.com/zyly/p/8996070.html

(1)BN的来源:

批标准化(BN)是为了克服神经网络层数加深导致难以训练而产生的。

神经网络层数加深,收敛速度会很慢,常常导致梯度弥散或者梯度爆炸问题。

统计学中有一个ICS(Internal Covariate Shift)理论,这是一个经典假设:源域和目标域的数据分布是一致的。即训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好效果的一个基本保障。

【Covariate Shift:指训练集的样本数据和目标样本集分布不一致时,训练得到的模型无法很好地泛化,它是分布不一致假设之下的一个分支问题,即源域和目标域的条件概率是一致的,但其边缘概率不同。的确,神经网络的各层输出,输出分布和对应的输入分布不同,且差异随着网络深度增大而加大,但每一层所指向的样本标记(label)是不变的。】

解决思路:根据训练样本和目标样本的比例对训练样本做一个矫正。通过BN来规范化某些层或所有层的输入,来固定每层输入的均值与方差。

(2)BN方法:

BN一般用在激活函数之前,对x=Wu+b做规范化,使结果的均值为0,方差为1.让每一层的输入有一个稳定的分布会有利于网络的训练。

(3)BN优点:

BN通过规范化让激活函数分布在线性区间,结果就是加大了梯度,优点如下:

  1. 加大探索的步长,加快收敛速度
  2. 容易跳出局部最小值
  3. 破坏原来的数据分布,一定程度上缓解过拟合

(4)代码示例:

 

转载于:https://www.cnblogs.com/Lee-yl/p/10029246.html

你可能感兴趣的文章
smb服务器
查看>>
1,支付宝开发 - 使用OpenSSL 将RSA私钥 转码为pkcs8格式
查看>>
初识PHP
查看>>
我的友情链接
查看>>
Web浏览器中的JavaScript(一)
查看>>
openstack I版的搭建十--基于NFS的云硬盘
查看>>
控件自定义属性介绍 http://www.jb51.net/article/32172.htm
查看>>
我的Git忽略文件
查看>>
大型网站技术架构(五)网站高可用架构
查看>>
HTML5 Session Local Storage
查看>>
青春路上,岁月如烟
查看>>
Linux实用工具
查看>>
将centos7打造成桌面系统
查看>>
puppet基本配置
查看>>
Spring常用注解
查看>>
linux登录日志
查看>>
yum不能升级
查看>>
又见那风鸣
查看>>
网络公开课《八一建军节引发的Oracle数据库思考:虚拟私有数据库》
查看>>
Ubuntu下完全卸载Nginx
查看>>