Pytorch實(shí)現(xiàn)簡單自定義網(wǎng)絡(luò)層的方法
前言
Pytorch、Tensoflow等許多深度學(xué)習(xí)框架集成了大量常見的網(wǎng)絡(luò)層,為我們搭建神經(jīng)網(wǎng)絡(luò)提供了諸多便利。但在實(shí)際工作中,因?yàn)轫?xiàng)目要求、研究需要或者發(fā)論文需要等等,大家一般都會(huì)需要自己發(fā)明一個(gè)現(xiàn)在在深度學(xué)習(xí)框架中還不存在的層。 在這些情況下,就必須構(gòu)建自定義層。
博主在學(xué)習(xí)了沐神的動(dòng)手學(xué)深度學(xué)習(xí)這本書之后,學(xué)到了許多東西。這里記錄一下書中基于Pytorch實(shí)現(xiàn)簡單自定義網(wǎng)絡(luò)層的方法,僅供參考。
一、不帶參數(shù)的層
首先,我們構(gòu)造一個(gè)沒有任何參數(shù)的自定義層,要構(gòu)建它,只需繼承基礎(chǔ)層類并實(shí)現(xiàn)前向傳播功能。
import torch import torch.nn.functional as F from torch import nn class CenteredLayer(nn.Module): def __init__(self): super().__init__() def forward(self, X): return X - X.mean()
輸入一些數(shù)據(jù),驗(yàn)證一下網(wǎng)絡(luò)是否能正常工作:
layer = CenteredLayer() print(layer(torch.FloatTensor([1, 2, 3, 4, 5])))
輸出結(jié)果如下:
tensor([-2., -1., 0., 1., 2.])
運(yùn)行正常,表明網(wǎng)絡(luò)沒有問題。
現(xiàn)在將我們自建的網(wǎng)絡(luò)層作為組件合并到更復(fù)雜的模型中,并輸入數(shù)據(jù)進(jìn)行驗(yàn)證:
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer()) Y = net(torch.rand(4, 8)) print(Y.mean()) # 因?yàn)槟P蛥?shù)較多,輸出也較多,所以這里輸出Y的均值,驗(yàn)證模型可運(yùn)行即可
結(jié)果如下:
tensor(-5.5879e-09, grad_fn=<MeanBackward0>)
二、帶參數(shù)的層
這里使用內(nèi)置函數(shù)來創(chuàng)建參數(shù),這些函數(shù)可以提供一些基本的管理功能,使用更加方便。
這里實(shí)現(xiàn)了一個(gè)簡單的自定義的全連接層,大家可根據(jù)需要自行修改即可。
class MyLinear(nn.Module): def __init__(self, in_units, units): super().__init__() self.weight = nn.Parameter(torch.randn(in_units, units)) self.bias = nn.Parameter(torch.randn(units,)) def forward(self, X): linear = torch.matmul(X, self.weight.data) + self.bias.data return F.relu(linear)
接下來實(shí)例化類并訪問其模型參數(shù):
linear = MyLinear(5, 3) print(linear.weight)
結(jié)果如下:
Parameter containing:
tensor([[-0.3708, 1.2196, 1.3658],
[ 0.4914, -0.2487, -0.9602],
[ 1.8458, 0.3016, -0.3956],
[ 0.0616, -0.3942, 1.6172],
[ 0.7839, 0.6693, -0.8890]], requires_grad=True)
而后輸入一些數(shù)據(jù),查看模型輸出結(jié)果:
print(linear(torch.rand(2, 5))) # 結(jié)果如下 tensor([[1.2394, 0.0000, 0.0000], [1.3514, 0.0968, 0.6667]])
我們還可以使用自定義層構(gòu)建模型,使用方法與使用內(nèi)置的全連接層相同。
net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1)) print(net(torch.rand(2, 64))) # 結(jié)果如下 tensor([[4.1416], [0.2567]])
三、總結(jié)
我們可以通過基本層類設(shè)計(jì)自定義層。這允許我們定義靈活的新層,其行為與深度學(xué)習(xí)框架中的任何現(xiàn)有層不同。
在自定義層定義完成后,我們就可以在任意環(huán)境和網(wǎng)絡(luò)架構(gòu)中調(diào)用該自定義層。
層可以有局部參數(shù),這些參數(shù)可以通過內(nèi)置函數(shù)創(chuàng)建。
四、參考
《動(dòng)手學(xué)深度學(xué)習(xí)》 — 動(dòng)手學(xué)深度學(xué)習(xí) 2.0.0-beta0 documentation
附:pytorch獲取網(wǎng)絡(luò)的層數(shù)和每層的名字
#創(chuàng)建自己的網(wǎng)絡(luò) import models model = models.__dict__["resnet50"](pretrained=True) for index ,(name, param) in enumerate(model.named_parameters()): ? ? print( str(index) + " " +name)
結(jié)果如下:
0 conv1.weight
1 bn1.weight
2 bn1.bias
3 layer1.0.conv1.weight
4 layer1.0.bn1.weight
5 layer1.0.bn1.bias
6 layer1.0.conv2.weight
7 layer1.0.bn2.weight
8 layer1.0.bn2.bias
9 layer1.0.conv3.weight
到此這篇關(guān)于Pytorch實(shí)現(xiàn)簡單自定義網(wǎng)絡(luò)層的文章就介紹到這了,更多相關(guān)Pytorch自定義網(wǎng)絡(luò)層內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python Shiny庫創(chuàng)建交互式Web應(yīng)用及高級(jí)功能案例
Shiny是一個(gè)基于Python的交互式Web應(yīng)用框架,專注于簡化Web應(yīng)用的開發(fā)流程,本文將深入探討Shiny庫的基本用法、高級(jí)功能以及實(shí)際應(yīng)用案例,以幫助開發(fā)者充分發(fā)揮Shiny在Web應(yīng)用開發(fā)中的優(yōu)勢(shì)2023-12-12python基礎(chǔ)教程之匿名函數(shù)lambda
這篇文章主要介紹了 python基礎(chǔ)教程之匿名函數(shù)lambda的相關(guān)資料,需要的朋友可以參考下2017-01-01Python利用multiprocessing實(shí)現(xiàn)最簡單的分布式作業(yè)調(diào)度系統(tǒng)實(shí)例
這篇文章主要給大家介紹了關(guān)于Python利用multiprocessing如何實(shí)現(xiàn)最簡單的分布式作業(yè)調(diào)度系統(tǒng)的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面來一起看看吧。2017-11-11python 對(duì)類的成員函數(shù)開啟線程的方法
今天小編就為大家分享一篇python 對(duì)類的成員函數(shù)開啟線程的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-01-01使用Pytorch實(shí)現(xiàn)Swish激活函數(shù)的示例詳解
激活函數(shù)是人工神經(jīng)網(wǎng)絡(luò)的基本組成部分,他們將非線性引入模型,使其能夠?qū)W習(xí)數(shù)據(jù)中的復(fù)雜關(guān)系,Swish 激活函數(shù)就是此類激活函數(shù)之一,在本文中,我們將深入研究 Swish 激活函數(shù),提供數(shù)學(xué)公式,探索其相對(duì)于 ReLU 的優(yōu)勢(shì),并使用 PyTorch 演示其實(shí)現(xiàn)2023-11-11python numpy中multiply與*及matul 的區(qū)別說明
這篇文章主要介紹了python numpy中multiply與*及matul 的區(qū)別說明,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2021-05-05