pytorch __init__、forward與__call__的用法小結
1.介紹
當我們使用pytorch來構建網(wǎng)絡框架的時候,也會遇到和tensorflow(tensorflow __init__、build 和call小結)類似的情況,即經(jīng)常會遇到__init__、forward和call這三個互相搭配著使用,那么它們的主要區(qū)別又在哪里呢?
1)__init__主要用來做參數(shù)初始化用,比如我們要初始化卷積的一些參數(shù),就可以放到這里面,這點和tf里面的用法是一樣的
2)forward是表示一個前向傳播,構建網(wǎng)絡層的先后運算步驟
3)__call__的功能其實和forward類似,所以很多時候,我們構建網(wǎng)絡的時候,可以用__call__替代forward函數(shù),但它們兩個的區(qū)別又在哪里呢?
當網(wǎng)絡構建完之后,調(diào)__call__的時候,會去先調(diào)forward,即__call__其實是包了一層forward,所以會導致兩者的功能類似。
在pytorch在nn.Module中,實現(xiàn)了__call__方法,而在__call__方法中調(diào)用了forward函數(shù):
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
2.代碼
import torch import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self, in_channels, mid_channels, out_channels): super(Net, self).__init__() self.conv0 = torch.nn.Sequential( torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), torch.nn.LeakyReLU()) self.conv1 = torch.nn.Sequential( torch.nn.Conv2d(mid_channels, out_channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))) def forward(self, x): x = self.conv0(x) x = self.conv1(x) return x class Net(nn.Module): def __init__(self, in_channels, mid_channels, out_channels): super(Net, self).__init__() self.conv0 = torch.nn.Sequential( torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), torch.nn.LeakyReLU()) self.conv1 = torch.nn.Sequential( torch.nn.Conv2d(mid_channels, out_channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))) def __call__(self, x): x = self.conv0(x) x = self.conv1(x) return x
補充:torch/nn目錄結構以及__init__.py
torch/nn目錄結構以及init.py
torch/nn目錄結構
__init__.py:
from .modules import * #nn.modules 導入modules目錄下內(nèi)容 定義容器modules from .parameter import Parameter #nn.Parameter 導入parameter.py 定義parameter from .parallel import DataParallel #導入parallel目錄下data_parallel.py中的DataParallel類 from . import init #nn.init 導入init.py 參數(shù)初始化 from . import utils #nn.utils 導入utils目錄下內(nèi)容 官網(wǎng)api下nn.utils下api
對于backends, functional.py, _functions 需要在代碼前重新Import
例如我們常用的
import torch.nn.functional as F 就是導入了functional.py
backends和_functions是functional.py實現(xiàn)各種函數(shù)時所用到的。
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。如有錯誤或未考慮完全的地方,望不吝賜教。
相關文章
python中BackgroundScheduler和BlockingScheduler的區(qū)別
這篇文章主要介紹了python中BackgroundScheduler和BlockingScheduler的區(qū)別,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2021-07-0710分鐘教你用python動畫演示深度優(yōu)先算法搜尋逃出迷宮的路徑
這篇文章主要介紹了10分鐘教你用python動畫演示深度優(yōu)先算法搜尋逃出迷宮的路徑,非常不錯,具有一定的參考借鑒價值,需要的朋友可以參考下2019-08-08解決Pyinstaller 打包exe文件 取消dos窗口(黑框框)的問題
今天小編就為大家分享一篇解決Pyinstaller 打包exe文件 取消dos窗口(黑框框)的問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-06-06