pytorch GPU和CPU模型相互加載方式
1 pytorch保存模型的兩種方式
1.1 直接保存模型并讀取
# 創(chuàng)建你的模型實例對象: model model = net() ## 保存模型 torch.save(model, 'model_name.pth') ## 讀取模型 model = torch.load('model_name.pth')
1.2 只保存模型中的參數(shù)并讀取
## 保存模型 torch.save({'model': model.state_dict()}, 'model_name.pth') ## 讀取模型 model = net() state_dict = torch.load('model_name.pth') model.load_state_dict(state_dict['model'])
- 第一種方法可以直接保存模型,加載模型的時候直接把讀取的模型給一個參數(shù)就行。
- 第二種方法則只是保存參數(shù),在讀取模型參數(shù)前要先定義一個模型(模型必須與原模型相同的構(gòu)造),然后對這個模型導(dǎo)入?yún)?shù)。雖然麻煩,但是可以同時保存多個模型的參數(shù),而第一種方法則不能,而且第一種方法有時不能保證模型的相同性(你讀取的模型并不是你想要的)。
如何保存模型決定了如何讀取模型,一般來選擇第二種來保存和讀取。
2 GPU / CPU模型相互加載
2.1 單個CPU和單個GPU模型加載
pytorch 允許把在GPU上訓(xùn)練的模型加載到CPU上,也允許把在CPU上訓(xùn)練的模型加載到GPU上。
加載模型參數(shù)的時候,在GPU和CPU訓(xùn)練的模型是不一樣的,這兩種模型是不能混為一談的,下面分情況進行操作說明。
情況一:CPU -> CPU, GPU -> GPU
- GPU訓(xùn)練的模型,在GPU上使用;
- CPU訓(xùn)練的模型,在CPU上使用,
這種情況下我們都只用直接用下面的語句即可:
torch.load('model_dict.pth')
情況二:GPU -> CPG/GPU
GPU訓(xùn)練的模型,不知道放在CPU還是GPU運行,兩種情況都要考慮
import torch from torchvision import models # 加載預(yù)訓(xùn)練的GPU模型權(quán)重文件 weights_path = 'model_gpu.pth' # 定義一個與原模型結(jié)構(gòu)相同的新模型 model = models.resnet50() # 檢查是否有可用的CUDA設(shè)備 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 將權(quán)重映射到相應(yīng)的設(shè)備內(nèi)存并加載到模型中 weights = torch.load(weights_path, map_location=device) model.load_state_dict(weights) # 設(shè)置為評估模式 model.eval() print("Model is successfully loaded and can be used on a", device.type, "!")
情況三:CPU -> CPG/GPU
模型是在CPU上訓(xùn)練的,但不確定要在CPU還是GPU上運行時,兩種情況都要考慮
import torch from torchvision import models # 加載預(yù)訓(xùn)練的CPU模型權(quán)重文件 weights_path = 'model_cpu.pth' # 定義一個與原模型結(jié)構(gòu)相同的新模型 model = models.resnet50() # 檢查是否有可用的CUDA設(shè)備 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 將權(quán)重映射到相應(yīng)的設(shè)備內(nèi)存并加載到模型中 if device.type == 'cuda': model.to(device) weights = torch.load(weights_path, map_location=device) else: weights = torch.load(weights_path, map_location='cpu') model.load_state_dict(weights) # 設(shè)置為評估模式 model.eval() print("Model is successfully loaded and can be used on a", device.type, "!")
總結(jié)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
詳解用python -m http.server搭一個簡易的本地局域網(wǎng)
這篇文章主要介紹了詳解用python -m http.server搭一個簡易的本地局域網(wǎng),文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-09-09Python利用卡方Chi特征檢驗實現(xiàn)提取關(guān)鍵文本特征
卡方檢驗最基本的思想就是通過觀察實際值與理論值的偏差來確定理論的正確與否。本文將利用卡方Chi特征檢驗實現(xiàn)提取關(guān)鍵文本特征功能,感興趣的可以了解一下2022-12-12python linecache讀取行更新的實現(xiàn)
本文主要介紹了python linecache讀取行更新的實現(xiàn),文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-03-03