欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyTorch中的參數(shù)類torch.nn.Parameter()詳解

 更新時(shí)間:2022年02月24日 09:40:18   作者:Adenialzz  
這篇文章主要給大家介紹了關(guān)于PyTorch中torch.nn.Parameter()的相關(guān)資料,要內(nèi)容包括基礎(chǔ)應(yīng)用、實(shí)用技巧、原理機(jī)制等方面,文章通過(guò)實(shí)例介紹的非常詳細(xì),需要的朋友可以參考下

前言

今天來(lái)聊一下PyTorch中的torch.nn.Parameter()這個(gè)函數(shù),筆者第一次見(jiàn)的時(shí)候也是大概能理解函數(shù)的用途,但是具體實(shí)現(xiàn)原理細(xì)節(jié)也是云里霧里,在參考了幾篇博文,做過(guò)幾個(gè)實(shí)驗(yàn)之后算是清晰了,本文在記錄的同時(shí)希望給后來(lái)人一個(gè)參考,歡迎留言討論。

分析

先看其名,parameter,中文意為參數(shù)。我們知道,使用PyTorch訓(xùn)練神經(jīng)網(wǎng)絡(luò)時(shí),本質(zhì)上就是訓(xùn)練一個(gè)函數(shù),這個(gè)函數(shù)輸入一個(gè)數(shù)據(jù)(如CV中輸入一張圖像),輸出一個(gè)預(yù)測(cè)(如輸出這張圖像中的物體是屬于什么類別)。而在我們給定這個(gè)函數(shù)的結(jié)構(gòu)(如卷積、全連接等)之后,能學(xué)習(xí)的就是這個(gè)函數(shù)的參數(shù)了,我們?cè)O(shè)計(jì)一個(gè)損失函數(shù),配合梯度下降法,使得我們學(xué)習(xí)到的函數(shù)(神經(jīng)網(wǎng)絡(luò))能夠盡量準(zhǔn)確地完成預(yù)測(cè)任務(wù)。

通常,我們的參數(shù)都是一些常見(jiàn)的結(jié)構(gòu)(卷積、全連接等)里面的計(jì)算參數(shù)。而當(dāng)我們的網(wǎng)絡(luò)有一些其他的設(shè)計(jì)時(shí),會(huì)需要一些額外的參數(shù)同樣很著整個(gè)網(wǎng)絡(luò)的訓(xùn)練進(jìn)行學(xué)習(xí)更新,最后得到最優(yōu)的值,經(jīng)典的例子有注意力機(jī)制中的權(quán)重參數(shù)、Vision Transformer中的class token和positional embedding等。

而這里的torch.nn.Parameter()就可以很好地適應(yīng)這種應(yīng)用場(chǎng)景。

下面是這篇博客的一個(gè)總結(jié),筆者認(rèn)為講的比較明白,在這里引用一下:

首先可以把這個(gè)函數(shù)理解為類型轉(zhuǎn)換函數(shù),將一個(gè)不可訓(xùn)練的類型Tensor轉(zhuǎn)換成可以訓(xùn)練的類型parameter并將這個(gè)parameter綁定到這個(gè)module里面(net.parameter()中就有這個(gè)綁定的parameter,所以在參數(shù)優(yōu)化的時(shí)候可以進(jìn)行優(yōu)化的),所以經(jīng)過(guò)類型轉(zhuǎn)換這個(gè)self.v變成了模型的一部分,成為了模型中根據(jù)訓(xùn)練可以改動(dòng)的參數(shù)了。使用這個(gè)函數(shù)的目的也是想讓某些變量在學(xué)習(xí)的過(guò)程中不斷的修改其值以達(dá)到最優(yōu)化。

ViT中nn.Parameter()的實(shí)驗(yàn)

看過(guò)這個(gè)分析后,我們?cè)倏匆幌耉ision Transformer中的用法:

...

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
...

我們知道在ViT中,positonal embedding和class token是兩個(gè)需要隨著網(wǎng)絡(luò)訓(xùn)練學(xué)習(xí)的參數(shù),但是它們又不屬于FC、MLP、MSA等運(yùn)算的參數(shù),在這時(shí),就可以用nn.Parameter()來(lái)將這個(gè)隨機(jī)初始化的Tensor注冊(cè)為可學(xué)習(xí)的參數(shù)Parameter。

為了確定這兩個(gè)參數(shù)確實(shí)是被添加到了net.Parameters()內(nèi),筆者稍微改動(dòng)源碼,顯式地指定這兩個(gè)參數(shù)的初始數(shù)值為0.98,并打印迭代器net.Parameters()。

...

self.pos_embedding = nn.Parameter(torch.ones(1, num_patches+1, dim) * 0.98)
self.cls_token = nn.Parameter(torch.ones(1, 1, dim) * 0.98)
...

實(shí)例化一個(gè)ViT模型并打印net.Parameters():

net_vit = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

for para in net_vit.parameters():
        print(para.data)

輸出結(jié)果中可以看到,最前兩行就是我們顯式指定為0.98的兩個(gè)參數(shù)pos_embedding和cls_token:

tensor([[[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         ...,
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800]]])
tensor([[[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800]]])
tensor([[-0.0026, -0.0064,  0.0111,  ...,  0.0091, -0.0041, -0.0060],
        [ 0.0003,  0.0115,  0.0059,  ..., -0.0052, -0.0056,  0.0010],
        [ 0.0079,  0.0016, -0.0094,  ...,  0.0174,  0.0065,  0.0001],
        ...,
        [-0.0110, -0.0137,  0.0102,  ...,  0.0145, -0.0105, -0.0167],
        [-0.0116, -0.0147,  0.0030,  ...,  0.0087,  0.0022,  0.0108],
        [-0.0079,  0.0033, -0.0087,  ..., -0.0174,  0.0103,  0.0021]])
...
...

這就可以確定nn.Parameter()添加的參數(shù)確實(shí)是被添加到了Parameters列表中,會(huì)被送入優(yōu)化器中隨訓(xùn)練一起學(xué)習(xí)更新。

from torch.optim import Adam
opt = Adam(net_vit.parameters(), learning_rate=0.001)

其他解釋

以下是國(guó)外StackOverflow的一個(gè)大佬的解讀,筆者自行翻譯并放在這里供大家參考,想查看原文的同學(xué)請(qǐng)戳這里。

我們知道Tensor相當(dāng)于是一個(gè)高維度的矩陣,它是Variable類的子類。Variable和Parameter之間的差異體現(xiàn)在與Module關(guān)聯(lián)時(shí)。當(dāng)Parameter作為model的屬性與module相關(guān)聯(lián)時(shí),它會(huì)被自動(dòng)添加到Parameters列表中,并且可以使用net.Parameters()迭代器進(jìn)行訪問(wèn)。

最初在Torch中,一個(gè)Variable(例如可以是某個(gè)中間state)也會(huì)在賦值時(shí)被添加為模型的Parameter。在某些實(shí)例中,需要緩存變量,而不是將它們添加到Parameters列表中。

文檔中提到的一種情況是RNN,在這種情況下,您需要保存最后一個(gè)hidden state,這樣就不必一次又一次地傳遞它。需要緩存一個(gè)Variable,而不是讓它自動(dòng)注冊(cè)為模型的Parameter,這就是為什么我們有一個(gè)顯式的方法將參數(shù)注冊(cè)到我們的模型,即nn.Parameter類。

舉個(gè)例子:

import torch
import torch.nn as nn
from torch.optim import Adam

class NN_Network(nn.Module):
    def __init__(self,in_dim,hid,out_dim):
        super(NN_Network, self).__init__()
        self.linear1 = nn.Linear(in_dim,hid)
        self.linear2 = nn.Linear(hid,out_dim)
        self.linear1.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))
        self.linear1.bias = torch.nn.Parameter(torch.ones(hid))
        self.linear2.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))
        self.linear2.bias = torch.nn.Parameter(torch.ones(hid))

    def forward(self, input_array):
        h = self.linear1(input_array)
        y_pred = self.linear2(h)
        return y_pred

in_d = 5
hidn = 2
out_d = 3
net = NN_Network(in_d, hidn, out_d)

然后檢查一下這個(gè)模型的Parameters列表:

for param in net.parameters():
    print(type(param.data), param.size())

""" Output
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
"""

可以輕易地送入到優(yōu)化器中:

opt = Adam(net.parameters(), learning_rate=0.001)

另外,請(qǐng)注意Parameter的require_grad會(huì)自動(dòng)設(shè)定。

各位讀者有疑惑或異議的地方,歡迎留言討論。

參考:

http://www.dbjr.com.cn/article/238632.htm

https://stackoverflow.com/questions/50935345/understanding-torch-nn-parameter

總結(jié)

到此這篇關(guān)于PyTorch中torch.nn.Parameter()的文章就介紹到這了,更多相關(guān)PyTorch中torch.nn.Parameter()內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • Python OpenCV實(shí)現(xiàn)鼠標(biāo)畫(huà)框效果

    Python OpenCV實(shí)現(xiàn)鼠標(biāo)畫(huà)框效果

    這篇文章主要為大家詳細(xì)介紹了Python OpenCV實(shí)現(xiàn)鼠標(biāo)畫(huà)框效果,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2019-08-08
  • python使用正則表達(dá)式匹配字符串開(kāi)頭并打印示例

    python使用正則表達(dá)式匹配字符串開(kāi)頭并打印示例

    這篇文章主要介紹了python使用正則表達(dá)式匹配字符串開(kāi)頭并打印的方法,結(jié)合實(shí)例形式分析了Python基于正則表達(dá)式操作字符串的相關(guān)技巧,需要的朋友可以參考下
    2017-01-01
  • Python+AI實(shí)現(xiàn)給老照片上色

    Python+AI實(shí)現(xiàn)給老照片上色

    今天給大家分享一個(gè)有趣的AI項(xiàng)目——利用NoGAN的圖像增強(qiáng)技術(shù)給老照片著色,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以了解一下
    2022-06-06
  • python爬取網(wǎng)頁(yè)數(shù)據(jù)到保存到csv

    python爬取網(wǎng)頁(yè)數(shù)據(jù)到保存到csv

    大家好,本篇文章主要講的是python爬取網(wǎng)頁(yè)數(shù)據(jù)到保存到csv,感興趣的同學(xué)趕快來(lái)看一看吧,對(duì)你有幫助的話記得收藏一下,方便下次瀏覽
    2022-01-01
  • python 生成器協(xié)程運(yùn)算實(shí)例

    python 生成器協(xié)程運(yùn)算實(shí)例

    下面小編就為大家?guī)?lái)一篇python 生成器協(xié)程運(yùn)算實(shí)例。小編覺(jué)得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧
    2017-09-09
  • python遞歸計(jì)算N!的方法

    python遞歸計(jì)算N!的方法

    這篇文章主要介紹了python遞歸計(jì)算N!的方法,涉及Python遞歸計(jì)算階乘的技巧,非常簡(jiǎn)單實(shí)用,需要的朋友可以參考下
    2015-05-05
  • Python如何獲取文件路徑/目錄

    Python如何獲取文件路徑/目錄

    這篇文章主要介紹了Python如何獲取文件路徑/目錄,幫助大家更好的利用python處理文件,感興趣的朋友可以了解下
    2020-09-09
  • Python3.5 + sklearn利用SVM自動(dòng)識(shí)別字母驗(yàn)證碼方法示例

    Python3.5 + sklearn利用SVM自動(dòng)識(shí)別字母驗(yàn)證碼方法示例

    這篇文章主要給大家介紹了關(guān)于Python3.5 + sklearn利用SVM自動(dòng)識(shí)別字母驗(yàn)證碼的相關(guān)資料,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家學(xué)習(xí)或者使用Python具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面來(lái)一起學(xué)習(xí)學(xué)習(xí)吧
    2019-05-05
  • wxPython實(shí)現(xiàn)窗口用圖片做背景

    wxPython實(shí)現(xiàn)窗口用圖片做背景

    這篇文章主要為大家詳細(xì)介紹了wxPython實(shí)現(xiàn)窗口用圖片做背景,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2018-04-04
  • Django為窗體加上防機(jī)器人的驗(yàn)證碼功能過(guò)程解析

    Django為窗體加上防機(jī)器人的驗(yàn)證碼功能過(guò)程解析

    這篇文章主要介紹了Django為窗體加上防機(jī)器人的驗(yàn)證碼功能過(guò)程解析,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2019-08-08

最新評(píng)論