關(guān)于Pytorch中模型的保存與遷移問題
1 引言
各位朋友大家好,歡迎來到月來客棧。今天要和大家介紹的內(nèi)容是如何在Pytorch框架中對模型進行保存和載入、以及模型的遷移和再訓(xùn)練。一般來說,最常見的場景就是模型完成訓(xùn)練后的推斷過程。一個網(wǎng)絡(luò)模型在完成訓(xùn)練后通常都需要對新樣本進行預(yù)測,此時就只需要構(gòu)建模型的前向傳播過程,然后載入已訓(xùn)練好的參數(shù)初始化網(wǎng)絡(luò)即可。
第2個場景就是模型的再訓(xùn)練過程。一個模型在一批數(shù)據(jù)上訓(xùn)練完成之后需要將其保存到本地,并且可能過了一段時間后又收集到了一批新的數(shù)據(jù),因此這個時候就需要將之前的模型載入進行在新數(shù)據(jù)上進行增量訓(xùn)練(或者是在整個數(shù)據(jù)上進行全量訓(xùn)練)。
第3個應(yīng)用場景就是模型的遷移學(xué)習。這個時候就是將別人已經(jīng)訓(xùn)練好的預(yù)模型拿過來,作為你自己網(wǎng)絡(luò)模型參數(shù)的一部分進行初始化。例如:你自己在Bert模型的基礎(chǔ)上加了幾個全連接層來做分類任務(wù),那么你就需要將原始BERT模型中的參數(shù)載入并以此來初始化你的網(wǎng)絡(luò)中的BERT部分的權(quán)重參數(shù)。
在接下來的這篇文章中,筆者就以上述3個場景為例來介紹如何利用Pytorch框架來完成上述過程。
2 模型的保存與復(fù)用
在Pytorch中,我們可以通過torch.save()
和torch.load()
來完成上述場景中的主要步驟。下面,筆者將以之前介紹的LeNet5網(wǎng)絡(luò)模型為例來分別進行介紹。不過在這之前,我們先來看看Pytorch中模型參數(shù)的保存形式。
2.1 查看網(wǎng)絡(luò)模型參數(shù)
(1)查看參數(shù)
首先定義好LeNet5的網(wǎng)絡(luò)模型結(jié)構(gòu),如下代碼所示:
class LeNet5(nn.Module): def __init__(self, ): super(LeNet5, self).__init__() self.conv = nn.Sequential( # [n,1,28,28] nn.Conv2d(1, 6, 5, padding=2), # in_channels, out_channels, kernel_size nn.ReLU(), # [n,6,24,24] nn.MaxPool2d(2, 2), # kernel_size, stride [n,6,14,14] nn.Conv2d(6, 16, 5), # [n,16,10,10] nn.ReLU(), nn.MaxPool2d(2, 2)) # [n,16,5,5] self.fc = nn.Sequential( nn.Flatten(), nn.Linear(16 * 5 * 5, 120), nn.ReLU(), nn.Linear(120, 84), nn.ReLU(), nn.Linear(84, 10)) def forward(self, img): output = self.conv(img) output = self.fc(output) return output
在定義好LeNet5這個網(wǎng)絡(luò)結(jié)構(gòu)的類之后,只要我們完成了這個類的實例化操作,那么網(wǎng)絡(luò)中對應(yīng)的權(quán)重參數(shù)也都完成了初始化的工作,即有了一個初始值。同時,我們可以通過如下方式來訪問:
# Print model's state_dict print("Model's state_dict:") for param_tensor in model.state_dict(): print(param_tensor, "\t", model.state_dict()[param_tensor].size())
其輸出的結(jié)果為:
conv.0.weight torch.Size([6, 1, 5, 5])
conv.0.bias torch.Size([6])
conv.3.weight torch.Size([16, 6, 5, 5])
....
....
可以發(fā)現(xiàn),網(wǎng)絡(luò)模型中的參數(shù)model.state_dict()
其實是以字典的形式(實質(zhì)上是collections
模塊中的OrderedDict
)保存下來的:
print(model.state_dict().keys()) # odict_keys(['conv.0.weight', 'conv.0.bias', 'conv.3.weight', 'conv.3.bias', 'fc.1.weight', 'fc.1.bias', 'fc.3.weight', 'fc.3.bias', 'fc.5.weight', 'fc.5.bias'])
(2)自定義參數(shù)前綴
同時,這里值得注意的地方有兩點:①參數(shù)名中的fc
和conv
前綴是根據(jù)你在上面定義nn.Sequential()
時的名字所確定的;②參數(shù)名中的數(shù)字表示每個Sequential()
中網(wǎng)絡(luò)層所在的位置。例如將網(wǎng)絡(luò)結(jié)構(gòu)定義成如下形式:
class LeNet5(nn.Module): def __init__(self, ): super(LeNet5, self).__init__() self.moon = nn.Sequential( # [n,1,28,28] nn.Conv2d(1, 6, 5, padding=2), # in_channels, out_channels, kernel_size nn.ReLU(), # [n,6,24,24] nn.MaxPool2d(2, 2), # kernel_size, stride [n,6,14,14] nn.Conv2d(6, 16, 5), # [n,16,10,10] nn.ReLU(), nn.MaxPool2d(2, 2), nn.Flatten(), nn.Linear(16 * 5 * 5, 120), nn.ReLU(), nn.Linear(120, 84), nn.ReLU(), nn.Linear(84, 10))
那么其參數(shù)名則為:
print(model.state_dict().keys()) odict_keys(['moon.0.weight', 'moon.0.bias', 'moon.3.weight', 'moon.3.bias', 'moon.7.weight', 'moon.7.bias', 'moon.9.weight', 'moon.9.bias', 'moon.11.weight', 'moon.11.bias'])
理解了這一點對于后續(xù)我們?nèi)ソ馕龊洼d入一些預(yù)訓(xùn)練模型很有幫助。
除此之外,對于中的優(yōu)化器等,其同樣有對應(yīng)的state_dict()
方法來獲取對于的參數(shù),例如:
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) print("Optimizer's state_dict:") for var_name in optimizer.state_dict(): print(var_name, "\t", optimizer.state_dict()[var_name]) # Optimizer's state_dict: state {} param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140239245300504, 140239208339784, 140239245311360, 140239245310856, 140239266942480, 140239266942552, 140239266942624, 140239266942696, 140239266942912, 140239267041352]}]
在介紹完模型參數(shù)的查看方法后,就可以進入到模型復(fù)用階段的內(nèi)容介紹了。
2.2 載入模型進行推斷
(1) 模型保存
在Pytorch中,對于模型的保存來說是非常簡單的,通常來說通過如下兩行代碼便可以實現(xiàn):
model_save_path = os.path.join(model_save_dir, 'model.pt') torch.save(model.state_dict(), model_save_path)
在指定保存的模型名稱時Pytorch官方建議的后綴為.pt
或者.pth
(當然也不是強制的)。最后,只需要在合適的地方加入第2行代碼即可完成模型的保存。
同時,如果想要在訓(xùn)練過程中保存某個條件下的最優(yōu)模型,那么應(yīng)該通過如下方式:
best_model_state = deepcopy(model.state_dict()) torch.save(best_model_state, model_save_path)
而不是:
best_model_state = model.state_dict() torch.save(best_model_state, model_save_path)
因為后者best_model_state
得到只是model.state_dict()
的引用,它依舊會隨著訓(xùn)練過程而發(fā)生改變。
(2)復(fù)用模型進行推斷
在推斷過程中,首先需要完成網(wǎng)絡(luò)的初始化,然后再載入已有的模型參數(shù)來覆蓋網(wǎng)絡(luò)中的權(quán)重參數(shù)即可,示例代碼如下:
def inference(data_iter, device, model_save_dir='./MODEL'): model = LeNet5() # 初始化現(xiàn)有模型的權(quán)重參數(shù) model.to(device) model_save_path = os.path.join(model_save_dir, 'model.pt') if os.path.exists(model_save_path): loaded_paras = torch.load(model_save_path) model.load_state_dict(loaded_paras) # 用本地已有模型來重新初始化網(wǎng)絡(luò)權(quán)重參數(shù) model.eval() # 注意不要忘記 with torch.no_grad(): acc_sum, n = 0.0, 0 for x, y in data_iter: x, y = x.to(device), y.to(device) logits = model(x) acc_sum += (logits.argmax(1) == y).float().sum().item() n += len(y) print("Accuracy in test data is :", acc_sum / n)
在上述代碼中,4-7行便是用來載入本地模型參數(shù),并用其覆蓋網(wǎng)絡(luò)模型中原有的參數(shù)。這樣,便可以進行后續(xù)的推斷工作:
Accuracy in test data is : 0.8851
2.3 載入模型進行訓(xùn)練
在介紹完模型的保存與復(fù)用之后,對于網(wǎng)絡(luò)的追加訓(xùn)練就很簡單了。最簡便的一種方式就是在訓(xùn)練過程中只保存網(wǎng)絡(luò)權(quán)重,然后在后續(xù)進行追加訓(xùn)練時只載入網(wǎng)絡(luò)權(quán)重參數(shù)初始化網(wǎng)絡(luò)進行訓(xùn)練即可,示例如下(完整代碼參見[2]):
def train(self): #...... model_save_path = os.path.join(self.model_save_dir, 'model.pt') if os.path.exists(model_save_path): loaded_paras = torch.load(model_save_path) self.model.load_state_dict(loaded_paras) print("#### 成功載入已有模型,進行追加訓(xùn)練...") optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) # 定義優(yōu)化器 #...... for epoch in range(self.epochs): for i, (x, y) in enumerate(train_iter): x, y = x.to(device), y.to(device) logits = self.model(x) # ...... print("Epochs[{}/{}]--acc on test {:.4}".format(epoch, self.epochs, self.evaluate(test_iter, self.model, device))) torch.save(self.model.state_dict(), model_save_path)
這樣,便完成了模型的追加訓(xùn)練:
#### 成功載入已有模型,進行追加訓(xùn)練... Epochs[0/5]---batch[938/0]---acc 0.9062---loss 0.2926 Epochs[0/5]---batch[938/100]---acc 0.9375---loss 0.1598 ......
除此之外,你也可以在保存參數(shù)的時候,將優(yōu)化器參數(shù)、損失值等一同保存下來,然后在恢復(fù)模型的時候連同其它參數(shù)一起恢復(fù),示例如下:
model_save_path = os.path.join(model_save_dir, 'model.pt') torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, ... }, model_save_path)
載入方式如下:
checkpoint = torch.load(model_save_path) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss']
2.4 載入模型進行遷移
(1)定義新模型
到目前為止,對于前面兩種應(yīng)用場景的介紹就算完成了,可以發(fā)現(xiàn)總體上并不復(fù)雜。但是對于第3中場景的應(yīng)用來說就會略微復(fù)雜一點。
假設(shè)現(xiàn)在有一個LeNet6網(wǎng)絡(luò)模型,它是在LeNet5的基礎(chǔ)最后多加了一個全連接層,其定義如下:
class LeNet6(nn.Module): def __init__(self, ): super(LeNet6, self).__init__() self.conv = nn.Sequential( # [n,1,28,28] nn.Conv2d(1, 6, 5, padding=2), # in_channels, out_channels, kernel_size nn.ReLU(), # [n,6,24,24] nn.MaxPool2d(2, 2), # kernel_size, stride [n,6,14,14] nn.Conv2d(6, 16, 5), # [n,16,10,10] nn.ReLU(), nn.MaxPool2d(2, 2)) # [n,16,5,5] self.fc = nn.Sequential( nn.Flatten(), nn.Linear(16 * 5 * 5, 120), nn.ReLU(), nn.Linear(120, 84), nn.ReLU(), nn.Linear(84, 64), nn.ReLU(), nn.Linear(64, 10) ) # 新加入的全連接層
接下來,我們需要將在LeNet5上訓(xùn)練得到的權(quán)重參數(shù)遷移到LeNet6網(wǎng)絡(luò)中去。從上面LeNet6的定義可以發(fā)現(xiàn),此時盡管只是多加了一個全連接層,但是倒數(shù)第2層參數(shù)的維度也發(fā)生了變換。因此,對于LeNet6來說只能復(fù)用LeNet5網(wǎng)絡(luò)前面4層的權(quán)重參數(shù)。
(2)查看模型參數(shù)
在拿到一個模型參數(shù)后,首先我們可以將其載入,然查看相關(guān)參數(shù)的信息:
model_save_path = os.path.join('./MODEL', 'model.pt') loaded_paras = torch.load(model_save_path) for param_tensor in loaded_paras: print(param_tensor, "\t", loaded_paras[param_tensor].size()) #---- 可復(fù)用部分 conv.0.weight torch.Size([6, 1, 5, 5]) conv.0.bias torch.Size([6]) conv.3.weight torch.Size([16, 6, 5, 5]) conv.3.bias torch.Size([16]) fc.1.weight torch.Size([120, 400]) fc.1.bias torch.Size([120]) fc.3.weight torch.Size([84, 120]) fc.3.bias torch.Size([84]) #----- 不可復(fù)用部分 fc.5.weight torch.Size([10, 84]) fc.5.bias torch.Size([10])
同時,對于LeNet6網(wǎng)絡(luò)的參數(shù)信息為:
model = LeNet6() for param_tensor in model.state_dict(): print(param_tensor, "\t", model.state_dict()[param_tensor].size()) # conv.0.weight torch.Size([6, 1, 5, 5]) conv.0.bias torch.Size([6]) conv.3.weight torch.Size([16, 6, 5, 5]) conv.3.bias torch.Size([16]) fc.1.weight torch.Size([120, 400]) fc.1.bias torch.Size([120]) fc.3.weight torch.Size([84, 120]) fc.3.bias torch.Size([84]) #------ 新加入部分 fc.5.weight torch.Size([64, 84]) fc.5.bias torch.Size([64]) fc.7.weight torch.Size([10, 64]) fc.7.bias torch.Size([10])
在理清楚了新舊模型的參數(shù)后,下面就可以將LeNet5中我們需要的參數(shù)給取出來,然后再換到LeNet6的網(wǎng)絡(luò)中。
(3)模型遷移
雖然本地載入的模型參數(shù)(上面的loaded_paras
)和模型初始化后的參數(shù)(上面的model.state_dict()
)都是一個字典的形式,但是我們并不能夠直接改變model.state_dict()
中的權(quán)重參數(shù)。這里需要先構(gòu)造一個state_dict
然后通過model.load_state_dict()
方法來重新初始化網(wǎng)絡(luò)中的參數(shù)。
同時,在這個過程中我們需要篩選掉本地模型中不可復(fù)用的部分,具體代碼如下:
def para_state_dict(model, model_save_dir): state_dict = deepcopy(model.state_dict()) model_save_path = os.path.join(model_save_dir, 'model.pt') if os.path.exists(model_save_path): loaded_paras = torch.load(model_save_path) for key in state_dict: # 在新的網(wǎng)絡(luò)模型中遍歷對應(yīng)參數(shù) if key in loaded_paras and state_dict[key].size() == loaded_paras[key].size(): print("成功初始化參數(shù):", key) state_dict[key] = loaded_paras[key] return state_dict
在上述代碼中,第2行的作用是先拷貝網(wǎng)絡(luò)中(LeNet6)原有的參數(shù);第6-9行則是用本地的模型參數(shù)(LeNet5)中可以復(fù)用的替換掉LeNet6中的對應(yīng)部分,其中第7行就是判斷可用的條件。同時需要注意的是在不同的情況下篩選的方式可能不一樣,因此具體情況需要具體分析,但是整體邏輯是一樣的。
最后,我們只需要在模型訓(xùn)練之前調(diào)用該函數(shù),然后重新初始化LeNet6中的部分權(quán)重參數(shù)即可[2]:
state_dict = para_state_dict(self.model, self.model_save_dir) self.model.load_state_dict(state_dict)
訓(xùn)練結(jié)果如下:
成功初始化參數(shù): conv.0.weight
成功初始化參數(shù): conv.0.bias
成功初始化參數(shù): conv.3.weight
成功初始化參數(shù): conv.3.bias
成功初始化參數(shù): fc.1.weight
成功初始化參數(shù): fc.1.bias
成功初始化參數(shù): fc.3.weight
成功初始化參數(shù): fc.3.bias
#### 成功載入已有模型,進行追加訓(xùn)練...
Epochs[0/5]---batch[938/0]---acc 0.1094---loss 2.512
Epochs[0/5]---batch[938/100]---acc 0.9375---loss 0.2141
Epochs[0/5]---batch[938/200]---acc 0.9219---loss 0.2729
Epochs[0/5]---batch[938/300]---acc 0.8906---loss 0.2958
......
Epochs[0/5]---batch[938/900]---acc 0.8906---loss 0.2828
Epochs[0/5]--acc on test 0.8808
可以發(fā)現(xiàn),在大約100個batch之后,模型的準確率就提升上來了。
3 總結(jié)
在本篇文章中,筆者首先介紹了模型復(fù)用的幾種典型場景;然后介紹了如何查看Pytorch模型中的相關(guān)參數(shù)信息;接著介紹了如何載入模型、如何進行追加訓(xùn)練以及進行模型的遷移學(xué)習等。有了這部分內(nèi)容的鋪墊,在后續(xù)介紹類似BERT這樣的預(yù)訓(xùn)練模型載入時就會容易很多了。
到此這篇關(guān)于Pytorch中模型的保存與遷移的文章就介紹到這了,更多相關(guān)Pytorch模型內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python Socket多線程并發(fā)原理及實現(xiàn)
這篇文章主要介紹了Python Socket多線程并發(fā)原理及實現(xiàn),幫助大家更好的理解和使用python,感興趣的朋友可以了解下2020-12-12python備份文件以及mysql數(shù)據(jù)庫的腳本代碼
最近正在學(xué)習python,看了幾天了,,所以寫個小腳本練習練習,沒什么含金量,只當練手2013-06-06Python定時任務(wù)APScheduler原理及實例解析
這篇文章主要介紹了Python定時任務(wù)APScheduler原理及實例解析,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習或者工作具有一定的參考學(xué)習價值,需要的朋友可以參考下2020-05-05Python Tkinter GUI編程實現(xiàn)Frame切換
本文主要介紹了Python Tkinter GUI編程實現(xiàn)Frame切換,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習或者工作具有一定的參考學(xué)習價值,需要的朋友們下面隨著小編來一起學(xué)習學(xué)習吧2022-04-04