Pytorch參數(shù)注冊和nn.ModuleList nn.ModuleDict的問題
參考自官方文檔
參數(shù)注冊
嘗試自己寫GoogLeNet時碰到的問題,放在字典中的參數(shù)無法自動注冊,所謂的注冊,就是當(dāng)參數(shù)注冊到這個網(wǎng)絡(luò)上時,它會隨著你在外部調(diào)用net.cuda()后自動遷移到GPU上,而沒有注冊的參數(shù)則不會隨著網(wǎng)絡(luò)遷到GPU上,這就可能導(dǎo)致輸入在GPU上而參數(shù)不在GPU上,從而出現(xiàn)錯誤,為了說明這個現(xiàn)象。
舉一個有點鐵憨憨的例子:
import torch import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): ?? ?def __init__(self): ?? ??? ?super(Net,self).__init__() ?? ??? ?self.weight = torch.rand((3,4)) # 這里其實可以直接用nn.Linear,但為了舉例這里先憨憨一下 ?? ? ?? ?def forward(self,x): ?? ??? ?return F.linear(x,self.weight) if __name__ == "__main__": ?? ?batch_size = 10 ?? ?dummy = torch.rand((batch_size,4)) ?? ?net = Net() ?? ?print(net(dummy))
上面的代碼可以成功運(yùn)行,因為所有的數(shù)值都是放在CPU上的,但是,一旦我們要把模型移到GPU上時
import torch import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): ?? ?def __init__(self): ?? ??? ?super(Net,self).__init__() ?? ??? ?self.weight = torch.rand((3,4)) ?? ? ?? ?def forward(self,x): ?? ??? ?return F.linear(x,self.weight) if __name__ == "__main__": ?? ?batch_size = 10 ?? ?dummy = torch.rand((batch_size,4)).cuda() ?? ?net = Net().cuda() ?? ?print(net(dummy))
運(yùn)行后就會出現(xiàn)
...
RuntimeError: Expected object of backend CUDA but got backend CPU for argument #2 'mat2'
這就是因為self.weight沒有隨著模型一起移到GPU上的原因,此時我們查看模型的參數(shù),會發(fā)現(xiàn)并沒有self.weight
import torch import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): ?? ?def __init__(self): ?? ??? ?super(Net,self).__init__() ?? ??? ?self.weight = torch.rand((3,4)) ?? ? ?? ?def forward(self,x): ?? ??? ?return F.linear(x,self.weight) if __name__ == "__main__": ?? ?net = Net() ?? ?for parameter in net.parameters(): ?? ??? ?print(parameter)
上面的代碼沒有輸出,因為net根本沒有參數(shù)
那么為了讓net有參數(shù),我們需要手動地將self.weight注冊到網(wǎng)絡(luò)上
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
?? ?def __init__(self):
?? ??? ?super(Net,self).__init__()
?? ??? ?self.weight = nn.Parameter(torch.rand((3,4))) # 被注冊的參數(shù)必須是nn.Parameter類型
?? ??? ?self.register_parameter('weight',self.weight) # 手動注冊參數(shù)
?? ??? ?
?? ?
?? ?def forward(self,x):
?? ??? ?return F.linear(x,self.weight)
if __name__ == "__main__":
?? ?net = Net()
?? ?for parameter in net.parameters():
?? ??? ?print(parameter)
?? ?batch_size = 10
?? ?net = net.cuda()
?? ?dummy = torch.rand((batch_size,4)).cuda()
?? ?print(net(dummy))此時網(wǎng)絡(luò)的參數(shù)就有了輸出,同時會隨著一起遷到GPU上,輸出就類似這樣
Parameter containing:
tensor([...])
tensor([...])
不過后來我實驗了以下,好像只寫nn.Parameter不寫register也可以被默認(rèn)注冊
nn.ModuleList和nn.ModuleDict
有時候我們?yōu)榱藞D省事,可能會這樣寫網(wǎng)絡(luò)
import torch import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): ?? ?def __init__(self): ?? ??? ?super(Net,self).__init__() ?? ??? ?self.linears = [nn.Linear(4,4),nn.Linear(4,4),nn.Linear(4,2)] ?? ? ?? ?def forward(self,x): ?? ??? ?for linear in self.linears: ?? ??? ??? ?x = linear(x) ?? ??? ??? ?x = F.relu(x) ?? ??? ?return x if __name__ == '__main__': ?? ?net = Net() ?? ?for parameter in net.parameters(): ?? ??? ?print(parameter)??
同樣,輸出網(wǎng)絡(luò)的參數(shù)啥也沒有,這意味著當(dāng)調(diào)用net.cuda時,self.linears里面的參數(shù)不會一起走到GPU上去
此時我們可以在__init__方法中手動對self.parameters()迭代然后把每個參數(shù)注冊,但更好的方法是,pytorch已經(jīng)為我們提供了nn.ModuleList,用來代替python內(nèi)置的list,放在nn.ModuleList中的參數(shù)將會自動被正確注冊
import torch import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): ?? ?def __init__(self): ?? ??? ?super(Net,self).__init__() ?? ??? ?self.linears = nn.ModuleList([nn.Linear(4,4),nn.Linear(4,4),nn.Linear(4,2)]) ?? ? ?? ?def forward(self,x): ?? ??? ?for linear in self.linears: ?? ??? ??? ?x = linear(x) ?? ??? ??? ?x = F.relu(x) ?? ??? ?return x if __name__ == '__main__': ?? ?net = Net() ?? ?for parameter in net.parameters(): ?? ??? ?print(parameter)?? ??? ?
此時就有輸出了
Parameter containing:
tensor(...)
Parameter containing:
tensor(...)
...
nn.ModuleDict也是類似,當(dāng)我們需要把參數(shù)放在一個字典里的時候,能夠用的上,這里直接給一個官方的例子看一看就OK
class MyModule(nn.Module):
? ? def __init__(self):
? ? ? ? super(MyModule, self).__init__()
? ? ? ? self.choices = nn.ModuleDict({
? ? ? ? ? ? ? ? 'conv': nn.Conv2d(10, 10, 3),
? ? ? ? ? ? ? ? 'pool': nn.MaxPool2d(3)
? ? ? ? })
? ? ? ? self.activations = nn.ModuleDict([
? ? ? ? ? ? ? ? ['lrelu', nn.LeakyReLU()],
? ? ? ? ? ? ? ? ['prelu', nn.PReLU()]
? ? ? ? ])
? ? def forward(self, x, choice, act):
? ? ? ? x = self.choices[choice](x)
? ? ? ? x = self.activations[act](x)
? ? ? ? return x需要注意的是,雖然直接放在python list中的參數(shù)不會自動注冊,但如果只是暫時放在list里,隨后又調(diào)用了nn.Sequential把整個list整合起來,參數(shù)仍然是會自動注冊的
另外一點要注意的是ModuleList和ModuleDict里面只能放Module的子類,也就是nn.Conv,nn.Linear這樣的,但不能放nn.Parameter,如果要放nn.Parameter,用nn.ParameterList即可,用法和nn.ModuleList一樣
總結(jié)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python利用多進(jìn)程將大量數(shù)據(jù)放入有限內(nèi)存的教程
這篇文章主要介紹了Python利用多進(jìn)程將大量數(shù)據(jù)放入有限內(nèi)存的教程,使用了multiprocessing和pandas來加速內(nèi)存中的操作,需要的朋友可以參考下2015-04-04
PyCharm安裝庫numpy失敗問題的詳細(xì)解決方法
今天使用pycharm編譯python程序時,由于要調(diào)用numpy包,但又未曾安裝numpy,于是就根據(jù)pycharm的提示進(jìn)行安裝,最后竟然提示出錯,下面這篇文章主要給大家介紹了關(guān)于PyCharm安裝庫numpy失敗問題的詳細(xì)解決方法,需要的朋友可以參考下2022-06-06
Python提取特定時間段內(nèi)數(shù)據(jù)的方法實例
今天小編就為大家分享一篇關(guān)于Python提取特定時間段內(nèi)數(shù)據(jù)的方法實例,小編覺得內(nèi)容挺不錯的,現(xiàn)在分享給大家,具有很好的參考價值,需要的朋友一起跟隨小編來看看吧2019-04-04
Python通過WHL文件實現(xiàn)離線安裝的操作詳解
在Python開發(fā)中,我們經(jīng)常需要安裝第三方庫來擴(kuò)展Python的功能,通常情況下,我們可以通過pip命令在線安裝這些庫,此時,WHL(Wheel)文件成為了非常實用的解決方案,本教程將結(jié)合實際案例,詳細(xì)介紹如何通過WHL文件在Python中進(jìn)行離線安裝,需要的朋友可以參考下2024-08-08
Django使用Celery實現(xiàn)異步發(fā)送郵件
這篇文章主要為大家詳細(xì)介紹了Django如何使用Celery實現(xiàn)異步發(fā)送郵件的功能,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以了解一下2023-04-04
解決python和pycharm安裝gmpy2 出現(xiàn)ERROR的問題
這篇文章主要介紹了python和pycharm安裝gmpy2 出現(xiàn)ERROR的解決方法,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2020-08-08
基于循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)實現(xiàn)影評情感分類
這篇文章主要為大家詳細(xì)介紹了基于循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)實現(xiàn)影評情感分類,具有一定的參考價值,感興趣的小伙伴們可以參考一下2018-03-03

