深入理解Pytorch微調(diào)torchvision模型
一、簡介
在本小節(jié),深入探討如何對torchvision進(jìn)行微調(diào)和特征提取。所有模型都已經(jīng)預(yù)先在1000類的magenet數(shù)據(jù)集上訓(xùn)練完成。 本節(jié)將深入介紹如何使用幾個(gè)現(xiàn)代的CNN架構(gòu),并將直觀展示如何微調(diào)任意的PyTorch模型。
本節(jié)將執(zhí)行兩種類型的遷移學(xué)習(xí):
- 微調(diào):從預(yù)訓(xùn)練模型開始,更新我們新任務(wù)的所有模型參數(shù),實(shí)質(zhì)上是重新訓(xùn)練整個(gè)模型。
- 特征提?。簭念A(yù)訓(xùn)練模型開始,僅更新從中導(dǎo)出預(yù)測的最終圖層權(quán)重。它被稱為特征提取,因?yàn)槲覀兪褂妙A(yù)訓(xùn)練的CNN作為固定 的特征提取器,并且僅改變輸出層。
通常這兩種遷移學(xué)習(xí)方法都會(huì)遵循一下步驟:
- 初始化預(yù)訓(xùn)練模型
- 重組最后一層,使其具有與新數(shù)據(jù)集類別數(shù)相同的輸出數(shù)
- 為優(yōu)化算法定義想要的訓(xùn)練期間更新的參數(shù)
- 運(yùn)行訓(xùn)練步驟
二、導(dǎo)入相關(guān)包
from __future__ import print_function from __future__ import division import torch import torch.nn as nn import torch.optim as optim import numpy as np import torchvision from torchvision import datasets,models,transforms import matplotlib.pyplot as plt import time import os import copy print("Pytorch version:",torch.__version__) print("torchvision version:",torchvision.__version__)
運(yùn)行結(jié)果
三、數(shù)據(jù)輸入
數(shù)據(jù)集——>我在這里
鏈接:https://pan.baidu.com/s/1G3yRfKTQf9sIq1iCSoymWQ
提取碼:1234
#%%輸入 data_dir="D:\Python\Pytorch\data\hymenoptera_data" # 從[resnet,alexnet,vgg,squeezenet,desenet,inception] model_name='squeezenet' # 數(shù)據(jù)集中類別數(shù)量 num_classes=2 # 訓(xùn)練的批量大小 batch_size=8 # 訓(xùn)練epoch數(shù) num_epochs=15 # 用于特征提取的標(biāo)志。為FALSE,微調(diào)整個(gè)模型,為TRUE只更新圖層參數(shù) feature_extract=True
四、輔助函數(shù)
1、模型訓(xùn)練和驗(yàn)證
- train_model函數(shù)處理給定模型的訓(xùn)練和驗(yàn)證。作為輸入,它需要PyTorch模型、數(shù)據(jù)加載器字典、損失函數(shù)、優(yōu)化器、用于訓(xùn)練和驗(yàn) 證epoch數(shù),以及當(dāng)模型是初始模型時(shí)的布爾標(biāo)志。
- is_inception標(biāo)志用于容納 Inception v3 模型,因?yàn)樵擉w系結(jié)構(gòu)使用輔助輸出, 并且整體模型損失涉及輔助輸出和最終輸出,如此處所述。 這個(gè)函數(shù)訓(xùn)練指定數(shù)量的epoch,并且在每個(gè)epoch之后運(yùn)行完整的驗(yàn)證步驟。它還跟蹤最佳性能的模型(從驗(yàn)證準(zhǔn)確率方面),并在訓(xùn)練 結(jié)束時(shí)返回性能最好的模型。在每個(gè)epoch之后,打印訓(xùn)練和驗(yàn)證正確率。
#%%模型訓(xùn)練和驗(yàn)證 device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu") def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False): since=time.time() val_acc_history=[] best_model_wts=copy.deepcopy(model.state_dict()) best_acc=0.0 for epoch in range(num_epochs): print('Epoch{}/{}'.format(epoch, num_epochs-1)) print('-'*10) # 每個(gè)epoch都有一個(gè)訓(xùn)練和驗(yàn)證階段 for phase in['train','val']: if phase=='train': model.train() else: model.eval() running_loss=0.0 running_corrects=0 # 迭代數(shù)據(jù) for inputs,labels in dataloaders[phase]: inputs=inputs.to(device) labels=labels.to(device) # 梯度置零 optimizer.zero_grad() # 向前傳播 with torch.set_grad_enabled(phase=='train'): # 獲取模型輸出并計(jì)算損失,開始的特殊情況在訓(xùn)練中他有一個(gè)輔助輸出 # 在訓(xùn)練模式下,通過將最終輸出和輔助輸出相加來計(jì)算損耗,在測試中值考慮最終輸出 if is_inception and phase=='train': outputs,aux_outputs=model(inputs) loss1=criterion(outputs,labels) loss2=criterion(aux_outputs,labels) loss=loss1+0.4*loss2 else: outputs=model(inputs) loss=criterion(outputs,labels) _,preds=torch.max(outputs,1) if phase=='train': loss.backward() optimizer.step() # 添加 running_loss+=loss.item()*inputs.size(0) running_corrects+=torch.sum(preds==labels.data) epoch_loss=running_loss/len(dataloaders[phase].dataset) epoch_acc=running_corrects.double()/len(dataloaders[phase].dataset) print('{}loss : {:.4f} acc:{:.4f}'.format(phase, epoch_loss,epoch_acc)) if phase=='train' and epoch_acc>best_acc: best_acc=epoch_acc best_model_wts=copy.deepcopy(model.state_dict()) if phase=='val': val_acc_history.append(epoch_acc) print() time_elapsed=time.time()-since print('training complete in {:.0f}s'.format(time_elapsed//60, time_elapsed%60)) print('best val acc:{:.4f}'.format(best_acc)) model.load_state_dict(best_model_wts) return model,val_acc_history
2、設(shè)置模型參數(shù)的'.requires_grad屬性'
當(dāng)我們進(jìn)行特征提取時(shí),此輔助函數(shù)將模型中參數(shù)的 .requires_grad 屬性設(shè)置為False。
默認(rèn)情況下,當(dāng)我們加載一個(gè)預(yù)訓(xùn)練模型時(shí),所有參數(shù)都是 .requires_grad = True
,如果我們從頭開始訓(xùn)練或微調(diào),這種設(shè)置就沒問題。
但是,如果我們要運(yùn)行特征提取并且只想為新初始化的層計(jì)算梯度,那么我們希望所有其他參數(shù)不需要梯度變化。
#%%設(shè)置模型參數(shù)的.require——grad屬性 def set_parameter_requires_grad(model,feature_extracting): if feature_extracting: for param in model.parameters(): param.require_grad=False
靚仔今天先去跑步了,再不跑來不及了,先更這么多,后續(xù)明天繼續(xù)~(感謝有人沒有催更!感謝監(jiān)督!希望繼續(xù)監(jiān)督?。?/p>
以上就是深入理解Pytorch微調(diào)torchvision模型的詳細(xì)內(nèi)容,更多關(guān)于Pytorch torchvision模型的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
python-opencv 中值濾波{cv2.medianBlur(src, ksize)}的用法
這篇文章主要介紹了python-opencv 中值濾波{cv2.medianBlur(src, ksize)}的用法,具有很好的參考價(jià)值,希望對大家有所幫助。2021-06-06詳解如何使用Pandas處理時(shí)間序列數(shù)據(jù)
時(shí)間序列數(shù)據(jù)在數(shù)據(jù)分析建模中很常見,例如天氣預(yù)報(bào),空氣狀態(tài)監(jiān)測,股票交易等金融場景,本文給大家詳細(xì)介紹了如何使用Pandas處理時(shí)間序列數(shù)據(jù),文中通過代碼示例講解的非常詳細(xì),需要的朋友可以參考下2024-01-01Python序列化與反序列化相關(guān)知識總結(jié)
今天給大家?guī)黻P(guān)于python的相關(guān)知識,文章圍繞著Python序列化與反序列展開,文中有非常詳細(xì)的介紹,需要的朋友可以參考下2021-06-06Python pandas RFM模型應(yīng)用實(shí)例詳解
這篇文章主要介紹了Python pandas RFM模型應(yīng)用,結(jié)合實(shí)例形式詳細(xì)分析了pandas RFM模型的概念、原理、應(yīng)用及相關(guān)操作注意事項(xiàng),需要的朋友可以參考下2019-11-11使用matplotlib繪制并排柱狀圖的實(shí)戰(zhàn)案例
堆積柱狀圖有堆積柱狀圖的好處,比如說我們可以很方便地看到多分類總和的趨勢,下面這篇文章主要給大家介紹了關(guān)于使用matplotlib繪制并排柱狀圖的相關(guān)資料,需要的朋友可以參考下2022-07-07python實(shí)現(xiàn)web方式logview的方法
這篇文章主要介紹了python實(shí)現(xiàn)web方式logview的方法,涉及Python基于web模塊操作Linux命令的技巧,具有一定參考借鑒價(jià)值,需要的朋友可以參考下2015-08-08python安裝dlib庫報(bào)錯(cuò)問題及解決方法
這篇文章主要介紹了python安裝dlib庫報(bào)錯(cuò)問題及解決方法,本文通過實(shí)例代碼給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-03-03Flask與數(shù)據(jù)庫的交互插件Flask-Sqlalchemy的使用
在構(gòu)建Web應(yīng)用時(shí),與數(shù)據(jù)庫的交互是必不可少的部分,本文主要介紹了Flask與數(shù)據(jù)庫的交互插件Flask-Sqlalchemy的使用,具有一定的參考價(jià)值,感興趣的可以了解一下2024-03-03python實(shí)現(xiàn)登錄與注冊系統(tǒng)
這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)登錄與注冊系統(tǒng),文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2020-11-11