python神經網絡pytorch中BN運算操作自實現
更新時間:2022年05月07日 15:46:12 作者:皮特潘
這篇文章主要為大家介紹了python神經網絡pytorch中BN運算操作自實現示例,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪
BN 想必大家都很熟悉,來自論文:
《Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift》
也是面試??疾斓膬热荩m然一行代碼就能搞定,但是還是很有必要用代碼自己實現一下,也可以加深一下對其內部機制的理解。
通用公式:
直奔代碼:
首先是定義一個函數,實現BN的運算操作:
def batch_norm(is_training, x, gamma, beta, moving_mean, moving_var, eps=1e-5, momentum=0.9): # 判斷當前模式是訓練模式還是預測模式 if not is_training: # 如果是在預測模式下,直接使用傳入的移動平均所得的均值和方差 x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps) else: if len(x.shape) == 2: # 使用全連接層的情況,計算特征維上的均值和方差 mean = x.mean(dim=0) var = ((x - mean) ** 2).mean(dim=0) else: # 使用二維卷積層的情況,計算通道維上(axis=1)的均值和方差。這里我們需要保持 # x的形狀以便后面可以做廣播運算 mean = x.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True) var = ((x - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True) # 訓練模式下用當前的均值和方差做標準化 x_hat = (x - mean) / torch.sqrt(var + eps) # 更新移動平均的均值和方差 moving_mean = momentum * moving_mean + (1.0 - momentum) * mean moving_var = momentum * moving_var + (1.0 - momentum) * var x = gamma * x_hat + beta # 拉伸和偏移 return Y, moving_mean, moving_var
然后再定義一個類,就是常用的集成nn.Module的類了。
這里說明三點:
- 在卷積上的BN實現,是在 Batch,W,H上進行歸一化操作的,也就是BWH拉成一個維度求均值和方差,均值和方差以及beta和gamma的尺寸為channel。當然其他各種N,包括IN,LN,GN都是包含WH維度的。
- 不需要計算梯度和參與梯度更新的參數,可以用self.register_buffer直接注冊就可以了;注冊的變量同樣使用;
- 被包成nn.Parameter的參數,需要求梯度,但是不能加cuda(),否則會報錯。 如果想在gpu上運算,可以將整個類的實例加.cuda()。例如 bn = BatchNorm(**param),bn=bn.cuda().
class BatchNorm(nn.Module): def __init__(self, num_features, num_dims): super(BatchNorm, self).__init__() if num_dims == 2: # 同樣是判斷是全連層還是卷積層 shape = (1, num_features) else: shape = (1, num_features, 1, 1) # 參與求梯度和迭代的拉伸和偏移參數,分別初始化成0和1 self.gamma = nn.Parameter(torch.ones(shape)) self.beta = nn.Parameter(torch.zeros(shape)) # 不參與求梯度和迭代的變量,全初始化成0 self.register_buffer('moving_mean', torch.zeros(shape)) self.register_buffer('moving_var', torch.ones(shape)) def forward(self, x): # 如果X不在內存上,將moving_mean和moving_var復制到X所在顯存上 if self.moving_mean.device != x.device: self.moving_mean = self.moving_mean.to(X.device) self.moving_var = self.moving_var.to(X.device) # 保存更新過的moving_mean和moving_var, Module實例的traning屬性默認為true, 調用.eval()后設成false y, self.moving_mean, self.moving_var = batch_norm(self.training, x, self.gamma, self.beta, self.moving_mean, self.moving_var, eps=1e-5, momentum=0.9) return x
以上就是python神經網絡pytorch中BN運算操作自實現的詳細內容,更多關于pytorch BN運算的資料請關注腳本之家其它相關文章!
相關文章
利用Pandas 創(chuàng)建空的DataFrame方法
下面小編就為大家分享一篇利用Pandas 創(chuàng)建空的DataFrame方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-04-04Python獲取Linux系統(tǒng)下的本機IP地址代碼分享
這篇文章主要介紹了Python獲取Linux系統(tǒng)下的本機IP地址代碼分享,本文直接給出實現代碼,可以獲取到eth0等網卡的IP地址,需要的朋友可以參考下2014-11-11Python3.10.4激活venv環(huán)境失敗解決方法
這篇文章主要介紹了Python3.10.4激活venv環(huán)境失敗解決方法的相關資料,需要的朋友可以參考下2023-01-01