PyTorch如何利用parameters()獲取模型參數(shù)
利用parameters()獲取模型參數(shù)
在PyTorch中,可以使用parameters函數(shù)來獲取模型中的所有可學(xué)習(xí)參數(shù)。
以下是一個(gè)示例:
import torch.nn as nn class MyModel(nn.Module): ? ? def __init__(self): ? ? ? ? super(MyModel, self).__init__() ? ? ? ? self.fc1 = nn.Linear(10, 5) ? ? ? ? self.fc2 = nn.Linear(5, 1) ? ? def forward(self, x): ? ? ? ? x = self.fc1(x) ? ? ? ? x = self.fc2(x) ? ? ? ? return x model = MyModel() params = list(model.parameters())
在這個(gè)示例中,我們首先定義了一個(gè)包含兩個(gè)線性層的神經(jīng)網(wǎng)絡(luò),然后通過list(model.parameters())獲取了模型中的所有可學(xué)習(xí)參數(shù)。
這些參數(shù)存儲(chǔ)在一個(gè)Python列表中,可以用于進(jìn)行優(yōu)化器的初始化和模型的保存和加載。
PyTorch中模型的parameters()方法
首先先定義一個(gè)模型:
import torch as t import torch.nn as nn class A(nn.Module): ? ? def __init__(self): ? ? ? ? super().__init__() ? ? ? ? self.conv1 = nn.Conv2d(2, 2, 3) ? ? ? ? self.conv2 = nn.Conv2d(2, 2, 3) ? ? ? ? self.conv3 = nn.Conv2d(2, 2, 3) ? ? def forward(self, x): ? ? ? ? x = self.conv1(x) ? ? ? ? x = self.conv2(x) ? ? ? ? x = self.conv3(x) ? ? ? ? return x
然后打印出該模型的參數(shù):
pythona = A() print(a.parameters()) #<generator object Module.parameters at 0x7f7b740d2360>
以上代碼說明parameters()會(huì)返回一個(gè)生成器(迭代器)
然后將其迭代打印出來:
print(list(a.parameters())):#將迭代器轉(zhuǎn)換成列表 Parameter containing: tensor([[[[-0.0299, ?0.0891, ?0.0303], ? ? ? ? ? [ 0.0869, -0.0230, -0.1760], ? ? ? ? ? [ 0.1408, ?0.0348, ?0.1795]], ? ? ? ? ?[[ 0.2001, ?0.0023, -0.1775], ? ? ? ? ? [ 0.0947, -0.0231, -0.1756], ? ? ? ? ? [ 0.1201, -0.0997, -0.0303]]], ? ? ? ? [[[-0.0425, ?0.0748, -0.1754], ? ? ? ? ? [-0.1191, -0.1203, -0.1219], ? ? ? ? ? [-0.0794, ?0.0895, -0.1719]], ? ? ? ? ?[[ 0.1968, -0.0463, ?0.0550], ? ? ? ? ? [-0.0386, ?0.1594, ?0.1282], ? ? ? ? ? [-0.0009, ?0.2167, -0.1783]]]], requires_grad=True) Parameter containing: tensor([ 0.0147, -0.0406], requires_grad=True) Parameter containing: tensor([[[[-0.0578, -0.1114, -0.1194], ? ? ? ? ? [-0.1469, -0.1175, -0.1616], ? ? ? ? ? [-0.2289, -0.0975, -0.1700]], ? ? ? ? ?[[-0.0894, ?0.0074, ?0.1222], ? ? ? ? ? [-0.0176, -0.0509, ?0.1622], ? ? ? ? ? [-0.0405, -0.1349, ?0.1782]]], ? ? ? ? [[[-0.0739, ?0.2167, ?0.1864], ? ? ? ? ? [ 0.0956, -0.1761, ?0.0464], ? ? ? ? ? [ 0.0062, -0.0685, ?0.0748]], ? ? ? ? ?[[ 0.1085, ?0.1481, ?0.1334], ? ? ? ? ? [ 0.2236, -0.0706, -0.0224], ? ? ? ? ? [ 0.0079, -0.1835, -0.0407]]]], requires_grad=True) Parameter containing: tensor([-8.0720e-05, ?1.6026e-01], requires_grad=True) Parameter containing: tensor([[[[-0.0702, ?0.1846, ?0.0419], ? ? ? ? ? [-0.1891, -0.0893, -0.0024], ? ? ? ? ? [-0.0349, -0.0213, ?0.0936]], ? ? ? ? ?[[-0.1062, ?0.1242, ?0.0391], ? ? ? ? ? [-0.1924, ?0.0535, -0.1480], ? ? ? ? ? [ 0.0400, -0.0487, -0.2317]]], ? ? ? ? [[[ 0.1202, ?0.0961, ?0.2336], ? ? ? ? ? [ 0.2225, -0.2294, -0.2283], ? ? ? ? ? [-0.0963, -0.0311, -0.2354]], ? ? ? ? ?[[ 0.0676, -0.0439, -0.0962], ? ? ? ? ? [-0.2316, -0.0639, -0.0671], ? ? ? ? ? [ 0.1737, -0.1169, -0.1751]]]], requires_grad=True) Parameter containing: tensor([-0.1939, -0.0959], requires_grad=True)
從以上結(jié)果可以看出列表中有6個(gè)元素,由于nn.Conv2d()的參數(shù)包括self.weight和self.bias兩部分,所以每個(gè)2D卷積層包括兩部分的參數(shù).注意self.bias是加在每個(gè)通道上的,所以self.bias的長(zhǎng)度與output_channl相同
心得:
parameters()會(huì)返回一個(gè)生成器(迭代器),生成器每次生成的是Tensor類型的數(shù)據(jù).
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python機(jī)器學(xué)習(xí)庫(kù)之Scikit-learn基本用法詳解
Scikit-learn?是?Python?中最著名的機(jī)器學(xué)習(xí)庫(kù)之一,它提供了大量實(shí)用的機(jī)器學(xué)習(xí)算法以及相關(guān)的工具,可以方便我們進(jìn)行數(shù)據(jù)挖掘和數(shù)據(jù)分析,在這篇文章中,我們將介紹?Scikit-learn?的基本使用,包括如何導(dǎo)入數(shù)據(jù)、預(yù)處理數(shù)據(jù)、選擇和訓(xùn)練模型,以及評(píng)估模型的性能2023-07-07Python?Matplotlib繪制箱線圖boxplot()函數(shù)詳解
箱線圖一般用來展現(xiàn)數(shù)據(jù)的分布(如上下四分位值、中位數(shù)等),同時(shí)也可以用箱線圖來反映數(shù)據(jù)的異常情況,下面這篇文章主要給大家介紹了關(guān)于Python?Matplotlib繪制箱線圖boxplot()函數(shù)的相關(guān)資料,需要的朋友可以參考下2022-07-07Python實(shí)現(xiàn)向好友發(fā)送微信消息優(yōu)化篇
利用python可以實(shí)現(xiàn)微信消息發(fā)送功能,怎么實(shí)現(xiàn)呢?你肯定會(huì)想著很復(fù)雜,但是python的好處就是很多人已經(jīng)把接口打包做好了,只需要調(diào)用即可,今天通過本文給大家分享使用?Python?實(shí)現(xiàn)微信消息發(fā)送的思路代碼,一起看看吧2022-06-06Python集成開發(fā)環(huán)境pycharm配置git的實(shí)現(xiàn)步驟
本文主要介紹了Python集成開發(fā)環(huán)境pycharm配置git的實(shí)現(xiàn)步驟,文中通過圖文的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2024-05-05python中超簡(jiǎn)單的字符分割算法記錄(車牌識(shí)別、儀表識(shí)別等)
這篇文章主要給大家介紹了關(guān)于python中超簡(jiǎn)單的字符分割算法記錄,如車牌識(shí)別、儀表識(shí)別等,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2021-09-09Anaconda修改默認(rèn)虛擬環(huán)境安裝位置的方案分享
新安裝Anaconda后,在創(chuàng)建環(huán)境時(shí)環(huán)境自動(dòng)安裝在C盤,但是C盤空間有限,下面這篇文章主要給大家介紹了關(guān)于Anaconda修改默認(rèn)虛擬環(huán)境安裝位置的相關(guān)資料,需要的朋友可以參考下2023-01-01