Weight initialization in PyTorch
How to initialize model weights in PyTorch, following the official forum discussion on weight initialization.
torch.nn.Module.apply(fn)
# Recursively applies fn to every submodule of an nn.Module, as well as to the module itself.
# Commonly used to initialize the parameters of a model.
# fn is a handle to the initialization function; it receives each nn.Module (or subclass of nn.Module) in turn.
# fn (Module -> None) – function to be applied to each submodule
# Returns: self
# Return type: Module
Example:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
        # m.weight.data holds the kernel parameters, m.bias.data the bias parameters
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
netG = _netG(ngpu)       # instantiate the generator
netG.apply(weights_init) # recursively applies weights_init to every submodule of netG
#-*-coding:utf-8-*-
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
# Initializing model parameters
# Official forum thread: https://discuss.pytorch.org/t/weight-initilzation/157/3
# Method 1
# Define a standalone weights_init function whose argument m is a module
# (torch.nn.Module or a subclass you define), then call net.apply() to initialize the parameters.
# m.__class__.__name__ returns the class name of the module
# https://github.com/pytorch/examples/blob/master/dcgan/main.py#L90-L96
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
netG = _netG(ngpu)       # instantiate the generator (_netG and ngpu come from the DCGAN example above)
netG.apply(weights_init) # recursively applies weights_init to every submodule of netG
# function to be applied to each submodule
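To make Method 1 runnable without the DCGAN generator, here is a self-contained toy sketch; the two-layer model is made up purely for illustration:
net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1), # matched by the 'Conv' branch
    nn.BatchNorm2d(16),                         # matched by the 'BatchNorm' branch
)
net.apply(weights_init)          # visits Conv2d, BatchNorm2d, then the Sequential itself
print(net[1].weight.data.mean()) # should be close to 1.0 after initialization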
# Method 2
# 1. Iterate over the layers of the model with net.modules(), checking each layer's type.
# 2. Initialize the weight.data tensor (and bias) of each matching layer m directly.
# Another initialization example from the PyTorch Vision resnet implementation:
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L112-L118
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        # Weight initialization (He-style for the conv layers)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
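Newer PyTorch versions ship helpers in torch.nn.init that implement the same scheme, so the manual fan-out computation can be replaced; a minimal sketch on a made-up model:
net = nn.Sequential(nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU())
for m in net.modules():
    if isinstance(m, nn.Conv2d):
        # He initialization: std = sqrt(2 / fan_out), same as the manual formula above
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)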
# Method 3
# If you know the order and types of the parameters in the network, you can read them out
# one by one and initialize each with the functions in torch.nn.init.
net = AlexNet(2)
params = list(net.parameters()) # params holds the Conv2d weight and bias parameters in order
# or, for a single layer:
conv1Params = list(net.conv1.parameters())
# conv1Params[0] is the kernel weight, conv1Params[1] the bias
# then initialize them with the functions in torch.nn.init
torch.nn.init.normal_(conv1Params[0], mean=0, std=1) # in-place normal init of the kernel
torch.nn.init.constant_(conv1Params[1], 0)           # in-place constant init of the bias
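Relying on parameter order is brittle when the architecture changes; a more robust sketch (a suggestion, not from the original post) keys off parameter names and shapes instead:
for name, param in net.named_parameters():
    if 'weight' in name and param.dim() > 1: # conv/linear kernels are at least 2-D
        torch.nn.init.normal_(param, mean=0, std=0.02)
    elif 'bias' in name:
        torch.nn.init.constant_(param, 0)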
# net.modules() yields every module recursively: AlexNet, Sequential, Conv2d, ReLU, MaxPool2d, LRN, AvgPool3d, ..., Conv2d, ..., Conv2d, ..., Linear
# of these, only Conv2d and Linear actually have parameters here
# net.children() yields only the immediate child modules: Sequential, Sequential, Sequential, Sequential, Sequential, Sequential, Sequential, Linear
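The difference is easy to see by printing both iterators:
net = AlexNet(2)
print([type(m).__name__ for m in net.children()]) # immediate children only
print([type(m).__name__ for m in net.modules()])  # full module tree, including net itself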
# For reference, the definition of AlexNet used above
class AlexNet(nn.Module):
    def __init__(self, num_classes=2): # two classes by default: cats and dogs
        # super().__init__() # Python 3
        super(AlexNet, self).__init__()
        # Build the AlexNet model: 5 convolutional layers and 3 fully connected layers
        # (LRN is a custom local response normalization module defined elsewhere in the original code)
        # The 5 convolutional layers
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, groups=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        # The 3 fully connected layers
        # In the forward pass, the input must first be flattened from a 3-D tensor to 1-D with view()
        self.fc6 = nn.Sequential(
            nn.Linear(in_features=6*6*256, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout()
        )
        self.fc7 = nn.Sequential(
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout()
        )
        self.fc8 = nn.Linear(in_features=4096, out_features=num_classes)

    def forward(self, x):
        x = self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(x)))))
        x = x.view(-1, 6*6*256)
        x = self.fc8(self.fc7(self.fc6(x)))
        return x
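A quick smoke test of the definition above; the 227x227 input size is an assumption (it is what makes the 6*6*256 flatten work), and the custom LRN module must be defined for this to run:
net = AlexNet(num_classes=2)
dummy = torch.randn(1, 3, 227, 227) # a batch of one RGB image
out = net(dummy)
print(out.shape) # expected: torch.Size([1, 2])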
Bonus: loading a subset of pretrained weights in PyTorch
A model downloaded from the web often differs from our own model by a layer or two, and retraining all of the parameters from scratch because of that would be unreasonable.
Instead, we can load the parameters that match and skip the ones that do not:
pretrained_dict = torch.load("model.pth")
model_dict = net.state_dict()
# keep only the entries whose keys also exist in our model
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)
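If the only mismatches are missing or unexpected keys (not shape mismatches), load_state_dict can do the filtering itself; a one-line sketch, assuming a reasonably recent PyTorch:
net.load_state_dict(torch.load("model.pth"), strict=False) # ignores missing and unexpected keys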
That concludes this post on weight initialization in PyTorch. I hope it serves as a useful reference, and thank you for supporting 腳本之家.