欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

淺析pytorch中對nn.BatchNorm2d()函數(shù)的理解

 更新時間:2023年11月15日 16:48:30   作者:Code_LiShi  
Batch Normalization強行將數(shù)據(jù)拉回到均值為0,方差為1的正太分布上,一方面使得數(shù)據(jù)分布一致,另一方面避免梯度消失,這篇文章主要介紹了pytorch中對nn.BatchNorm2d()函數(shù)的理解,需要的朋友可以參考下

簡介

機器學習中,進行模型訓練之前,需對數(shù)據(jù)做歸一化處理,使其分布一致。在深度神經(jīng)網(wǎng)絡訓練過程中,通常一次訓練是一個batch,而非全體數(shù)據(jù)。每個batch具有不同的分布產(chǎn)生了internal covarivate shift問題——在訓練過程中,數(shù)據(jù)分布會發(fā)生變化,對下一層網(wǎng)絡的學習帶來困難。Batch Normalization強行將數(shù)據(jù)拉回到均值為0,方差為1的正太分布上,一方面使得數(shù)據(jù)分布一致,另一方面避免梯度消失。

計算

如圖所示:

3. Pytorch的nn.BatchNorm2d()函數(shù)

其主要需要輸入4個參數(shù):
(1)num_features:輸入數(shù)據(jù)的shape一般為[batch_size, channel, height, width], num_features為其中的channel;
(2)eps: 分母中添加的一個值,目的是為了計算的穩(wěn)定性,默認:1e-5;
(3)momentum: 一個用于運行過程中均值和方差的一個估計參數(shù),默認值為0.1.

(4)affine:當設為true時,給定可以學習的系數(shù)矩陣 γ \gamma γ和 β \beta β

4 代碼示例

import torch
data = torch.ones(size=(2, 2, 3, 4))
data[0][0][0][0] = 25
print("data = ", data)
print("\n")
print("=========================使用封裝的BatchNorm2d()計算================================")
BN = torch.nn.BatchNorm2d(num_features=2, eps=0, momentum=0)
BN_data = BN(data)
print("BN_data = ", BN_data)
print("\n")
print("=========================自行計算================================")
x = torch.cat((data[0][0], data[1][0]), dim=1)      # 1.將同一通道進行拼接(即把同一通道當作一個整體)
x_mean = torch.Tensor.mean(x)                       # 2.計算同一通道所有制的均值(即拼接后的均值)
x_var = torch.Tensor.var(x, False)                  # 3.計算同一通道所有制的方差(即拼接后的方差)
# 4.使用第一個數(shù)按照公式來求BatchNorm后的值
bn_first = ((data[0][0][0][0] - x_mean) / ( torch.pow(x_var, 0.5))) * BN.weight[0] + BN.bias[0]
print("bn_first = ", bn_first)

到此這篇關(guān)于pytorch中對nn.BatchNorm2d()函數(shù)的理解的文章就介紹到這了,更多相關(guān)pytorch nn.BatchNorm2d()函數(shù)內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

最新評論