PyTorch之前向傳播函數(shù)forward用法解讀
之前向傳播函數(shù)forward用法
神經(jīng)網(wǎng)絡(luò)的典型處理如下所示:
- 1.定義可學(xué)習(xí)參數(shù)的網(wǎng)絡(luò)結(jié)構(gòu)(堆疊各層和層的設(shè)計(jì));
- 2.數(shù)據(jù)集輸入;
- 3.對(duì)輸入進(jìn)行處理(由定義的網(wǎng)絡(luò)層進(jìn)行處理),主要體現(xiàn)在網(wǎng)絡(luò)的前向傳播;
- 4.計(jì)算loss ,由Loss層計(jì)算;
- 5.反向傳播求梯度;
- 6.根據(jù)梯度改變參數(shù)值,最簡(jiǎn)單的實(shí)現(xiàn)方式(SGD)為:
weight = weight - learning_rate * gradient
利用PyTorch定義深度網(wǎng)絡(luò)層(Op)示例
class FeatureL2Norm(torch.nn.Module): def __init__(self): super(FeatureL2Norm, self).__init__() def forward(self, feature): epsilon = 1e-6 # print(feature.size()) # print(torch.pow(torch.sum(torch.pow(feature,2),1)+epsilon,0.5).size()) norm = torch.pow(torch.sum(torch.pow(feature,2),1)+epsilon,0.5).unsqueeze(1).expand_as(feature) return torch.div(feature,norm)
class FeatureRegression(nn.Module): def __init__(self, output_dim=6, use_cuda=True): super(FeatureRegression, self).__init__() self.conv = nn.Sequential( nn.Conv2d(225, 128, kernel_size=7, padding=0), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.Conv2d(128, 64, kernel_size=5, padding=0), nn.BatchNorm2d(64), nn.ReLU(inplace=True), ) self.linear = nn.Linear(64 * 5 * 5, output_dim) if use_cuda: self.conv.cuda() self.linear.cuda() def forward(self, x): x = self.conv(x) x = x.view(x.size(0), -1) x = self.linear(x) return x
由上例代碼可以看到,不論是在定義網(wǎng)絡(luò)結(jié)構(gòu)還是定義網(wǎng)絡(luò)層的操作(Op),均需要定義forward函數(shù),下面看一下PyTorch官網(wǎng)對(duì)PyTorch的forward方法的描述:
那么調(diào)用forward方法的具體流程是什么樣的呢?
以一個(gè)Module為例:
- 1.調(diào)用module的call方法
- 2.module的call里面調(diào)用module的forward方法
- 3.forward里面如果碰到Module的子類,回到第1步,如果碰到的是Function的子類,繼續(xù)往下
- 4.調(diào)用Function的call方法
- 5.Function的call方法調(diào)用了Function的forward方法。
- 6.Function的forward返回值
- 7.module的forward返回值
- 8.在module的call進(jìn)行forward_hook操作,然后返回值。
上述中“調(diào)用module的call方法”是指nn.Module 的__call__方法。
定義__call__方法的類可以當(dāng)作函數(shù)調(diào)用,具體參考Python的面向?qū)ο缶幊獭?/p>
也就是說(shuō),當(dāng)把定義的網(wǎng)絡(luò)模型model當(dāng)作函數(shù)調(diào)用的時(shí)候就自動(dòng)調(diào)用定義的網(wǎng)絡(luò)模型的forward方法。
nn.Module 的__call__方法部分源碼
如下所示:
def __call__(self, *input, **kwargs): result = self.forward(*input, **kwargs) for hook in self._forward_hooks.values(): #將注冊(cè)的hook拿出來(lái)用 hook_result = hook(self, input, result) ... return result
可以看到,當(dāng)執(zhí)行model(x)的時(shí)候,底層自動(dòng)調(diào)用forward方法計(jì)算結(jié)果。
具體示例如下:
class LeNet(nn.Module): def __init__(self): super(LeNet, self).__init__() layer1 = nn.Sequential() layer1.add_module('conv1', nn.Conv(1, 6, 3, padding=1)) layer1.add_moudle('pool1', nn.MaxPool2d(2, 2)) self.layer1 = layer1 layer2 = nn.Sequential() layer2.add_module('conv2', nn.Conv(6, 16, 5)) layer2.add_moudle('pool2', nn.MaxPool2d(2, 2)) self.layer2 = layer2 layer3 = nn.Sequential() layer3.add_module('fc1', nn.Linear(400, 120)) layer3.add_moudle('fc2', nn.Linear(120, 84)) layer3.add_moudle('fc3', nn.Linear(84, 10)) self.layer3 = layer3 def forward(self, x): x = self.layer1(x) x = self.layer2(x) x = x.view(x.size(0), -1) x = self.layer3(x) return x
- model = LeNet()
- y = model(x)
如上則調(diào)用網(wǎng)絡(luò)模型定義的forward方法。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
matplotlib繪制多個(gè)子圖(subplot)的方法
這篇文章主要介紹了matplotlib繪制多個(gè)子圖(subplot)的方法,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-12-12python實(shí)現(xiàn)簡(jiǎn)單的socket server實(shí)例
這篇文章主要介紹了python實(shí)現(xiàn)簡(jiǎn)單的socket server的方法,實(shí)例分析了Python中socket的操作技巧,非常具有實(shí)用價(jià)值,需要的朋友可以參考下2015-04-04python數(shù)據(jù)預(yù)處理之將類別數(shù)據(jù)轉(zhuǎn)換為數(shù)值的方法
下面小編就為大家?guī)?lái)一篇python數(shù)據(jù)預(yù)處理之將類別數(shù)據(jù)轉(zhuǎn)換為數(shù)值的方法。小編覺得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2017-07-07Python?Pandas如何獲取和修改任意位置的值(at,iat,loc,iloc)
在我們對(duì)數(shù)據(jù)進(jìn)行選擇之后,需要對(duì)特定的數(shù)據(jù)進(jìn)行設(shè)置更改,設(shè)置,下面這篇文章主要給大家介紹了關(guān)于Python?Pandas如何獲取和修改任意位置的值(at,iat,loc,iloc)的相關(guān)資料,文中通過(guò)實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-01-01Python人臉檢測(cè)實(shí)戰(zhàn)之疲勞檢測(cè)
本文主要介紹了實(shí)現(xiàn)疲勞檢測(cè):如果眼睛已經(jīng)閉上了一段時(shí)間,我們會(huì)認(rèn)為他們開始打瞌睡并發(fā)出警報(bào)來(lái)喚醒他們并引起他們的注意。感興趣的朋友可以了解一下2021-12-12搭建python django虛擬環(huán)境完整步驟詳解
這篇文章主要介紹了搭建python django虛擬環(huán)境完整步驟詳解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-07-07對(duì)python中的乘法dot和對(duì)應(yīng)分量相乘multiply詳解
今天小編就為大家分享一篇對(duì)python中的乘法dot和對(duì)應(yīng)分量相乘multiply詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-11-11淺談在django中使用redirect重定向數(shù)據(jù)傳輸?shù)膯?wèn)題
這篇文章主要介紹了淺談在django中使用redirect重定向數(shù)據(jù)傳輸?shù)膯?wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-03-03