PyTorch中nn.Module示例詳解
直接print(dir(nn.Module)),得到如下內(nèi)容:

一、模型結(jié)構(gòu)與參數(shù)
parameters()- 用途:返回模塊的所有可訓(xùn)練參數(shù)(如權(quán)重、偏置)。
- 示例:
for param in model.parameters(): print(param.shape)
named_parameters()- 用途:返回帶名稱的參數(shù)迭代器,便于調(diào)試和訪問特定參數(shù)。
- 示例:
for name, param in model.named_parameters(): if 'weight' in name: print(name, param.shape)
children()- 用途:返回直接子模塊的迭代器。
- 示例:
for child in model.children(): print(type(child))
modules()- 用途:遞歸返回所有子模塊(包括自身)。
- 示例:
for module in model.modules(): if isinstance(module, nn.Conv2d): print(module.kernel_size)
二、模型狀態(tài)與模式
train()和eval()- 用途:切換訓(xùn)練/推理模式(影響Dropout、BatchNorm等層)。
- 示例:
model.train() # 訓(xùn)練模式 model.eval() # 推理模式
training- 用途:布爾屬性,指示當(dāng)前模式(
True為訓(xùn)練,False為推理)。 - 示例:
print(model.training) # 輸出:True/False
- 用途:布爾屬性,指示當(dāng)前模式(
三、模型保存與加載
state_dict()- 用途:返回包含模型所有參數(shù)的字典(
OrderedDict)。 - 示例:
torch.save(model.state_dict(), 'model.pth')
- 用途:返回包含模型所有參數(shù)的字典(
load_state_dict()- 用途:從字典加載模型參數(shù)。
- 示例:
model.load_state_dict(torch.load('model.pth'))
四、設(shè)備與數(shù)據(jù)類型
to()- 用途:將模型移動到指定設(shè)備(如GPU)或轉(zhuǎn)換數(shù)據(jù)類型。
- 示例:
model.to('cuda') # 移動到GPU model.to(torch.float16) # 轉(zhuǎn)換為半精度
cpu()和cuda()- 用途:快捷方法,分別將模型移動到CPU或GPU。
- 示例:
model.cuda() # 等價于 model.to('cuda')
五、前向傳播與計算
forward()- 用途:定義模型的前向傳播邏輯(需在自定義模塊中重寫)。
- 示例:
class MyModel(nn.Module): def forward(self, x): return self.layer(x)
__call__()- 用途:調(diào)用模型實例時觸發(fā)(內(nèi)部調(diào)用
forward(),支持鉤子函數(shù))。 - 示例:
output = model(x) # 等價于 output = model.forward(x)
- 用途:調(diào)用模型實例時觸發(fā)(內(nèi)部調(diào)用
六、參數(shù)初始化與優(yōu)化
zero_grad()- 用途:清空所有參數(shù)的梯度(通常在每個訓(xùn)練步驟前調(diào)用)。
- 示例:
optimizer.zero_grad() # 等價于 model.zero_grad()
requires_grad_()- 用途:設(shè)置參數(shù)是否需要梯度(用于凍結(jié)部分模型)。
- 示例:
for param in model.parameters(): param.requires_grad = False # 凍結(jié)所有參數(shù)
七、調(diào)試與信息
extra_repr()- 用途:自定義模塊打印信息(需在子類中重寫)。
- 示例:
class MyModel(nn.Module): def extra_repr(self): return f"hidden_size={self.hidden_size}"
dump_patches()- 用途:打印模型的補丁信息(用于調(diào)試版本差異)。
八、其他實用方法
apply()- 用途:遞歸應(yīng)用函數(shù)到所有子模塊(如初始化權(quán)重)。
- 示例:
def init_weights(m): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) model.apply(init_weights)
register_forward_hook()- 用途:注冊前向傳播鉤子(用于捕獲中間輸出,調(diào)試或特征提?。?/li>
總結(jié)
日常使用中,最頻繁的方法包括:
- 模型構(gòu)建:
parameters(),children(),modules() - 訓(xùn)練與推理:
train(),eval(),zero_grad(),forward() - 保存與加載:
state_dict(),load_state_dict() - 設(shè)備管理:
to(),cuda(),cpu()
其他方法根據(jù)具體需求選擇使用,例如鉤子函數(shù)用于高級調(diào)試,apply() 用于統(tǒng)一初始化。
與nn.Sequential對比:
1. 繼承關(guān)系與基礎(chǔ)屬性
nn.Module- 是所有神經(jīng)網(wǎng)絡(luò)模塊的基類,提供最基礎(chǔ)的功能(如參數(shù)管理、鉤子機制)。
- 包含核心屬性:
_parameters,_modules,_buffers等。
nn.Sequential- 是
nn.Module的子類,繼承了所有基礎(chǔ)功能。 - 額外添加了與順序執(zhí)行相關(guān)的屬性(如
__getitem__、append)。
- 是
2. 核心差異對比
| 功能類別 | nn.Module | nn.Sequential |
|---|---|---|
| 模塊構(gòu)建 | 需要手動實現(xiàn) forward 方法 | 自動按順序執(zhí)行子模塊,無需定義 forward |
| 子模塊訪問 | 通過屬性名(如 self.conv1) | 通過索引或命名(如 model[0]) |
| 動態(tài)修改 | 需手動管理子模塊 | 支持 append、extend、insert 等操作 |
| 適用場景 | 復(fù)雜網(wǎng)絡(luò)結(jié)構(gòu)(如ResNet、U-Net) | 簡單順序結(jié)構(gòu)(如LeNet卷積部分) |
3. 具體方法對比
3.1 公共方法(兩者都有)
# 模型參數(shù)與結(jié)構(gòu) ['parameters', 'named_parameters', 'children', 'modules', 'named_children', 'named_modules'] # 模型狀態(tài) ['train', 'eval', 'training', 'zero_grad', 'requires_grad_'] # 設(shè)備與數(shù)據(jù)類型 ['to', 'cpu', 'cuda', 'float', 'double', 'half', 'bfloat16'] # 保存與加載 ['state_dict', 'load_state_dict'] # 鉤子機制 ['register_forward_hook', 'register_backward_hook']
3.2nn.Sequential特有的方法
# 列表操作(動態(tài)修改模塊順序) ['__getitem__', '__setitem__', '__delitem__', '__len__', 'append', 'extend', 'insert', 'pop'] # 索引相關(guān) ['_get_item_by_idx']
3.3nn.Module特有的方法
# 自定義實現(xiàn) ['forward', 'extra_repr'] # 高級管理 ['add_module', 'register_module', 'register_parameter', 'register_buffer']
4. 示例對比
4.1 創(chuàng)建模型
# nn.Module(需自定義 forward)
class CustomModel(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 64, 3)
self.relu = nn.ReLU()
def forward(self, x):
return self.relu(self.conv(x))
# nn.Sequential(自動按順序執(zhí)行)
seq_model = nn.Sequential(
nn.Conv2d(3, 64, 3),
nn.ReLU()
)4.2 訪問子模塊
# nn.Module custom_model.conv # 通過屬性名訪問 # nn.Sequential seq_model[0] # 通過索引訪問 seq_model.append(nn.MaxPool2d(2)) # 動態(tài)添加模塊
5. 總結(jié)
| 特性 | nn.Module | nn.Sequential |
|---|---|---|
| 靈活性 | 高(自定義任意邏輯) | 低(僅支持順序執(zhí)行) |
| 代碼復(fù)雜度 | 較高(需手動實現(xiàn) forward) | 低(自動處理前向傳播) |
| 動態(tài)修改 | 不支持直接操作(需手動管理) | 支持 append、insert 等操作 |
| 適用場景 | 復(fù)雜網(wǎng)絡(luò)、分支結(jié)構(gòu)、自定義操作 | 簡單堆疊模塊(如CNN的卷積部分) |
建議:
- 對于簡單的順序網(wǎng)絡(luò),優(yōu)先使用
nn.Sequential以減少代碼量。 - 對于包含復(fù)雜邏輯(如殘差連接、多輸入輸出)的網(wǎng)絡(luò),使用
nn.Module自定義實現(xiàn)。
到此這篇關(guān)于PyTorch中nn.Module詳解的文章就介紹到這了,更多相關(guān)PyTorch nn.Module內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
django將網(wǎng)絡(luò)中的圖片,保存成model中的ImageField的實例
今天小編就為大家分享一篇django將網(wǎng)絡(luò)中的圖片,保存成model中的ImageField的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-08-08
python實現(xiàn)通過flask和前端進(jìn)行數(shù)據(jù)收發(fā)
今天小編就為大家分享一篇python實現(xiàn)通過flask和前端進(jìn)行數(shù)據(jù)收發(fā),具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-08-08
使用python telnetlib批量備份交換機配置的方法
今天小編就為大家分享一篇使用python telnetlib批量備份交換機配置的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-07-07
DjangoUeditor圖片不顯示img的src沒有域名問題
在使用DjangoUeditor過程中,可能遇到圖片上傳后不顯示問題,解決辦法是修改源碼view.py,加入代碼使得保存的圖片URL帶有協(xié)議和域名,具體做法是在保存圖片代碼中添加request.scheme獲取協(xié)議,request.META['HTTP_HOST']獲取域名2024-09-09
python 裝飾器帶參數(shù)和不帶參數(shù)步驟詳解
裝飾器是Python語言中一種特殊的語法,用于在不修改原函數(shù)代碼的情況下,為函數(shù)添加額外的功能或修改函數(shù)的行為,這篇文章主要介紹了python裝飾器帶參數(shù)和不帶參數(shù)的相關(guān)知識,需要的朋友可以參考下2024-05-05
使用Python3 poplib模塊刪除服務(wù)器多天前的郵件實現(xiàn)代碼
這篇文章主要介紹了使用Python3 poplib模塊刪除多天前的郵件的實現(xiàn)代碼,代碼簡單易懂,非常不錯,對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2020-04-04

