欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch模型的保存加載與續(xù)訓(xùn)練詳解

 更新時間:2022年11月10日 09:54:05   作者:禿頭小蘇  
這篇文章主要為大家介紹了pytorch模型的保存加載與續(xù)訓(xùn)練詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪

前面

最近,看到不少小伙伴問pytorch如何保存和加載模型,其實這部分pytorch官網(wǎng)介紹的也是很清楚的,感興趣的點擊了解詳情??????

但是肯定有很多人是不愿意看官網(wǎng)的,所以我還是花一篇文章來為大家介紹介紹。當(dāng)然了,在介紹中我會加入自己的一些理解,讓大家有一個更深的認(rèn)識。如果準(zhǔn)備好了的話,就讓我們開始吧。???

模型保存與加載

pytorch中介紹了幾種不同的模型保存和加載方式,我會在下文一一為大家介紹。首先先讓我們來隨便定義一個模型,如下:【用的是pytorch官網(wǎng)的例子】

# 模型定義
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

定義好模型結(jié)構(gòu)后,我們可以實例化這個模型:

#模型初始化
model = TheModelClass()

模型初始化過后,我們就一起來看看模型保存和加載的方式吧。??????

方式1

方式1是官方推薦的一種方式,我們直接來看代碼好了,如下:

# 保存模型
torch.save(model.state_dict(), './model/model_state_dict.pth')

該方法后面的參數(shù)'./model/model_state_dict.pth'為模型的保存路徑,模型后綴名官方推薦使用.pth.pt,當(dāng)然了,你取別的后綴名也是完全可行的。???

介紹了模型的保存,下面就來看看方式1是如何加載模型的。【這里我說明一點,模型保存往往是在訓(xùn)練中進(jìn)行的,而模型加載多數(shù)用在模型推理中,它們存在兩個文件中,故我們在推理過程中要先實列化模型】

# 加載模型
model_test1 = TheModelClass()   # 加載模型時應(yīng)先實例化模型
# load_state_dict()函數(shù)接收一個字典,所以不能直接將'./model/model_state_dict.pth'傳入,而是先使用load函數(shù)將保存的模型參數(shù)反序列化
model_test1.load_state_dict(torch.load('./model/model_state_dict.pth'))
model_test1.eval()    # 模型推理時設(shè)置

在上述的代碼注釋中我有寫到,我們使用load_state_dict()加載模型時先需要使用load方法將保存的模型參數(shù)==反序列化==,load后的結(jié)果是一個字典,這時就可以通過load_state_dict()方法來加載了。

這里我來簡單說一下我理解的反序列化,其和序列化是相對應(yīng)的一個概念。序列化就是把內(nèi)存中的數(shù)據(jù)保存到磁盤中,像我們使用torch.save()方法保存模型就是序列化;而反序列化則是將硬盤中的數(shù)據(jù)加載到內(nèi)存當(dāng)中,顯然我們加載模型的過程就是反序列化過程。【大致的意思如下圖所示,偶然在水群的時候看到一個畫圖軟件,是不是還挺好看的??????】

方式2

方式2非常簡單,直接上代碼:

# 保存模型
torch.save(model, './model/model.pt')    #這里我們保存模型的后綴名取.pt
# 加載模型
model_test2 = torch.load('./model/model.pt')     
model_test2.eval()   # 模型推理時設(shè)置

但是這種方式是不推薦使用的,因為你使用這種方式保存模型,然后再加載時會遇到各種各樣的錯誤。為了加深大家理解,我們來看這樣的一個例子。文件的結(jié)構(gòu)如下圖所示:

models.py文件中存儲的是模型的定義,其位于文件夾models下。save_model.py文件中寫的是保存模型的代碼,如下:

from models.models import TheModelClass
from torch import optim
import torch
#模型初始化
model = TheModelClass()
# 初始化優(yōu)化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# ## 保存加載方式2——save/load
# # 保存模型
# torch.save(models, './models/models.pt')

執(zhí)行此文件后,會生成models.pt文件,我們在執(zhí)行load_mode.py文件即可實現(xiàn)加載,load_mode.py內(nèi)容如下:

from models.models import TheModelClass
import torch
## 加載方式2
# 加載模型
model_test2 = TheModelClass()
model_test2 = torch.load('./models/models.pt')     
model_test2.eval()   # 模型推理時設(shè)置
print(model_test2)

此時我們可以正常加載。但如果我們將models文件夾修改為model,如下:

此時我們在使用如下代碼加載模型的話就會出現(xiàn)錯誤:

from models.models import TheModelClass
import torch
## 加載方式2
# 加載模型
model_test2 = TheModelClass()
model_test2 = torch.load('./model/models.pt')     #這里需要修改一下文件路徑  
model_test2.eval()   # 模型推理時設(shè)置
print(model_test2)

出現(xiàn)這種錯誤的原因是使用方式2進(jìn)行模型保存的時候會把模型結(jié)構(gòu)定義文件路徑記錄下來,加載的時候就會根據(jù)路徑解析它然后裝載參數(shù);當(dāng)把模型定義文件路徑修改以后,使用torch.load(path)就會報錯。

其實使用方式2進(jìn)行模型的保存和加載還會存在各種問題,感興趣的可以看看這篇博文??傊?,在我們今后的使用中,盡量不要用方式2來加載模型。??????

方式3

pytorch還為我們提供了一種模型保存與加載的方式——checkpoint。這種方式保存的是一個字典,如果我們程序在運行中由于某種原因異常中止,那么這種方式可以很方便的讓我們接著上次訓(xùn)練,正因為這樣,我非常推薦大家使用這種方式進(jìn)行模型的保存與加載。下面就讓我們一起來看看方式3是如何使用的吧?。?!??????

首先,我們同樣使用torch.save來保存模型,但是這里保存的是一個字典,里面可以填入你需要保存的參數(shù),如下:

# 保存checkpoint
torch.save({
            'epoch':epoch,
            'model_state_dict':model.state_dict(),
            'optimizer_state_dict':optimizer.state_dict(),
            'loss':loss
            }, './model/model_checkpoint.tar'    #這里的后綴名官方推薦使用.tar
            )

接著我們來看看如何加載checkpoint,代碼如下:

# 加載checkpoint
model_checkpoint = TheModelClass()
optimizer =  optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load('./model/model_checkpoint.tar')    # 先反序列化模型
model_checkpoint.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

看了我上文的介紹,大家是否知道如何使用checkpoint了呢,我想大家都會覺得這個不是很難,但要自己寫可能還是不好把握,那么第一次就讓我來帶領(lǐng)大家看看如何在代碼中使用checkpoint吧?。?!??????

這節(jié)我采用cifar10數(shù)據(jù)集實現(xiàn)物體分類的例子,我的這篇博文對其進(jìn)行了詳細(xì)介紹,那么這里介紹checkpoint我將利用這個demo來為大家講解。首先我們直接來看模型保存的完整代碼,如下:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#1、準(zhǔn)備數(shù)據(jù)集
train_dataset = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(), download= True)
test_dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(), download= True)
#2、加載數(shù)據(jù)集
train_dataset_loader = DataLoader(dataset=train_dataset, batch_size=100)
test_dataset_loader = DataLoader(dataset=test_dataset, batch_size=100)
#3、搭建神經(jīng)網(wǎng)絡(luò)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )
    def forward(self, input):
        input = self.model1(input)
        return input
#4、創(chuàng)建網(wǎng)絡(luò)模型
net = Net()
#5、設(shè)置損失函數(shù)、優(yōu)化器
#損失函數(shù)
loss_fun = nn.CrossEntropyLoss()   #交叉熵
loss_fun = loss_fun.to(device)
#優(yōu)化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(net.parameters(), learning_rate)   #SGD:梯度下降算法
#6、設(shè)置網(wǎng)絡(luò)訓(xùn)練中的一些參數(shù)
total_train_step = 0   #記錄總計訓(xùn)練次數(shù)
total_test_step = 0    #記錄總計測試次數(shù)
Max_epoch = 10    #設(shè)計訓(xùn)練輪數(shù)
#7、開始進(jìn)行訓(xùn)練
for epoch in range(Max_epoch):
    print("---第{}輪訓(xùn)練開始---".format(epoch))
    net.train()     #開始訓(xùn)練,不是必須的,在網(wǎng)絡(luò)中有BN,dropout時需要
    #由于訓(xùn)練集數(shù)據(jù)較多,這里我沒用訓(xùn)練集訓(xùn)練,而是采用測試集(test_dataset_loader)當(dāng)訓(xùn)練集,但思想是一致的
    for data in test_dataset_loader:      
        imgs, targets = data
        targets = targets.to(device)
        outputs = net(imgs)
        #比較輸出與真實值,計算Loss
        loss = loss_fun(outputs, targets)
        #反向傳播,調(diào)整參數(shù)
        optimizer.zero_grad()    #每次讓梯度重置
        loss.backward()
        optimizer.step()
        total_train_step += 1
        if total_train_step % 50 == 0:
            print("---第{}次訓(xùn)練結(jié)束, Loss:{})".format(total_train_step, loss.item()))
    if (epoch+1) % 2 == 0:
        # 保存checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, './model/model_checkpoint_epoch_{}.tar'.format(epoch)  # 這里的后綴名官方推薦使用.tar
        )
    if epoch > 5:
        print("---意外中斷---")
        break

整個流程和這篇文章基本一致,不清楚的建議先花幾分鐘閱讀一下哈。??????主要區(qū)別就是在最后保存模型的時候我使用了checkpoint進(jìn)行保存,且兩個epoch保存一次。當(dāng)epoch=6時,我設(shè)置了一個break模擬程序意外中斷,中斷后可以來看一下終端的輸出信息,如下圖所示:

我們可以看到在進(jìn)行第6輪循環(huán)時,程序中斷了,此時最新的保存的模型是第五次訓(xùn)練結(jié)果,如下:

同時注意到第5次訓(xùn)練結(jié)束的loss在2.0左右,如果我們下次接著訓(xùn)練,損失應(yīng)該是在2.0附近。??????

好了,上面由于一些糟糕的原因?qū)е鲁绦蛑袛嗔?,現(xiàn)在我想接著上次訓(xùn)練的結(jié)果繼續(xù)訓(xùn)練,我該怎么辦呢?代碼如下:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#1、準(zhǔn)備數(shù)據(jù)集
train_dataset = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(), download= True)
test_dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(), download= True)
#2、加載數(shù)據(jù)集
train_dataset_loader = DataLoader(dataset=train_dataset, batch_size=100)
test_dataset_loader = DataLoader(dataset=test_dataset, batch_size=100)
#3、搭建神經(jīng)網(wǎng)絡(luò)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )
    def forward(self, input):
        input = self.model1(input)
        return input
#4、創(chuàng)建網(wǎng)絡(luò)模型
net = Net()
#5、設(shè)置損失函數(shù)、優(yōu)化器
#損失函數(shù)
loss_fun = nn.CrossEntropyLoss()   #交叉熵
loss_fun = loss_fun.to(device)
#優(yōu)化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(net.parameters(), learning_rate)   #SGD:梯度下降算法
#6、設(shè)置網(wǎng)絡(luò)訓(xùn)練中的一些參數(shù)
total_train_step = 0   #記錄總計訓(xùn)練次數(shù)
total_test_step = 0    #記錄總計測試次數(shù)
Max_epoch = 10    #設(shè)計訓(xùn)練輪數(shù)
##########################################################################################
# 加載checkpoint
checkpoint = torch.load('./model/model_checkpoint_epoch_5.tar')    # 先反序列化模型
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
loss = checkpoint['loss']
##########################################################################################
#7、開始進(jìn)行訓(xùn)練
for epoch in range(start_epoch+1, Max_epoch):
    print("---第{}輪訓(xùn)練開始---".format(epoch))
    net.train()     #開始訓(xùn)練,不是必須的,在網(wǎng)絡(luò)中有BN,dropout時需要
    for data in test_dataset_loader:
        imgs, targets = data
        targets = targets.to(device)
        outputs = net(imgs)
        #比較輸出與真實值,計算Loss
        loss = loss_fun(outputs, targets)
        #反向傳播,調(diào)整參數(shù)
        optimizer.zero_grad()    #每次讓梯度重置
        loss.backward()
        optimizer.step()
        total_train_step += 1
        if total_train_step % 50 == 0:
            print("---第{}次訓(xùn)練結(jié)束, Loss:{})".format(total_train_step, loss.item()))
    if (epoch+1) % 2 == 0:
        # 保存checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, './model/model_checkpoint_epoch_{}.tar'.format(epoch)  # 這里的后綴名官方推薦使用.tar
        )

這里的代碼相較之前的多了一個加載checkpoint的過程,我將其截取出來,如下圖所示:

通過加載checkpoint我們就保存了之前訓(xùn)練的參數(shù),進(jìn)而實現(xiàn)斷點續(xù)訓(xùn)練,我們直接來看執(zhí)行此代碼的結(jié)果,如下圖所示:

從上圖可以看出我們的訓(xùn)練是從第6輪開始的,并且初始的loss為1.99,和2.0接近。這就說明了我們已經(jīng)實現(xiàn)了中斷后恢復(fù)訓(xùn)練的操作。

????????????????????????????????????????

這里我簡單的說兩句,上文介紹checkpoint的用法時,訓(xùn)練中斷和訓(xùn)練恢復(fù)我是放在兩個文件中的進(jìn)行的,但是在實際中我們肯定是在一個文件中運行,那這該怎么辦呢?其實方法很簡單啦,我們只需要設(shè)置一個if條件將加載checkpoint的部分放在訓(xùn)練文件中,然后設(shè)置一個參數(shù)來控制if條件的執(zhí)行即可。具體細(xì)節(jié)我就不給大家介紹了,如果有不明白的評論區(qū)見吧?。?!????????

????????????????????????????????????????

總結(jié)

這部分還是蠻簡單的,但一些細(xì)節(jié)還是需要大家自行考量,我就為大家介紹到這里啦,希望大家都能夠有所收獲吧

以上就是pytorch模型的保存加載與續(xù)訓(xùn)練詳解的詳細(xì)內(nèi)容,更多關(guān)于pytorch模型保存加載續(xù)訓(xùn)練的資料請關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • wxPython電子表格功能wx.grid實例教程

    wxPython電子表格功能wx.grid實例教程

    這篇文章主要介紹了wxPython電子表格功能wx.grid實例教程,文中示例代碼介紹的非常詳細(xì),具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2019-11-11
  • Python-VTK隱式函數(shù)屬性選擇和剪切數(shù)據(jù)

    Python-VTK隱式函數(shù)屬性選擇和剪切數(shù)據(jù)

    這篇文章主要介紹了Python-VTK隱式函數(shù)屬性選擇和剪切數(shù)據(jù),VTK,是一個開放資源的免費軟件系統(tǒng),主要用于三維計算機(jī)圖形學(xué)、圖像處理和可視化,下面文章主題相關(guān)詳細(xì)內(nèi)容需要的小伙伴可以參考一下
    2022-04-04
  • Python struct.unpack

    Python struct.unpack

    Python中按一定的格式取出某字符串中的子字符串,使用struck.unpack是非常高效的。
    2008-09-09
  • 對python 生成拼接xml報文的示例詳解

    對python 生成拼接xml報文的示例詳解

    今天小編就為大家分享一篇對python 生成拼接xml報文的示例詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2018-12-12
  • Python日志:自定義輸出字段 json格式輸出方式

    Python日志:自定義輸出字段 json格式輸出方式

    這篇文章主要介紹了Python日志:自定義輸出字段 json格式輸出方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-04-04
  • pandas DataFrame的修改方法(值、列、索引)

    pandas DataFrame的修改方法(值、列、索引)

    這篇文章主要介紹了pandas DataFrame的修改方法(值、列、索引),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2019-08-08
  • Python 使用 prettytable 庫打印表格美化輸出功能

    Python 使用 prettytable 庫打印表格美化輸出功能

    這篇文章主要介紹了Python 使用 prettytable 庫打印表格美化輸出功能,本文通過實例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價值,需要的朋友可以參考下
    2019-12-12
  • 如何將python代碼打包成pip包(可以pip?install)

    如何將python代碼打包成pip包(可以pip?install)

    這篇文章主要介紹了如何將python代碼打包成pip包(可以pip?install),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2023-02-02
  • Django重置migrations文件的方法步驟

    Django重置migrations文件的方法步驟

    這篇文章主要介紹了Django重置migrations文件的方法步驟,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2019-05-05
  • python 3.8.3 安裝配置圖文教程

    python 3.8.3 安裝配置圖文教程

    這篇文章主要為大家詳細(xì)介紹了python 3.8.3 安裝配置圖文教程,文中安裝步驟介紹的非常詳細(xì),具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2020-05-05

最新評論