PyTorch中torch.tensor()和torch.to_tensor()的區(qū)別
前言
在跑模型的時候,遇到如下報錯
UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
網(wǎng)上查了一下,發(fā)現(xiàn)將 torch.tensor() 改寫成 torch.as_tensor() 就可以避免報錯了。
# 如下寫法報錯 feature = torch.tensor(image, dtype=torch.float32) # 改為 feature = torch.as_tensor(image, dtype=torch.float32)
然后就又仔細研究了下 torch.as_tensor() 和 torch.tensor() 的區(qū)別,在此記錄。
1、torch.as_tensor()
new_data = torch.as_tensor(data, dtype=None,device=None)->Tensor
作用:生成一個新的 tensor, 這個新生成的tensor 會根據(jù)原數(shù)據(jù)的實際情況,來決定是進行淺拷貝,還是深拷貝。當(dāng)然,會優(yōu)先淺拷貝,淺拷貝會共享內(nèi)存,并共享 autograd 歷史記錄。
情況一:數(shù)據(jù)類型相同 且 device相同,會進行淺拷貝,共享內(nèi)存
import numpy import torch a = numpy.array([1, 2, 3]) t = torch.as_tensor(a) t[0] = -1 print(a) # [-1 2 3] print(a.dtype) # int64 print(t) # tensor([-1, 2, 3]) print(t.dtype) # torch.int64
import numpy
import torch
a = torch.tensor([1, 2, 3], device=torch.device('cuda'))
t = torch.as_tensor(a)
t[0] = -1
print(a) # tensor([-1, 2, 3], device='cuda:0')
print(t) # tensor([-1, 2, 3], device='cuda:0')
情況二: 數(shù)據(jù)類型相同,但是device不同,深拷貝,不再共享內(nèi)存
import numpy
import torch
import numpy
a = numpy.array([1, 2, 3])
t = torch.as_tensor(a, device=torch.device('cuda'))
t[0] = -1
print(a) # [1 2 3]
print(a.dtype) # int64
print(t) # tensor([-1, 2, 3], device='cuda:0')
print(t.dtype) # torch.int64
情況三:device相同,但數(shù)據(jù)類型不同,深拷貝,不再共享內(nèi)存
import numpy import torch a = numpy.array([1, 2, 3]) t = torch.as_tensor(a, dtype=torch.float32) t[0] = -1 print(a) # [1 2 3] print(a.dtype) # int64 print(t) # tensor([-1., 2., 3.]) print(t.dtype) # torch.float32
2、torch.tensor()
torch.tensor() 是深拷貝方式。
torch.tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False)
深拷貝:會拷貝 數(shù)據(jù)類型 和 device,不會記錄 autograd 歷史 (also known as a “leaf tensor” 葉子tensor)
重點是:
- 如果原數(shù)據(jù)的數(shù)據(jù)類型是:list, tuple, NumPy ndarray, scalar, and other types,不會 waring
- 如果原數(shù)據(jù)的數(shù)據(jù)類型是:tensor,使用 torch.tensor(data) 就會報waring
# 原數(shù)據(jù)類型是:tensor 會發(fā)出警告
import numpy
import torch
a = torch.tensor([1, 2, 3], device=torch.device('cuda'))
t = torch.tensor(a)
t[0] = -1
print(a)
print(t)
# 輸出:
# tensor([1, 2, 3], device='cuda:0')
# tensor([-1, 2, 3], device='cuda:0')
# /opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).# 原數(shù)據(jù)類型是:list, tuple, NumPy ndarray, scalar, and other types, 沒警告 import torch import numpy a = numpy.array([1, 2, 3]) t = torch.tensor(a) b = [1,2,3] t= torch.tensor(b) c = (1,2,3) t= torch.tensor(c)
結(jié)論就是:以后盡量用 torch.as_tensor() 吧
總結(jié)
到此這篇關(guān)于PyTorch中torch.tensor()和torch.to_tensor()區(qū)別的文章就介紹到這了,更多相關(guān)torch.tensor()和torch.to_tensor()區(qū)別內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python實現(xiàn)網(wǎng)絡(luò)自動化eNSP
這篇文章主要介紹了Python實現(xiàn)網(wǎng)絡(luò)自動化eNSP,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-05-05
Pyecharts 動態(tài)地圖 geo()和map()的安裝與用法詳解
這篇文章主要介紹了Pyecharts 動態(tài)地圖 geo()和map()的安裝與用法詳解,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-03-03

