PyTorch使用torch.nn.Module模塊自定義模型結(jié)構(gòu)方式
以實(shí)現(xiàn)LeNet網(wǎng)絡(luò)為例,來(lái)學(xué)習(xí)使用pytorch如何搭建一個(gè)神經(jīng)網(wǎng)絡(luò)。
LeNet網(wǎng)絡(luò)的結(jié)構(gòu)如下圖所示。
一、使用torch.nn.Module類(lèi)構(gòu)建網(wǎng)絡(luò)模型
搭建自己的網(wǎng)絡(luò)模型,我們需要新建一個(gè)類(lèi),讓它繼承torch.nn.Module類(lèi),并必須重寫(xiě)Module類(lèi)中的__init__()和forward()函數(shù)。
init()函數(shù)用來(lái)申明模型中各層的定義,forward()函數(shù)用來(lái)描述各層之間的連接關(guān)系,定義前向傳播計(jì)算的過(guò)程。
也就是說(shuō)__init__()函數(shù)只是用來(lái)定義層,但并沒(méi)有將它們連接起來(lái),forward()函數(shù)的作用就是將這些定義好的層連接成網(wǎng)絡(luò)。
使用上述方法實(shí)現(xiàn)LeNet網(wǎng)絡(luò)的代碼如下。
import torch.nn as nn class LeNet(nn.Module): def __init__(self): super().__init__() self.C1 = nn.Conv2d(1, 6, 5) self.sig = nn.Sigmoid() self.S2 = nn.MaxPool2d(2, 2) self.C3 = nn.Conv2d(6, 16, 5) self.S4 = nn.MaxPool2d(2, 2) self.C5 = nn.Conv2d(16, 120, 5) self.C6 = nn.Linear(120, 84) self.C7 = nn.Linear(84, 10) def forward(self, x): x1 = self.C1(x) x2 = self.sig(x1) x3 = self.S2(x2) x4 = self.C3(x3) x5 = self.sig(x4) x6 = self.S4(x5) x7 = self.C5(x6) x8 = self.C6(x7) y = self.C7(x8) return y net = LeNet() print(net)
結(jié)果為
在__init__()函數(shù)中,實(shí)例化了nn.Linear()、nn.Conv2d()這種pytorch封裝好的類(lèi),用來(lái)定義全連接層、卷積層等網(wǎng)絡(luò)層,并規(guī)定好它們的參數(shù)。
例如,self.C1 = nn.Conv2d(1, 6, 5)表示定義一個(gè)卷積層,它的卷積核輸入通道為1、輸出通道為6,大小為5×5。
真正向這個(gè)卷積層輸入數(shù)據(jù)是在forward()函數(shù)中,x1 = self.C1(x)表示將輸入x喂給卷積層,并得到輸出x1。
二、引入torch.nn.functional實(shí)現(xiàn)層的運(yùn)算
引入torch.nn.functional模塊中的函數(shù),可以簡(jiǎn)化__init__()函數(shù)中的內(nèi)容。
在__init__()函數(shù)中,我們可以只定義具有需要學(xué)習(xí)的參數(shù)的層,如卷積層、線性層,它們的權(quán)重都需要學(xué)習(xí)。
對(duì)于不需要學(xué)習(xí)參數(shù)的層,我們不需要在__init__()函數(shù)中定義,只需要在forward()函數(shù)中引入torch.nn.functional類(lèi)中相關(guān)函數(shù)的調(diào)用。
例如LeNet中,我們?cè)赺_init__()中只定義了卷積層和全連接層。池化層和激活函數(shù)只需要在forward()函數(shù)中,調(diào)用torch.nn.functional中的函數(shù)進(jìn)行實(shí)現(xiàn)即可。
import torch.nn as nn import torch.nn.functional as F class LeNet(nn.Module): def __init__(self): super().__init__() self.C1 = nn.Conv2d(1, 6, 5) self.C3 = nn.Conv2d(6, 16, 5) self.C5 = nn.Conv2d(16, 120, 5) self.C6 = nn.Linear(120, 84) self.C7 = nn.Linear(84, 10) def forward(self, x): x1 = self.C1(x) x2 = F.sigmoid(x1) x3 = F.max_pool2d(x2) x4 = self.C3(x3) x5 = F.sigmoid(x4) x6 = F.max_pool2d(x5) x7 = self.C5(x6) x8 = self.C6(x7) y = self.C7(x8) return y net = LeNet() print(net)
運(yùn)行結(jié)果為
當(dāng)然,torch.nn.functional中也對(duì)需要學(xué)習(xí)參數(shù)的層進(jìn)行了實(shí)現(xiàn),包括卷積層conv2d()和線性層linear(),但pytorch官方推薦我們只對(duì)不需要學(xué)習(xí)參數(shù)的層使用nn.functional中的函數(shù)。
對(duì)于一個(gè)層,使用nn.Xxx實(shí)現(xiàn)和使用nn.functional.xxx()實(shí)現(xiàn)的區(qū)別為:
1.nn.Xxx是一個(gè)類(lèi),繼承自nn.Modules,因此內(nèi)部會(huì)有很多屬性和方法,如train(), eval(),load_state_dict, state_dict 等。
2.nn.functional.xxx()僅僅是一個(gè)函數(shù)。作為一個(gè)類(lèi),nn.Xxx需要先實(shí)例化并傳入?yún)?shù),然后以函數(shù)調(diào)用的方式向?qū)嵗瘜?duì)象中喂入輸入數(shù)據(jù)。
conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding) output = conv(input)
nn.functional.xxx()是在調(diào)用時(shí)同時(shí)傳入輸入數(shù)據(jù)和設(shè)置參數(shù)。
output = nn.functional.conv2d(input, weight, bias, padding)
3.nn.Xxx不需要自己定義和管理權(quán)重,但nn.functional.xxx()需要自己定義權(quán)重,每次調(diào)用時(shí)要手動(dòng)傳入。
三、Sequential類(lèi)
1. 基礎(chǔ)使用
Sequential類(lèi)繼承自Module類(lèi)。對(duì)于一個(gè)簡(jiǎn)單的序貫?zāi)P停梢圆槐刈约涸俣鄬?xiě)一個(gè)類(lèi)繼承Module類(lèi),而是直接使用pytorch提供的Sequential類(lèi),來(lái)將若干層或若干子模塊直接包裝成一個(gè)大的模塊。
例如在LeNet中,我們直接將各個(gè)層按順序排列好,然后用Sequential類(lèi)包裝一下,就可以方便地構(gòu)建好一個(gè)神經(jīng)網(wǎng)路了。
import torch.nn as nn net = nn.Sequential( nn.Conv2d(1, 6, 5), nn.Sigmoid(), nn.MaxPool2d(2, 2), nn.Conv2d(6, 16, 5), nn.Sigmoid(), nn.MaxPool2d(2, 2), nn.Conv2d(16, 120, 5), nn.Linear(120, 84), nn.Linear(84, 10) ) print(net) print(net[2]) #通過(guò)索引可以獲取到層
運(yùn)行結(jié)果為
上面這種方法沒(méi)有給每一個(gè)層指定名稱(chēng),默認(rèn)使用層的索引數(shù)0、1、2來(lái)命名。我們可以通過(guò)索引值來(lái)直接獲對(duì)應(yīng)的層的信息。
當(dāng)然,我們也可以給層指定名稱(chēng),但我們并不能通過(guò)名稱(chēng)獲取層,想獲取層依舊要使用索引數(shù)字。
import torch.nn as nn from collections import OrderedDict net = nn.Sequential(OrderedDict([ ('C1', nn.Conv2d(1, 6, 5)), ('Sig1', nn.Sigmoid()), ('S2', nn.MaxPool2d(2, 2)), ('C3', nn.Conv2d(6, 16, 5)), ('Sig2', nn.Sigmoid()), ('S4', nn.MaxPool2d(2, 2)), ('C5', nn.Conv2d(16, 120, 5)), ('C6', nn.Linear(120, 84)), ('C7', nn.Linear(84, 10)) ])) print(net) print(net[2]) #通過(guò)索引可以獲取到層
運(yùn)行結(jié)果為
也可以使用add_module函數(shù)向Sequential()中添加層。
import torch.nn as nn net = nn.Sequential() net.add_module('C1', nn.Conv2d(1, 6, 5)) net.add_module('Sig1', nn.Sigmoid()) net.add_module('S2', nn.MaxPool2d(2, 2)) net.add_module('C3', nn.Conv2d(6, 16, 5)) net.add_module('Sig2', nn.Sigmoid()) net.add_module('S4', nn.MaxPool2d(2, 2)) net.add_module('C5', nn.Conv2d(16, 120, 5)) net.add_module('C6', nn.Linear(120, 84)) net.add_module('C7', nn.Linear(84, 10)) print(net) print(net[2])
輸出為
2. 使用Sequential類(lèi)將層包裝成子模塊
Sequential類(lèi)也可以應(yīng)用到自定義Module類(lèi)的方法中,用來(lái)將幾個(gè)層包裝成一個(gè)大層(塊)。
當(dāng)然Sequential依舊有三種使用方法,我們這里只使用第一種作為舉例。
import torch.nn as nn class LeNet(nn.Module): def __init__(self): super().__init__() self.conv = nn.Sequential( nn.Conv2d(1, 6, 5), nn.Sigmoid(), nn.MaxPool2d(2, 2), nn.Conv2d(6, 16, 5), nn.Sigmoid(), nn.MaxPool2d(2, 2) ) self.fc = nn.Sequential( nn.Conv2d(16, 120, 5), nn.Linear(120, 84), nn.Linear(84, 10) ) def forward(self, x): x1 = self.conv(x) y = self.fc(x1) return y net = LeNet() print(net)
輸出為
四、ModuleList類(lèi)和ModuleDict類(lèi)
ModuleList類(lèi)和ModuleDict類(lèi)都是Modules類(lèi)的子類(lèi),和Sequential類(lèi)似,它也可以對(duì)若干層或子模塊進(jìn)行打包,列表化的構(gòu)造網(wǎng)絡(luò)。
但與Sequential類(lèi)不同的是,這兩個(gè)類(lèi)只是將這些層定義并排列成列表(List)或字典(Dict),但并沒(méi)有將它們連接起來(lái),也就是說(shuō)并沒(méi)有實(shí)現(xiàn)forward()函數(shù)。
因此,這兩個(gè)類(lèi)并不要求相鄰層的輸入輸出維度匹配,也不能直接向ModuleList和ModuleDict中直接喂入輸入數(shù)據(jù)。
ModuleList的訪問(wèn)方法和普通的List類(lèi)似。
net = nn.ModuleList([ nn.Linear(784, 256), nn.ReLU() ]) net.append(nn.Linear(256, 20)) # ModuleList可以像普通的List以下進(jìn)行append操作 print(net[-1]) # ModuleList的訪問(wèn)方法與List也相似 print(net) # X = torch.zeros(1, 784) # net(X) # 出錯(cuò)。向ModuleList中輸入數(shù)據(jù)會(huì)出錯(cuò),因?yàn)镸oduleList的作用僅僅是存儲(chǔ) # 網(wǎng)絡(luò)的各個(gè)模塊,但并不連接它們,即沒(méi)有實(shí)現(xiàn)forward()
輸出為
ModuleDict的使用方法也和普通的字典類(lèi)似。
net = nn.ModuleDict({ 'linear': nn.Linear(784, 256), 'act': nn.ReLU(), }) net['output'] = nn.Linear(256, 10) # 添加 print(net['linear']) # 訪問(wèn) print(net.output) print(net) # net(torch.zeros(1, 784)) # 會(huì)報(bào)NotImplementedError
輸出為
ModuleList和ModuleDict的使用是為了在定義前向傳播時(shí)能更加靈活。下面是官網(wǎng)上的一個(gè)關(guān)于ModuleList使用的例子。
class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)]) def forward(self, x): # ModuleList can act as an iterable, or be indexed using ints for i, l in enumerate(self.linears): x = self.linears[i // 2](x) + l(x) return x
此外,ModuleList和ModuleDict里,所有子模塊的參數(shù)都會(huì)被自動(dòng)添加到神經(jīng)網(wǎng)絡(luò)中,這一點(diǎn)是與普通的List和Dict不同的。
舉個(gè)例子。
class Module_ModuleList(nn.Module): def __init__(self): super(Module_ModuleList, self).__init__() self.linears = nn.ModuleList([nn.Linear(10, 10)]) class Module_List(nn.Module): def __init__(self): super(Module_List, self).__init__() self.linears = [nn.Linear(10, 10)] net1 = Module_ModuleList() net2 = Module_List() print("net1:") for p in net1.parameters(): print(p.size()) print("net2:") for p in net2.parameters(): print(p)
輸出為
五、向模型中輸入數(shù)據(jù)
假設(shè)我們向模型中輸入的數(shù)據(jù)為input,從模型中得到的前向傳播結(jié)果為output,則輸入數(shù)據(jù)的方法為
output = net(input)
net是對(duì)象名,我們直接將輸入作為參數(shù)傳入到對(duì)象名中,而并沒(méi)有顯示的調(diào)用forward()函數(shù),就完成了前向傳播的計(jì)算。
上面的寫(xiě)法其實(shí)等價(jià)于
output = net.forward(input)
這是因?yàn)樵趖orch.nn.Module類(lèi)中,定義了__call__()函數(shù),其中就包括了對(duì)forward()方法的調(diào)用。
在python語(yǔ)法中__call__()方法使得類(lèi)實(shí)例對(duì)象可以像調(diào)用普通函數(shù)那樣,以“對(duì)象名()”的形式使用,并執(zhí)行__call__()函數(shù)體中的內(nèi)容。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
基于Python編寫(xiě)一個(gè)串口調(diào)試工具
這篇文章主要為大家詳細(xì)介紹了如何基于 Python編寫(xiě)一個(gè)tkinter 和 pyserial 的串口調(diào)試工具,可以方便地進(jìn)行串口通信的設(shè)置等操作,感興趣的小伙伴可以了解下2025-02-02Python關(guān)于版本升級(jí)與包的維護(hù)方式
這篇文章主要介紹了Python關(guān)于版本升級(jí)與包的維護(hù)方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2021-06-06python編程冒泡排序法實(shí)現(xiàn)動(dòng)圖排序示例解析
這篇文章主要介紹了python編程中如何使用冒泡排序法實(shí)現(xiàn)動(dòng)圖排序的示例解析,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步2021-10-10Python用requests模塊實(shí)現(xiàn)動(dòng)態(tài)網(wǎng)頁(yè)爬蟲(chóng)
大家好,本篇文章主要講的是Python用requests模塊實(shí)現(xiàn)動(dòng)態(tài)網(wǎng)頁(yè)爬蟲(chóng),感興趣的同學(xué)趕快來(lái)看一看吧,對(duì)你有幫助的話記得收藏一下2022-02-02python中可以發(fā)生異常自動(dòng)重試庫(kù)retrying
這篇文章主要介紹了python中可以發(fā)生異常自動(dòng)重試庫(kù)retrying,retrying是一個(gè)極簡(jiǎn)的使用Python編寫(xiě)的庫(kù),主題更多相關(guān)內(nèi)容需要的朋友可以參考一下2022-06-06初次部署django+gunicorn+nginx的方法步驟
這篇文章主要介紹了初次部署django+gunicorn+nginx的方法步驟,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-09-09python使用minimax算法實(shí)現(xiàn)五子棋
這篇文章主要為大家詳細(xì)介紹了python使用minimax算法實(shí)現(xiàn)五子棋,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2019-07-07Python中的數(shù)據(jù)標(biāo)準(zhǔn)化與反標(biāo)準(zhǔn)化全面指南
在數(shù)據(jù)處理和機(jī)器學(xué)習(xí)中,數(shù)據(jù)標(biāo)準(zhǔn)化是一項(xiàng)至關(guān)重要的預(yù)處理步驟,標(biāo)準(zhǔn)化能夠?qū)⒉煌叨群头秶臄?shù)據(jù)轉(zhuǎn)換為相同的標(biāo)準(zhǔn),有助于提高模型的性能和穩(wěn)定性,Python提供了多種庫(kù)和函數(shù)來(lái)執(zhí)行數(shù)據(jù)標(biāo)準(zhǔn)化和反標(biāo)準(zhǔn)化,如Scikit-learn和TensorFlow2024-01-01Python實(shí)現(xiàn)讀取txt文件并畫(huà)三維圖簡(jiǎn)單代碼示例
這篇文章主要介紹了Python實(shí)現(xiàn)讀取txt文件并畫(huà)三維圖簡(jiǎn)單代碼示例,具有一定借鑒價(jià)值,需要的朋友可以參考下。2017-12-12Win7下搭建python開(kāi)發(fā)環(huán)境圖文教程(安裝Python、pip、解釋器)
這篇文章主要為大家分享了Win7下搭建python開(kāi)發(fā)環(huán)境圖文教程,本文主要介紹了安裝Python、pip、解釋器的詳細(xì)步驟,感興趣的小伙伴們可以參考一下2016-05-05