python神經(jīng)網(wǎng)絡Batch?Normalization底層原理詳解
什么是Batch Normalization
Batch Normalization是神經(jīng)網(wǎng)絡中常用的層,解決了很多深度學習中遇到的問題,我們一起來學習一哈。
Batch Normalization是由google提出的一種訓練優(yōu)化方法。參考論文:Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift。
Batch Normalization的名稱為批標準化,它的功能是使得輸入的X數(shù)據(jù)符合同一分布,從而使得訓練更加簡單、快速。
一般來講,Batch Normalization會放在卷積層后面,即卷積 + 標準化 + 激活函數(shù)。
其計算過程可以簡單歸納為以下3點:
1、求數(shù)據(jù)均值。
2、求數(shù)據(jù)方差。
3、數(shù)據(jù)進行標準化。
Batch Normalization的計算公式
Batch Normalization的計算公式主要看如下這幅圖:

這個公式一定要靜下心來看,整個公式可以分為四行:
1、對輸入進來的數(shù)據(jù)X進行均值求取。
2、利用輸入進來的數(shù)據(jù)X減去第一步得到的均值,然后求平方和,獲得輸入X的方差。
3、利用輸入X、第一步獲得的均值和第二步獲得的方差對數(shù)據(jù)進行歸一化,即利用X減去均值,然后除上方差開根號。方差開根號前需要添加上一個極小值。
4、引入γ和β變量,對輸入進來的數(shù)據(jù)進行縮放和平移。利用γ和β兩個參數(shù),讓我們的網(wǎng)絡可以學習恢復出原始網(wǎng)絡所要學習的特征分布。
前三步是標準化工序,最后一步是反標準化工序。
Bn層的好處
1、加速網(wǎng)絡的收斂速度。在神經(jīng)網(wǎng)絡中,存在內(nèi)部協(xié)變量偏移的現(xiàn)象,如果每層的數(shù)據(jù)分布不同的話,會導致非常難收斂,如果把每層的數(shù)據(jù)都在轉換在均值為零,方差為1的狀態(tài)下,這樣每層數(shù)據(jù)的分布都是一樣的,訓練會比較容易收斂。
2、防止梯度爆炸和梯度消失。對于梯度消失而言,以Sigmoid函數(shù)為例,它會使得輸出在[0,1]之間,實際上當x到了一定的大小,sigmoid激活函數(shù)的梯度值就變得非常小,不易訓練。歸一化數(shù)據(jù)的話,就能讓梯度維持在比較大的值和變化率;
對于梯度爆炸而言,在方向傳播的過程中,每一層的梯度都是由上一層的梯度乘以本層的數(shù)據(jù)得到。如果歸一化的話,數(shù)據(jù)均值都在0附近,很顯然,每一層的梯度不會產(chǎn)生爆炸的情況。
3、防止過擬合。在網(wǎng)絡的訓練中,Bn使得一個minibatch中所有樣本都被關聯(lián)在了一起,因此網(wǎng)絡不會從某一個訓練樣本中生成確定的結果,這樣就會使得整個網(wǎng)絡不會朝這一個方向使勁學習。一定程度上避免了過擬合。
為什么要引入γ和β變量
Bn層在進行前三步后,會引入γ和β變量,對輸入進來的數(shù)據(jù)進行縮放和平移。
γ和β變量是網(wǎng)絡參數(shù),是可學習的。
引入γ和β變量進行縮放平移可以使得神經(jīng)網(wǎng)絡有自適應的能力,在標準化效果好時,盡量不抵消標準化的作用,而在標準化效果不好時,盡量去抵消一部分標準化的效果,相當于讓神經(jīng)網(wǎng)絡學會要不要標準化,如何折中選擇。
Bn層的代碼實現(xiàn)
Pytorch代碼看起來比較簡單,而且和上面的公式非常符合,可以學習一下,參考自
http://www.dbjr.com.cn/article/247197.htm
def batch_norm(is_training, x, gamma, beta, moving_mean, moving_var, eps=1e-5, momentum=0.9):
if not is_training:
x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
else:
mean = x.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
var = ((x - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
x_hat = (x - mean) / torch.sqrt(var + eps)
moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
moving_var = momentum * moving_var + (1.0 - momentum) * var
Y = gamma * x_hat + beta
return Y, moving_mean, moving_var
class BatchNorm2d(nn.Module):
def __init__(self, num_features):
super(BatchNorm2d, self).__init__()
shape = (1, num_features, 1, 1)
self.gamma = nn.Parameter(torch.ones(shape))
self.beta = nn.Parameter(torch.zeros(shape))
self.register_buffer('moving_mean', torch.zeros(shape))
self.register_buffer('moving_var', torch.ones(shape))
def forward(self, x):
if self.moving_mean.device != x.device:
self.moving_mean = self.moving_mean.to(x.device)
self.moving_var = self.moving_var.to(x.device)
y, self.moving_mean, self.moving_var = batch_norm(self.training,
x, self.gamma, self.beta, self.moving_mean,
self.moving_var, eps=1e-5, momentum=0.9)
return y
以上就是python神經(jīng)網(wǎng)絡Batch Normalization底層原理詳解的詳細內(nèi)容,更多關于Batch Normalization底層原理的資料請關注腳本之家其它相關文章!
相關文章
Python爬蟲使用bs4方法實現(xiàn)數(shù)據(jù)解析
這篇文章主要介紹了Python爬蟲使用bs4方法實現(xiàn)數(shù)據(jù)解析,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下2020-08-08
PyTorch中的神經(jīng)網(wǎng)絡 Mnist 分類任務
這篇文章主要介紹了PyTorch中的神經(jīng)網(wǎng)絡 Mnist 分類任務,在本次的分類任務當中,我們使用的數(shù)據(jù)集是 Mnist 數(shù)據(jù)集,這個數(shù)據(jù)集大家都比較熟悉,需要的朋友可以參考下2023-03-03
一文了解Python中NotImplementedError的作用
NotImplementedError是一個內(nèi)置異常類,本文主要介紹了一文了解Python中NotImplementedError的作用,具有一定的參考價值,感興趣的可以了解一下2024-03-03
Python虛擬環(huán)境virtualenv創(chuàng)建及使用過程圖解
這篇文章主要介紹了Python虛擬環(huán)境virtualenv創(chuàng)建及使用過程圖解,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下2020-12-12

