PyTorch模型創(chuàng)建與nn.Module構建
模型創(chuàng)建與nn.Module
文章和代碼已經(jīng)歸檔至【Github倉庫:https://github.com/timerring/dive-into-AI 】
創(chuàng)建網(wǎng)絡模型通常有2個要素:
- 構建子模塊
- 拼接子模塊
class LeNet(nn.Module): # 子模塊創(chuàng)建 ? ?def __init__(self, classes): ? ? ? ?super(LeNet, self).__init__() ? ? ? ?self.conv1 = nn.Conv2d(3, 6, 5) ? ? ? ?self.conv2 = nn.Conv2d(6, 16, 5) ? ? ? ?self.fc1 = nn.Linear(16*5*5, 120) ? ? ? ?self.fc2 = nn.Linear(120, 84) ? ? ? ?self.fc3 = nn.Linear(84, classes) # 子模塊拼接 ? ?def forward(self, x): ? ? ? ?out = F.relu(self.conv1(x)) ? ? ? ?out = F.max_pool2d(out, 2) ? ? ? ?out = F.relu(self.conv2(out)) ? ? ? ?out = F.max_pool2d(out, 2) ? ? ? ?out = out.view(out.size(0), -1) ? ? ? ?out = F.relu(self.fc1(out)) ? ? ? ?out = F.relu(self.fc2(out)) ? ? ? ?out = self.fc3(out) ? ? ? ?return out
調(diào)用net = LeNet(classes=2)
創(chuàng)建模型時,會調(diào)用__init__()
方法創(chuàng)建模型的子模塊。
訓練調(diào)用outputs = net(inputs)
時,會進入module.py
的call()
函數(shù)中:
def __call__(self, *input, **kwargs): ? ? ? ?for hook in self._forward_pre_hooks.values(): ? ? ? ? ? ?result = hook(self, input) ? ? ? ? ? ?if result is not None: ? ? ? ? ? ? ? ?if not isinstance(result, tuple): ? ? ? ? ? ? ? ? ? ?result = (result,) ? ? ? ? ? ? ? ?input = result ? ? ? ?if torch._C._get_tracing_state(): ? ? ? ? ? ?result = self._slow_forward(*input, **kwargs) ? ? ? ?else: ? ? ? ? ? ?result = self.forward(*input, **kwargs) ? ? ? ... ? ? ? ... ? ? ? ...
最終會調(diào)用result = self.forward(*input, **kwargs)
函數(shù),該函數(shù)會進入模型的forward()
函數(shù)中,進行前向傳播。
在 torch.nn
中包含 4 個模塊,如下圖所示。
本次重點就在于nn.Model的解析:
nn.Module
nn.Module
有 8 個屬性,都是OrderDict
(有序字典)的結構。在 LeNet 的__init__()
方法中會調(diào)用父類nn.Module
的__init__()
方法,創(chuàng)建這 8 個屬性。
def __init__(self): ? ? ? ?""" ? ? ? Initializes internal Module state, shared by both nn.Module and ScriptModule. ? ? ? """ ? ? ? ?torch._C._log_api_usage_once("python.nn_module") ? ? ? ? ?self.training = True ? ? ? ?self._parameters = OrderedDict() ? ? ? ?self._buffers = OrderedDict() ? ? ? ?self._backward_hooks = OrderedDict() ? ? ? ?self._forward_hooks = OrderedDict() ? ? ? ?self._forward_pre_hooks = OrderedDict() ? ? ? ?self._state_dict_hooks = OrderedDict() ? ? ? ?self._load_state_dict_pre_hooks = OrderedDict() ? ? ? ?self._modules = OrderedDict()
- _parameters 屬性:存儲管理 nn.Parameter 類型的參數(shù)
- _modules 屬性:存儲管理 nn.Module 類型的參數(shù)
- _buffers 屬性:存儲管理緩沖屬性,如 BN 層中的 running_mean
- 5 個 *_hooks 屬性:存儲管理鉤子函數(shù)
LeNet 的__init__()
中創(chuàng)建了 5 個子模塊,nn.Conv2d()
和nn.Linear()
都繼承于nn.module
,即一個 module 都是包含多個子 module 的。
class LeNet(nn.Module): # 子模塊創(chuàng)建 ? ?def __init__(self, classes): ? ? ? ?super(LeNet, self).__init__() ? ? ? ?self.conv1 = nn.Conv2d(3, 6, 5) ? ? ? ?self.conv2 = nn.Conv2d(6, 16, 5) ? ? ? ?self.fc1 = nn.Linear(16*5*5, 120) ? ? ? ?self.fc2 = nn.Linear(120, 84) ? ? ? ?self.fc3 = nn.Linear(84, classes) ? ? ? ?... ? ? ? ?... ? ? ? ?...
當調(diào)用net = LeNet(classes=2)
創(chuàng)建模型后,net
對象的 modules 屬性就包含了這 5 個子網(wǎng)絡模塊。
下面看下每個子模塊是如何添加到 LeNet 的_modules
屬性中的。以self.conv1 = nn.Conv2d(3, 6, 5)
為例,當我們運行到這一行時,首先 Step Into 進入 Conv2d
的構造,然后 Step Out。右鍵Evaluate Expression
查看nn.Conv2d(3, 6, 5)
的屬性。
上面說了Conv2d
也是一個 module,里面的_modules
屬性為空,_parameters
屬性里包含了該卷積層的可學習參數(shù),這些參數(shù)的類型是 Parameter,繼承自 Tensor。
此時只是完成了nn.Conv2d(3, 6, 5)
module 的創(chuàng)建。還沒有賦值給self.conv1
。在nn.Module
里有一個機制,會攔截所有的類屬性賦值操作(self.conv1
是類屬性) ,進入到__setattr__()
函數(shù)中。我們再次 Step Into 就可以進入__setattr__()
。
def __setattr__(self, name, value): ? ? ? ?def remove_from(*dicts): ? ? ? ? ? ?for d in dicts: ? ? ? ? ? ? ? ?if name in d: ? ? ? ? ? ? ? ? ? ?del d[name] ? ? ? ? ?params = self.__dict__.get('_parameters') ? ? ? ?if isinstance(value, Parameter): ? ? ? ? ? ?if params is None: ? ? ? ? ? ? ? ?raise AttributeError( ? ? ? ? ? ? ? ? ? ?"cannot assign parameters before Module.__init__() call") ? ? ? ? ? ?remove_from(self.__dict__, self._buffers, self._modules) ? ? ? ? ? ?self.register_parameter(name, value) ? ? ? ?elif params is not None and name in params: ? ? ? ? ? ?if value is not None: ? ? ? ? ? ? ? ?raise TypeError("cannot assign '{}' as parameter '{}' " ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?"(torch.nn.Parameter or None expected)" ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? .format(torch.typename(value), name)) ? ? ? ? ? ?self.register_parameter(name, value) ? ? ? ?else: ? ? ? ? ? ?modules = self.__dict__.get('_modules') ? ? ? ? ? ?if isinstance(value, Module): ? ? ? ? ? ? ? ?if modules is None: ? ? ? ? ? ? ? ? ? ?raise AttributeError( ? ? ? ? ? ? ? ? ? ? ? ?"cannot assign module before Module.__init__() call") ? ? ? ? ? ? ? ?remove_from(self.__dict__, self._parameters, self._buffers) ? ? ? ? ? ? ? ?modules[name] = value ? ? ? ? ? ?elif modules is not None and name in modules: ? ? ? ? ? ? ? ?if value is not None: ? ? ? ? ? ? ? ? ? ?raise TypeError("cannot assign '{}' as child module '{}' " ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?"(torch.nn.Module or None expected)" ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? .format(torch.typename(value), name)) ? ? ? ? ? ? ? ?modules[name] = value ? ? ? ? ? ... ? ? ? ? ? ... ? ? ? ? ? ...
在這里判斷 value 的類型是Parameter
還是Module
,存儲到對應的有序字典中。
這里nn.Conv2d(3, 6, 5)
的類型是Module
,因此會執(zhí)行modules[name] = value
,key 是類屬性的名字conv1
,value 就是nn.Conv2d(3, 6, 5)
。
總結
- 一個 module 里可包含多個子 module。比如 LeNet 是一個 Module,里面包括多個卷積層、池化層、全連接層等子 module
- 一個 module 相當于一個運算,必須實現(xiàn) forward() 函數(shù)
- 每個 module 都有 8 個字典管理自己的屬性
以上就是PyTorch模型創(chuàng)建與nn.Module構建的詳細內(nèi)容,更多關于PyTorch模型創(chuàng)建nn.Module的資料請關注腳本之家其它相關文章!
相關文章
Python BeautifulSoup基本用法詳解(通過標簽及class定位元素)
這篇文章主要介紹了Python BeautifulSoup基本用法(通過標簽及class定位元素),本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2021-08-08Python中dictionary items()系列函數(shù)的用法實例
這篇文章主要介紹了Python中dictionary items()系列函數(shù)的用法,很實用的函數(shù),需要的朋友可以參考下2014-08-08python中itertools模塊zip_longest函數(shù)詳解
itertools模塊包含創(chuàng)建高效迭代器的函數(shù),這些函數(shù)的返回值不是list,而是iterator(可迭代對象),可以用各種方式對數(shù)據(jù)執(zhí)行循環(huán)操作,今天我們來詳細探討下zip_longest函數(shù)2018-06-06