PyTorch中torch.tensor()和torch.to_tensor()的區(qū)別
前言
在跑模型的時候,遇到如下報錯
UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
網(wǎng)上查了一下,發(fā)現(xiàn)將 torch.tensor()
改寫成 torch.as_tensor()
就可以避免報錯了。
# 如下寫法報錯 feature = torch.tensor(image, dtype=torch.float32) # 改為 feature = torch.as_tensor(image, dtype=torch.float32)
然后就又仔細(xì)研究了下 torch.as_tensor()
和 torch.tensor()
的區(qū)別,在此記錄。
1、torch.as_tensor()
new_data = torch.as_tensor(data, dtype=None,device=None)->Tensor
作用:生成一個新的 tensor, 這個新生成的tensor 會根據(jù)原數(shù)據(jù)的實際情況,來決定是進(jìn)行淺拷貝,還是深拷貝。當(dāng)然,會優(yōu)先淺拷貝,淺拷貝會共享內(nèi)存,并共享 autograd 歷史記錄。
情況一:數(shù)據(jù)類型相同 且 device相同,會進(jìn)行淺拷貝,共享內(nèi)存
import numpy import torch a = numpy.array([1, 2, 3]) t = torch.as_tensor(a) t[0] = -1 print(a) # [-1 2 3] print(a.dtype) # int64 print(t) # tensor([-1, 2, 3]) print(t.dtype) # torch.int64
import numpy import torch a = torch.tensor([1, 2, 3], device=torch.device('cuda')) t = torch.as_tensor(a) t[0] = -1 print(a) # tensor([-1, 2, 3], device='cuda:0') print(t) # tensor([-1, 2, 3], device='cuda:0')
情況二: 數(shù)據(jù)類型相同,但是device不同,深拷貝,不再共享內(nèi)存
import numpy import torch import numpy a = numpy.array([1, 2, 3]) t = torch.as_tensor(a, device=torch.device('cuda')) t[0] = -1 print(a) # [1 2 3] print(a.dtype) # int64 print(t) # tensor([-1, 2, 3], device='cuda:0') print(t.dtype) # torch.int64
情況三:device相同,但數(shù)據(jù)類型不同,深拷貝,不再共享內(nèi)存
import numpy import torch a = numpy.array([1, 2, 3]) t = torch.as_tensor(a, dtype=torch.float32) t[0] = -1 print(a) # [1 2 3] print(a.dtype) # int64 print(t) # tensor([-1., 2., 3.]) print(t.dtype) # torch.float32
2、torch.tensor()
torch.tensor()
是深拷貝方式。
torch.tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False)
深拷貝:會拷貝 數(shù)據(jù)類型 和 device,不會記錄 autograd 歷史 (also known as a “leaf tensor” 葉子tensor)
重點(diǎn)是:
- 如果原數(shù)據(jù)的數(shù)據(jù)類型是:list, tuple, NumPy ndarray, scalar, and other types,不會 waring
- 如果原數(shù)據(jù)的數(shù)據(jù)類型是:tensor,使用 torch.tensor(data) 就會報waring
# 原數(shù)據(jù)類型是:tensor 會發(fā)出警告 import numpy import torch a = torch.tensor([1, 2, 3], device=torch.device('cuda')) t = torch.tensor(a) t[0] = -1 print(a) print(t) # 輸出: # tensor([1, 2, 3], device='cuda:0') # tensor([-1, 2, 3], device='cuda:0') # /opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
# 原數(shù)據(jù)類型是:list, tuple, NumPy ndarray, scalar, and other types, 沒警告 import torch import numpy a = numpy.array([1, 2, 3]) t = torch.tensor(a) b = [1,2,3] t= torch.tensor(b) c = (1,2,3) t= torch.tensor(c)
結(jié)論就是:以后盡量用 torch.as_tensor()
吧
總結(jié)
到此這篇關(guān)于PyTorch中torch.tensor()和torch.to_tensor()區(qū)別的文章就介紹到這了,更多相關(guān)torch.tensor()和torch.to_tensor()區(qū)別內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python生成任意范圍任意精度的隨機(jī)數(shù)方法
下面小編就為大家分享一篇Python生成任意范圍任意精度的隨機(jī)數(shù)方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-04-04Python實現(xiàn)網(wǎng)絡(luò)自動化eNSP
這篇文章主要介紹了Python實現(xiàn)網(wǎng)絡(luò)自動化eNSP,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-05-05Pyecharts 動態(tài)地圖 geo()和map()的安裝與用法詳解
這篇文章主要介紹了Pyecharts 動態(tài)地圖 geo()和map()的安裝與用法詳解,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-03-03