PyTorch的nn.Module類的定義和使用介紹
在PyTorch中,nn.Module 類是構(gòu)建神經(jīng)網(wǎng)絡(luò)模型的基礎(chǔ)類,所有自定義的層、模塊或整個(gè)神經(jīng)網(wǎng)絡(luò)架構(gòu)都需要繼承自這個(gè)類。nn.Module 類提供了一系列屬性和方法用于管理網(wǎng)絡(luò)的結(jié)構(gòu)和訓(xùn)練過程中的計(jì)算。
1. PyTorch中nn.Module基類的定義
在PyTorch中,nn.Module 是所有神經(jīng)網(wǎng)絡(luò)模塊的基礎(chǔ)類。盡管這里不能提供完整的源代碼(因?yàn)樗婕按罅績(jī)?nèi)部邏輯和API細(xì)節(jié)),但我可以給出一個(gè)簡(jiǎn)化的 nn.Module 類的基本結(jié)構(gòu),并描述其關(guān)鍵方法:
Python
# 此處簡(jiǎn)化了 nn.Module 的定義,實(shí)際 PyTorch 源碼更為復(fù)雜
import torch
class nn.Module:
def __init__(self):
super().__init__()
# 存儲(chǔ)子模塊的字典
self._modules = dict()
# 參數(shù)和緩沖區(qū)的集合
self._parameters = OrderedDict()
self._buffers = OrderedDict()
def __setattr__(self, name, value):
# 特殊處理參數(shù)和子模塊的設(shè)置
if isinstance(value, nn.Parameter):
# 注冊(cè)參數(shù)到 _parameters 字典中
self.register_parameter(name, value)
elif isinstance(value, Module) and not isinstance(value, Container):
# 注冊(cè)子模塊到 _modules 字典中
self.add_module(name, value)
else:
# 對(duì)于普通屬性,執(zhí)行標(biāo)準(zhǔn)的 setattr 操作
object.__setattr__(self, name, value)
def add_module(self, name: str, module: 'Module') -> None:
r"""添加子模塊到當(dāng)前模塊"""
# 內(nèi)部實(shí)現(xiàn)細(xì)節(jié)省略...
self._modules[name] = module
def register_parameter(self, name: str, param: nn.Parameter) -> None:
r"""注冊(cè)一個(gè)新的參數(shù)"""
# 內(nèi)部實(shí)現(xiàn)細(xì)節(jié)省略...
self._parameters[name] = param
def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
r"""返回一個(gè)包含所有可學(xué)習(xí)參數(shù)的迭代器"""
# 內(nèi)部實(shí)現(xiàn)細(xì)節(jié)省略...
return iter(getattr(self, '_parameters', {}).values())
def forward(self, *input: Tensor) -> Tensor:
r"""定義前向傳播操作"""
raise NotImplementedError
# 還有許多其他的方法如:zero_grad、to、state_dict、load_state_dict 等等...
# 在自定義模型時(shí),繼承 nn.Module 并重寫 forward 方法
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.linear = nn.Linear(20, 30)
def forward(self, x):
return self.linear(x) 這段代碼定義了 PyTorch 中 nn.Module 類的基礎(chǔ)結(jié)構(gòu)。在實(shí)際的 PyTorch 源碼中,nn.Module 的實(shí)現(xiàn)更為復(fù)雜,但這里簡(jiǎn)化后的代碼片段展示了其核心部分。
class nn.Module::定義了一個(gè)名為nn.Module的類,它是所有神經(jīng)網(wǎng)絡(luò)模塊(如卷積層、全連接層、激活函數(shù)等)的基類。def __init__(self)::這是類的初始化方法,在創(chuàng)建一個(gè)nn.Module或其子類實(shí)例時(shí)會(huì)被自動(dòng)調(diào)用。這里的self參數(shù)代表將來創(chuàng)建出的實(shí)例自身。super().__init__():調(diào)用父類的構(gòu)造函數(shù),確?;惖某跏蓟壿嫷玫綀?zhí)行。在這里,雖然沒有顯示指定父類,但因?yàn)?nbsp;nn.Module是其他所有模塊的基類,所以實(shí)際上它是在調(diào)用自身的構(gòu)造函數(shù)來初始化內(nèi)部狀態(tài)。self._modules = dict():聲明并初始化一個(gè)字典_modules,用于存儲(chǔ)模型中的所有子模塊。每個(gè)子模塊是一個(gè)同樣繼承自nn.Module的對(duì)象,并通過名稱進(jìn)行索引。這樣可以方便地管理和組織復(fù)雜的層次化網(wǎng)絡(luò)結(jié)構(gòu)。self._parameters = OrderedDict():使用有序字典(OrderedDict)類型聲明和初始化一個(gè)變量_parameters,用來保存模型的所有可學(xué)習(xí)參數(shù)(權(quán)重和偏置等)。有序字典保證參數(shù)按添加順序存儲(chǔ),這對(duì)于一些依賴參數(shù)順序的操作(如加載預(yù)訓(xùn)練模型的權(quán)重)是必要的。self._buffers = OrderedDict():類似地,聲明并初始化另一個(gè)有序字典_buffers,用于存儲(chǔ)模型中的緩沖區(qū)(Buffer)。緩沖區(qū)通常是不參與梯度計(jì)算的變量,比如在 BatchNorm 層中存儲(chǔ)的均值和方差統(tǒng)計(jì)量。
總結(jié)來說,這段代碼為構(gòu)建神經(jīng)網(wǎng)絡(luò)模型提供了一個(gè)基礎(chǔ)框架,其中包含了對(duì)子模塊、參數(shù)和緩沖區(qū)的管理機(jī)制,這些基礎(chǔ)設(shè)施對(duì)于構(gòu)建、運(yùn)行和優(yōu)化深度學(xué)習(xí)模型至關(guān)重要。在自定義模塊時(shí),開發(fā)者通常會(huì)在此基礎(chǔ)上添加更多的層和功能,并重寫 forward 方法以定義前向傳播邏輯。
以上代碼僅展示了 nn.Module 類的部分核心功能,實(shí)際上 PyTorch 官方的實(shí)現(xiàn)會(huì)更加詳盡和復(fù)雜,包括更多的內(nèi)部機(jī)制來支持模塊化構(gòu)建深度學(xué)習(xí)模型。開發(fā)者通常需要繼承 nn.Module 類并重寫 forward 方法來實(shí)現(xiàn)自定義的神經(jīng)網(wǎng)絡(luò)層或整個(gè)網(wǎng)絡(luò)架構(gòu)。
2. nn.Module類中的關(guān)鍵屬性和方法
在PyTorch的nn.Module類中,有以下幾個(gè)關(guān)鍵屬性和方法:
__init__(self, ...): 這是每個(gè)派生自nn.Module的類都必須重載的方法,在該方法中定義并初始化模型的所有層和參數(shù)。.parameters():這是一個(gè)動(dòng)態(tài)生成器,用于獲取模型的所有可學(xué)習(xí)參數(shù)(權(quán)重和偏置等)。這些參數(shù)都是nn.Parameter類型的張量,在訓(xùn)練過程中可以自動(dòng)計(jì)算梯度。
示例:
Python
for param in model.parameters(): print(param)
.buffers():類似于.parameters(),但返回的是模塊內(nèi)定義的非可學(xué)習(xí)緩沖區(qū)變量,例如一些統(tǒng)計(jì)量或臨時(shí)存儲(chǔ)數(shù)據(jù)。
.named_parameters() 和 .named_buffers():與上面類似,但返回元組形式的迭代器,每個(gè)元素是一個(gè)包含名稱和對(duì)應(yīng)參數(shù)/緩沖區(qū)的元組,便于按名稱訪問特定參數(shù)。
.children() 和 .modules():這兩個(gè)方法分別返回一個(gè)包含當(dāng)前模塊所有直接子模塊的迭代器和包含所有層級(jí)子模塊(包括自身)的迭代器。
.state_dict():該方法返回一個(gè)字典,包含了模型的所有狀態(tài)信息(即參數(shù)和緩沖區(qū)),方便保存和恢復(fù)模型。
.train() 和 .eval():方法用于切換模型的運(yùn)行模式。在訓(xùn)練模式下,某些層如批次歸一化層會(huì)有不同的行為;而在評(píng)估模式下,通常會(huì)禁用dropout層并使用移動(dòng)平均統(tǒng)計(jì)量(對(duì)于批歸一化層)。
._parameters 和 ._buffers:這是內(nèi)部字典屬性,分別儲(chǔ)存了模型的所有參數(shù)和緩沖區(qū),雖然不推薦直接操作,但在自定義模塊時(shí)可能需要用到。
.to(device):將整個(gè)模型及其參數(shù)轉(zhuǎn)移到指定設(shè)備上,比如從CPU到GPU。
其他內(nèi)部維護(hù)的屬性,如 _forward_pre_hooks 和 _forward_hooks 用于實(shí)現(xiàn)向前傳播過程中的預(yù)處理和后處理鉤子,以及 _backward_hooks 用于反向傳播過程中的鉤子,這些通常在高級(jí)功能開發(fā)時(shí)使用。
forward(self, input):定義模型如何處理輸入數(shù)據(jù)并生成輸出,這是構(gòu)建神經(jīng)網(wǎng)絡(luò)的核心部分,每次調(diào)用模型實(shí)例都會(huì)執(zhí)行 forward 函數(shù)。
add_module(name, module):將一個(gè)子模塊添加到當(dāng)前模塊,并通過給定的名字引用它。
register_parameter(name, param):注冊(cè)一個(gè)新的參數(shù)到模塊中。
zero_grad():將模塊及其所有子模塊的參數(shù)梯度設(shè)置為零,通常在優(yōu)化器更新前調(diào)用。
train(mode=True) 和 eval():切換模型的工作模式,在訓(xùn)練模式下會(huì)啟用批次歸一化層和丟棄層等依賴于訓(xùn)練/預(yù)測(cè)階段的行為,在評(píng)估模式下則關(guān)閉這些行為。
state_dict() 和 load_state_dict(state_dict):用于保存和加載模型的狀態(tài)字典,其中包括模型的權(quán)重和配置信息,便于模型持久化和遷移。
其他與模型保存和恢復(fù)相關(guān)的方法,例如 save(filename)、load(filename) 等。
請(qǐng)注意,具體的屬性和方法可能會(huì)隨著PyTorch版本的更新而有所增減或改進(jìn)。
3. nn.Module子類的定義和使用
在PyTorch中,nn.Module 類扮演著核心角色,它是構(gòu)建任何自定義神經(jīng)網(wǎng)絡(luò)層、復(fù)雜模塊或完整神經(jīng)網(wǎng)絡(luò)架構(gòu)的基礎(chǔ)構(gòu)建塊。通過繼承 nn.Module 并在其子類中定義模型結(jié)構(gòu)和前向傳播邏輯(forward() 方法),開發(fā)者能夠方便地搭建并訓(xùn)練深度學(xué)習(xí)模型。
具體來說,在自定義一個(gè) nn.Module 子類時(shí),通常會(huì)執(zhí)行以下操作:
初始化 (__init__):在類的初始化方法中定義并實(shí)例化所有需要的層、參數(shù)和其他組件。
Python
class MyModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MyModel, self).__init__()
self.layer1 = nn.Linear(input_size, hidden_size)
self.layer2 = nn.Linear(hidden_size, output_size)前向傳播 (forward):實(shí)現(xiàn)前向傳播函數(shù)來描述輸入數(shù)據(jù)如何通過網(wǎng)絡(luò)產(chǎn)生輸出結(jié)果。
Python
class MyModel(nn.Module):
# ...
def forward(self, x):
x = torch.relu(self.layer1(x))
x = self.layer2(x)
return x管理參數(shù)和模塊:
- 使用
.parameters()或.named_parameters()訪問模型的所有可學(xué)習(xí)參數(shù)。 - 使用
add_module()添加子模塊,并給它們命名以便于訪問。 - 使用
register_buffer()為模型注冊(cè)非可學(xué)習(xí)的緩沖區(qū)變量。
訓(xùn)練與評(píng)估模式切換:
- 使用
model.train()將模型設(shè)置為訓(xùn)練模式,這會(huì)影響某些層的行為,如批量歸一化層和丟棄層。 - 使用
model.eval()將模型設(shè)置為評(píng)估模式,此時(shí)會(huì)禁用這些依賴于訓(xùn)練階段的行為。
保存和加載模型狀態(tài):
- 調(diào)用
model.state_dict()獲取模型權(quán)重和優(yōu)化器狀態(tài)的字典形式。 - 使用
torch.save()和torch.load()來保存和恢復(fù)整個(gè)模型或者僅其狀態(tài)字典。 - 通過
model.load_state_dict(state_dict)加載先前保存的狀態(tài)字典到模型中。
此外,nn.Module 還提供了諸如移動(dòng)模型至不同設(shè)備(CPU或GPU)、零化梯度等實(shí)用功能,這些功能在整個(gè)模型訓(xùn)練過程中起到重要作用。
到此這篇關(guān)于PyTorch的nn.Module類的詳細(xì)介紹的文章就介紹到這了,更多相關(guān)PyTorch nn.Module類內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
- PyTorch模型創(chuàng)建與nn.Module構(gòu)建
- 關(guān)于PyTorch中nn.Module類的簡(jiǎn)介
- Pytorch參數(shù)注冊(cè)和nn.ModuleList nn.ModuleDict的問題
- 人工智能學(xué)習(xí)PyTorch實(shí)現(xiàn)CNN卷積層及nn.Module類示例分析
- pytorch 中的重要模塊化接口nn.Module的使用
- 用pytorch的nn.Module構(gòu)造簡(jiǎn)單全鏈接層實(shí)例
- 淺析PyTorch中nn.Module的使用
- 對(duì)Pytorch中nn.ModuleList 和 nn.Sequential詳解
相關(guān)文章
Python實(shí)現(xiàn)TCP/IP協(xié)議下的端口轉(zhuǎn)發(fā)及重定向示例
這篇文章主要介紹了Python實(shí)現(xiàn)TCP/IP協(xié)議下的端口轉(zhuǎn)發(fā)及重定向示例,以一個(gè)webpy站點(diǎn)在本機(jī)的兩個(gè)端口雙向通信下演示,需要的朋友可以參考下2016-06-06
python在linux環(huán)境下安裝skimage的示例代碼
這篇文章主要介紹了python在linux環(huán)境下安裝skimage,本文通過實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-10-10
pygame編寫音樂播放器的實(shí)現(xiàn)代碼示例
這篇文章主要介紹了pygame編寫音樂播放器的實(shí)現(xiàn)代碼示例,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-11-11
淺談keras 的抽象后端(from keras import backend as K)
這篇文章主要介紹了淺談keras 的抽象后端(from keras import backend as K),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-06-06
pyinstaller打包可執(zhí)行文件出現(xiàn)KeyError的問題
這篇文章主要介紹了pyinstaller打包可執(zhí)行文件出現(xiàn)KeyError的問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-11-11
python獲取局域網(wǎng)占帶寬最大3個(gè)ip的方法
這篇文章主要介紹了python獲取局域網(wǎng)占帶寬最大3個(gè)ip的方法,涉及Python解析URL參數(shù)的相關(guān)技巧,具有一定參考借鑒價(jià)值,需要的朋友可以參考下2015-07-07
python3 小數(shù)位的四舍五入(用兩種方法解決round 遇5不進(jìn))
這篇文章主要介紹了python3 小數(shù)位的四舍五入(用兩種方法解決round 遇5不進(jìn)),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-04-04

