欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch模型中的parameter與buffer用法

 更新時間:2021年06月01日 11:24:27   作者:CV/NLP大蝦  
這篇文章主要介紹了Pytorch模型中的parameter與buffer用法,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教

Parameter 和 buffer

If you have parameters in your model, which should be saved and restored in the state_dict, but not trained by the optimizer, you should register them as buffers.Buffers won't be returned in model.parameters(), so that the optimizer won't have a change to update them.

模型中需要保存下來的參數(shù)包括兩種

一種是反向傳播需要被optimizer更新的,稱之為 parameter

一種是反向傳播不需要被optimizer更新,稱之為 buffer

第一種參數(shù)我們可以通過 model.parameters() 返回;第二種參數(shù)我們可以通過 model.buffers() 返回。因為我們的模型保存的是 state_dict 返回的 OrderDict,所以這兩種參數(shù)不僅要滿足是否需要被更新的要求,還需要被保存到OrderDict。

那么現(xiàn)在的問題是這兩種參數(shù)如何創(chuàng)建呢,創(chuàng)建好了如何保存到OrderDict呢?

第一種參數(shù)有兩種方式

我們可以直接將模型的成員變量(http://self.xxx) 通過nn.Parameter() 創(chuàng)建,會自動注冊到parameters中,可以通過model.parameters() 返回,并且這樣創(chuàng)建的參數(shù)會自動保存到OrderDict中去;

通過nn.Parameter() 創(chuàng)建普通Parameter對象,不作為模型的成員變量,然后將Parameter對象通過register_parameter()進(jìn)行注冊,可以通model.parameters() 返回,注冊后的參數(shù)也會自動保存到OrderDict中去;

第二種參數(shù)我們需要創(chuàng)建tensor

然后將tensor通過register_buffer()進(jìn)行注冊,可以通model.buffers() 返回,注冊完后參數(shù)也會自動保存到OrderDict中去。

Pytorch中Module,Parameter和Buffer區(qū)別

下文都將torch.nn簡寫成nn

Module: 就是我們常用的torch.nn.Module類,你定義的所有網(wǎng)絡(luò)結(jié)構(gòu)都必須繼承這個類。

Buffer: buffer和parameter相對,就是指那些不需要參與反向傳播的參數(shù)

示例如下:

class MyModel(nn.Module):
 def __init__(self):
  super(MyModel, self).__init__()
  self.my_tensor = torch.randn(1) # 參數(shù)直接作為模型類成員變量
  self.register_buffer('my_buffer', torch.randn(1)) # 參數(shù)注冊為 buffer
  self.my_param = nn.Parameter(torch.randn(1))
 def forward(self, x):
  return x 

model = MyModel()
print(model.state_dict())
>>>OrderedDict([('my_param', tensor([1.2357])), ('my_buffer', tensor([-0.9982]))])
Parameter: 是nn.parameter.Paramter,也就是組成Module的參數(shù)。例如一個nn.Linear通常由weight和bias參數(shù)組成。它的特點是默認(rèn)requires_grad=True,也就是說訓(xùn)練過程中需要反向傳播的,就需要使用這個
import torch.nn as nn
fc = nn.Linear(2,2)

# 讀取參數(shù)的方式一
fc._parameters
>>> OrderedDict([('weight', Parameter containing:
              tensor([[0.4142, 0.0424],
                      [0.3940, 0.0796]], requires_grad=True)),
             ('bias', Parameter containing:
              tensor([-0.2885,  0.5825], requires_grad=True))])
     
# 讀取參數(shù)的方式二(推薦這種)
for n, p in fc.named_parameters():
 print(n,p)
>>>weight Parameter containing:
tensor([[0.4142, 0.0424],
        [0.3940, 0.0796]], requires_grad=True)
bias Parameter containing:
tensor([-0.2885,  0.5825], requires_grad=True)

# 讀取參數(shù)的方式三
for p in fc.parameters():
 print(p)
>>>Parameter containing:
tensor([[0.4142, 0.0424],
        [0.3940, 0.0796]], requires_grad=True)
Parameter containing:
tensor([-0.2885,  0.5825], requires_grad=True)

通過上面的例子可以看到,nn.parameter.Paramter的requires_grad屬性值默認(rèn)為True。另外上面例子給出了三種讀取parameter的方法,推薦使用后面兩種,因為是以迭代生成器的方式來讀取,第一種方式是一股腦的把參數(shù)全丟給你,要是模型很大,估計你的電腦會吃不消。

另外需要介紹的是_parameters是nn.Module在__init__()函數(shù)中就定義了的一個OrderDict類,這個可以通過看下面給出的部分源碼看到,可以看到還初始化了很多其他東西,其實原理都大同小異,你理解了這個之后,其他的也是同樣的道理。

class Module(object):
 ...
    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

每當(dāng)我們給一個成員變量定義一個nn.parameter.Paramter的時候,都會自動注冊到_parameters,具體的步驟如下:

import torch.nn as nn
class MyModel(nn.Module):
 def __init__(self):
  super(MyModel, self).__init__()
  # 下面兩種定義方式均可
  self.p1 = nn.paramter.Paramter(torch.tensor(1.0))
  print(self._parameters)
  self.p2 = nn.Paramter(torch.tensor(2.0))
  print(self._parameters)

首先運行super(MyModel, self).__init__(),這樣MyModel就初始化了_paramters等一系列的OrderDict,此時所有變量還都是空的。

self.p1 = nn.paramter.Paramter(torch.tensor(1.0)): 這行代碼會觸發(fā)nn.Module預(yù)定義好的__setattr__函數(shù),該函數(shù)部分源碼如下:

def __setattr__(self, name, value):
 ...
 params = self.__dict__.get('_parameters')
 if isinstance(value, Parameter):
  if params is None:
   raise AttributeError(
    "cannot assign parameters before Module.__init__() call")
  remove_from(self.__dict__, self._buffers, self._modules)
  self.register_parameter(name, value)
 ...

__setattr__函數(shù)作用簡單理解就是判斷你定義的參數(shù)是否正確,如果正確就繼續(xù)調(diào)用register_parameter函數(shù)進(jìn)行注冊,這個函數(shù)簡單概括就是做了下面這件事

def register_parameter(self,name,param):
 ...
 self._parameters[name]=param

下面我們實例化這個模型看結(jié)果怎樣

model = MyModel()
>>>OrderedDict([('p1', Parameter containing:
tensor(1., requires_grad=True))])
OrderedDict([('p1', Parameter containing:
tensor(1., requires_grad=True)), ('p2', Parameter containing:
tensor(2., requires_grad=True))])

結(jié)果和上面分析的一致。

以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • 簡單介紹Ruby中的CGI編程

    簡單介紹Ruby中的CGI編程

    這篇文章主要介紹了簡單介紹Ruby中的CGI編程,包括創(chuàng)建Form表單等基本內(nèi)容,需要的朋友可以參考下
    2015-04-04
  • Python多項式回歸的實現(xiàn)方法

    Python多項式回歸的實現(xiàn)方法

    這篇文章主要介紹了Python多項式回歸的實現(xiàn)方法,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2019-03-03
  • Python 流媒體播放器的實現(xiàn)(基于VLC)

    Python 流媒體播放器的實現(xiàn)(基于VLC)

    這篇文章主要介紹了Python 流媒體播放器的實現(xiàn)(基于VLC),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2021-04-04
  • Python關(guān)于迭代器的使用

    Python關(guān)于迭代器的使用

    這篇文章主要介紹了Python關(guān)于迭代器的使用,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教
    2024-06-06
  • 一文搞懂Python讀取text,CSV,JSON文件的方法

    一文搞懂Python讀取text,CSV,JSON文件的方法

    文件處理是一種用于創(chuàng)建文件、寫入數(shù)據(jù)和從中讀取數(shù)據(jù)的過程,Python 擁有豐富的用于處理不同文件類型的包,從而使得我們可以更加輕松方便的完成文件處理的工作,本文將來為大家詳細(xì)講講
    2022-06-06
  • Python疊加矩形框圖層2種方法及效果

    Python疊加矩形框圖層2種方法及效果

    這篇文章主要介紹了Python疊加矩形框圖層2種方法及效果,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下
    2020-06-06
  • python抓取最新博客內(nèi)容并生成Rss

    python抓取最新博客內(nèi)容并生成Rss

    本文給大家分享的是使用python抓取最新博客內(nèi)容并生成Rss的代碼,主要用到了PyRSS2Gen方法,非常的簡單實用,有需要的小伙伴可以參考下。
    2015-05-05
  • 初探利用Python進(jìn)行圖文識別(OCR)

    初探利用Python進(jìn)行圖文識別(OCR)

    這篇文章主要介紹了初探利用Python進(jìn)行圖文識別(OCR),小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2019-02-02
  • Numpy數(shù)組的廣播機制的實現(xiàn)

    Numpy數(shù)組的廣播機制的實現(xiàn)

    這篇文章主要介紹了Numpy數(shù)組的廣播機制的實現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-11-11
  • tensorboard 可以顯示graph,卻不能顯示scalar的解決方式

    tensorboard 可以顯示graph,卻不能顯示scalar的解決方式

    今天小編就為大家分享一篇tensorboard 可以顯示graph,卻不能顯示scalar的解決方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-02-02

最新評論