欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

關(guān)于PyTorch中nn.Module類的簡介

 更新時(shí)間:2023年02月20日 08:41:18   作者:fengbingchun  
這篇文章主要介紹了關(guān)于PyTorch中nn.Module類的簡介,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

PyTorch nn.Module類的簡介

torch.nn.Module類是所有神經(jīng)網(wǎng)絡(luò)模塊(modules)的基類,它的實(shí)現(xiàn)在torch/nn/modules/module.py中。你的模型也應(yīng)該繼承這個(gè)類,主要重載__init__、forward和extra_repr函數(shù)。Modules還可以包含其它Modules,從而可以將它們嵌套在樹結(jié)構(gòu)中。

只要在自己的類中定義了forward函數(shù),backward函數(shù)就會(huì)利用Autograd被自動(dòng)實(shí)現(xiàn)。只要實(shí)例化一個(gè)對(duì)象并傳入對(duì)應(yīng)的參數(shù)就可以自動(dòng)調(diào)用forward函數(shù)。因?yàn)榇藭r(shí)會(huì)調(diào)用對(duì)象的__call__方法,而nn.Module類中的__call__方法會(huì)調(diào)用forward函數(shù)。

nn.Module類中函數(shù)介紹:

  • __init__:初始化內(nèi)部module狀態(tài)。
  • register_buffer:向module添加buffer,不作為模型參數(shù),可作為module狀態(tài)的一部分。默認(rèn)情況下,buffer是持久(persistent)的,將與參數(shù)一起保存。buffer是否persistent的區(qū)別在于這個(gè)buffer是否被放入self.state_dict()中被保存下來。
  • register_parameter:向module添加參數(shù)。
  • add_module:添加一個(gè)submodule(children)到當(dāng)前module中。
  • apply:將fn遞歸應(yīng)用于每個(gè)submodule(children),典型用途為初始化模型參數(shù)。
  • cuda:將所有模型參數(shù)和buffers轉(zhuǎn)移到GPU上。
  • xpu:將所有模型參數(shù)和buffers轉(zhuǎn)移到XPU上。
  • cpu:將所有模型參數(shù)和buffers轉(zhuǎn)移到CPU上。
  • type:將所有參數(shù)和buffers轉(zhuǎn)換為所需的類型。
  • float:將所有浮點(diǎn)參數(shù)和buffers轉(zhuǎn)換為float32數(shù)據(jù)類型。
  • double:將所有浮點(diǎn)參數(shù)和buffers轉(zhuǎn)換為double數(shù)據(jù)類型。
  • half:將所有浮點(diǎn)參數(shù)和buffers轉(zhuǎn)換為float16數(shù)據(jù)類型。
  • bfloat16:將所有浮點(diǎn)參數(shù)和buffers轉(zhuǎn)換為bfloat16數(shù)據(jù)類型。
  • to:將參數(shù)和buffers轉(zhuǎn)換為指定的數(shù)據(jù)類型或轉(zhuǎn)換到指定的設(shè)備上。
  • register_backward_hook:在module中注冊一個(gè)反向鉤子。不推薦使用。
  • register_full_backward_hook:在module中注冊一個(gè)反向鉤子。每次計(jì)算梯度時(shí)都會(huì)調(diào)用此鉤子。使用此鉤子時(shí)不允許就地(in place)修改輸入或輸出,否則會(huì)觸發(fā)error。
  • register_forward_pre_hook:在module中注冊前向pre-hook。每次調(diào)用forward之前都會(huì)調(diào)用此鉤子。
  • register_forward_hook:在module中注冊一個(gè)前向鉤子。每次forward計(jì)算輸出后都會(huì)調(diào)用此鉤子。
  • state_dict:返回包含了module的整個(gè)狀態(tài)的字典。其中keys是對(duì)應(yīng)的參數(shù)和buffer名稱。
  • load_state_dict:將參數(shù)和buffers從state_dict復(fù)制到module及其后代(descendants)中。
  • parameters:返回module的參數(shù)的迭代器。
  • named_parameters:返回module的參數(shù)的迭代器,產(chǎn)生(yield)參數(shù)的名稱以及參數(shù)本身。不會(huì)返回重復(fù)的parameter。
  • buffers:返回module的buffers的迭代器。
  • named_buffers:返回module的buffers的迭代器,產(chǎn)生(yield)buffer的名稱以及buffer本身。不會(huì)返回重復(fù)的buffer。
  • children:返回直接子module的迭代器。
  • named_children:返回直接子module的迭代器,產(chǎn)生(yield)子module的名稱以及子module本身。不會(huì)返回重復(fù)的children。
  • modules:返回網(wǎng)絡(luò)中所有modules的迭代器。
  • named_modules:返回網(wǎng)絡(luò)中所有modules的迭代器,產(chǎn)生(yield)module的名稱以及module本身。不會(huì)返回重復(fù)的module。
  • train:將module設(shè)置為訓(xùn)練模式。這僅對(duì)某些module起作用。module.py實(shí)現(xiàn)中會(huì)修改self.training并通過self.children()來調(diào)整所有submodule的狀態(tài)。
  • eval:將module設(shè)置為評(píng)估模式。這僅對(duì)某些module起作用。module.py實(shí)現(xiàn)中直接調(diào)用train(False)。
  • requires_grad_:更改autograd是否應(yīng)記錄對(duì)此module中參數(shù)的操作。此方法就地(in place)設(shè)置參數(shù)的requires_grad屬性。
  • zero_grad:將所有模型參數(shù)的梯度設(shè)置為零。
  • extra_repr:設(shè)置module的額外表示。你應(yīng)該在自己的modules中重新實(shí)現(xiàn)此方法。

測試代碼如下:

import torch
import torch.nn as nn
import torch.nn.functional as F # nn.functional.py中存放激活函數(shù)等的實(shí)現(xiàn)
?
@torch.no_grad()
def init_weights(m):
? ? print("xxxx:", m)
? ? if type(m) == nn.Linear:
? ? ? ? ?m.weight.fill_(1.0)
? ? ? ? ?print("yyyy:", m.weight)
?
class Model(nn.Module):
? ? def __init__(self):
? ? ? ? # 在實(shí)現(xiàn)自己的__init__函數(shù)時(shí),為了正確初始化自定義的神經(jīng)網(wǎng)絡(luò)模塊,一定要先調(diào)用super().__init__
? ? ? ? super(Model, self).__init__()
? ? ? ? self.conv1 = nn.Conv2d(1, 20, 5) # submodule(child module)
? ? ? ? self.conv2 = nn.Conv2d(20, 20, 5)
? ? ? ? self.add_module("conv3", nn.Conv2d(10, 40, 5)) # 添加一個(gè)submodule到當(dāng)前module,等價(jià)于self.conv3 = nn.Conv2d(10, 40, 5)
? ? ? ? self.register_buffer("buffer", torch.randn([2,3])) # 給module添加一個(gè)presistent(持久的) buffer
? ? ? ? self.param1 = nn.Parameter(torch.rand([1])) # module參數(shù)的tensor
? ? ? ? self.register_parameter("param2", nn.Parameter(torch.rand([1]))) # 向module添加參數(shù)
?
? ? ? ? # nn.Sequential: 順序容器,module將按照它們在構(gòu)造函數(shù)中傳遞的順序添加,它允許將整個(gè)容器視為單個(gè)module
? ? ? ? self.feature = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
? ? ? ? self.feature.apply(init_weights) # 將fn遞歸應(yīng)用于每個(gè)submodule,典型用途為初始化模型參數(shù)
? ? ? ? self.feature.to(torch.double) # 將參數(shù)數(shù)據(jù)類型轉(zhuǎn)換為double
? ? ? ? cpu = torch.device("cpu")
? ? ? ? self.feature.to(cpu) # 將參數(shù)數(shù)據(jù)轉(zhuǎn)換到cpu設(shè)備上
?
? ? def forward(self, x):
? ? ? ?x = F.relu(self.conv1(x))
? ? ? ?return F.relu(self.conv2(x))
?
model = Model()
print("## Model:", model)
?
model.cpu() # 將所有模型參數(shù)和buffers移動(dòng)到CPU上
model.float() # 將所有浮點(diǎn)參數(shù)和buffers轉(zhuǎn)換為float數(shù)據(jù)類型
model.zero_grad() # 將所有模型參數(shù)的梯度設(shè)置為零
?
# state_dict:返回一個(gè)字典,保存著module的所有狀態(tài),參數(shù)和persistent buffers都會(huì)包含在字典中,字典的key就是參數(shù)和buffer的names
print("## state_dict:", model.state_dict().keys())
?
for name, parameters in model.named_parameters(): # 返回module的參數(shù)(weight and bias)的迭代器,產(chǎn)生(yield)參數(shù)的名稱以及參數(shù)本身
? ? print(f"## named_parameters: name: {name}; parameters size: {parameters.size()}")
?
for name, buffers in model.named_buffers(): # 返回module的buffers的迭代器,產(chǎn)生(yield)buffer的名稱以及buffer本身
? ? print(f"## named_buffers: name: {name}; buffers size: {buffers.size()}")
?
# 注:children和modules中重復(fù)的module只被返回一次
for children in model.children(): # 返回當(dāng)前module的child module(submodule)的迭代器
? ? print("## children:", children)
?
for name, children in model.named_children(): # 返回直接submodule的迭代器,產(chǎn)生(yield) submodule的名稱以及submodule本身
? ? print(f"## named_children: name: {name}; children: {children}")
?
for modules in model.modules(): # 返回當(dāng)前模型所有module的迭代器,注意與children的區(qū)別
? ? print("## modules:", modules)
?
for name, modules in model.named_modules(): # 返回網(wǎng)絡(luò)中所有modules的迭代器,產(chǎn)生(yield)module的名稱以及module本身,注意與named_children的區(qū)別
? ? print(f"## named_modules: name: {name}; module: {modules}")
?
model.train() # 將module設(shè)置為訓(xùn)練模式
model.eval() # 將module設(shè)置為評(píng)估模式
?
print("test finish")

GitHub:https://github.com/fengbingchun/PyTorch_Test

PyTorch中nn.Module理解

nn.Module是Pytorch封裝的一個(gè)類,是搭建神經(jīng)網(wǎng)絡(luò)時(shí)需要繼承的父類:

import torch
import torch.nn as nn

# 括號(hào)中加入nn.Module(父類)。Test2變成子類,繼承父類(nn.Module)的所有特性。
class Test2(nn.Module):  
    def __init__(self):  # Test2類定義初始化方法
       super(Test2, self).__init__()  # 父類初始化
       self.M = nn.Parameter(torch.ones(10))
        
    def weightInit(self):
        print('Testing')

    def forward(self, n):
        # print(2 * n)
        print(self.M * n)
        self.weightInit()

# 調(diào)用方法
network = Test2()
network(2)  # 2賦值給forward(self, n)中的n。
……省略一部分代碼……
# 因?yàn)門est2是nn.Module的子類,所以也可以執(zhí)行父類中的方法。如:
model_dict = network.state_dict()  # 調(diào)用父類中的方法state_dict(),將Test2中訓(xùn)練參數(shù)賦值model_dict。
for k, v in model_dict.items():  # 查看自己網(wǎng)絡(luò)參數(shù)各層名稱、數(shù)值
	print(k)  # 輸出網(wǎng)絡(luò)參數(shù)名字
    # print(v)  # 輸出網(wǎng)絡(luò)參數(shù)數(shù)值

繼承nn.Module的子類程序是從forward()方法開始執(zhí)行的,如果要想執(zhí)行其他方法,必須把它放在forward()方法中。這一點(diǎn)與python中繼承有稍許的不同。

總結(jié)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python3+Selenium+Chrome實(shí)現(xiàn)自動(dòng)填寫WPS表單

    Python3+Selenium+Chrome實(shí)現(xiàn)自動(dòng)填寫WPS表單

    本文通過python3、第三方python庫Selenium和谷歌瀏覽器Chrome,完成WPS表單的自動(dòng)填寫,通過實(shí)例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2020-02-02
  • python實(shí)現(xiàn)修改xml文件內(nèi)容

    python實(shí)現(xiàn)修改xml文件內(nèi)容

    這篇文章主要介紹了python實(shí)現(xiàn)修改xml文件內(nèi)容,XML 指可擴(kuò)展標(biāo)記語言,是一種標(biāo)記語言,是從標(biāo)準(zhǔn)通用標(biāo)記語言(SGML)中簡化修改出來的
    2022-07-07
  • 簡單介紹Python中的decode()方法的使用

    簡單介紹Python中的decode()方法的使用

    這篇文章主要介紹了簡單介紹Python中的decode()方法的使用,是Python入門學(xué)習(xí)當(dāng)中必須掌握的基礎(chǔ)知識(shí),需要的朋友可以參考下
    2015-05-05
  • 詳解如何在Python中使用Jinja2進(jìn)行模板渲染

    詳解如何在Python中使用Jinja2進(jìn)行模板渲染

    Jinja2 是一個(gè)現(xiàn)代的、設(shè)計(jì)精美的 Python 模板引擎,它使用類似于 Django 的模板語言來渲染文本文件,下面我將通過幾個(gè)例子展示如何在 Python 中使用 Jinja2 進(jìn)行模板渲染,文中有詳細(xì)的代碼供大家參考,需要的朋友可以參考下
    2024-08-08
  • python的turtle庫使用詳解

    python的turtle庫使用詳解

    在本篇文章里小編給大家分享了關(guān)于python的turtle庫相關(guān)知識(shí)點(diǎn)以及使用方法,需要的朋友們跟著學(xué)習(xí)下。
    2019-05-05
  • pyspark操作hive分區(qū)表及.gz.parquet和part-00000文件壓縮問題

    pyspark操作hive分區(qū)表及.gz.parquet和part-00000文件壓縮問題

    這篇文章主要介紹了pyspark操作hive分區(qū)表及.gz.parquet和part-00000文件壓縮問題,針對(duì)問題整理了spark操作hive表的幾種方式,需要的朋友可以參考下
    2021-08-08
  • Python字典推導(dǎo)式將cookie字符串轉(zhuǎn)化為字典解析

    Python字典推導(dǎo)式將cookie字符串轉(zhuǎn)化為字典解析

    這篇文章主要介紹了Python字典推導(dǎo)式將cookie字符串轉(zhuǎn)化為字典解析,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2019-08-08
  • 教你Pycharm安裝使用requests第三方庫的詳細(xì)教程

    教你Pycharm安裝使用requests第三方庫的詳細(xì)教程

    PyCharm安裝第三方庫是十分方便的,無需pip或其他工具,平臺(tái)就自帶了這個(gè)功能而且操作十分簡便,今天通過本文帶領(lǐng)大家學(xué)習(xí)Pycharm安裝使用requests第三方庫的詳細(xì)教程,感興趣的朋友一起看看吧
    2021-07-07
  • python 類中函數(shù)名前后加下劃線的具體使用

    python 類中函數(shù)名前后加下劃線的具體使用

    在Python編程語言中,函數(shù)名前后有下劃線是一種常見的命名約定,,被廣泛應(yīng)用于類中的函數(shù),本文將介紹下劃線命名風(fēng)格的由來、使用場景以及如何正確應(yīng)用它,感興趣的可以了解一下
    2024-01-01
  • Python讀寫配置文件的方法

    Python讀寫配置文件的方法

    這篇文章主要介紹了Python讀寫配置文件的方法,涉及ConfigParser模塊的操作技巧,需要的朋友可以參考下
    2015-06-06

最新評(píng)論