python中torch.load中的map_location參數(shù)使用
引言
在PyTorch中,torch.load()
函數(shù)是用于加載保存模型或張量數(shù)據(jù)的重要工具。當(dāng)我們訓(xùn)練好一個(gè)深度學(xué)習(xí)模型后,通常需要將模型的參數(shù)(或稱為狀態(tài)字典,state_dict)保存下來,以便后續(xù)進(jìn)行模型評(píng)估、繼續(xù)訓(xùn)練或部署到其他環(huán)境中。在加載這些保存的數(shù)據(jù)時(shí),map_location
參數(shù)為我們提供了極大的靈活性,以決定這些數(shù)據(jù)應(yīng)該被加載到哪個(gè)設(shè)備上。本文將詳細(xì)解析map_location
參數(shù)的功能和使用方法,并通過實(shí)戰(zhàn)案例來展示其在不同場(chǎng)景下的應(yīng)用。
map_location參數(shù)詳解
map_location
參數(shù)在torch.load()
函數(shù)中扮演著至關(guān)重要的角色。它決定了從保存的文件中加載數(shù)據(jù)時(shí)應(yīng)將它們映射到哪個(gè)設(shè)備上。在PyTorch中,設(shè)備可以是CPU或GPU,而GPU可以有多個(gè),每個(gè)都有其獨(dú)立的索引。map_location
的靈活使用能夠讓我們輕松地在不同設(shè)備之間遷移模型,從而充分利用不同設(shè)備的計(jì)算優(yōu)勢(shì)。
map_location參數(shù)的數(shù)據(jù)類型
map_location
參數(shù)的數(shù)據(jù)類型可以是:
參數(shù)類型 | 描述 | 示例 |
---|---|---|
字符串(str) | 預(yù)定義的設(shè)備字符串,指定目標(biāo)設(shè)備。 | 1. 'cpu' :加載到CPU上;2. 'cuda:X' :加載到索引為X的GPU上。 |
torch.device對(duì)象 | 一個(gè)表示目標(biāo)設(shè)備的torch.device 對(duì)象。 | 1.torch.device('cpu') :加載到CPU上;2. torch.device('cuda:1') :加載到索引為1的GPU上。 |
可調(diào)用對(duì)象(callable) | 一個(gè)接收存儲(chǔ)路徑并返回新位置的函數(shù)。 | lambda storage, loc: storage.cuda(1) :將每個(gè)存儲(chǔ)對(duì)象移動(dòng)到索引為1的GPU上。 |
字典(dict) | 一個(gè)將存儲(chǔ)路徑映射到新位置的字典。 | {'cuda:1':'cuda:0'} :將原本在GPU 1上的張量加載到GPU 0上。 |
map_location參數(shù)的使用場(chǎng)景
CPU加載:當(dāng)你想在CPU上加載模型時(shí),可以設(shè)置
map_location='cpu'
。這適用于那些不需要GPU加速的推理任務(wù),或者在沒有GPU的環(huán)境中部署模型。指定GPU加載:如果你有多個(gè)GPU,并且想將模型加載到特定的GPU上,可以使用
'cuda:X'
格式的字符串,其中X
是GPU的索引。這在多GPU環(huán)境中非常有用,可以確保模型加載到指定的設(shè)備上。自動(dòng)選擇GPU:如果你只想在GPU上加載模型,但不關(guān)心具體是哪一個(gè)GPU,可以設(shè)置
map_location=torch.device('cuda')
。這會(huì)自動(dòng)選擇第一個(gè)可用的GPU來加載模型。保持原始設(shè)備:如果你想保持模型在加載時(shí)的原始設(shè)備(即如果模型原先是在GPU上訓(xùn)練的,就仍然在GPU上加載;如果是在CPU上,就在CPU上加載),可以使用
map_location=None
或map_location=torch.device('cpu')
(對(duì)于CPU模型)和map_location=torch.device('cuda')
(對(duì)于GPU模型)。自定義映射邏輯:通過傳遞一個(gè)可調(diào)用對(duì)象,你可以實(shí)現(xiàn)更復(fù)雜的映射邏輯。例如,你可以編寫一個(gè)函數(shù),根據(jù)存儲(chǔ)路徑或模型結(jié)構(gòu)來決定將模型加載到哪個(gè)設(shè)備上。這在需要根據(jù)特定條件動(dòng)態(tài)選擇加載設(shè)備時(shí)非常有用。
代碼實(shí)戰(zhàn)(詳細(xì)注釋)
下面將通過幾個(gè)實(shí)戰(zhàn)案例來展示map_location
參數(shù)在不同場(chǎng)景下的應(yīng)用。
案例1:從文件加載張量到CPU
# 案例1:從文件加載張量到CPU # 使用torch.load()函數(shù)加載tensors.pt文件中的所有張量到CPU上 tensors = torch.load('tensors.pt')
案例2:指定設(shè)備加載張量
# 案例2:指定設(shè)備加載張量 # 使用torch.load()函數(shù)并指定map_location參數(shù)為CPU設(shè)備,加載tensors.pt文件中的所有張量到CPU上 tensors_on_cpu = torch.load('tensors.pt', map_location=torch.device('cpu'))
案例3:使用匿名函數(shù)指定加載位置
# 案例3:使用函數(shù)指定加載位置 # 使用torch.load()函數(shù)和map_location參數(shù)為一個(gè)lambda函數(shù),該函數(shù)不做任何改變,保持張量原始位置(通常是CPU) tensors_original_location = torch.load('tensors.pt', map_location=lambda storage, loc: storage)
案例4:將張量加載到指定GPU
# 案例4:將張量加載到指定GPU # 使用torch.load()函數(shù)和map_location參數(shù)為一個(gè)lambda函數(shù),該函數(shù)將張量移動(dòng)到索引為1的GPU上 tensors_on_gpu1 = torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
案例5:張量從一個(gè)GPU映射到另一個(gè)GPU
# 案例5:張量從一個(gè)GPU映射到另一個(gè)GPU # 使用torch.load()函數(shù)和map_location參數(shù)為一個(gè)字典,將原本在GPU 1上的張量映射到GPU 0上 tensors_mapped = torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
案例6:從io.BytesIO對(duì)象加載張量
# 案例6:從io.BytesIO對(duì)象加載張量 # 打開tensor.pt文件并讀取內(nèi)容到BytesIO緩沖區(qū) with open('tensor.pt', 'rb') as f: buffer = io.BytesIO(f.read()) # 使用torch.load()函數(shù)從BytesIO緩沖區(qū)加載張量 tensors_from_buffer = torch.load(buffer)
案例7:使用ASCII編碼加載模塊
# 案例7:使用ASCII編碼加載模塊 # 使用torch.load()函數(shù)和encoding參數(shù)為'ascii',加載module.pt文件中的模塊(如神經(jīng)網(wǎng)絡(luò)模型) model = torch.load('module.pt', encoding='ascii')
這些案例代碼和注釋展示了如何使用torch.load()
函數(shù)的不同map_location
參數(shù)和編碼設(shè)置來加載張量和模型。這些設(shè)置對(duì)于控制數(shù)據(jù)加載的位置和格式非常重要,特別是在跨設(shè)備或跨平臺(tái)加載數(shù)據(jù)時(shí)。
參考文檔
[1] PyTorch官方文檔
到此這篇關(guān)于python中torch.load中的map_location參數(shù)使用的文章就介紹到這了,更多相關(guān)python torch.load map_location參數(shù)內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
分析總結(jié)Python數(shù)據(jù)化運(yùn)營(yíng)KMeans聚類
本文主要以 Python 使用 Keans 進(jìn)行聚類分析的簡(jiǎn)單舉例應(yīng)用介紹聚類分析,它是探索性數(shù)據(jù)挖掘的主要任務(wù),也是統(tǒng)計(jì)數(shù)據(jù)分析的常用技術(shù),用于許多領(lǐng)域2021-08-08Python簡(jiǎn)單獲取自身外網(wǎng)IP的方法
這篇文章主要介紹了Python簡(jiǎn)單獲取自身外網(wǎng)IP的方法,涉及Python基于第三方平臺(tái)獲取本機(jī)外網(wǎng)IP的操作技巧,需要的朋友可以參考下2016-09-09python應(yīng)用程序在windows下不出現(xiàn)cmd窗口的辦法
這篇文章主要介紹了python應(yīng)用程序在windows下不出現(xiàn)cmd窗口的辦法,適用于python寫的GTK程序并用py2exe編譯的情況下,需要的朋友可以參考下2014-05-05Python根據(jù)文件名批量轉(zhuǎn)移圖片的方法
今天小編就為大家分享一篇Python根據(jù)文件名批量轉(zhuǎn)移圖片的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-10-10安裝python-docx后,無法在pycharm中導(dǎo)入的解決方案
這篇文章主要介紹了安裝python-docx后,無法在pycharm中導(dǎo)入的解決方案,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2021-03-03Python判斷和循環(huán)語句的分析與應(yīng)用
判斷語句是用來篩選條件,過濾條件的。循環(huán)語句是用來解決重復(fù)性代碼的問題,提高工作效率。今天的知識(shí)點(diǎn)不多,耐心看完吧2022-07-07python實(shí)現(xiàn)爬蟲統(tǒng)計(jì)學(xué)校BBS男女比例之多線程爬蟲(二)
這篇文章主要介紹了python實(shí)現(xiàn)爬蟲統(tǒng)計(jì)學(xué)校BBS男女比例之多線程爬蟲,感興趣的小伙伴們可以參考一下2015-12-12在Flask使用TensorFlow的幾個(gè)常見錯(cuò)誤及解決
這篇文章主要介紹了在Flask使用TensorFlow的幾個(gè)常見錯(cuò)誤及解決,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-01-01Python中flatten( )函數(shù)及函數(shù)用法詳解
flatten是numpy.ndarray.flatten的一個(gè)函數(shù),即返回一個(gè)一維數(shù)組。這篇文章主要介紹了Python中flatten( )函數(shù),需要的朋友可以參考下2018-11-11