pytorch更新tensor中指定index位置的值scatter_add_問(wèn)題
使用scatter_add_更新tensor張量中指定index位置的值
例子
import torch a = torch.zeros((3, 4)) print(a) """ tensor([[0., 0., 0., 0.], ? ? ? ? [0., 0., 0., 0.], ? ? ? ? [0., 0., 0., 0.]]) """ b = torch.rand((2, 4)) print(b) """ tensor([[0.6293, 0.3050, 0.9608, 0.5577], ? ? ? ? [0.3469, 0.1025, 0.8185, 0.5085]]) """ # 將a中第0行和第2行的值修改為b a = a.scatter_add_(0, torch.tensor([[0, 0, 0], [2, 2, 2]]), b) print(a) """ tensor([[0.6293, 0.3050, 0.9608, 0.0000], ? ? ? ? [0.0000, 0.0000, 0.0000, 0.0000], ? ? ? ? [0.3469, 0.1025, 0.8185, 0.0000]]) """
torch_scatter.scatter_add、Tensor.scatter_add_ 、Tensor.scatter_、Tensor.scatter_add 、Tensor.scatter
torch_scatter.scatter_add
官方文檔:
torch_scatter.scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0)
Sums all values from the src tensor into out at the indices specified in the index tensor along a given axis dim. For each value in src, its output index is specified by its index in input for dimensions outside of dim and by the corresponding value in index for dimension dim. If multiple indices reference the same location, their contributions add.
看著挺疑惑的,自己試了一把:
src = torch.tensor([10, 20, 30, 40, 1, 2, 2, 2, 9]) index = torch.tensor([2, 1, 1, 1, 1, 1, 1, 1, 0]) out=scatter_add(src, index) print(out)
輸出結(jié)果為:tensor([ 9, 97, 10])
說(shuō)白了就是:index就是out的下標(biāo),將src所有和此下標(biāo)對(duì)應(yīng)的值加起來(lái),就是out的值。
例如上面的例子:index中等于1的,對(duì)應(yīng)于src是【20, 30, 40, 1, 2, 2, 2】,將這些值加起來(lái)是97,于是,out[1]=97
同理:out[0]=src[8]=9 out[2]=src[0]=10
另一個(gè)函數(shù)
Tensor.scatter_add_
官方文檔:
scatter_add_(self, dim, index, other):
For a 3-D tensor, :attr:`self` is updated as:: ? ? self[index[i][j][k]][j][k] += other[i][j][k] ?# if dim == 0 ? ? self[i][index[i][j][k]][k] += other[i][j][k] ?# if dim == 1 ? ? self[i][j][index[i][j][k]] += other[i][j][k] ?# if dim == 2
官方例子:
? ? ? ? ? ? >>> x = torch.rand(2, 5) ? ? ? ? ? ? >>> x ? ? ? ? ? ? tensor([[0.7404, 0.0427, 0.6480, 0.3806, 0.8328], ? ? ? ? ? ? ? ? ? ? [0.7953, 0.2009, 0.9154, 0.6782, 0.9620]]) ? ? ? ? ? ? >>> torch.ones(3, 5).scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) ? ? ? ? ? ? tensor([[1.7404, 1.2009, 1.9154, 1.3806, 1.8328], ? ? ? ? ? ? ? ? ? ? [1.0000, 1.0427, 1.0000, 1.6782, 1.0000], ? ? ? ? ? ? ? ? ? ? [1.7953, 1.0000, 1.6480, 1.0000, 1.9620]])
以index來(lái)遍歷,就比較容易看懂。self中并不是每個(gè)值都要改變的。
以上面為例
index[0][0]=0 ?self[index[0][0]][0]=self[0][0] =self[0][0]+ x[0][0]=1 +0.7404=1.7404 index[0][1]=1 ?self[index[0][1]][1]=self[1][1] =self[1][1]+ x[0][1] =1 +0.0427 =1.0427
。。。
以此類推,將index遍歷一遍,就得到最終的結(jié)果
所以,self中需要改變的是index中列出的坐標(biāo),其他的是不動(dòng)的。
Tensor.scatter_
scatter_(self, dim, index, src)
和Tensor.scatter_add_的區(qū)別是直接將src中的值填充到self中,不做相加
例子:
>>> x = torch.rand(2, 5) ? ? ? ? ? ? >>> x ? ? ? ? ? ? tensor([[ 0.3992, ?0.2908, ?0.9044, ?0.4850, ?0.6004], ? ? ? ? ? ? ? ? ? ? [ 0.5735, ?0.9006, ?0.6797, ?0.4152, ?0.1732]]) ? ? ? ? ? ? >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) ? ? ? ? ? ? tensor([[ 0.3992, ?0.9006, ?0.6797, ?0.4850, ?0.6004], ? ? ? ? ? ? ? ? ? ? [ 0.0000, ?0.2908, ?0.0000, ?0.4152, ?0.0000], ? ? ? ? ? ? ? ? ? ? [ 0.5735, ?0.0000, ?0.9044, ?0.0000, ?0.1732]]) ? ? ? ? ? ? >>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23) ? ? ? ? ? ? >>> z ? ? ? ? ? ? tensor([[ 0.0000, ?0.0000, ?1.2300, ?0.0000], ? ? ? ? ? ? ? ? ? ? [ 0.0000, ?0.0000, ?0.0000, ?1.2300]])
另外,pytorch中還有
scatter_add和scatter函數(shù),和上面兩個(gè)函數(shù)不同的是這個(gè)兩個(gè)函數(shù)不改變self,會(huì)返回結(jié)果值;上面兩個(gè)函數(shù)(scatter_add_和scatter_)是直接在原數(shù)據(jù)self上進(jìn)行修改
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
利用Pyhton中的requests包進(jìn)行網(wǎng)頁(yè)訪問(wèn)測(cè)試的方法
今天小編就為大家分享一篇利用Pyhton中的requests包進(jìn)行網(wǎng)頁(yè)訪問(wèn)測(cè)試的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-12-12python學(xué)習(xí)筆記之調(diào)用eval函數(shù)出現(xiàn)invalid syntax錯(cuò)誤問(wèn)題
python是一門多種用途的編程語(yǔ)言,時(shí)常扮演腳本語(yǔ)言的角色。一般來(lái)說(shuō),python可以定義為面向?qū)ο蟮哪_本語(yǔ)言,這個(gè)定義把面向?qū)ο蟮闹С趾兔嫦蚰_本語(yǔ)言的角色融合在一起。很多時(shí)候,人們常常喜歡用“腳本”和不是語(yǔ)言來(lái)描述python的代碼文件。2015-10-10PyTorch?使用torchvision進(jìn)行圖片數(shù)據(jù)增廣
本文主要介紹了PyTorch?使用torchvision進(jìn)行圖片數(shù)據(jù)增廣,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2022-05-05pytorch簡(jiǎn)單實(shí)現(xiàn)神經(jīng)網(wǎng)絡(luò)功能
這篇文章主要介紹了pytorch簡(jiǎn)單實(shí)現(xiàn)神經(jīng)網(wǎng)絡(luò),本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2022-09-09