PyTorch的nn.Module類的定義和使用介紹
在PyTorch中,nn.Module
類是構(gòu)建神經(jīng)網(wǎng)絡(luò)模型的基礎(chǔ)類,所有自定義的層、模塊或整個(gè)神經(jīng)網(wǎng)絡(luò)架構(gòu)都需要繼承自這個(gè)類。nn.Module
類提供了一系列屬性和方法用于管理網(wǎng)絡(luò)的結(jié)構(gòu)和訓(xùn)練過(guò)程中的計(jì)算。
1. PyTorch中nn.Module基類的定義
在PyTorch中,nn.Module
是所有神經(jīng)網(wǎng)絡(luò)模塊的基礎(chǔ)類。盡管這里不能提供完整的源代碼(因?yàn)樗婕按罅績(jī)?nèi)部邏輯和API細(xì)節(jié)),但我可以給出一個(gè)簡(jiǎn)化的 nn.Module
類的基本結(jié)構(gòu),并描述其關(guān)鍵方法:
Python
# 此處簡(jiǎn)化了 nn.Module 的定義,實(shí)際 PyTorch 源碼更為復(fù)雜 import torch class nn.Module: def __init__(self): super().__init__() # 存儲(chǔ)子模塊的字典 self._modules = dict() # 參數(shù)和緩沖區(qū)的集合 self._parameters = OrderedDict() self._buffers = OrderedDict() def __setattr__(self, name, value): # 特殊處理參數(shù)和子模塊的設(shè)置 if isinstance(value, nn.Parameter): # 注冊(cè)參數(shù)到 _parameters 字典中 self.register_parameter(name, value) elif isinstance(value, Module) and not isinstance(value, Container): # 注冊(cè)子模塊到 _modules 字典中 self.add_module(name, value) else: # 對(duì)于普通屬性,執(zhí)行標(biāo)準(zhǔn)的 setattr 操作 object.__setattr__(self, name, value) def add_module(self, name: str, module: 'Module') -> None: r"""添加子模塊到當(dāng)前模塊""" # 內(nèi)部實(shí)現(xiàn)細(xì)節(jié)省略... self._modules[name] = module def register_parameter(self, name: str, param: nn.Parameter) -> None: r"""注冊(cè)一個(gè)新的參數(shù)""" # 內(nèi)部實(shí)現(xiàn)細(xì)節(jié)省略... self._parameters[name] = param def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]: r"""返回一個(gè)包含所有可學(xué)習(xí)參數(shù)的迭代器""" # 內(nèi)部實(shí)現(xiàn)細(xì)節(jié)省略... return iter(getattr(self, '_parameters', {}).values()) def forward(self, *input: Tensor) -> Tensor: r"""定義前向傳播操作""" raise NotImplementedError # 還有許多其他的方法如:zero_grad、to、state_dict、load_state_dict 等等... # 在自定義模型時(shí),繼承 nn.Module 并重寫 forward 方法 class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.linear = nn.Linear(20, 30) def forward(self, x): return self.linear(x)
這段代碼定義了 PyTorch 中 nn.Module
類的基礎(chǔ)結(jié)構(gòu)。在實(shí)際的 PyTorch 源碼中,nn.Module
的實(shí)現(xiàn)更為復(fù)雜,但這里簡(jiǎn)化后的代碼片段展示了其核心部分。
class nn.Module:
:定義了一個(gè)名為nn.Module
的類,它是所有神經(jīng)網(wǎng)絡(luò)模塊(如卷積層、全連接層、激活函數(shù)等)的基類。def __init__(self):
:這是類的初始化方法,在創(chuàng)建一個(gè)nn.Module
或其子類實(shí)例時(shí)會(huì)被自動(dòng)調(diào)用。這里的self
參數(shù)代表將來(lái)創(chuàng)建出的實(shí)例自身。super().__init__()
:調(diào)用父類的構(gòu)造函數(shù),確保基類的初始化邏輯得到執(zhí)行。在這里,雖然沒有顯示指定父類,但因?yàn)?nbsp;nn.Module
是其他所有模塊的基類,所以實(shí)際上它是在調(diào)用自身的構(gòu)造函數(shù)來(lái)初始化內(nèi)部狀態(tài)。self._modules = dict()
:聲明并初始化一個(gè)字典_modules
,用于存儲(chǔ)模型中的所有子模塊。每個(gè)子模塊是一個(gè)同樣繼承自nn.Module
的對(duì)象,并通過(guò)名稱進(jìn)行索引。這樣可以方便地管理和組織復(fù)雜的層次化網(wǎng)絡(luò)結(jié)構(gòu)。self._parameters = OrderedDict()
:使用有序字典(OrderedDict)類型聲明和初始化一個(gè)變量_parameters
,用來(lái)保存模型的所有可學(xué)習(xí)參數(shù)(權(quán)重和偏置等)。有序字典保證參數(shù)按添加順序存儲(chǔ),這對(duì)于一些依賴參數(shù)順序的操作(如加載預(yù)訓(xùn)練模型的權(quán)重)是必要的。self._buffers = OrderedDict()
:類似地,聲明并初始化另一個(gè)有序字典_buffers
,用于存儲(chǔ)模型中的緩沖區(qū)(Buffer)。緩沖區(qū)通常是不參與梯度計(jì)算的變量,比如在 BatchNorm 層中存儲(chǔ)的均值和方差統(tǒng)計(jì)量。
總結(jié)來(lái)說(shuō),這段代碼為構(gòu)建神經(jīng)網(wǎng)絡(luò)模型提供了一個(gè)基礎(chǔ)框架,其中包含了對(duì)子模塊、參數(shù)和緩沖區(qū)的管理機(jī)制,這些基礎(chǔ)設(shè)施對(duì)于構(gòu)建、運(yùn)行和優(yōu)化深度學(xué)習(xí)模型至關(guān)重要。在自定義模塊時(shí),開發(fā)者通常會(huì)在此基礎(chǔ)上添加更多的層和功能,并重寫 forward
方法以定義前向傳播邏輯。
以上代碼僅展示了 nn.Module
類的部分核心功能,實(shí)際上 PyTorch 官方的實(shí)現(xiàn)會(huì)更加詳盡和復(fù)雜,包括更多的內(nèi)部機(jī)制來(lái)支持模塊化構(gòu)建深度學(xué)習(xí)模型。開發(fā)者通常需要繼承 nn.Module
類并重寫 forward
方法來(lái)實(shí)現(xiàn)自定義的神經(jīng)網(wǎng)絡(luò)層或整個(gè)網(wǎng)絡(luò)架構(gòu)。
2. nn.Module類中的關(guān)鍵屬性和方法
在PyTorch的nn.Module
類中,有以下幾個(gè)關(guān)鍵屬性和方法:
__init__(self, ...)
: 這是每個(gè)派生自nn.Module
的類都必須重載的方法,在該方法中定義并初始化模型的所有層和參數(shù)。.parameters()
:這是一個(gè)動(dòng)態(tài)生成器,用于獲取模型的所有可學(xué)習(xí)參數(shù)(權(quán)重和偏置等)。這些參數(shù)都是nn.Parameter
類型的張量,在訓(xùn)練過(guò)程中可以自動(dòng)計(jì)算梯度。
示例:
Python
for param in model.parameters(): print(param)
.buffers()
:類似于.parameters()
,但返回的是模塊內(nèi)定義的非可學(xué)習(xí)緩沖區(qū)變量,例如一些統(tǒng)計(jì)量或臨時(shí)存儲(chǔ)數(shù)據(jù)。
.named_parameters()
和 .named_buffers()
:與上面類似,但返回元組形式的迭代器,每個(gè)元素是一個(gè)包含名稱和對(duì)應(yīng)參數(shù)/緩沖區(qū)的元組,便于按名稱訪問(wèn)特定參數(shù)。
.children()
和 .modules()
:這兩個(gè)方法分別返回一個(gè)包含當(dāng)前模塊所有直接子模塊的迭代器和包含所有層級(jí)子模塊(包括自身)的迭代器。
.state_dict()
:該方法返回一個(gè)字典,包含了模型的所有狀態(tài)信息(即參數(shù)和緩沖區(qū)),方便保存和恢復(fù)模型。
.train()
和 .eval()
:方法用于切換模型的運(yùn)行模式。在訓(xùn)練模式下,某些層如批次歸一化層會(huì)有不同的行為;而在評(píng)估模式下,通常會(huì)禁用dropout層并使用移動(dòng)平均統(tǒng)計(jì)量(對(duì)于批歸一化層)。
._parameters
和 ._buffers
:這是內(nèi)部字典屬性,分別儲(chǔ)存了模型的所有參數(shù)和緩沖區(qū),雖然不推薦直接操作,但在自定義模塊時(shí)可能需要用到。
.to(device)
:將整個(gè)模型及其參數(shù)轉(zhuǎn)移到指定設(shè)備上,比如從CPU到GPU。
其他內(nèi)部維護(hù)的屬性,如 _forward_pre_hooks
和 _forward_hooks
用于實(shí)現(xiàn)向前傳播過(guò)程中的預(yù)處理和后處理鉤子,以及 _backward_hooks
用于反向傳播過(guò)程中的鉤子,這些通常在高級(jí)功能開發(fā)時(shí)使用。
forward(self, input)
:定義模型如何處理輸入數(shù)據(jù)并生成輸出,這是構(gòu)建神經(jīng)網(wǎng)絡(luò)的核心部分,每次調(diào)用模型實(shí)例都會(huì)執(zhí)行 forward
函數(shù)。
add_module(name, module)
:將一個(gè)子模塊添加到當(dāng)前模塊,并通過(guò)給定的名字引用它。
register_parameter(name, param)
:注冊(cè)一個(gè)新的參數(shù)到模塊中。
zero_grad()
:將模塊及其所有子模塊的參數(shù)梯度設(shè)置為零,通常在優(yōu)化器更新前調(diào)用。
train(mode=True)
和 eval()
:切換模型的工作模式,在訓(xùn)練模式下會(huì)啟用批次歸一化層和丟棄層等依賴于訓(xùn)練/預(yù)測(cè)階段的行為,在評(píng)估模式下則關(guān)閉這些行為。
state_dict()
和 load_state_dict(state_dict)
:用于保存和加載模型的狀態(tài)字典,其中包括模型的權(quán)重和配置信息,便于模型持久化和遷移。
其他與模型保存和恢復(fù)相關(guān)的方法,例如 save(filename)
、load(filename)
等。
請(qǐng)注意,具體的屬性和方法可能會(huì)隨著PyTorch版本的更新而有所增減或改進(jìn)。
3. nn.Module子類的定義和使用
在PyTorch中,nn.Module
類扮演著核心角色,它是構(gòu)建任何自定義神經(jīng)網(wǎng)絡(luò)層、復(fù)雜模塊或完整神經(jīng)網(wǎng)絡(luò)架構(gòu)的基礎(chǔ)構(gòu)建塊。通過(guò)繼承 nn.Module
并在其子類中定義模型結(jié)構(gòu)和前向傳播邏輯(forward()
方法),開發(fā)者能夠方便地搭建并訓(xùn)練深度學(xué)習(xí)模型。
具體來(lái)說(shuō),在自定義一個(gè) nn.Module
子類時(shí),通常會(huì)執(zhí)行以下操作:
初始化 (__init__
):在類的初始化方法中定義并實(shí)例化所有需要的層、參數(shù)和其他組件。
Python
class MyModel(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(MyModel, self).__init__() self.layer1 = nn.Linear(input_size, hidden_size) self.layer2 = nn.Linear(hidden_size, output_size)
前向傳播 (forward
):實(shí)現(xiàn)前向傳播函數(shù)來(lái)描述輸入數(shù)據(jù)如何通過(guò)網(wǎng)絡(luò)產(chǎn)生輸出結(jié)果。
Python
class MyModel(nn.Module): # ... def forward(self, x): x = torch.relu(self.layer1(x)) x = self.layer2(x) return x
管理參數(shù)和模塊:
- 使用
.parameters()
或.named_parameters()
訪問(wèn)模型的所有可學(xué)習(xí)參數(shù)。 - 使用
add_module()
添加子模塊,并給它們命名以便于訪問(wèn)。 - 使用
register_buffer()
為模型注冊(cè)非可學(xué)習(xí)的緩沖區(qū)變量。
訓(xùn)練與評(píng)估模式切換:
- 使用
model.train()
將模型設(shè)置為訓(xùn)練模式,這會(huì)影響某些層的行為,如批量歸一化層和丟棄層。 - 使用
model.eval()
將模型設(shè)置為評(píng)估模式,此時(shí)會(huì)禁用這些依賴于訓(xùn)練階段的行為。
保存和加載模型狀態(tài):
- 調(diào)用
model.state_dict()
獲取模型權(quán)重和優(yōu)化器狀態(tài)的字典形式。 - 使用
torch.save()
和torch.load()
來(lái)保存和恢復(fù)整個(gè)模型或者僅其狀態(tài)字典。 - 通過(guò)
model.load_state_dict(state_dict)
加載先前保存的狀態(tài)字典到模型中。
此外,nn.Module
還提供了諸如移動(dòng)模型至不同設(shè)備(CPU或GPU)、零化梯度等實(shí)用功能,這些功能在整個(gè)模型訓(xùn)練過(guò)程中起到重要作用。
到此這篇關(guān)于PyTorch的nn.Module類的詳細(xì)介紹的文章就介紹到這了,更多相關(guān)PyTorch nn.Module類內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
- PyTorch模型創(chuàng)建與nn.Module構(gòu)建
- 關(guān)于PyTorch中nn.Module類的簡(jiǎn)介
- Pytorch參數(shù)注冊(cè)和nn.ModuleList nn.ModuleDict的問(wèn)題
- 人工智能學(xué)習(xí)PyTorch實(shí)現(xiàn)CNN卷積層及nn.Module類示例分析
- pytorch 中的重要模塊化接口nn.Module的使用
- 用pytorch的nn.Module構(gòu)造簡(jiǎn)單全鏈接層實(shí)例
- 淺析PyTorch中nn.Module的使用
- 對(duì)Pytorch中nn.ModuleList 和 nn.Sequential詳解
相關(guān)文章
Python實(shí)現(xiàn)TCP/IP協(xié)議下的端口轉(zhuǎn)發(fā)及重定向示例
這篇文章主要介紹了Python實(shí)現(xiàn)TCP/IP協(xié)議下的端口轉(zhuǎn)發(fā)及重定向示例,以一個(gè)webpy站點(diǎn)在本機(jī)的兩個(gè)端口雙向通信下演示,需要的朋友可以參考下2016-06-06Python利用pdfplumber庫(kù)提取pdf中的文字
pdfplumber是一個(gè)用于從PDF文檔中提取文本和表格數(shù)據(jù)的Python庫(kù),它可以幫助用戶輕松地從PDF文件中提取有用的信息,例如表格、文本、元數(shù)據(jù)等,本文將給大家介紹如何通過(guò)Python的pdfplumber庫(kù)提取pdf中的文字,需要的朋友可以參考下2023-05-05python在linux環(huán)境下安裝skimage的示例代碼
這篇文章主要介紹了python在linux環(huán)境下安裝skimage,本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-10-10pygame編寫音樂播放器的實(shí)現(xiàn)代碼示例
這篇文章主要介紹了pygame編寫音樂播放器的實(shí)現(xiàn)代碼示例,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-11-11淺談keras 的抽象后端(from keras import backend as K)
這篇文章主要介紹了淺談keras 的抽象后端(from keras import backend as K),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-06-06pyinstaller打包可執(zhí)行文件出現(xiàn)KeyError的問(wèn)題
這篇文章主要介紹了pyinstaller打包可執(zhí)行文件出現(xiàn)KeyError的問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-11-11python獲取局域網(wǎng)占帶寬最大3個(gè)ip的方法
這篇文章主要介紹了python獲取局域網(wǎng)占帶寬最大3個(gè)ip的方法,涉及Python解析URL參數(shù)的相關(guān)技巧,具有一定參考借鑒價(jià)值,需要的朋友可以參考下2015-07-07python3 小數(shù)位的四舍五入(用兩種方法解決round 遇5不進(jìn))
這篇文章主要介紹了python3 小數(shù)位的四舍五入(用兩種方法解決round 遇5不進(jìn)),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-04-04