欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

python神經(jīng)網(wǎng)絡(luò)pytorch中BN運(yùn)算操作自實(shí)現(xiàn)

 更新時(shí)間:2022年05月07日 15:46:12   作者:皮特潘  
這篇文章主要為大家介紹了python神經(jīng)網(wǎng)絡(luò)pytorch中BN運(yùn)算操作自實(shí)現(xiàn)示例,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪

BN 想必大家都很熟悉,來(lái)自論文:

《Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift》

也是面試??疾斓膬?nèi)容,雖然一行代碼就能搞定,但是還是很有必要用代碼自己實(shí)現(xiàn)一下,也可以加深一下對(duì)其內(nèi)部機(jī)制的理解。

通用公式:

直奔代碼:

首先是定義一個(gè)函數(shù),實(shí)現(xiàn)BN的運(yùn)算操作:

def batch_norm(is_training, x, gamma, beta, moving_mean, moving_var, eps=1e-5, momentum=0.9):
    # 判斷當(dāng)前模式是訓(xùn)練模式還是預(yù)測(cè)模式
    if not is_training:
        # 如果是在預(yù)測(cè)模式下,直接使用傳入的移動(dòng)平均所得的均值和方差
        x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        if len(x.shape) == 2:
            # 使用全連接層的情況,計(jì)算特征維上的均值和方差
            mean = x.mean(dim=0)
            var = ((x - mean) ** 2).mean(dim=0)
        else:
            # 使用二維卷積層的情況,計(jì)算通道維上(axis=1)的均值和方差。這里我們需要保持
            # x的形狀以便后面可以做廣播運(yùn)算
            mean = x.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
            var = ((x - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        # 訓(xùn)練模式下用當(dāng)前的均值和方差做標(biāo)準(zhǔn)化
        x_hat = (x - mean) / torch.sqrt(var + eps)
        # 更新移動(dòng)平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    x = gamma * x_hat + beta  # 拉伸和偏移
    return Y, moving_mean, moving_var

然后再定義一個(gè)類(lèi),就是常用的集成nn.Module的類(lèi)了。

這里說(shuō)明三點(diǎn):

  • 在卷積上的BN實(shí)現(xiàn),是在 Batch,W,H上進(jìn)行歸一化操作的,也就是BWH拉成一個(gè)維度求均值和方差,均值和方差以及beta和gamma的尺寸為channel。當(dāng)然其他各種N,包括IN,LN,GN都是包含WH維度的。
  • 不需要計(jì)算梯度和參與梯度更新的參數(shù),可以用self.register_buffer直接注冊(cè)就可以了;注冊(cè)的變量同樣使用;
  • 被包成nn.Parameter的參數(shù),需要求梯度,但是不能加cuda(),否則會(huì)報(bào)錯(cuò)。 如果想在gpu上運(yùn)算,可以將整個(gè)類(lèi)的實(shí)例加.cuda()。例如 bn = BatchNorm(**param),bn=bn.cuda().
class BatchNorm(nn.Module):
    def __init__(self, num_features, num_dims):
        super(BatchNorm, self).__init__()
        if num_dims == 2: # 同樣是判斷是全連層還是卷積層
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # 參與求梯度和迭代的拉伸和偏移參數(shù),分別初始化成0和1
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 不參與求梯度和迭代的變量,全初始化成0
        self.register_buffer('moving_mean', torch.zeros(shape))
        self.register_buffer('moving_var', torch.ones(shape))

    def forward(self, x):
        # 如果X不在內(nèi)存上,將moving_mean和moving_var復(fù)制到X所在顯存上
        if self.moving_mean.device != x.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新過(guò)的moving_mean和moving_var, Module實(shí)例的traning屬性默認(rèn)為true, 調(diào)用.eval()后設(shè)成false
        y, self.moving_mean, self.moving_var = batch_norm(self.training,
            x, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return x

以上就是python神經(jīng)網(wǎng)絡(luò)pytorch中BN運(yùn)算操作自實(shí)現(xiàn)的詳細(xì)內(nèi)容,更多關(guān)于pytorch BN運(yùn)算的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • python的scipy實(shí)現(xiàn)插值的示例代碼

    python的scipy實(shí)現(xiàn)插值的示例代碼

    這篇文章主要介紹了python的scipy實(shí)現(xiàn)插值的示例代碼,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧
    2019-11-11
  • 利用Pandas 創(chuàng)建空的DataFrame方法

    利用Pandas 創(chuàng)建空的DataFrame方法

    下面小編就為大家分享一篇利用Pandas 創(chuàng)建空的DataFrame方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2018-04-04
  • Python獲取Linux系統(tǒng)下的本機(jī)IP地址代碼分享

    Python獲取Linux系統(tǒng)下的本機(jī)IP地址代碼分享

    這篇文章主要介紹了Python獲取Linux系統(tǒng)下的本機(jī)IP地址代碼分享,本文直接給出實(shí)現(xiàn)代碼,可以獲取到eth0等網(wǎng)卡的IP地址,需要的朋友可以參考下
    2014-11-11
  • Pyqt QImage 與 np array 轉(zhuǎn)換方法

    Pyqt QImage 與 np array 轉(zhuǎn)換方法

    今天小編就為大家分享一篇Pyqt QImage 與 np array 轉(zhuǎn)換方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2019-06-06
  • Python面向?qū)ο笾?lèi)的定義與繼承用法示例

    Python面向?qū)ο笾?lèi)的定義與繼承用法示例

    這篇文章主要介紹了Python面向?qū)ο笾?lèi)的定義與繼承用法,結(jié)合實(shí)例形式分析了Python類(lèi)的定義、實(shí)例化、繼承等基本操作技巧,需要的朋友可以參考下
    2019-01-01
  • Java中的各種單例模式優(yōu)缺點(diǎn)解析

    Java中的各種單例模式優(yōu)缺點(diǎn)解析

    這篇文章主要介紹了Java中的各種單例模式解析,單例模式是Java中最簡(jiǎn)單的設(shè)計(jì)模式之一,這種類(lèi)型的設(shè)計(jì)模式屬于創(chuàng)建者模式,它提供了一種訪問(wèn)對(duì)象的最佳方式,需要的朋友可以參考下
    2023-07-07
  • python基礎(chǔ)中的文件對(duì)象詳解

    python基礎(chǔ)中的文件對(duì)象詳解

    這篇文章主要為大家介紹了python基礎(chǔ)中的文件對(duì)象,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下,希望能夠給你帶來(lái)幫助
    2022-01-01
  • Python3.10.4激活venv環(huán)境失敗解決方法

    Python3.10.4激活venv環(huán)境失敗解決方法

    這篇文章主要介紹了Python3.10.4激活venv環(huán)境失敗解決方法的相關(guān)資料,需要的朋友可以參考下
    2023-01-01
  • python讀取文本繪制動(dòng)態(tài)速度曲線(xiàn)

    python讀取文本繪制動(dòng)態(tài)速度曲線(xiàn)

    這篇文章主要為大家詳細(xì)介紹了python讀取文本繪制動(dòng)態(tài)速度曲線(xiàn),多圖同步顯示,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2018-06-06
  • 詳解Python 中的容器 collections

    詳解Python 中的容器 collections

    這篇文章主要介紹了Python 中的容器 collections的相關(guān)資料,幫助大家更好的理解和學(xué)習(xí)python,感興趣的朋友可以了解下
    2020-08-08

最新評(píng)論