PyTorch network parameters: weight and bias initialization explained
Weight initialization is critical when training neural networks; a good initialization can effectively avoid problems such as vanishing gradients.
Below are several weight-initialization approaches available in PyTorch for reference.
Note: the first method is not recommended; prefer the latter two.
# not recommended
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
# recommended
def initialize_weights(m):
    if isinstance(m, nn.Conv2d):
        m.weight.data.normal_(0, 0.02)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.02)
        m.bias.data.zero_()
# recommended
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight.data)
        # xavier init is only defined for tensors with >= 2 dims, so the 1-D bias is set to a constant instead
        nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm1d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
Once the weights_init function is written, the model's apply method can be used to initialize the model's weights.
net = Residual()        # instantiate the network (Residual is a user-defined nn.Module)
net.apply(weights_init) # recursively apply weights_init to every submodule
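As a quick sanity check, here is a minimal sketch (using a small throwaway network rather than the Residual model above) that applies weights_init and inspects the result:
import torch
import torch.nn as nn

def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight.data)
        nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
net.apply(weights_init)          # apply() walks every submodule, including nested ones

print(net[0].weight.std())       # roughly the xavier std for a 3x3, 3->16 conv
print(net[1].weight.unique())    # tensor([1.]) for the BatchNorm scale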
Supplementary: PyTorch weight initialization and parameter grouping
1. Model parameter initialization
# ————————————————— Initialization via model.apply(weights_init)
import math
import torch
import torch.nn as nn

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # He (Kaiming) initialization: std = sqrt(2 / fan_out)
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm') != -1:
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        m.weight.data.normal_(0, 0.01)
        m.bias.data = torch.ones(m.bias.data.size())
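The hand-computed normal_(0, sqrt(2/n)) above is He initialization; PyTorch's built-in nn.init.kaiming_normal_ computes the same std from the tensor shape (mode='fan_out' matches n = k*k*out_channels used above). A minimal equivalent sketch:
def weights_init_kaiming(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)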
# ————————————————— Initialization placed directly in the __init__ constructor
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm1d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            m.bias.data.zero_()
# ————————————————— Initializing self-defined Parameters directly in __init__
self.weight = Parameter(torch.Tensor(out_features, in_features))
self.bias = Parameter(torch.FloatTensor(out_features))
nn.init.xavier_uniform_(self.weight)
nn.init.zeros_(self.bias)             # note: the function is zeros_, not zero_
# other options:
# nn.init.constant_(tensor, val)      # fill a tensor with a constant value
# nn.init.kaiming_uniform_(tensor)
# self.weight.data.normal_(std=0.001)
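Putting the fragment above into context, here is a minimal sketch of a hypothetical custom layer (the name MyLinear is illustrative only) that creates its own Parameters and initializes them in __init__:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

class MyLinear(nn.Module):  # hypothetical module, for illustration
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        self.bias = Parameter(torch.Tensor(out_features))
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

layer = MyLinear(128, 64)
print(layer.weight.shape, layer.bias.abs().sum())  # bias is all zeros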
2. Grouping model parameters for weight_decay
def separate_bn_prelu_params(model, ignored_params=[]):
    bn_prelu_params = []
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            ignored_params += list(map(id, m.parameters()))
            bn_prelu_params += m.parameters()
        if isinstance(m, nn.BatchNorm1d):
            ignored_params += list(map(id, m.parameters()))
            bn_prelu_params += m.parameters()
        elif isinstance(m, nn.PReLU):
            ignored_params += list(map(id, m.parameters()))
            bn_prelu_params += m.parameters()
    base_params = list(filter(lambda p: id(p) not in ignored_params, model.parameters()))
    return base_params, bn_prelu_params, ignored_params
OPTIMIZER = optim.SGD([
    {'params': base_params, 'weight_decay': WEIGHT_DECAY},
    {'params': fc_head_param, 'weight_decay': WEIGHT_DECAY * 10},
    {'params': bn_prelu_params, 'weight_decay': 0.0}
], lr=LR, momentum=MOMENTUM)  # , nesterov=True
Note 1:PReLU(x) = max(0,x) + a * min(0,x). Here a is a learnable parameter. When called without arguments, nn.PReLU() uses a single parameter a across all input channels. If called with nn.PReLU(nChannels), a separate a is used for each input channel.
Note 2: for good performance, weight decay should not be applied when learning a.
Note 3: by default only a single a is learned, and its default initial value is 0.25.
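A small sketch illustrating the two PReLU variants and why their parameters are usually excluded from weight decay:
import torch.nn as nn

shared = nn.PReLU()     # one learnable a shared across all channels
per_ch = nn.PReLU(64)   # one learnable a per input channel

print(sum(p.numel() for p in shared.parameters()))  # 1
print(sum(p.numel() for p in per_ch.parameters()))  # 64
print(shared.weight)  # initialized to 0.25; decaying it toward 0 would push PReLU toward a plain ReLU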
3. Parameter grouping for weight_decay: other approaches
Section 2 covers typical parameter-grouping needs; this section handles more customized grouping. Reference: face_evoLVe_Pytorch-master.
Custom learning-rate schedule
def schedule_lr(optimizer):
    for params in optimizer.param_groups:
        params['lr'] /= 10.
    print(optimizer)
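A minimal usage sketch, assuming the OPTIMIZER from section 2 and hypothetical milestone epochs:
MILESTONES = [30, 60, 90]        # hypothetical epochs at which to decay the lr
for epoch in range(120):
    if epoch in MILESTONES:
        schedule_lr(OPTIMIZER)   # divides every param group's lr by 10
    # ... run the training step for this epoch ...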
Method 1: use model.modules() and obj.__class__ (more general)
# Difference between model.modules() and model.children(): model.modules() iterates recursively over all sub-layers of the model, whereas model.children() only iterates over the model's immediate children.
# The keyword 'model' in the if below comes from the model definition file: every nn.Module subclass defined in, say, model_resnet.py carries the prefix 'model_resnet' in its class path, so the custom modules can be filtered out in one pass this way.
def separate_irse_bn_paras(model):
    paras_only_bn = []
    paras_no_bn = []
    for layer in model.modules():
        if 'model' in str(layer.__class__):      # e.g. a = [1, 2]; type(a) and a.__class__ both give <class 'list'>
            continue
        if 'container' in str(layer.__class__):  # skip Sequential-style container modules
            continue
        else:
            if 'batchnorm' in str(layer.__class__):
                paras_only_bn.extend([*layer.parameters()])
            else:
                paras_no_bn.extend([*layer.parameters()])  # extend() appends all items of another sequence to the end of the list
    return paras_only_bn, paras_no_bn
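A usage sketch, with torchvision's resnet18 as a stand-in backbone (it happens to pass the 'model'/'container' filters because its class paths contain 'models' and its Sequential wrappers contain 'container'); the weight-decay constants are illustrative only:
import torchvision
import torch.optim as optim

backbone = torchvision.models.resnet18()   # stand-in for a custom backbone
paras_only_bn, paras_no_bn = separate_irse_bn_paras(backbone)

optimizer = optim.SGD([
    {'params': paras_no_bn, 'weight_decay': 5e-4},   # conv / linear weights are decayed
    {'params': paras_only_bn, 'weight_decay': 0.0},  # BN affine parameters are not
], lr=0.1, momentum=0.9)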
Method 2: use modules.parameters() and named_parameters()
In essence, parameters() is derived from named_parameters(), which in turn is derived from modules(). The prerequisite for this method is that the model is defined in the style of items 1 and 2 below, or built with Sequential + OrderedDict.
def separate_resnet_bn_paras(model):
    all_parameters = model.parameters()
    paras_only_bn = []
    for pname, p in model.named_parameters():
        if pname.find('bn') >= 0:
            paras_only_bn.append(p)
    paras_only_bn_id = list(map(id, paras_only_bn))
    paras_no_bn = list(filter(lambda p: id(p) not in paras_only_bn_id, all_parameters))
    return paras_only_bn, paras_no_bn
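A quick sanity check of the name-based split, again using torchvision's resnet18 purely as an example (its BN attributes are named bn1, bn2, ...). Note that BN layers wrapped in a plain Sequential, such as the downsample branch, get numeric names and are missed by the 'bn' match, which is exactly the limitation discussed below:
import torchvision

model = torchvision.models.resnet18()
paras_only_bn, paras_no_bn = separate_resnet_bn_paras(model)
print(len(paras_only_bn), len(paras_no_bn))
print(len(paras_only_bn) + len(paras_no_bn) == len(list(model.parameters())))  # True: the split is a partition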
Differences between the two methods
The difference in parameter grouping mirrors a difference in how the model is constructed. For example:
1. When building a ResNet basic block, __init__() defines
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = BatchNorm2d(planes)
self.relu = ReLU(inplace=True)
…
2. forward() then defines
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
…
3. Calling model.named_parameters() on this ResNet returns pnames of the form:
'layer1.0.conv1.weight'
'layer1.0.bn1.weight'
'layer1.0.bn1.bias'
# 'layer1'…'layer4' correspond to conv2_x…conv5_x; '0' is the block index within each layer (e.g. conv2_x has 3 blocks, indexed layer1.0…layer1.2); 'conv1' is the self.conv1 defined in __init__()
4. If the model is built with Sequential(), model.named_parameters() returns pnames of the form:
'body.3.res_layer.1.weight', where '1.weight' is actually the BN weight; such a module cannot be found via pname.find('bn').
self.res_layer = Sequential(
    Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
    BatchNorm2d(depth),
    ReLU(depth),
    Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
    BatchNorm2d(depth)
)
5. For the situation in item 4 there are two fixes: wrap the Sequential contents in an OrderedDict, or use Method 1.
downsample = Sequential(OrderedDict([
    ('conv_ds', conv1x1(self.inplanes, planes * block.expansion, stride)),
    ('bn_ds', BatchNorm2d(planes * block.expansion)),
]))
# this way, the pnames of these modules will contain 'conv_ds' and 'bn_ds'
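A minimal sketch of the effect, using plain nn layers so it runs standalone:
from collections import OrderedDict
import torch.nn as nn

plain = nn.Sequential(nn.Conv2d(3, 8, 1, bias=False), nn.BatchNorm2d(8))
named = nn.Sequential(OrderedDict([
    ('conv_ds', nn.Conv2d(3, 8, 1, bias=False)),
    ('bn_ds', nn.BatchNorm2d(8)),
]))

print([n for n, _ in plain.named_parameters()])  # ['0.weight', '1.weight', '1.bias'] -> 'bn' is not findable
print([n for n, _ in named.named_parameters()])  # ['conv_ds.weight', 'bn_ds.weight', 'bn_ds.bias']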