欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch固定BN層參數(shù)的操作

 更新時間:2021年05月27日 08:58:15   作者:grllery  
這篇文章主要介紹了pytorch固定BN層參數(shù)的操作,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教

背景:

基于PyTorch的模型,想固定主分支參數(shù),只訓練子分支,結(jié)果發(fā)現(xiàn)在不同epoch相同的測試數(shù)據(jù)經(jīng)過主分支輸出的結(jié)果不同。

原因:

未固定主分支BN層中的running_mean和running_var。

解決方法:

將需要固定的BN層狀態(tài)設(shè)置為eval。

問題示例:

環(huán)境:torch:1.7.0

# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.bn2 = nn.BatchNorm2d(16)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 5)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

def print_parameter_grad_info(net):
    print('-------parameters requires grad info--------')
    for name, p in net.named_parameters():
        print(f'{name}:\t{p.requires_grad}')

def print_net_state_dict(net):
    for key, v in net.state_dict().items():
        print(f'{key}')

if __name__ == "__main__":
    net = Net()

    print_parameter_grad_info(net)
    net.requires_grad_(False)
    print_parameter_grad_info(net)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, 假設(shè)每個epoch只迭代一次
        net.train()
        pre = net(train_data)
        # 計算損失和參數(shù)更新等
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

運行結(jié)果:

-------parameters requires grad info--------
conv1.weight: True
conv1.bias: True
bn1.weight: True
bn1.bias: True
conv2.weight: True
conv2.bias: True
bn2.weight: True
bn2.bias: True
fc1.weight: True
fc1.bias: True
fc2.weight: True
fc2.bias: True
fc3.weight: True
fc3.bias: True
-------parameters requires grad info--------
conv1.weight: False
conv1.bias: False
bn1.weight: False
bn1.bias: False
conv2.weight: False
conv2.bias: False
bn2.weight: False
bn2.bias: False
fc1.weight: False
fc1.bias: False
fc2.weight: False
fc2.bias: False
fc3.weight: False
fc3.bias: False
epoch:0 tensor([[-0.0755, 0.1138, 0.0966, 0.0564, -0.0224]])
epoch:1 tensor([[-0.0763, 0.1113, 0.0970, 0.0574, -0.0235]])

可以看到:

net.requires_grad_(False)已經(jīng)將網(wǎng)絡(luò)中的各參數(shù)設(shè)置成了不需要梯度更新的狀態(tài),但是同樣的測試數(shù)據(jù)test_data在不同epoch中前向之后出現(xiàn)了不同的結(jié)果。

調(diào)用print_net_state_dict可以看到BN層中的參數(shù)running_mean和running_var并沒在可優(yōu)化參數(shù)net.parameters中

bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked

但在training pahse的前向過程中,這兩個參數(shù)被更新了。導致整個網(wǎng)絡(luò)在freeze的情況下,同樣的測試數(shù)據(jù)出現(xiàn)了不同的結(jié)果

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a defaultmomentumof 0.1. source

因此在training phase時對BN層顯式設(shè)置eval狀態(tài):

if __name__ == "__main__":
    net = Net()
    net.requires_grad_(False)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, 假設(shè)每個epoch只迭代一次
        net.train()
        net.bn1.eval()
        net.bn2.eval()
        pre = net(train_data)
        # 計算損失和參數(shù)更新等
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

可以看到結(jié)果正常了:

epoch:0 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])

補充:pytorch---之BN層參數(shù)詳解及應用(1,2,3)(1,2)?

BN層參數(shù)詳解(1,2)

一般來說pytorch中的模型都是繼承nn.Module類的,都有一個屬性trainning指定是否是訓練狀態(tài),訓練狀態(tài)與否將會影響到某些層的參數(shù)是否是固定的,比如BN層(對于BN層測試的均值和方差是通過統(tǒng)計訓練的時候所有的batch的均值和方差的平均值)或者Dropout層(對于Dropout層在測試的時候所有神經(jīng)元都是激活的)。通常用model.train()指定當前模型model為訓練狀態(tài),model.eval()指定當前模型為測試狀態(tài)。

同時,BN的API中有幾個參數(shù)需要比較關(guān)心的,一個是affine指定是否需要仿射,還有個是track_running_stats指定是否跟蹤當前batch的統(tǒng)計特性。容易出現(xiàn)問題也正好是這三個參數(shù):trainning,affine,track_running_stats。

其中的affine指定是否需要仿射,也就是是否需要上面算式的第四個,如果affine=False則γ=1,β=0 \gamma=1,\beta=0γ=1,β=0,并且不能學習被更新。一般都會設(shè)置成affine=True。(這里是一個可學習參數(shù))

trainning和track_running_stats,track_running_stats=True表示跟蹤整個訓練過程中的batch的統(tǒng)計特性,得到方差和均值,而不只是僅僅依賴與當前輸入的batch的統(tǒng)計特性(意思就是說新的batch依賴于之前的batch的均值和方差這里使用momentum參數(shù),參考了指數(shù)移動平均的算法EMA)。相反的,如果track_running_stats=False那么就只是計算當前輸入的batch的統(tǒng)計特性中的均值和方差了。當在推理階段的時候,如果track_running_stats=False,此時如果batch_size比較小,那么其統(tǒng)計特性就會和全局統(tǒng)計特性有著較大偏差,可能導致糟糕的效果。

應用技巧:(1,2)

通常pytorch都會用到optimizer.zero_grad() 來清空以前的batch所累加的梯度,因為pytorch中Variable計算的梯度會進行累計,所以每一個batch都要重新清空一次梯度,原始的做法是下面這樣的:

問題:參數(shù)non_blocking,以及pytorch的整體框架??

代碼(1)

for index,data,target in enumerate(dataloader):
    data = data.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking = Trye)
    output = model(data)
    loss = criterion(output,target)
    
    #清空梯度
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

而這里為了模仿minibacth,我們每次batch不清0,累積到一定次數(shù)再清0,再更新權(quán)重:

for index, data, target in enumerate(dataloader):
    #如果不是Tensor,一般要用到torch.from_numpy()
    data = data.cuda(non_blocking = True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking = True)
    output = model(data)
    loss = criterion(data, target)
    loss.backward()
    if index%accumulation == 0:
        #用累積的梯度更新權(quán)重
        optimizer.step()
        #清空梯度
        optimizer.zero_grad()

雖然這里的梯度是相當于原來的accumulation倍,但是實際在前向傳播的過程中,對于BN幾乎沒有影響,因為前向的BN還是只是一個batch的均值和方差,這個時候可以用pytorch中BN的momentum參數(shù),默認是0.1,BN參數(shù)如下,就是指數(shù)移動平均

x_new_running = (1 - momentum) * x_running + momentum * x_new_observed. momentum

以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • python微信好友數(shù)據(jù)分析詳解

    python微信好友數(shù)據(jù)分析詳解

    這篇文章主要為大家詳細介紹了python微信好友數(shù)據(jù)分析,實現(xiàn)對微信好友的獲取,并對省份、性別等數(shù)據(jù)分析,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2018-11-11
  • 通過字符串導入 Python 模塊的方法詳解

    通過字符串導入 Python 模塊的方法詳解

    這篇文章主要介紹了通過字符串導入 Python 模塊的方法詳解,本文通過實例結(jié)合,給大家介紹的非常詳細,具有一定的參考借鑒價值,需要的朋友可以參考下
    2019-10-10
  • python實現(xiàn)向微信用戶發(fā)送每日一句 python實現(xiàn)微信聊天機器人

    python實現(xiàn)向微信用戶發(fā)送每日一句 python實現(xiàn)微信聊天機器人

    這篇文章主要為大家詳細介紹了python實現(xiàn)向微信用戶發(fā)送每日一句,python調(diào)實現(xiàn)微信聊天機器人,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2019-03-03
  • 基于python的docx模塊處理word和WPS的docx格式文件方式

    基于python的docx模塊處理word和WPS的docx格式文件方式

    今天小編就為大家分享一篇基于python的docx模塊處理word和WPS的docx格式文件方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-02-02
  • 使用Python生成XML的方法實例

    使用Python生成XML的方法實例

    這篇文章主要介紹了使用Python生成XML的方法,結(jié)合具體實例形式詳細分析了Python生成xml文件的具體流暢與相關(guān)注意事項,需要的朋友可以參考下
    2017-03-03
  • 如何在vscode中安裝python庫的方法步驟

    如何在vscode中安裝python庫的方法步驟

    這篇文章主要介紹了如何在vscode中安裝python庫的方法步驟,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2021-01-01
  • python中BackgroundScheduler和BlockingScheduler的區(qū)別

    python中BackgroundScheduler和BlockingScheduler的區(qū)別

    這篇文章主要介紹了python中BackgroundScheduler和BlockingScheduler的區(qū)別,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2021-07-07
  • Python掃描IP段查看指定端口是否開放的方法

    Python掃描IP段查看指定端口是否開放的方法

    這篇文章主要介紹了Python掃描IP段查看指定端口是否開放的方法,涉及Python使用socket模塊實現(xiàn)端口掃描功能的相關(guān)技巧,需要的朋友可以參考下
    2015-06-06
  • Python判斷文件和字符串編碼類型的實例

    Python判斷文件和字符串編碼類型的實例

    下面小編就為大家分享一篇Python判斷文件和字符串編碼類型的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2017-12-12
  • PyTorch深度學習LSTM從input輸入到Linear輸出

    PyTorch深度學習LSTM從input輸入到Linear輸出

    這篇文章主要為大家介紹了PyTorch深度學習LSTM從input輸入到Linear輸出深入理解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪
    2022-05-05

最新評論