欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch中的weight-initilzation用法

 更新時間:2020年06月24日 09:08:45   作者:tsq292978891  
這篇文章主要介紹了pytorch中的weight-initilzation用法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

pytorch中的權(quán)值初始化

官方論壇對weight-initilzation的討論

torch.nn.Module.apply(fn)

torch.nn.Module.apply(fn)
# 遞歸的調(diào)用weights_init函數(shù),遍歷nn.Module的submodule作為參數(shù)
# 常用來對模型的參數(shù)進行初始化
# fn是對參數(shù)進行初始化的函數(shù)的句柄,fn以nn.Module或者自己定義的nn.Module的子類作為參數(shù)
# fn (Module -> None) – function to be applied to each submodule
# Returns: self
# Return type: Module

例子:

def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02) 
  # m.weight.data是卷積核參數(shù), m.bias.data是偏置項參數(shù)
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型實例
netG.apply(weights_init) # 遞歸的調(diào)用weights_init函數(shù),遍歷netG的submodule作為參數(shù)
#-*-coding:utf-8-*-
import torch
from torch.autograd import Variable

# 對模型參數(shù)進行初始化
# 官方論壇鏈接:https://discuss.pytorch.org/t/weight-initilzation/157/3

# 方法一
# 單獨定義一個weights_init函數(shù),輸入?yún)?shù)是m(torch.nn.module或者自己定義的繼承nn.module的子類)
# 然后使用net.apply()進行參數(shù)初始化
# m.__class__.__name__ 獲得nn.module的名字
# https://github.com/pytorch/examples/blob/master/dcgan/main.py#L90-L96
def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02)
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型實例
netG.apply(weights_init) # 遞歸的調(diào)用weights_init函數(shù),遍歷netG的submodule作為參數(shù)

# function to be applied to each submodule

# 方法二
# 1. 使用net.modules()遍歷模型中的網(wǎng)絡層的類型 2. 對其中的m層的weigth.data(tensor)部分進行初始化操作
# Another initialization example from PyTorch Vision resnet implementation.
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L112-L118
class ResNet(nn.Module):
 def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(ResNet, self).__init__()
  self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
        bias=False)
  self.bn1 = nn.BatchNorm2d(64)
  self.relu = nn.ReLU(inplace=True)
  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  self.layer1 = self._make_layer(block, 64, layers[0])
  self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
  self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
  self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
  self.avgpool = nn.AvgPool2d(7, stride=1)
  self.fc = nn.Linear(512 * block.expansion, num_classes)
  # 權(quán)值參數(shù)初始化
  for m in self.modules():
   if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
   elif isinstance(m, nn.BatchNorm2d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()

# 方法三
# 自己知道網(wǎng)絡中參數(shù)的順序和類型, 然后將參數(shù)依次讀取出來,調(diào)用torch.nn.init中的方法進行初始化
net = AlexNet(2)
params = list(net.parameters()) # params依次為Conv2d參數(shù)和Bias參數(shù)
# 或者
conv1Params = list(net.conv1.parameters())
# 其中,conv1Params[0]表示卷積核參數(shù), conv1Params[1]表示bias項參數(shù)
# 然后使用torch.nn.init中函數(shù)進行初始化
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.constant(tensor, 0)

# net.modules()迭代的返回: AlexNet,Sequential,Conv2d,ReLU,MaxPool2d,LRN,AvgPool3d....,Conv2d,...,Conv2d,...,Linear,
# 這里,只有Conv2d和Linear才有參數(shù)
# net.children()只返回實際存在的子模塊: Sequential,Sequential,Sequential,Sequential,Sequential,Sequential,Sequential,Linear

# 附AlexNet的定義
class AlexNet(nn.Module):
 def __init__(self, num_classes = 2): # 默認為兩類,貓和狗
#   super().__init__() # python3
  super(AlexNet, self).__init__()
  # 開始構(gòu)建AlexNet網(wǎng)絡模型,5層卷積,3層全連接層
  # 5層卷積層
  self.conv1 = nn.Sequential(
   nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2),
   LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
  )
  self.conv2 = nn.Sequential(
   nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, groups=2, padding=2),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2),
   LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
  )
  self.conv3 = nn.Sequential(
   nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
   nn.ReLU(inplace=True)
  )
  self.conv4 = nn.Sequential(
   nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
   nn.ReLU(inplace=True)
  )
  self.conv5 = nn.Sequential(
   nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2)
  )
  # 3層全連接層
  # 前向計算的時候,最開始輸入需要進行view操作,將3D的tensor變?yōu)?D
  self.fc6 = nn.Sequential(
   nn.Linear(in_features=6*6*256, out_features=4096),
   nn.ReLU(inplace=True),
   nn.Dropout()
  )
  self.fc7 = nn.Sequential(
   nn.Linear(in_features=4096, out_features=4096),
   nn.ReLU(inplace=True),
   nn.Dropout()
  )
  self.fc8 = nn.Linear(in_features=4096, out_features=num_classes)

 def forward(self, x):
  x = self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(x)))))
  x = x.view(-1, 6*6*256)
  x = self.fc8(self.fc7(self.fc6(x)))
  return x

補充知識:pytorch Load部分weights

我們從網(wǎng)上down下來的模型與我們的模型可能就存在一個層的差異,此時我們就需要重新訓練所有的參數(shù)是不合理的。

因此我們可以加載相同的參數(shù),而忽略不同的參數(shù),代碼如下:

  pretrained_dict = torch.load(“model.pth”)
  model_dict = et.state_dict()
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
  model_dict.update(pretrained_dict)
  net.load_state_dict(model_dict)

以上這篇pytorch中的weight-initilzation用法就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • python裝飾器decorator介紹

    python裝飾器decorator介紹

    這篇文章主要介紹了python裝飾器decorator介紹,decorator設計模式允許動態(tài)地對現(xiàn)有的對象或函數(shù)包裝以至于修改現(xiàn)有的職責和行為,簡單地講用來動態(tài)地擴展現(xiàn)有的功能,需要的朋友可以參考下
    2014-11-11
  • python通過文本在一個圖中畫多條線的實例

    python通過文本在一個圖中畫多條線的實例

    今天小編就為大家分享一篇python通過文本在一個圖中畫多條線的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-02-02
  • Python操作配置文件ini的三種方法講解

    Python操作配置文件ini的三種方法講解

    今天小編就為大家分享一篇關(guān)于Python操作配置文件ini的三種方法講解,小編覺得內(nèi)容挺不錯的,現(xiàn)在分享給大家,具有很好的參考價值,需要的朋友一起跟隨小編來看看吧
    2019-02-02
  • Python蛇形方陣的實現(xiàn)

    Python蛇形方陣的實現(xiàn)

    本文主要介紹了Python蛇形方陣的實現(xiàn),文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2023-05-05
  • Tensorflow 定義變量,函數(shù),數(shù)值計算等名字的更新方式

    Tensorflow 定義變量,函數(shù),數(shù)值計算等名字的更新方式

    今天小編就為大家分享一篇Tensorflow 定義變量,函數(shù),數(shù)值計算等名字的更新方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-02-02
  • Python操作MongoDB增刪改查代碼示例

    Python操作MongoDB增刪改查代碼示例

    這篇文章主要介紹了Python操作MongoDB增刪改查代碼示例,需要的朋友可以參考下
    2022-12-12
  • Python 用三行代碼提取PDF表格數(shù)據(jù)

    Python 用三行代碼提取PDF表格數(shù)據(jù)

    這篇文章主要介紹了Python 用三行代碼提取PDF表格數(shù)據(jù),文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2019-10-10
  • 講解Python中if語句的嵌套用法

    講解Python中if語句的嵌套用法

    這篇文章主要介紹了講解Python中if語句的嵌套用法,是Python入門當中的基礎知識,需要的朋友可以參考下
    2015-05-05
  • Django框架靜態(tài)文件使用/中間件/禁用ip功能實例詳解

    Django框架靜態(tài)文件使用/中間件/禁用ip功能實例詳解

    這篇文章主要介紹了Django框架靜態(tài)文件使用/中間件/禁用ip功能,結(jié)合實例形式詳細分析了Django框架靜態(tài)文件的使用、中間件的原理、操作方法以及禁用ip功能相關(guān)實現(xiàn)技巧,需要的朋友可以參考下
    2019-07-07
  • Python?第三方庫?Pandas?數(shù)據(jù)分析教程

    Python?第三方庫?Pandas?數(shù)據(jù)分析教程

    這篇文章主要介紹了Python?第三方庫?Pandas?數(shù)據(jù)分析教程的相關(guān)資料,需要的朋友可以參考下
    2022-09-09

最新評論