pytorch GPU和CPU模型相互加載方式
1 pytorch保存模型的兩種方式
1.1 直接保存模型并讀取
# 創(chuàng)建你的模型實(shí)例對(duì)象: model
model = net()
## 保存模型
torch.save(model, 'model_name.pth')
## 讀取模型
model = torch.load('model_name.pth')1.2 只保存模型中的參數(shù)并讀取
## 保存模型
torch.save({'model': model.state_dict()}, 'model_name.pth')
## 讀取模型
model = net()
state_dict = torch.load('model_name.pth')
model.load_state_dict(state_dict['model'])- 第一種方法可以直接保存模型,加載模型的時(shí)候直接把讀取的模型給一個(gè)參數(shù)就行。
- 第二種方法則只是保存參數(shù),在讀取模型參數(shù)前要先定義一個(gè)模型(模型必須與原模型相同的構(gòu)造),然后對(duì)這個(gè)模型導(dǎo)入?yún)?shù)。雖然麻煩,但是可以同時(shí)保存多個(gè)模型的參數(shù),而第一種方法則不能,而且第一種方法有時(shí)不能保證模型的相同性(你讀取的模型并不是你想要的)。
如何保存模型決定了如何讀取模型,一般來(lái)選擇第二種來(lái)保存和讀取。
2 GPU / CPU模型相互加載
2.1 單個(gè)CPU和單個(gè)GPU模型加載
pytorch 允許把在GPU上訓(xùn)練的模型加載到CPU上,也允許把在CPU上訓(xùn)練的模型加載到GPU上。
加載模型參數(shù)的時(shí)候,在GPU和CPU訓(xùn)練的模型是不一樣的,這兩種模型是不能混為一談的,下面分情況進(jìn)行操作說明。
情況一:CPU -> CPU, GPU -> GPU
- GPU訓(xùn)練的模型,在GPU上使用;
- CPU訓(xùn)練的模型,在CPU上使用,
這種情況下我們都只用直接用下面的語(yǔ)句即可:
torch.load('model_dict.pth')情況二:GPU -> CPG/GPU
GPU訓(xùn)練的模型,不知道放在CPU還是GPU運(yùn)行,兩種情況都要考慮
import torch
from torchvision import models
# 加載預(yù)訓(xùn)練的GPU模型權(quán)重文件
weights_path = 'model_gpu.pth'
# 定義一個(gè)與原模型結(jié)構(gòu)相同的新模型
model = models.resnet50()
# 檢查是否有可用的CUDA設(shè)備
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 將權(quán)重映射到相應(yīng)的設(shè)備內(nèi)存并加載到模型中
weights = torch.load(weights_path, map_location=device)
model.load_state_dict(weights)
# 設(shè)置為評(píng)估模式
model.eval()
print("Model is successfully loaded and can be used on a", device.type, "!")情況三:CPU -> CPG/GPU
模型是在CPU上訓(xùn)練的,但不確定要在CPU還是GPU上運(yùn)行時(shí),兩種情況都要考慮
import torch
from torchvision import models
# 加載預(yù)訓(xùn)練的CPU模型權(quán)重文件
weights_path = 'model_cpu.pth'
# 定義一個(gè)與原模型結(jié)構(gòu)相同的新模型
model = models.resnet50()
# 檢查是否有可用的CUDA設(shè)備
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 將權(quán)重映射到相應(yīng)的設(shè)備內(nèi)存并加載到模型中
if device.type == 'cuda':
model.to(device)
weights = torch.load(weights_path, map_location=device)
else:
weights = torch.load(weights_path, map_location='cpu')
model.load_state_dict(weights)
# 設(shè)置為評(píng)估模式
model.eval()
print("Model is successfully loaded and can be used on a", device.type, "!")總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
詳解用python -m http.server搭一個(gè)簡(jiǎn)易的本地局域網(wǎng)
這篇文章主要介紹了詳解用python -m http.server搭一個(gè)簡(jiǎn)易的本地局域網(wǎng),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-09-09
用selenium解決滑塊驗(yàn)證碼的實(shí)現(xiàn)步驟
驗(yàn)證碼作為一種自然人的機(jī)器人的判別工具,被廣泛的用于各種防止程序做自動(dòng)化的場(chǎng)景中,下面這篇文章主要給大家介紹了關(guān)于用selenium解決滑塊驗(yàn)證碼的實(shí)現(xiàn)步驟,需要的朋友可以參考下2023-02-02
Python利用卡方Chi特征檢驗(yàn)實(shí)現(xiàn)提取關(guān)鍵文本特征
卡方檢驗(yàn)最基本的思想就是通過觀察實(shí)際值與理論值的偏差來(lái)確定理論的正確與否。本文將利用卡方Chi特征檢驗(yàn)實(shí)現(xiàn)提取關(guān)鍵文本特征功能,感興趣的可以了解一下2022-12-12
Pygame實(shí)現(xiàn)簡(jiǎn)易版趣味小游戲之反彈球
這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)簡(jiǎn)易版趣味反彈球游戲,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2022-03-03
python linecache讀取行更新的實(shí)現(xiàn)
本文主要介紹了python linecache讀取行更新的實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2023-03-03

