pytorch fine-tune 預(yù)訓(xùn)練的模型操作
之一:
torchvision 中包含了很多預(yù)訓(xùn)練好的模型,這樣就使得 fine-tune 非常容易。本文主要介紹如何 fine-tune torchvision 中預(yù)訓(xùn)練好的模型。
安裝
pip install torchvision
如何 fine-tune
以 resnet18 為例:
from torchvision import models from torch import nn from torch import optim resnet_model = models.resnet18(pretrained=True) # pretrained 設(shè)置為 True,會(huì)自動(dòng)下載模型 所對(duì)應(yīng)權(quán)重,并加載到模型中 # 也可以自己下載 權(quán)重,然后 load 到 模型中,源碼中有 權(quán)重的地址。 # 假設(shè) 我們的 分類任務(wù)只需要 分 100 類,那么我們應(yīng)該做的是 # 1. 查看 resnet 的源碼 # 2. 看最后一層的 名字是啥 (在 resnet 里是 self.fc = nn.Linear(512 * block.expansion, num_classes)) # 3. 在外面替換掉這個(gè)層 resnet_model.fc= nn.Linear(in_features=..., out_features=100) # 這樣就 哦了,修改后的模型除了輸出層的參數(shù)是 隨機(jī)初始化的,其他層都是用預(yù)訓(xùn)練的參數(shù)初始化的。 # 如果只想訓(xùn)練 最后一層的話,應(yīng)該做的是: # 1. 將其它層的參數(shù) requires_grad 設(shè)置為 False # 2. 構(gòu)建一個(gè) optimizer, optimizer 管理的參數(shù)只有最后一層的參數(shù) # 3. 然后 backward, step 就可以了 # 這一步可以節(jié)省大量的時(shí)間,因?yàn)槎鄶?shù)的參數(shù)不需要計(jì)算梯度 for para in list(resnet_model.parameters())[:-2]: para.requires_grad=False optimizer = optim.SGD(params=[resnet_model.fc.weight, resnet_model.fc.bias], lr=1e-3) ...
為什么
這里介紹下 運(yùn)行resnet_model.fc= nn.Linear(in_features=..., out_features=100)時(shí) 框架內(nèi)發(fā)生了什么
這時(shí)應(yīng)該看 nn.Module 源碼的 __setattr__ 部分,因?yàn)?setattr 時(shí)都會(huì)調(diào)用這個(gè)方法:
def __setattr__(self, name, value): def remove_from(*dicts): for d in dicts: if name in d: del d[name]
首先映入眼簾就是 remove_from 這個(gè)函數(shù),這個(gè)函數(shù)的目的就是,如果出現(xiàn)了 同名的屬性,就將舊的屬性移除。 用剛才舉的例子就是:
預(yù)訓(xùn)練的模型中 有個(gè) 名字叫fc 的 Module。
在類定義外,我們 將另一個(gè) Module 重新 賦值給了 fc。
類定義內(nèi)的 fc 對(duì)應(yīng)的 Module 就會(huì)從 模型中 刪除。
之二:
前言
這篇文章算是論壇PyTorch Forums關(guān)于參數(shù)初始化和finetune的總結(jié),也是我在寫代碼中用的算是“最佳實(shí)踐”吧。最后希望大家沒事多逛逛論壇,有很多高質(zhì)量的回答。
參數(shù)初始化
參數(shù)的初始化其實(shí)就是對(duì)參數(shù)賦值。而我們需要學(xué)習(xí)的參數(shù)其實(shí)都是Variable,它其實(shí)是對(duì)Tensor的封裝,同時(shí)提供了data,grad等借口,這就意味著我們可以直接對(duì)這些參數(shù)進(jìn)行操作賦值了。這就是PyTorch簡潔高效所在。
所以我們可以進(jìn)行如下操作進(jìn)行初始化,當(dāng)然其實(shí)有其他的方法,但是這種方法是PyTorch作者所推崇的:
def weight_init(m): # 使用isinstance來判斷m屬于什么類型 if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels m.weight.data.normal_(0, math.sqrt(2. / n)) elif isinstance(m, nn.BatchNorm2d): # m中的weight,bias其實(shí)都是Variable,為了能學(xué)習(xí)參數(shù)以及后向傳播 m.weight.data.fill_(1) m.bias.data.zero_()
Finetune
往往在加載了預(yù)訓(xùn)練模型的參數(shù)之后,我們需要finetune模型,可以使用不同的方式finetune。
局部微調(diào)
有時(shí)候我們加載了訓(xùn)練模型后,只想調(diào)節(jié)最后的幾層,其他層不訓(xùn)練。其實(shí)不訓(xùn)練也就意味著不進(jìn)行梯度計(jì)算,PyTorch中提供的requires_grad使得對(duì)訓(xùn)練的控制變得非常簡單。
model = torchvision.models.resnet18(pretrained=True) for param in model.parameters(): param.requires_grad = False # 替換最后的全連接層, 改為訓(xùn)練100類 # 新構(gòu)造的模塊的參數(shù)默認(rèn)requires_grad為True model.fc = nn.Linear(512, 100) # 只優(yōu)化最后的分類層 optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
全局微調(diào)
有時(shí)候我們需要對(duì)全局都進(jìn)行finetune,只不過我們希望改換過的層和其他層的學(xué)習(xí)速率不一樣,這時(shí)候我們可以把其他層和新層在optimizer中單獨(dú)賦予不同的學(xué)習(xí)速率。比如:
ignored_params = list(map(id, model.fc.parameters())) base_params = filter(lambda p: id(p) not in ignored_params, model.parameters()) optimizer = torch.optim.SGD([ {'params': base_params}, {'params': model.fc.parameters(), 'lr': 1e-3} ], lr=1e-2, momentum=0.9)
其中base_params使用1e-3來訓(xùn)練,model.fc.parameters使用1e-2來訓(xùn)練,momentum是二者共有的。
之三:
pytorch finetune模型
文章主要講述如何在pytorch上讀取以往訓(xùn)練的模型參數(shù),在模型的名字已經(jīng)變更的情況下又如何讀取模型的部分參數(shù)等。
pytorch 模型的存儲(chǔ)與讀取
其中在模型的保存過程有存儲(chǔ)模型和參數(shù)一起的也有單獨(dú)存儲(chǔ)模型參數(shù)的
單獨(dú)存儲(chǔ)模型參數(shù)
存儲(chǔ)時(shí)使用:
torch.save(the_model.state_dict(), PATH)
讀取時(shí):
the_model = TheModelClass(*args, **kwargs) the_model.load_state_dict(torch.load(PATH))
存儲(chǔ)模型與參數(shù)
存儲(chǔ):
torch.save(the_model, PATH)
讀?。?/p>
the_model = torch.load(PATH)
模型的參數(shù)
fine-tune的過程是讀取原有模型的參數(shù),但是由于模型的所要處理的數(shù)據(jù)集不同,最后的一層class的總數(shù)不同,所以需要修改模型的最后一層,這樣模型讀取的參數(shù),和在大數(shù)據(jù)集上訓(xùn)練好下載的模型參數(shù)在形式上不一樣。需要我們自己去寫函數(shù)讀取參數(shù)。
pytorch模型參數(shù)的形式
模型的參數(shù)是以字典的形式存儲(chǔ)的。
model_dict = the_model.state_dict(), for k,v in model_dict.items(): print(k)
即可看到所有的鍵值
如果想修改模型的參數(shù),給相應(yīng)的鍵值賦值即可
model_dict[k] = new_value
最后更新模型的參數(shù)
the_model.load_state_dict(model_dict)
如果模型的key值和在大數(shù)據(jù)集上訓(xùn)練時(shí)的key值是一樣的
我們可以通過下列算法進(jìn)行讀取模型
model_dict = model.state_dict() pretrained_dict = torch.load(model_path) # 1. filter out unnecessary keys diff = {k: v for k, v in model_dict.items() if \ k in pretrained_dict and pretrained_dict[k].size() == v.size()} pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()} pretrained_dict.update(diff) # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) # 3. load the new state dict model.load_state_dict(model_dict)
如果模型的key值和在大數(shù)據(jù)集上訓(xùn)練時(shí)的key值是不一樣的,但是順序是一樣的
model_dict = model.state_dict() pretrained_dict = torch.load(model_path) keys = [] for k,v in pretrained_dict.items(): keys.append(k) i = 0 for k,v in model_dict.items(): if v.size() == pretrained_dict[keys[i]].size(): print(k, ',', keys[i]) model_dict[k]=pretrained_dict[keys[i]] i = i + 1 model.load_state_dict(model_dict)
如果模型的key值和在大數(shù)據(jù)集上訓(xùn)練時(shí)的key值是不一樣的,但是順序是也不一樣的
自己找對(duì)應(yīng)關(guān)系,一個(gè)key對(duì)應(yīng)一個(gè)key的賦值
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
使用Python實(shí)現(xiàn)提取PDF文件中指定頁面的內(nèi)容
在日常工作和學(xué)習(xí)中,我們經(jīng)常需要從PDF文件中提取特定頁面的內(nèi)容,本文主要為大家詳細(xì)介紹了如何使用Python編程語言和兩個(gè)強(qiáng)大的庫——pymupdf和wxPython來實(shí)現(xiàn)這個(gè)任務(wù),需要的可以了解下2023-12-12Python 快速把多個(gè)元素連接成一個(gè)字符串的操作方法
join() 方法一個(gè)用于將序列中的元素以指定的分隔符連接成一個(gè)字符串的方法,這個(gè)方法通常用于字符串操作,這篇文章主要介紹了Python 快速把多個(gè)元素連接成一個(gè)字符串的方法,需要的朋友可以參考下2024-06-06python3實(shí)現(xiàn)單目標(biāo)粒子群算法
這篇文章主要為大家詳細(xì)介紹了python3實(shí)現(xiàn)單目標(biāo)粒子群算法,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2019-11-11用ReactJS和Python的Flask框架編寫留言板的代碼示例
這篇文章主要介紹了用ReactJS和Python的Flask框架編寫留言板的代碼示例,其他的話用到了MongoDB這個(gè)方便使用JavaScript來操作的數(shù)據(jù)庫,需要的朋友可以參考下2015-12-12jupyter 中文亂碼設(shè)置編碼格式 避免控制臺(tái)輸出的解決
這篇文章主要介紹了jupyter 中文亂碼設(shè)置編碼格式 避免控制臺(tái)輸出的解決,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-04-04Django生成PDF文檔顯示在網(wǎng)頁上以及解決PDF中文顯示亂碼的問題
這篇文章主要介紹了Django生成PDF文檔顯示在網(wǎng)頁上以及解決PDF中文顯示亂碼的問題,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2019-07-07python3啟動(dòng)web服務(wù)引發(fā)的一系列問題匯總
由于行內(nèi)交付的機(jī)器已自帶python3 ,沒有采取自行安裝python3,但是運(yùn)行python腳本時(shí)報(bào)沒有tornado module,遇到這樣的問題如何處理呢,下面小編給大家介紹下python3啟動(dòng)web服務(wù)引發(fā)的一系列問題匯總,感興趣的朋友一起看看吧2023-02-02