對(duì)比分析BN和dropout在預(yù)測(cè)和訓(xùn)練時(shí)區(qū)別
Batch Normalization和Dropout是深度學(xué)習(xí)模型中常用的結(jié)構(gòu)。
但BN和dropout在訓(xùn)練和測(cè)試時(shí)使用卻不相同。
Batch Normalization
BN在訓(xùn)練時(shí)是在每個(gè)batch上計(jì)算均值和方差來(lái)進(jìn)行歸一化,每個(gè)batch的樣本量都不大,所以每次計(jì)算出來(lái)的均值和方差就存在差異。預(yù)測(cè)時(shí)一般傳入一個(gè)樣本,所以不存在歸一化,其次哪怕是預(yù)測(cè)一個(gè)batch,但batch計(jì)算出來(lái)的均值和方差是偏離總體樣本的,所以通常是通過(guò)滑動(dòng)平均結(jié)合訓(xùn)練時(shí)所有batch的均值和方差來(lái)得到一個(gè)總體均值和方差。
以tensorflow代碼實(shí)現(xiàn)為例:
def bn_layer(self, inputs, training, name='bn', moving_decay=0.9, eps=1e-5): # 獲取輸入維度并判斷是否匹配卷積層(4)或者全連接層(2) shape = inputs.shape param_shape = shape[-1] with tf.variable_scope(name): # 聲明BN中唯一需要學(xué)習(xí)的兩個(gè)參數(shù),y=gamma*x+beta gamma = tf.get_variable('gamma', param_shape, initializer=tf.constant_initializer(1)) beta = tf.get_variable('beat', param_shape, initializer=tf.constant_initializer(0)) # 計(jì)算當(dāng)前整個(gè)batch的均值與方差 axes = list(range(len(shape)-1)) batch_mean, batch_var = tf.nn.moments(inputs , axes, name='moments') # 采用滑動(dòng)平均更新均值與方差 ema = tf.train.ExponentialMovingAverage(moving_decay, name="ema") def mean_var_with_update(): ema_apply_op = ema.apply([batch_mean, batch_var]) with tf.control_dependencies([ema_apply_op]): return tf.identity(batch_mean), tf.identity(batch_var) # 訓(xùn)練時(shí),更新均值與方差,測(cè)試時(shí)使用之前最后一次保存的均值與方差 mean, var = tf.cond(tf.equal(training,True), mean_var_with_update, lambda:(ema.average(batch_mean), ema.average(batch_var))) # 最后執(zhí)行batch normalization return tf.nn.batch_normalization(inputs ,mean, var, beta, gamma, eps)
training參數(shù)可以通過(guò)tf.placeholder傳入,這樣就可以控制訓(xùn)練和預(yù)測(cè)時(shí)training的值。
self.training = tf.placeholder(tf.bool, name="training")
Dropout
Dropout在訓(xùn)練時(shí)會(huì)隨機(jī)丟棄一些神經(jīng)元,這樣會(huì)導(dǎo)致輸出的結(jié)果變小。而預(yù)測(cè)時(shí)往往關(guān)閉dropout,保證預(yù)測(cè)結(jié)果的一致性(不關(guān)閉dropout可能同一個(gè)輸入會(huì)得到不同的輸出,不過(guò)輸出會(huì)服從某一分布。另外有些情況下可以不關(guān)閉dropout,比如文本生成下,不關(guān)閉會(huì)增大輸出的多樣性)。
為了對(duì)齊Dropout訓(xùn)練和預(yù)測(cè)的結(jié)果,通常有兩種做法,假設(shè)dropout rate = 0.2。一種是訓(xùn)練時(shí)不做處理,預(yù)測(cè)時(shí)輸出乘以(1 - dropout rate)。另一種是訓(xùn)練時(shí)留下的神經(jīng)元除以(1 - dropout rate),預(yù)測(cè)時(shí)不做處理。以tensorflow為例。
x = tf.nn.dropout(x, self.keep_prob)
self.keep_prob = tf.placeholder(tf.float32, name="keep_prob")
tf.nn.dropout就是采用了第二種做法,訓(xùn)練時(shí)除以(1 - dropout rate),源碼如下:
binary_tensor = math_ops.floor(random_tensor) ret = math_ops.div(x, keep_prob) * binary_tensor if not context.executing_eagerly(): ret.set_shape(x.get_shape()) return ret
binary_tensor就是一個(gè)mask tensor,即里面的值由0或1組成。keep_prob = 1 - dropout rate。
以上就是對(duì)比分析BN和dropout在預(yù)測(cè)和訓(xùn)練時(shí)區(qū)別的詳細(xì)內(nèi)容,更多關(guān)于BN與dropout預(yù)測(cè)訓(xùn)練對(duì)比的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
用Python編寫生成樹狀結(jié)構(gòu)的文件目錄的腳本的教程
這篇文章主要介紹了用Python編寫生成樹狀結(jié)構(gòu)的文件目錄的腳本的教程,是一個(gè)利用os模塊下各函數(shù)的簡(jiǎn)單實(shí)現(xiàn),需要的朋友可以參考下2015-05-05python中requests使用代理proxies方法介紹
這篇文章主要介紹了python中requests使用代理proxies方法介紹,具有一定參考價(jià)值,需要的朋友可以了解下。2017-10-10python如何處理matlab的mat數(shù)據(jù)
這篇文章主要介紹了python如何處理matlab的mat數(shù)據(jù),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-05-05Python?Fire實(shí)現(xiàn)自動(dòng)生成命令行接口
命令行程序是平時(shí)寫一些小工具時(shí)最常用的方式,隨著命令行程序功能的豐富,也就是參數(shù)多了以后,解析和管理參數(shù)之間的關(guān)系會(huì)變得越來(lái)越繁重,而本次介紹的?Fire?庫(kù)正好可以解決這個(gè)問(wèn)題,下面我們就來(lái)看看具體實(shí)現(xiàn)方法吧2023-09-09Matplotlib實(shí)戰(zhàn)之平行坐標(biāo)系繪制詳解
平行坐標(biāo)系是一種統(tǒng)計(jì)圖表,它包含多個(gè)垂直平行的坐標(biāo)軸,每個(gè)軸表示一個(gè)字段,并用刻度標(biāo)明范圍,下面我們就來(lái)看看如何繪制平行坐標(biāo)系吧2023-08-08詳解pyqt5的UI中嵌入matplotlib圖形并實(shí)時(shí)刷新(挖坑和填坑)
這篇文章主要介紹了詳解pyqt5的UI中嵌入matplotlib圖形并實(shí)時(shí)刷新(挖坑和填坑),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-08-08python 找出list中最大或者最小幾個(gè)數(shù)的索引方法
今天小編就為大家分享一篇python 找出list中最大或者最小幾個(gè)數(shù)的索引方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-10-10