pytorch更新tensor中指定index位置的值scatter_add_問題
使用scatter_add_更新tensor張量中指定index位置的值
例子
import torch a = torch.zeros((3, 4)) print(a) """ tensor([[0., 0., 0., 0.], ? ? ? ? [0., 0., 0., 0.], ? ? ? ? [0., 0., 0., 0.]]) """ b = torch.rand((2, 4)) print(b) """ tensor([[0.6293, 0.3050, 0.9608, 0.5577], ? ? ? ? [0.3469, 0.1025, 0.8185, 0.5085]]) """ # 將a中第0行和第2行的值修改為b a = a.scatter_add_(0, torch.tensor([[0, 0, 0], [2, 2, 2]]), b) print(a) """ tensor([[0.6293, 0.3050, 0.9608, 0.0000], ? ? ? ? [0.0000, 0.0000, 0.0000, 0.0000], ? ? ? ? [0.3469, 0.1025, 0.8185, 0.0000]]) """
torch_scatter.scatter_add、Tensor.scatter_add_ 、Tensor.scatter_、Tensor.scatter_add 、Tensor.scatter
torch_scatter.scatter_add
官方文檔:
torch_scatter.scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0)
Sums all values from the src tensor into out at the indices specified in the index tensor along a given axis dim. For each value in src, its output index is specified by its index in input for dimensions outside of dim and by the corresponding value in index for dimension dim. If multiple indices reference the same location, their contributions add.
看著挺疑惑的,自己試了一把:
src = torch.tensor([10, 20, 30, 40, 1, 2, 2, 2, 9]) index = torch.tensor([2, 1, 1, 1, 1, 1, 1, 1, 0]) out=scatter_add(src, index) print(out)
輸出結果為:tensor([ 9, 97, 10])
說白了就是:index就是out的下標,將src所有和此下標對應的值加起來,就是out的值。
例如上面的例子:index中等于1的,對應于src是【20, 30, 40, 1, 2, 2, 2】,將這些值加起來是97,于是,out[1]=97
同理:out[0]=src[8]=9 out[2]=src[0]=10
另一個函數(shù)
Tensor.scatter_add_
官方文檔:
scatter_add_(self, dim, index, other):
For a 3-D tensor, :attr:`self` is updated as:: ? ? self[index[i][j][k]][j][k] += other[i][j][k] ?# if dim == 0 ? ? self[i][index[i][j][k]][k] += other[i][j][k] ?# if dim == 1 ? ? self[i][j][index[i][j][k]] += other[i][j][k] ?# if dim == 2
官方例子:
? ? ? ? ? ? >>> x = torch.rand(2, 5) ? ? ? ? ? ? >>> x ? ? ? ? ? ? tensor([[0.7404, 0.0427, 0.6480, 0.3806, 0.8328], ? ? ? ? ? ? ? ? ? ? [0.7953, 0.2009, 0.9154, 0.6782, 0.9620]]) ? ? ? ? ? ? >>> torch.ones(3, 5).scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) ? ? ? ? ? ? tensor([[1.7404, 1.2009, 1.9154, 1.3806, 1.8328], ? ? ? ? ? ? ? ? ? ? [1.0000, 1.0427, 1.0000, 1.6782, 1.0000], ? ? ? ? ? ? ? ? ? ? [1.7953, 1.0000, 1.6480, 1.0000, 1.9620]])
以index來遍歷,就比較容易看懂。self中并不是每個值都要改變的。
以上面為例
index[0][0]=0 ?self[index[0][0]][0]=self[0][0] =self[0][0]+ x[0][0]=1 +0.7404=1.7404 index[0][1]=1 ?self[index[0][1]][1]=self[1][1] =self[1][1]+ x[0][1] =1 +0.0427 =1.0427
。。。
以此類推,將index遍歷一遍,就得到最終的結果
所以,self中需要改變的是index中列出的坐標,其他的是不動的。
Tensor.scatter_
scatter_(self, dim, index, src)
和Tensor.scatter_add_的區(qū)別是直接將src中的值填充到self中,不做相加
例子:
>>> x = torch.rand(2, 5) ? ? ? ? ? ? >>> x ? ? ? ? ? ? tensor([[ 0.3992, ?0.2908, ?0.9044, ?0.4850, ?0.6004], ? ? ? ? ? ? ? ? ? ? [ 0.5735, ?0.9006, ?0.6797, ?0.4152, ?0.1732]]) ? ? ? ? ? ? >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) ? ? ? ? ? ? tensor([[ 0.3992, ?0.9006, ?0.6797, ?0.4850, ?0.6004], ? ? ? ? ? ? ? ? ? ? [ 0.0000, ?0.2908, ?0.0000, ?0.4152, ?0.0000], ? ? ? ? ? ? ? ? ? ? [ 0.5735, ?0.0000, ?0.9044, ?0.0000, ?0.1732]]) ? ? ? ? ? ? >>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23) ? ? ? ? ? ? >>> z ? ? ? ? ? ? tensor([[ 0.0000, ?0.0000, ?1.2300, ?0.0000], ? ? ? ? ? ? ? ? ? ? [ 0.0000, ?0.0000, ?0.0000, ?1.2300]])
另外,pytorch中還有
scatter_add和scatter函數(shù),和上面兩個函數(shù)不同的是這個兩個函數(shù)不改變self,會返回結果值;上面兩個函數(shù)(scatter_add_和scatter_)是直接在原數(shù)據(jù)self上進行修改
總結
以上為個人經驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
利用Pyhton中的requests包進行網頁訪問測試的方法
今天小編就為大家分享一篇利用Pyhton中的requests包進行網頁訪問測試的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-12-12python學習筆記之調用eval函數(shù)出現(xiàn)invalid syntax錯誤問題
python是一門多種用途的編程語言,時常扮演腳本語言的角色。一般來說,python可以定義為面向對象的腳本語言,這個定義把面向對象的支持和面向腳本語言的角色融合在一起。很多時候,人們常常喜歡用“腳本”和不是語言來描述python的代碼文件。2015-10-10PyTorch?使用torchvision進行圖片數(shù)據(jù)增廣
本文主要介紹了PyTorch?使用torchvision進行圖片數(shù)據(jù)增廣,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2022-05-05