對比分析BN和dropout在預(yù)測和訓(xùn)練時區(qū)別
Batch Normalization和Dropout是深度學(xué)習(xí)模型中常用的結(jié)構(gòu)。
但BN和dropout在訓(xùn)練和測試時使用卻不相同。
Batch Normalization
BN在訓(xùn)練時是在每個batch上計算均值和方差來進(jìn)行歸一化,每個batch的樣本量都不大,所以每次計算出來的均值和方差就存在差異。預(yù)測時一般傳入一個樣本,所以不存在歸一化,其次哪怕是預(yù)測一個batch,但batch計算出來的均值和方差是偏離總體樣本的,所以通常是通過滑動平均結(jié)合訓(xùn)練時所有batch的均值和方差來得到一個總體均值和方差。
以tensorflow代碼實現(xiàn)為例:
def bn_layer(self, inputs, training, name='bn', moving_decay=0.9, eps=1e-5):
# 獲取輸入維度并判斷是否匹配卷積層(4)或者全連接層(2)
shape = inputs.shape
param_shape = shape[-1]
with tf.variable_scope(name):
# 聲明BN中唯一需要學(xué)習(xí)的兩個參數(shù),y=gamma*x+beta
gamma = tf.get_variable('gamma', param_shape, initializer=tf.constant_initializer(1))
beta = tf.get_variable('beat', param_shape, initializer=tf.constant_initializer(0))
# 計算當(dāng)前整個batch的均值與方差
axes = list(range(len(shape)-1))
batch_mean, batch_var = tf.nn.moments(inputs , axes, name='moments')
# 采用滑動平均更新均值與方差
ema = tf.train.ExponentialMovingAverage(moving_decay, name="ema")
def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean, batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean), tf.identity(batch_var)
# 訓(xùn)練時,更新均值與方差,測試時使用之前最后一次保存的均值與方差
mean, var = tf.cond(tf.equal(training,True), mean_var_with_update,
lambda:(ema.average(batch_mean), ema.average(batch_var)))
# 最后執(zhí)行batch normalization
return tf.nn.batch_normalization(inputs ,mean, var, beta, gamma, eps)training參數(shù)可以通過tf.placeholder傳入,這樣就可以控制訓(xùn)練和預(yù)測時training的值。
self.training = tf.placeholder(tf.bool, name="training")
Dropout
Dropout在訓(xùn)練時會隨機丟棄一些神經(jīng)元,這樣會導(dǎo)致輸出的結(jié)果變小。而預(yù)測時往往關(guān)閉dropout,保證預(yù)測結(jié)果的一致性(不關(guān)閉dropout可能同一個輸入會得到不同的輸出,不過輸出會服從某一分布。另外有些情況下可以不關(guān)閉dropout,比如文本生成下,不關(guān)閉會增大輸出的多樣性)。
為了對齊Dropout訓(xùn)練和預(yù)測的結(jié)果,通常有兩種做法,假設(shè)dropout rate = 0.2。一種是訓(xùn)練時不做處理,預(yù)測時輸出乘以(1 - dropout rate)。另一種是訓(xùn)練時留下的神經(jīng)元除以(1 - dropout rate),預(yù)測時不做處理。以tensorflow為例。
x = tf.nn.dropout(x, self.keep_prob)
self.keep_prob = tf.placeholder(tf.float32, name="keep_prob")
tf.nn.dropout就是采用了第二種做法,訓(xùn)練時除以(1 - dropout rate),源碼如下:
binary_tensor = math_ops.floor(random_tensor) ret = math_ops.div(x, keep_prob) * binary_tensor if not context.executing_eagerly(): ret.set_shape(x.get_shape()) return ret
binary_tensor就是一個mask tensor,即里面的值由0或1組成。keep_prob = 1 - dropout rate。
以上就是對比分析BN和dropout在預(yù)測和訓(xùn)練時區(qū)別的詳細(xì)內(nèi)容,更多關(guān)于BN與dropout預(yù)測訓(xùn)練對比的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
用Python編寫生成樹狀結(jié)構(gòu)的文件目錄的腳本的教程
這篇文章主要介紹了用Python編寫生成樹狀結(jié)構(gòu)的文件目錄的腳本的教程,是一個利用os模塊下各函數(shù)的簡單實現(xiàn),需要的朋友可以參考下2015-05-05
python中requests使用代理proxies方法介紹
這篇文章主要介紹了python中requests使用代理proxies方法介紹,具有一定參考價值,需要的朋友可以了解下。2017-10-10
python如何處理matlab的mat數(shù)據(jù)
這篇文章主要介紹了python如何處理matlab的mat數(shù)據(jù),具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2022-05-05
Matplotlib實戰(zhàn)之平行坐標(biāo)系繪制詳解
平行坐標(biāo)系是一種統(tǒng)計圖表,它包含多個垂直平行的坐標(biāo)軸,每個軸表示一個字段,并用刻度標(biāo)明范圍,下面我們就來看看如何繪制平行坐標(biāo)系吧2023-08-08
詳解pyqt5的UI中嵌入matplotlib圖形并實時刷新(挖坑和填坑)
這篇文章主要介紹了詳解pyqt5的UI中嵌入matplotlib圖形并實時刷新(挖坑和填坑),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-08-08
python 找出list中最大或者最小幾個數(shù)的索引方法
今天小編就為大家分享一篇python 找出list中最大或者最小幾個數(shù)的索引方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-10-10

