欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch加載預(yù)訓(xùn)練模型與自己模型不匹配的解決方案

 更新時(shí)間:2021年05月13日 16:33:18   作者:找不到服務(wù)器1703  
這篇文章主要介紹了pytorch加載預(yù)訓(xùn)練模型與自己模型不匹配的解決方案,具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

pytorch中如果自己搭建網(wǎng)絡(luò)并且加載別人的與訓(xùn)練模型的話,如果模型和參數(shù)不嚴(yán)格匹配,就可能會(huì)出問題,接下來記錄一下我的解決方法。

兩個(gè)有序字典找不同

模型的參數(shù)和pth文件的參數(shù)都是有序字典(OrderedDict),把字典中的鍵轉(zhuǎn)為列表就可以在for循環(huán)里迭代找不同了。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        err = 1

自己搭建模型的注意事項(xiàng)

搭網(wǎng)絡(luò)時(shí)要對照pth文件的字典順序搭,字典順序、權(quán)重尺寸(shape)和變量命名必須與pth文件完全一致。如果僅僅是變量命名不同,可采用類似的方法對模型的權(quán)重重新賦值。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)

完整的代碼見自己搭建resnet18網(wǎng)絡(luò)并加載torchvision自帶權(quán)重

新增的改進(jìn)代碼

model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
    if m >= len1 or n >= len2:
        break
    layername1, layername2 = model_list1[m], model_list2[n]
    w1, w2 = model_dict1[layername1], model_dict2[layername2]
    if w1.shape != w2.shape:
        continue
    model_dict2[layername2] = model_dict1[layername1]
    m += 1
    n += 1
model.load_state_dict(model_dict2)

如果因?yàn)槟P筒黄ヅ?,運(yùn)行第14行語句后,可看自己情況手動(dòng)對m或n加上1。

補(bǔ)充:pytorch的一些坑:用預(yù)訓(xùn)練的vgg模型的部分層的特征報(bào)錯(cuò),如張量不匹配

看代碼吧~

#打算取VGG19的第二個(gè)全連接層的輸出,那么就需要構(gòu)建一個(gè)類,這個(gè)類要包含VGG的全部卷積層,
#以及到第二個(gè)全連接層的全部網(wǎng)絡(luò)還有他們對應(yīng)的參數(shù)
class Classification_att(nn.Module):
    def __init__(self, rgb_range):
        super(Classification_att, self).__init__()
        self.vgg19 =models.vgg19(pretrained=True)
        vgg = models.vgg19(pretrained=True).features
        conv_modules = [m for m in vgg]
        self.vgg_conv = nn.Sequential(*conv_modules[:37])
        classfi = models.vgg19(pretrained=True).classifier
        classif_modules = [n for n in classfi]
        self.vgg_class = nn.Sequential(*classif_modules[:4])
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.vgg_conv.parameters():
            p.requires_grad = False
        for p in self.vgg_class.parameters():
            p.requires_grad = False
        self.classifi = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
        )
 
    def forward(self, x):
        x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear', 
        align_corners=False)
        x = self.sub_mean(x)
        x = self.vgg_conv(x)  
        x = self.vgg_class(x)  #執(zhí)行這部報(bào)錯(cuò),說張量不匹配

原因是因?yàn)榫矸e層的輸出不能直接連接全連接層,即使輸出的張量的總的大小是一致的

查看vgg的pytorch源碼發(fā)現(xiàn)是

x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代碼沒有torch.flatten(x, 1)這步

所以自己的少了一步

x = torch.flatten(x, 1)

補(bǔ)上就好了!

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python調(diào)用PC攝像頭實(shí)現(xiàn)掃描二維碼

    Python調(diào)用PC攝像頭實(shí)現(xiàn)掃描二維碼

    PC攝像機(jī)掃描二維碼的應(yīng)用場景很廣泛,可以應(yīng)用于各種需要快速掃描、識(shí)別和管理的場景,本文就來具體講講如何用Python實(shí)現(xiàn)這一功能吧
    2023-05-05
  • PyTorch中 tensor.detach() 和 tensor.data 的區(qū)別詳解

    PyTorch中 tensor.detach() 和 tensor.data 的區(qū)別詳解

    今天小編就為大家分享一篇PyTorch中 tensor.detach() 和 tensor.data 的區(qū)別詳解,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-01-01
  • Python中使用?zipfile創(chuàng)建文件壓縮工具

    Python中使用?zipfile創(chuàng)建文件壓縮工具

    這篇文章主要介紹了Python中使用zipfile創(chuàng)建文件壓縮工具,通過使用 wxPython 模塊,我們創(chuàng)建了一個(gè)簡單而實(shí)用的文件壓縮工具,本文結(jié)合實(shí)例代碼給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的ca參考借鑒價(jià)值,需要的朋友可以參考下
    2023-09-09
  • Pygame實(shí)戰(zhàn)練習(xí)之炸彈人學(xué)院游戲

    Pygame實(shí)戰(zhàn)練習(xí)之炸彈人學(xué)院游戲

    炸彈人學(xué)院想必是很多人童年時(shí)期的經(jīng)典游戲,我們依舊能記得抱個(gè)老人機(jī)娛樂的場景,下面這篇文章主要給大家介紹了關(guān)于如何利用python寫一個(gè)簡單的炸彈人學(xué)院小游戲的相關(guān)資料,需要的朋友可以參考下
    2021-09-09
  • 使用python tkinter實(shí)現(xiàn)各種個(gè)樣的撩妹鼠標(biāo)拖尾效果

    使用python tkinter實(shí)現(xiàn)各種個(gè)樣的撩妹鼠標(biāo)拖尾效果

    這篇文章主要介紹了使用python tkinter實(shí)現(xiàn)各種個(gè)樣的撩妹鼠標(biāo)拖尾效果,本文通過實(shí)例代碼,給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2021-09-09
  • Python中Tkinter組件Button的具體使用

    Python中Tkinter組件Button的具體使用

    Button=組件用于實(shí)現(xiàn)各種各樣的按鈕,本文主要介紹了Python中Tkinter組件Button的具體使用,具有一定的參考價(jià)值,感興趣的可以了解一下
    2022-01-01
  • 把django中admin后臺(tái)界面的英文修改為中文顯示的方法

    把django中admin后臺(tái)界面的英文修改為中文顯示的方法

    今天小編就為大家分享一篇把django中admin后臺(tái)界面的英文修改為中文顯示的方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-07-07
  • python?print無法打印\r的問題及解決

    python?print無法打印\r的問題及解決

    這篇文章主要介紹了python?print無法打印\r的問題及解決方案,具有很好的參考價(jià)值,希望對大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2023-08-08
  • python?字符串常用方法超詳細(xì)梳理總結(jié)

    python?字符串常用方法超詳細(xì)梳理總結(jié)

    字符串是Python中基本的數(shù)據(jù)類型,幾乎在每個(gè)Python程序中都會(huì)使用到它。本文為大家總結(jié)了Python中必備的31個(gè)字符串方法,需要的可以參考一下
    2022-03-03
  • Python如何急速下載第三方庫詳解

    Python如何急速下載第三方庫詳解

    這篇文章主要給大家介紹了關(guān)于Python如何急速下載第三方庫的相關(guān)資料,文中通過圖文介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-11-11

最新評論