欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

解決torch.to(device)是否賦值的坑

 更新時(shí)間:2024年06月27日 14:45:39   作者:不會(huì)卷積  
這篇文章主要介紹了解決torch.to(device)是否賦值的坑,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

torch.to(device)是否賦值的坑

在我們用GPU跑程序時(shí),需要在程序中把變量和模型放到GPU里面。

有一些坑需要注意,本文用RNN模型實(shí)例

首先,定義device

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

對(duì)于變量,需要進(jìn)行賦值操作才能真正轉(zhuǎn)到GPU上:

all_input_batch=all_input_batch.to(device)

對(duì)于模型,不需要進(jìn)行賦值:

 model = TextRNN()
 model.to(device)

對(duì)模型進(jìn)行to(device),還有一種方法,就是在定義模型的時(shí)候全部對(duì)模型網(wǎng)絡(luò)參數(shù)to(device),這樣就可以不需要model.to(device)這句話。

class TextRNN(nn.Module):

    def __init__(self):
        super(TextRNN, self).__init__()
        #self.cnt = 0
        self.C = nn.Embedding(n_class, embedding_dim=emb_size,device=device)
        self.rnn = nn.RNN(input_size=emb_size, hidden_size=n_hidden,device=device)
        self.W = nn.Linear(n_hidden, n_class, bias=False,device=device)
        self.b = nn.Parameter(torch.ones([n_class])).to(device)


    def forward(self, X):
        X = self.C(X)
        #print(X.is_cuda)
        X = X.transpose(0, 1) # X : [n_step, batch_size, embeding size]
        outputs, hidden = self.rnn(X)
        # outputs : [n_step, batch_size, num_directions(=1) * n_hidden]
        # hidden : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
        outputs = outputs[-1] # [batch_size, num_directions(=1) * n_hidden]
        model = self.W(outputs) + self.b # model : [batch_size, n_class]
        return model

pytorch中model=model.to(device)用法

這代表將模型加載到指定設(shè)備上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")則代表的使用GPU。

當(dāng)我們指定了設(shè)備之后,就需要將模型加載到相應(yīng)設(shè)備中,此時(shí)需要使用model=model.to(device),將模型加載到相應(yīng)的設(shè)備中。

將由GPU保存的模型加載到CPU上

torch.load()函數(shù)中的map_location參數(shù)設(shè)置為torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

將由GPU保存的模型加載到GPU上。確保對(duì)輸入的tensors調(diào)用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

將由CPU保存的模型加載到GPU上

確保對(duì)輸入的tensors調(diào)用input = input.to(device)方法。

map_location是將模型加載到GPU上,model.to(torch.device('cuda'))是將模型參數(shù)加載為CUDA的tensor。

最后保證使用.to(torch.device('cuda'))方法將需要使用的參數(shù)放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

總結(jié)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • 基于Python數(shù)據(jù)可視化利器Matplotlib,繪圖入門篇,Pyplot詳解

    基于Python數(shù)據(jù)可視化利器Matplotlib,繪圖入門篇,Pyplot詳解

    下面小編就為大家?guī)硪黄赑ython數(shù)據(jù)可視化利器Matplotlib,繪圖入門篇,Pyplot詳解。小編覺得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧
    2017-10-10
  • 使用python opencv對(duì)目錄下圖片進(jìn)行去重的方法

    使用python opencv對(duì)目錄下圖片進(jìn)行去重的方法

    今天小編就為大家分享一篇使用python opencv對(duì)目錄下圖片進(jìn)行去重的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2019-01-01
  • Python GUI布局尺寸適配方法

    Python GUI布局尺寸適配方法

    今天小編就為大家分享一篇Python GUI布局尺寸適配方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2018-10-10
  • python中小數(shù)點(diǎn)后的位數(shù)問題

    python中小數(shù)點(diǎn)后的位數(shù)問題

    這篇文章主要介紹了python中小數(shù)點(diǎn)后的位數(shù)問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2023-03-03
  • python list.sort()根據(jù)多個(gè)關(guān)鍵字排序的方法實(shí)現(xiàn)

    python list.sort()根據(jù)多個(gè)關(guān)鍵字排序的方法實(shí)現(xiàn)

    Python list內(nèi)置sort()方法用來排序,也可以用python內(nèi)置的全局sorted()方法來對(duì)可迭代的序列排序生成新的序列,本文詳細(xì)的介紹了python list.sort()根據(jù)多個(gè)關(guān)鍵字排序,感興趣的可以了解一下
    2021-12-12
  • python實(shí)現(xiàn)簡易云音樂播放器

    python實(shí)現(xiàn)簡易云音樂播放器

    這篇文章主要介紹了python實(shí)現(xiàn)簡易云音樂播放器,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2018-01-01
  • Python字典,函數(shù),全局變量代碼解析

    Python字典,函數(shù),全局變量代碼解析

    這篇文章主要介紹了Python字典,函數(shù),全局變量代碼解析,具有一定借鑒價(jià)值,需要的朋友可以參考下。
    2017-12-12
  • python 微信好友特征數(shù)據(jù)分析及可視化

    python 微信好友特征數(shù)據(jù)分析及可視化

    這篇文章主要介紹了python 微信好友特征數(shù)據(jù)分析及可視化,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-01-01
  • Python三元運(yùn)算實(shí)現(xiàn)方法

    Python三元運(yùn)算實(shí)現(xiàn)方法

    這篇文章主要介紹了Python三元運(yùn)算實(shí)現(xiàn)方法,通過if else語句實(shí)現(xiàn)了三元運(yùn)算的功能,具有一定參考借鑒價(jià)值,需要的朋友可以參考下
    2015-01-01
  • 實(shí)例詳解Python中的numpy.abs和abs函數(shù)

    實(shí)例詳解Python中的numpy.abs和abs函數(shù)

    Numpy是python中最有用的工具之一,它可以有效地處理大容量數(shù)據(jù),下面這篇文章主要給大家介紹了關(guān)于Python中numpy.abs和abs函數(shù)的相關(guān)資料,文中通過實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下
    2022-08-08

最新評(píng)論