欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch 在網(wǎng)絡(luò)中添加可訓(xùn)練參數(shù),修改預(yù)訓(xùn)練權(quán)重文件的方法

 更新時間:2019年08月17日 14:58:25   作者:馬管子  
今天小編就為大家分享一篇pytorch 在網(wǎng)絡(luò)中添加可訓(xùn)練參數(shù),修改預(yù)訓(xùn)練權(quán)重文件的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

實踐中,針對不同的任務(wù)需求,我們經(jīng)常會在現(xiàn)成的網(wǎng)絡(luò)結(jié)構(gòu)上做一定的修改來實現(xiàn)特定的目的。

假如我們現(xiàn)在有一個簡單的兩層感知機網(wǎng)絡(luò):

# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
import torch.optim as optim
 
x = Variable(torch.FloatTensor([1, 2, 3])).cuda()
y = Variable(torch.FloatTensor([4, 5])).cuda()
 
class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.linear2(x)
 
    return x
 
model = MLP().cuda()
 
loss_fn = torch.nn.MSELoss(size_average=False)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
for t in range(500):
  y_pred = model(x)
  loss = loss_fn(y_pred, y)
  print(t, loss.data[0])
  model.zero_grad()
  loss.backward()
  optimizer.step()
 
print(model(x))

現(xiàn)在想在前向傳播時,在relu之后給x乘以一個可訓(xùn)練的系數(shù),只需要在__init__函數(shù)中添加一個nn.Parameter類型變量,并在forward函數(shù)中乘以該變量即可:

class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
    # the para to be added and updated in train phase, note that NO cuda() at last
    self.coefficient = torch.nn.Parameter(torch.Tensor([1.55]))
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.coefficient * x
    x = self.linear2(x)
 
    return x

注意,Parameter變量和Variable變量的操作大致相同,但是不能手動調(diào)用.cuda()方法將其加載在GPU上,事實上它會自動在GPU上加載,可以通過model.state_dict()或者model.named_parameters()函數(shù)查看現(xiàn)在的全部可訓(xùn)練參數(shù)(包括通過繼承得到的父類中的參數(shù)):

print(model.state_dict().keys())
for i, j in model.named_parameters():
  print(i)
  print(j)

輸出如下:

odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias'])
linear1.weight
Parameter containing:
-0.3582 -0.0283 0.2607
 0.5190 -0.2221 0.0665
-0.2586 -0.3311 0.1927
-0.2765 0.5590 -0.2598
 0.4679 -0.2923 -0.3379
[torch.cuda.FloatTensor of size 5x3 (GPU 0)]
 
linear1.bias
Parameter containing:
-0.2549
-0.5246
-0.1109
 0.5237
-0.1362
[torch.cuda.FloatTensor of size 5 (GPU 0)]
 
linear2.weight
Parameter containing:
-0.0286 -0.3045 0.1928 -0.2323 0.2966
 0.2601 0.1441 -0.2159 0.2484 0.0544
[torch.cuda.FloatTensor of size 2x5 (GPU 0)]
 
linear2.bias
Parameter containing:
-0.4038
 0.3129
[torch.cuda.FloatTensor of size 2 (GPU 0)]

這個參數(shù)會在反向傳播時與原有變量同時參與更新,這就達(dá)到了添加可訓(xùn)練參數(shù)的目的。

如果我們有原先網(wǎng)絡(luò)的預(yù)訓(xùn)練權(quán)重,現(xiàn)在添加了一個新的參數(shù),原有的權(quán)重文件自然就不能加載了,我們需要修改原權(quán)重文件,在其中添加我們的新變量的初始值。

調(diào)用model.state_dict查看我們添加的參數(shù)在參數(shù)字典中的完整名稱,然后打開原先的權(quán)重文件:

a = torch.load("OldWeights.pth") a是一個collecitons.OrderedDict類型變量,也就是一個有序字典,直接將新參數(shù)名稱和初始值作為鍵值對插入,然后保存即可。

a = torch.load("OldWeights.pth")
 
a["layer1.0.coefficient"] = torch.FloatTensor([1.2])
a["layer1.1.coefficient"] = torch.FloatTensor([1.5])
 
torch.save(a, "Weights.pth")

現(xiàn)在權(quán)重就可以加載在修改后的模型上了。

以上這篇pytorch 在網(wǎng)絡(luò)中添加可訓(xùn)練參數(shù),修改預(yù)訓(xùn)練權(quán)重文件的方法就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • python serial串口通信示例詳解

    python serial串口通信示例詳解

    Python的serial庫是一個用于串口通信的強大工具,它提供了一個簡單而靈活的接口,可以方便地與串口設(shè)備進(jìn)行通信,包括與驅(qū)動電機進(jìn)行通信,這篇文章主要介紹了python serial串口通信,需要的朋友可以參考下
    2023-12-12
  • 淺析python繼承與多重繼承

    淺析python繼承與多重繼承

    在本篇文章中我們給大家分析了python繼承與多重繼承的相關(guān)知識點內(nèi)容,有興趣的讀者們參考下。
    2018-09-09
  • pyautogui自動化控制鼠標(biāo)和鍵盤操作的步驟

    pyautogui自動化控制鼠標(biāo)和鍵盤操作的步驟

    這篇文章主要介紹了pyautogui自動化控制鼠標(biāo)和鍵盤操作的步驟,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-04-04
  • Python之plt.bar繪制柱狀圖參數(shù)解讀

    Python之plt.bar繪制柱狀圖參數(shù)解讀

    這篇文章主要介紹了Python之plt.bar繪制柱狀圖參數(shù),具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教
    2023-09-09
  • Python?時間操作datetime詳情

    Python?時間操作datetime詳情

    這篇文章主要介紹了?Python?時間操作datetime,datetime?模塊提供處理時間和日期的多種類,簡單方便,下面文章將詳細(xì)介紹其內(nèi)容,需要的朋友可以參考一下
    2021-11-11
  • python爬蟲請求頭的使用

    python爬蟲請求頭的使用

    這篇文章主要介紹了python爬蟲請求頭的使用,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-12-12
  • Python3 列表list合并的4種方法

    Python3 列表list合并的4種方法

    這篇文章主要介紹了Python3 列表list合并的4種方法,需要的朋友可以參考下
    2021-04-04
  • Python中的os.path路徑模塊中的操作方法總結(jié)

    Python中的os.path路徑模塊中的操作方法總結(jié)

    os.path模塊主要集成了針對路徑文件夾的操作功能,這里我們就來看一下Python中的os.path路徑模塊中的操作方法總結(jié),需要的朋友可以參考下
    2016-07-07
  • Python命名空間的本質(zhì)和加載順序

    Python命名空間的本質(zhì)和加載順序

    這篇文章主要介紹了Python命名空間的本質(zhì)和加載順序,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2018-12-12
  • Python中拆分具有多個分隔符的字符串方法實例

    Python中拆分具有多個分隔符的字符串方法實例

    str.split()是Python中字符串類型的一個方法,可以用來將字符串按照指定的分隔符分割成多個子字符串,這篇文章主要給大家介紹了關(guān)于Python中拆分具有多個分隔符的字符串的相關(guān)資料,需要的朋友可以參考下
    2023-04-04

最新評論