Python如何加載模型并查看網(wǎng)絡(luò)
加載模型并查看網(wǎng)絡(luò)
加載模型,以vgg19為例。
打開終端
> python Python 3.7.2 (tags/v3.7.2:9a3ffc0492, Dec 23 2018, 23:09:28) [MSC v.1916 64 bit (AMD64)] on win32 Type "help", "copyright", "credits" or "license" for more information. >>> from torchvision import models >>> model = models.vgg19(pretrained=True) #此時(shí)如果是第一次加載會(huì)開始下載模型的pth文件 >>> print(model.model)
結(jié)果:
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace)
(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(17): ReLU(inplace)
(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace)
(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(24): ReLU(inplace)
(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(26): ReLU(inplace)
(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace)
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): ReLU(inplace)
(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(33): ReLU(inplace)
(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(35): ReLU(inplace)
(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace)
(2): Dropout(p=0.5)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace)
(5): Dropout(p=0.5)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
注意,直接打印模型是沒有辦法看到模型結(jié)構(gòu)的,只能看到帶模型參數(shù)的pth文件內(nèi)容;需要打印model.model才可以看到模型本身。
神經(jīng)網(wǎng)絡(luò)_模型的保存,模型的加載
模型的保存(torch.save)
方式1(模型結(jié)構(gòu)+模型參數(shù))
參數(shù):保存位置
# 創(chuàng)建模型 vgg16 = torchvision.models.vgg16(pretrained=False) # 保存方式1——模型結(jié)構(gòu)+模型參數(shù) torch.save(vgg16, "vgg16_method1.pth")
方式2(模型參數(shù))
# 保存方式2 ?模型參數(shù)(官方推薦)。保存成字典,只保存網(wǎng)絡(luò)模型中的一些參數(shù) torch.save(vgg16.state_dict(), "vgg16_method2.pth")
模型的加載(torch.load)
對(duì)應(yīng)保存方式1
參數(shù):模型路徑
# 方式1 --》 保存方式1 model1 = torch.load("vgg16_method1.pth")
對(duì)應(yīng)保存方式2
vgg16.load_state_dict("vgg16_method2.pth")
輸出為字典形式。若要回復(fù)網(wǎng)絡(luò),采用以下形式:
model2 = torch.load("vgg16_method2.pth") ?#輸出是字典形式 # 恢復(fù)網(wǎng)絡(luò)結(jié)構(gòu) vgg16 = torchvision.models.vgg16(pretrained=False) vgg16.load_state_dict(model2)
方式1存儲(chǔ),加載時(shí)需注意事項(xiàng)
新建自己的網(wǎng)絡(luò):
class test(nn.Module): ? ? def __init__(self): ? ? ? ? super(lh, self).__init__() ? ? ? ? self.conv1 = nn.Conv2d(3, 64, kernel_size=3) ? ? def forward(self, x): ? ? ? ? x = self.conv1(x) ? ? ? ? return x
保存自己的網(wǎng)絡(luò):
Test = test() # 保存自己定義的網(wǎng)絡(luò) torch.save(Test, "Test_method1.pth")
加載自己的網(wǎng)絡(luò):
model3 = torch.load("Test_method1.pth")
會(huì)報(bào)錯(cuò)?。。。。?!
解決辦法(需要注意):
將定義的網(wǎng)絡(luò)復(fù)制到加載的python文件中:
class test(nn.Module): def __init__(self): super(test, self).__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=3) def forward(self, x): x = self.conv1(x) return x model3 = torch.load("Test_method1.pth")
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
python 字典(dict)遍歷的四種方法性能測(cè)試報(bào)告
本文主要是針對(duì)Python的字典dict遍歷的4種方法進(jìn)行了性能測(cè)試,以便分析得出效率最高的一種方法2014-06-06基于MATLAB和Python實(shí)現(xiàn)MFCC特征參數(shù)提取
這篇文章主要介紹了基于MATLAB和Python實(shí)現(xiàn)MFCC特征參數(shù)提取,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-08-08