欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch模型的保存/復(fù)用/遷移實(shí)現(xiàn)代碼

 更新時(shí)間:2023年05月05日 10:32:07   作者:信海  
本文整理了Pytorch框架下模型的保存、復(fù)用、推理、再訓(xùn)練和遷移等實(shí)現(xiàn),本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下

本文整理了Pytorch框架下模型的保存、復(fù)用、推理、再訓(xùn)練和遷移等實(shí)現(xiàn)。

模型的保存與復(fù)用

模型定義和參數(shù)打印

# 定義模型結(jié)構(gòu)
class LenNet(nn.Module):
    def __init__(self):
        super(LenNet, self).__init__()
        self.conv = nn.Sequential(  # [batch, 1, 28, 28]
            nn.Conv2d(1, 8, 5, 2),  # [batch, 1, 28, 28]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [batch, 8, 14, 14]
            nn.Conv2d(8, 16, 5),  # [batch, 16, 10, 10]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [batch, 16, 5, 5]
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16*5*5, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 10)
        )
    def forward(self, X):
        return self.fc(self.conv(X))
# 查看模型參數(shù)
# 網(wǎng)絡(luò)模型中的參數(shù)model.state_dict()是以字典形式保存(實(shí)質(zhì)上是collections模塊中的OrderedDict)
model = LenNet()
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# 參數(shù)名中的fc和conv前綴是根據(jù)定義nn.Sequential()時(shí)的名字所確定。
# 參數(shù)名中的數(shù)字表示每個(gè)Sequential()中網(wǎng)絡(luò)層所在的位置。
print(model.state_dict().keys())  # 打印鍵
print(model.state_dict().values())  # 打印值
# 優(yōu)化器optimizer的參數(shù)打印類似
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
   print(var_name, "\t", optimizer.state_dict()[var_name])

模型保存

import os
# 指定保存的模型名稱時(shí)Pytorch官方建議的后綴為.pt或者.pth
model_save_dir = './model_logs/'
model_save_path = os.path.join(model_save_dir, 'LeNet.pt')
torch.save(model.state_dict(), model_save_path)
# 在訓(xùn)練過(guò)程中保存某個(gè)條件下的最優(yōu)模型,可以如下操作
best_model_state = deepcopy(model.state_dict()) 
torch.save(best_model_state, model_save_path)
# 下面這種方法是錯(cuò)誤的,因?yàn)閎est_model_state只是model.state_dict()的引用,會(huì)隨著訓(xùn)練的改變而改變
best_model_state = model.state_dict() 
torch.save(best_model_state, model_save_path)

模型推理

def inference(data_iter, device, model_save_dir):
	model = LeNet()  # 初始化現(xiàn)有模型的權(quán)重參數(shù)
    model.to(device)
    model_save_path = os.path.join(model_save_dir, 'LeNet.pt')
    # 如果本地存在模型,則加載本地模型參數(shù)覆蓋原有模型
    if os.path.exists(model_save_path): 
        loaded_paras = torch.load(model_save_path)
        model.load_state_dict(loaded_paras)
        model.eval()
    with torch.no_grad():  # 開(kāi)始推理
        acc_sum, n = 0., 0
        for x, y in data_iter:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            acc_sum += (logits.argmax(1) == y).float().sum().item()
            n += len(y)
        print("Accuracy in test data is : ", acc_sum / n)

模型再訓(xùn)練

class MyModel:
    def __init__(self,
                 batch_size=64,
                 epochs=5,
                 learning_rate=0.001,
                 model_save_dir='./MODEL'):
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.model_save_dir = model_save_dir
        self.model = LeNet()
    def train(self):
        train_iter, test_iter = load_dataset(self.batch_size)
        # 在訓(xùn)練過(guò)程中只保存網(wǎng)絡(luò)權(quán)重,在再訓(xùn)練時(shí)只載入網(wǎng)絡(luò)權(quán)重參數(shù)初始化網(wǎng)絡(luò)訓(xùn)練。這里是核心部分,開(kāi)始。
        if not os.path.exists(self.model_save_dir):
            os.makedirs(self.model_save_dir)
        model_save_path = os.path.join(self.model_save_dir, 'model.pt')
        if os.path.exists(model_save_path):
            loaded_paras = torch.load(model_save_path)
            self.model.load_state_dict(loaded_paras)
            print("#### 成功載入已有模型,進(jìn)行再訓(xùn)練...")
        # 結(jié)束  
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)  
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(device)
        for epoch in range(self.epochs):
            for i, (x, y) in enumerate(train_iter):
                x, y = x.to(device), y.to(device)
                loss, logits = self.model(x)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()  
                if i % 100 == 0:
                    acc = (logits.argmax(1) == y).float().mean()
                    print("Epochs[{}/{}]---batch[{}/{}]---acc {:.4}---loss {:.4}".format(
                        epoch, self.epochs, len(train_iter), i, acc, loss.item()))
            print("Epochs[{}/{}]--acc on test {:.4}".format(epoch, self.epochs,
                                                            self.evaluate(test_iter, self.model, device)))
            torch.save(self.model.state_dict(), model_save_path)
    @staticmethod
    def evaluate(data_iter, model, device):
        with torch.no_grad():
            acc_sum, n = 0.0, 0
            for x, y in data_iter:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                acc_sum += (logits.argmax(1) == y).float().sum().item()
                n += len(y)
            return acc_sum / n
# 在保存參數(shù)的時(shí)候,將優(yōu)化器參數(shù)、損失值等可一同保存,然后在恢復(fù)模型時(shí)連同其它參數(shù)一起恢復(fù)
model_save_path = os.path.join(model_save_dir, 'LeNet.pt')
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, model_save_path)
# 加載方式如下
checkpoint = torch.load(model_save_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

模型遷移

# 定義新模型NewLeNet 和LeNet區(qū)別在于新增了一個(gè)全連接層
class NewLenNet(nn.Module):
    def __init__(self):
        super(NewLenNet, self).__init__()
        self.conv = nn.Sequential(  # [batch, 1, 28, 28]
            nn.Conv2d(1, 8, 5, 2),  # [batch, 1, 28, 28]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [batch, 8, 14, 14]
            nn.Conv2d(8, 16, 5),  # [batch, 16, 10, 10]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [batch, 16, 5, 5]
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16*5*5, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64), # 這層以前和LeNet結(jié)構(gòu)一致 可以用LeNet的參數(shù)來(lái)進(jìn)行替換
            nn.ReLU(inplace=True),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 10)
        )
    def forward(self, X):
        return self.fc(self.conv(X))
# 定義替換函數(shù) 匹配兩個(gè)網(wǎng)絡(luò) size相同處地方進(jìn)行參數(shù)替換
def para_state_dict(model, model_save_dir):
    state_dict = deepcopy(model.state_dict())
    model_save_path = os.path.join(model_save_dir, 'model.pt')
    if os.path.exists(model_save_path):
        loaded_paras = torch.load(model_save_path)
        for key in state_dict:  # 在新的網(wǎng)絡(luò)模型中遍歷對(duì)應(yīng)參數(shù)
            if key in loaded_paras and state_dict[key].size() == loaded_paras[key].size():
                print("成功初始化參數(shù):", key)
                state_dict[key] = loaded_paras[key]
    return state_dict
# 更新一下模型遷移后的訓(xùn)練代碼
def train(self):
        train_iter, test_iter = load_dataset(self.batch_size)
        if not os.path.exists(self.model_save_dir):
            os.makedirs(self.model_save_dir)
        model_save_path = os.path.join(self.model_save_dir, 'model_new.pt')
        old_model = os.path.join(self.model_save_dir, 'LeNet.pt')
        if os.path.exists(old_model):
            state_dict = para_state_dict(self.model, self.model_save_dir)  # 調(diào)用遷移代碼 將LeNet的前幾層參數(shù)遷移到NewLeNet
            self.model.load_state_dict(state_dict)
            print("#### 成功載入已有模型,進(jìn)行再訓(xùn)練...")
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)  
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(device)
        for epoch in range(self.epochs):
            for i, (x, y) in enumerate(train_iter):
                x, y = x.to(device), y.to(device)
                loss, logits = self.model(x)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()  
                if i % 100 == 0:
                    acc = (logits.argmax(1) == y).float().mean()
                    print("Epochs[{}/{}]---batch[{}/{}]---acc {:.4}---loss {:.4}".format(
                        epoch, self.epochs, len(train_iter), i, acc, loss.item()))
            print("Epochs[{}/{}]--acc on test {:.4}".format(epoch, self.epochs,
                                                            self.evaluate(test_iter, self.model, device)))
            torch.save(self.model.state_dict(), model_save_path)
# 這里更新未進(jìn)行訓(xùn)練的推理
def inference(data_iter, device, model_save_dir='./MODEL'):
    model = NewLeNet()  # 初始化現(xiàn)有模型的權(quán)重參數(shù)
    print("初始化參數(shù) conv.0.bias 為:", model.state_dict()['conv.0.bias'])
    model.to(device)
    state_dict = para_state_dict(model, model_save_dir) # 遷移模型參數(shù)
    model.load_state_dict(state_dict)
    model.eval()
    print("載入本地模型重新初始化 conv.0.bias 為:", model.state_dict()['conv.0.bias'])
    with torch.no_grad():
        acc_sum, n = 0.0, 0
        for x, y in data_iter:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            acc_sum += (logits.argmax(1) == y).float().sum().item()
            n += len(y)
        print("Accuracy in test data is :", acc_sum / n)

參考文獻(xiàn)

[1] https://github.com/moon-hotel/DeepLearningWithMe

到此這篇關(guān)于Pytorch模型的保存/復(fù)用/遷移的文章就介紹到這了,更多相關(guān)Pytorch模型保存遷移內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • 如何在Python?中獲取單成員集合中的唯一元素

    如何在Python?中獲取單成員集合中的唯一元素

    這篇文章主要介紹了如何在Python?中獲取單成員集合中的唯一元素,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2023-03-03
  • python圖像填充與裁剪/resize的實(shí)現(xiàn)代碼

    python圖像填充與裁剪/resize的實(shí)現(xiàn)代碼

    這篇文章主要介紹了python圖像填充與裁剪/resize,本文通過(guò)示例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2022-08-08
  • pandas中DataFrame新增行及global變量的使用方式

    pandas中DataFrame新增行及global變量的使用方式

    這篇文章主要介紹了pandas中DataFrame新增行及global變量的使用方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2024-02-02
  • python合并兩個(gè)字典的方法總結(jié)

    python合并兩個(gè)字典的方法總結(jié)

    在Python中,有多種方法可以通過(guò)使用各種函數(shù)和構(gòu)造函數(shù)來(lái)合并字典,在本文中,我們將討論一些合并字典的方法,有需要的小伙伴可以參考一下·
    2023-09-09
  • python 實(shí)現(xiàn)dict轉(zhuǎn)json并保存文件

    python 實(shí)現(xiàn)dict轉(zhuǎn)json并保存文件

    今天小編就為大家分享一篇python 實(shí)現(xiàn)dict轉(zhuǎn)json并保存文件,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2019-12-12
  • python 多線程死鎖問(wèn)題的解決方案

    python 多線程死鎖問(wèn)題的解決方案

    這篇文章主要介紹了python 多線程死鎖問(wèn)題的解決方案,幫助大家更好的理解和學(xué)習(xí)python 鎖,感興趣的朋友可以了解下
    2020-08-08
  • Python多進(jìn)程方式抓取基金網(wǎng)站內(nèi)容的方法分析

    Python多進(jìn)程方式抓取基金網(wǎng)站內(nèi)容的方法分析

    這篇文章主要介紹了Python多進(jìn)程方式抓取基金網(wǎng)站內(nèi)容的方法,結(jié)合實(shí)例形式分析了Python多進(jìn)程抓取網(wǎng)站內(nèi)容相關(guān)實(shí)現(xiàn)技巧與操作注意事項(xiàng),需要的朋友可以參考下
    2019-06-06
  • Pycharm 解決自動(dòng)格式化沖突的設(shè)置操作

    Pycharm 解決自動(dòng)格式化沖突的設(shè)置操作

    這篇文章主要介紹了Pycharm 解決自動(dòng)格式化沖突的設(shè)置操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2021-01-01
  • Python文本預(yù)處理學(xué)習(xí)指南

    Python文本預(yù)處理學(xué)習(xí)指南

    文本預(yù)處理是指在進(jìn)行自然語(yǔ)言處理(NLP)任務(wù)之前,對(duì)原始文本數(shù)據(jù)進(jìn)行清洗、轉(zhuǎn)換和標(biāo)準(zhǔn)化的過(guò)程,本文主要為大家介紹了文本預(yù)處理的使用,需要的可以參考下
    2023-07-07
  • Python面試之os.system()和os.popen()的區(qū)別詳析

    Python面試之os.system()和os.popen()的區(qū)別詳析

    Python調(diào)用Shell,有兩種方法:os.system(cmd)或os.popen(cmd)腳本執(zhí)行過(guò)程中的輸出內(nèi)容,下面這篇文章主要給大家介紹了關(guān)于Python面試之os.system()和os.popen()區(qū)別的相關(guān)資料,需要的朋友可以參考下
    2022-06-06

最新評(píng)論