解決torch.to(device)是否賦值的坑
torch.to(device)是否賦值的坑
在我們用GPU跑程序時(shí),需要在程序中把變量和模型放到GPU里面。
有一些坑需要注意,本文用RNN模型實(shí)例
首先,定義device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
對(duì)于變量,需要進(jìn)行賦值操作才能真正轉(zhuǎn)到GPU上:
all_input_batch=all_input_batch.to(device)
對(duì)于模型,不需要進(jìn)行賦值:
model = TextRNN() model.to(device)
對(duì)模型進(jìn)行to(device),還有一種方法,就是在定義模型的時(shí)候全部對(duì)模型網(wǎng)絡(luò)參數(shù)to(device),這樣就可以不需要model.to(device)這句話。
class TextRNN(nn.Module): def __init__(self): super(TextRNN, self).__init__() #self.cnt = 0 self.C = nn.Embedding(n_class, embedding_dim=emb_size,device=device) self.rnn = nn.RNN(input_size=emb_size, hidden_size=n_hidden,device=device) self.W = nn.Linear(n_hidden, n_class, bias=False,device=device) self.b = nn.Parameter(torch.ones([n_class])).to(device) def forward(self, X): X = self.C(X) #print(X.is_cuda) X = X.transpose(0, 1) # X : [n_step, batch_size, embeding size] outputs, hidden = self.rnn(X) # outputs : [n_step, batch_size, num_directions(=1) * n_hidden] # hidden : [num_layers(=1) * num_directions(=1), batch_size, n_hidden] outputs = outputs[-1] # [batch_size, num_directions(=1) * n_hidden] model = self.W(outputs) + self.b # model : [batch_size, n_class] return model
pytorch中model=model.to(device)用法
這代表將模型加載到指定設(shè)備上。
其中,device=torch.device("cpu")
代表的使用cpu,而device=torch.device("cuda")
則代表的使用GPU。
當(dāng)我們指定了設(shè)備之后,就需要將模型加載到相應(yīng)設(shè)備中,此時(shí)需要使用model=model.to(device)
,將模型加載到相應(yīng)的設(shè)備中。
將由GPU保存的模型加載到CPU上
將torch.load()
函數(shù)中的map_location
參數(shù)設(shè)置為torch.device('cpu')
device = torch.device('cpu') model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH, map_location=device))
將由GPU保存的模型加載到GPU上。確保對(duì)輸入的tensors
調(diào)用input = input.to(device)
方法。
device = torch.device("cuda") model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.to(device)
將由CPU保存的模型加載到GPU上
確保對(duì)輸入的tensors
調(diào)用input = input.to(device)
方法。
map_location
是將模型加載到GPU上,model.to(torch.device('cuda'))
是將模型參數(shù)加載為CUDA的tensor。
最后保證使用.to(torch.device('cuda'))
方法將需要使用的參數(shù)放入CUDA。
device = torch.device("cuda") model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want model.to(device)
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
基于Python數(shù)據(jù)可視化利器Matplotlib,繪圖入門篇,Pyplot詳解
下面小編就為大家?guī)硪黄赑ython數(shù)據(jù)可視化利器Matplotlib,繪圖入門篇,Pyplot詳解。小編覺得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2017-10-10使用python opencv對(duì)目錄下圖片進(jìn)行去重的方法
今天小編就為大家分享一篇使用python opencv對(duì)目錄下圖片進(jìn)行去重的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-01-01python中小數(shù)點(diǎn)后的位數(shù)問題
這篇文章主要介紹了python中小數(shù)點(diǎn)后的位數(shù)問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-03-03python list.sort()根據(jù)多個(gè)關(guān)鍵字排序的方法實(shí)現(xiàn)
Python list內(nèi)置sort()方法用來排序,也可以用python內(nèi)置的全局sorted()方法來對(duì)可迭代的序列排序生成新的序列,本文詳細(xì)的介紹了python list.sort()根據(jù)多個(gè)關(guān)鍵字排序,感興趣的可以了解一下2021-12-12python 微信好友特征數(shù)據(jù)分析及可視化
這篇文章主要介紹了python 微信好友特征數(shù)據(jù)分析及可視化,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-01-01Python三元運(yùn)算實(shí)現(xiàn)方法
這篇文章主要介紹了Python三元運(yùn)算實(shí)現(xiàn)方法,通過if else語句實(shí)現(xiàn)了三元運(yùn)算的功能,具有一定參考借鑒價(jià)值,需要的朋友可以參考下2015-01-01實(shí)例詳解Python中的numpy.abs和abs函數(shù)
Numpy是python中最有用的工具之一,它可以有效地處理大容量數(shù)據(jù),下面這篇文章主要給大家介紹了關(guān)于Python中numpy.abs和abs函數(shù)的相關(guān)資料,文中通過實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-08-08