欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyTorch使用torch.nn.Module模塊自定義模型結(jié)構(gòu)方式

 更新時(shí)間:2024年02月26日 10:20:58   作者:精致的螺旋線  
這篇文章主要介紹了PyTorch使用torch.nn.Module模塊自定義模型結(jié)構(gòu)方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

以實(shí)現(xiàn)LeNet網(wǎng)絡(luò)為例,來(lái)學(xué)習(xí)使用pytorch如何搭建一個(gè)神經(jīng)網(wǎng)絡(luò)。

LeNet網(wǎng)絡(luò)的結(jié)構(gòu)如下圖所示。

一、使用torch.nn.Module類(lèi)構(gòu)建網(wǎng)絡(luò)模型

搭建自己的網(wǎng)絡(luò)模型,我們需要新建一個(gè)類(lèi),讓它繼承torch.nn.Module類(lèi),并必須重寫(xiě)Module類(lèi)中的__init__()和forward()函數(shù)。

init()函數(shù)用來(lái)申明模型中各層的定義,forward()函數(shù)用來(lái)描述各層之間的連接關(guān)系,定義前向傳播計(jì)算的過(guò)程。

也就是說(shuō)__init__()函數(shù)只是用來(lái)定義層,但并沒(méi)有將它們連接起來(lái),forward()函數(shù)的作用就是將這些定義好的層連接成網(wǎng)絡(luò)。

使用上述方法實(shí)現(xiàn)LeNet網(wǎng)絡(luò)的代碼如下。

import torch.nn as nn
class LeNet(nn.Module):
	def __init__(self):
		super().__init__()
		self.C1 = nn.Conv2d(1, 6, 5)
		self.sig = nn.Sigmoid()
		self.S2 = nn.MaxPool2d(2, 2)
		
		self.C3 = nn.Conv2d(6, 16, 5)
		self.S4 = nn.MaxPool2d(2, 2)

		self.C5 = nn.Conv2d(16, 120, 5)
		self.C6 = nn.Linear(120, 84)
		self.C7 = nn.Linear(84, 10)
	
	def forward(self, x):
		x1 = self.C1(x)
		x2 = self.sig(x1)
		x3 = self.S2(x2)
	
		x4 = self.C3(x3)
		x5 = self.sig(x4)
		x6 = self.S4(x5)
	
		x7 = self.C5(x6)
		x8 = self.C6(x7)
		y = self.C7(x8)
		return y

net = LeNet()
print(net)

結(jié)果為

在__init__()函數(shù)中,實(shí)例化了nn.Linear()、nn.Conv2d()這種pytorch封裝好的類(lèi),用來(lái)定義全連接層、卷積層等網(wǎng)絡(luò)層,并規(guī)定好它們的參數(shù)。

例如,self.C1 = nn.Conv2d(1, 6, 5)表示定義一個(gè)卷積層,它的卷積核輸入通道為1、輸出通道為6,大小為5×5。

真正向這個(gè)卷積層輸入數(shù)據(jù)是在forward()函數(shù)中,x1 = self.C1(x)表示將輸入x喂給卷積層,并得到輸出x1。

二、引入torch.nn.functional實(shí)現(xiàn)層的運(yùn)算

引入torch.nn.functional模塊中的函數(shù),可以簡(jiǎn)化__init__()函數(shù)中的內(nèi)容。

在__init__()函數(shù)中,我們可以只定義具有需要學(xué)習(xí)的參數(shù)的層,如卷積層、線性層,它們的權(quán)重都需要學(xué)習(xí)。

對(duì)于不需要學(xué)習(xí)參數(shù)的層,我們不需要在__init__()函數(shù)中定義,只需要在forward()函數(shù)中引入torch.nn.functional類(lèi)中相關(guān)函數(shù)的調(diào)用。

例如LeNet中,我們?cè)赺_init__()中只定義了卷積層和全連接層。池化層和激活函數(shù)只需要在forward()函數(shù)中,調(diào)用torch.nn.functional中的函數(shù)進(jìn)行實(shí)現(xiàn)即可。

import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
	def __init__(self):
		super().__init__()
		self.C1 = nn.Conv2d(1, 6, 5)
		self.C3 = nn.Conv2d(6, 16, 5)
		self.C5 = nn.Conv2d(16, 120, 5)
		self.C6 = nn.Linear(120, 84)
		self.C7 = nn.Linear(84, 10)
	
	def forward(self, x):
		x1 = self.C1(x)
		x2 = F.sigmoid(x1)
		x3 = F.max_pool2d(x2)
	
		x4 = self.C3(x3)
		x5 = F.sigmoid(x4)
		x6 = F.max_pool2d(x5)
	
		x7 = self.C5(x6)
		x8 = self.C6(x7)
		y = self.C7(x8)
		return y

net = LeNet()
print(net)

運(yùn)行結(jié)果為

當(dāng)然,torch.nn.functional中也對(duì)需要學(xué)習(xí)參數(shù)的層進(jìn)行了實(shí)現(xiàn),包括卷積層conv2d()和線性層linear(),但pytorch官方推薦我們只對(duì)不需要學(xué)習(xí)參數(shù)的層使用nn.functional中的函數(shù)。

對(duì)于一個(gè)層,使用nn.Xxx實(shí)現(xiàn)和使用nn.functional.xxx()實(shí)現(xiàn)的區(qū)別為:

1.nn.Xxx是一個(gè)類(lèi),繼承自nn.Modules,因此內(nèi)部會(huì)有很多屬性和方法,如train(), eval(),load_state_dict, state_dict 等。

2.nn.functional.xxx()僅僅是一個(gè)函數(shù)。作為一個(gè)類(lèi),nn.Xxx需要先實(shí)例化并傳入?yún)?shù),然后以函數(shù)調(diào)用的方式向?qū)嵗瘜?duì)象中喂入輸入數(shù)據(jù)。

conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding)
output = conv(input)

nn.functional.xxx()是在調(diào)用時(shí)同時(shí)傳入輸入數(shù)據(jù)和設(shè)置參數(shù)。

output = nn.functional.conv2d(input, weight, bias, padding)

3.nn.Xxx不需要自己定義和管理權(quán)重,但nn.functional.xxx()需要自己定義權(quán)重,每次調(diào)用時(shí)要手動(dòng)傳入。

三、Sequential類(lèi)

1. 基礎(chǔ)使用

Sequential類(lèi)繼承自Module類(lèi)。對(duì)于一個(gè)簡(jiǎn)單的序貫?zāi)P停梢圆槐刈约涸俣鄬?xiě)一個(gè)類(lèi)繼承Module類(lèi),而是直接使用pytorch提供的Sequential類(lèi),來(lái)將若干層或若干子模塊直接包裝成一個(gè)大的模塊。

例如在LeNet中,我們直接將各個(gè)層按順序排列好,然后用Sequential類(lèi)包裝一下,就可以方便地構(gòu)建好一個(gè)神經(jīng)網(wǎng)路了。

import torch.nn as nn
net = nn.Sequential(
		nn.Conv2d(1, 6, 5),
		nn.Sigmoid(),
		nn.MaxPool2d(2, 2),
		
		nn.Conv2d(6, 16, 5),
		nn.Sigmoid(),
		nn.MaxPool2d(2, 2),

		nn.Conv2d(16, 120, 5),
		nn.Linear(120, 84),
		nn.Linear(84, 10)
	)

print(net)
print(net[2]) #通過(guò)索引可以獲取到層

運(yùn)行結(jié)果為

上面這種方法沒(méi)有給每一個(gè)層指定名稱(chēng),默認(rèn)使用層的索引數(shù)0、1、2來(lái)命名。我們可以通過(guò)索引值來(lái)直接獲對(duì)應(yīng)的層的信息。

當(dāng)然,我們也可以給層指定名稱(chēng),但我們并不能通過(guò)名稱(chēng)獲取層,想獲取層依舊要使用索引數(shù)字。

import torch.nn as nn
from collections import OrderedDict
net = nn.Sequential(OrderedDict([
		('C1', nn.Conv2d(1, 6, 5)),
		('Sig1', nn.Sigmoid()),
		('S2', nn.MaxPool2d(2, 2)),
		
		('C3', nn.Conv2d(6, 16, 5)),
		('Sig2', nn.Sigmoid()),
		('S4', nn.MaxPool2d(2, 2)),

		('C5', nn.Conv2d(16, 120, 5)),
		('C6', nn.Linear(120, 84)),
		('C7', nn.Linear(84, 10))
	]))

print(net)
print(net[2]) #通過(guò)索引可以獲取到層

運(yùn)行結(jié)果為

也可以使用add_module函數(shù)向Sequential()中添加層。

import torch.nn as nn
net = nn.Sequential()
net.add_module('C1', nn.Conv2d(1, 6, 5))
net.add_module('Sig1', nn.Sigmoid())
net.add_module('S2', nn.MaxPool2d(2, 2))

net.add_module('C3', nn.Conv2d(6, 16, 5))
net.add_module('Sig2', nn.Sigmoid())
net.add_module('S4', nn.MaxPool2d(2, 2))

net.add_module('C5', nn.Conv2d(16, 120, 5))
net.add_module('C6', nn.Linear(120, 84))
net.add_module('C7', nn.Linear(84, 10))

print(net)
print(net[2])

輸出為

2. 使用Sequential類(lèi)將層包裝成子模塊

Sequential類(lèi)也可以應(yīng)用到自定義Module類(lèi)的方法中,用來(lái)將幾個(gè)層包裝成一個(gè)大層(塊)。

當(dāng)然Sequential依舊有三種使用方法,我們這里只使用第一種作為舉例。

import torch.nn as nn
class LeNet(nn.Module):
	def __init__(self):
		super().__init__()
		self.conv = nn.Sequential(
			nn.Conv2d(1, 6, 5),
			nn.Sigmoid(),
			nn.MaxPool2d(2, 2),
			nn.Conv2d(6, 16, 5),
			nn.Sigmoid(),
			nn.MaxPool2d(2, 2)
		)

		self.fc = nn.Sequential(
			nn.Conv2d(16, 120, 5),
			nn.Linear(120, 84),
			nn.Linear(84, 10)
		)
	
	def forward(self, x):
		x1 = self.conv(x)
		y = self.fc(x1)
		return y

net = LeNet()
print(net)

輸出為

四、ModuleList類(lèi)和ModuleDict類(lèi)

ModuleList類(lèi)和ModuleDict類(lèi)都是Modules類(lèi)的子類(lèi),和Sequential類(lèi)似,它也可以對(duì)若干層或子模塊進(jìn)行打包,列表化的構(gòu)造網(wǎng)絡(luò)。

但與Sequential類(lèi)不同的是,這兩個(gè)類(lèi)只是將這些層定義并排列成列表(List)或字典(Dict),但并沒(méi)有將它們連接起來(lái),也就是說(shuō)并沒(méi)有實(shí)現(xiàn)forward()函數(shù)。

因此,這兩個(gè)類(lèi)并不要求相鄰層的輸入輸出維度匹配,也不能直接向ModuleList和ModuleDict中直接喂入輸入數(shù)據(jù)。

ModuleList的訪問(wèn)方法和普通的List類(lèi)似。

net = nn.ModuleList([
      nn.Linear(784, 256),
      nn.ReLU()
    ])
net.append(nn.Linear(256, 20)) # ModuleList可以像普通的List以下進(jìn)行append操作
print(net[-1]) # ModuleList的訪問(wèn)方法與List也相似
print(net)
# X = torch.zeros(1, 784)
# net(X) # 出錯(cuò)。向ModuleList中輸入數(shù)據(jù)會(huì)出錯(cuò),因?yàn)镸oduleList的作用僅僅是存儲(chǔ)
		 # 網(wǎng)絡(luò)的各個(gè)模塊,但并不連接它們,即沒(méi)有實(shí)現(xiàn)forward()

輸出為

ModuleDict的使用方法也和普通的字典類(lèi)似。

net = nn.ModuleDict({
    'linear': nn.Linear(784, 256),
    'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # 添加
print(net['linear']) # 訪問(wèn)
print(net.output)
print(net)
# net(torch.zeros(1, 784)) # 會(huì)報(bào)NotImplementedError

輸出為

ModuleList和ModuleDict的使用是為了在定義前向傳播時(shí)能更加靈活。下面是官網(wǎng)上的一個(gè)關(guān)于ModuleList使用的例子。

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

此外,ModuleList和ModuleDict里,所有子模塊的參數(shù)都會(huì)被自動(dòng)添加到神經(jīng)網(wǎng)絡(luò)中,這一點(diǎn)是與普通的List和Dict不同的。

舉個(gè)例子。

class Module_ModuleList(nn.Module):
    def __init__(self):
        super(Module_ModuleList, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10)])

class Module_List(nn.Module):
    def __init__(self):
        super(Module_List, self).__init__()
        self.linears = [nn.Linear(10, 10)]

net1 = Module_ModuleList()
net2 = Module_List()

print("net1:")
for p in net1.parameters():
    print(p.size())

print("net2:")
for p in net2.parameters():
    print(p)

輸出為

五、向模型中輸入數(shù)據(jù)

假設(shè)我們向模型中輸入的數(shù)據(jù)為input,從模型中得到的前向傳播結(jié)果為output,則輸入數(shù)據(jù)的方法為

output = net(input)

net是對(duì)象名,我們直接將輸入作為參數(shù)傳入到對(duì)象名中,而并沒(méi)有顯示的調(diào)用forward()函數(shù),就完成了前向傳播的計(jì)算。

上面的寫(xiě)法其實(shí)等價(jià)于

output = net.forward(input)

這是因?yàn)樵趖orch.nn.Module類(lèi)中,定義了__call__()函數(shù),其中就包括了對(duì)forward()方法的調(diào)用。

在python語(yǔ)法中__call__()方法使得類(lèi)實(shí)例對(duì)象可以像調(diào)用普通函數(shù)那樣,以“對(duì)象名()”的形式使用,并執(zhí)行__call__()函數(shù)體中的內(nèi)容。

總結(jié)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

最新評(píng)論