pytorch之torch_scatter.scatter_max()用法
torch_scatter.scatter_max()
torch_scatter.scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=None)
- 根據(jù)index將src分組,求每一組中的最大值輸出到out
- dim是維度
from torch_scatter import scatter_max src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) out = src.new_zeros((2, 6)) '''src根據(jù)index進(jìn)行分組''' out, argmax = scatter_max(src, index, out=out) print(out) print(argmax)
輸出
tensor([[0., 0., 4., 3., 2., 0.],
[2., 4., 3., 0., 0., 0.]])
tensor([[-1, -1, 3, 4, 0, 1],
[ 1, 4, 3, -1, -1, -1]])
解釋
torch_scatter.scatter()使用
1. 參數(shù)
具體來講,scatter函數(shù)的作用就是將index中相同索引對應(yīng)位置的src元素進(jìn)行某種方式的操作,例如 sum
、 mean
等,然后將這些操作結(jié)果按照索引順序進(jìn)行拼接。
下面我用具體的例子來進(jìn)行講解。
2. 示例
2.1 簡單示例
首先初始化src和index:
src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # (3, 3) index = torch.tensor([0, 0, 1], dtype=torch.int64)
接著使用scatter函數(shù):
out = scatter(src, index, dim=0, reduce='mean')
我們觀察 index=[0, 0, 1]
,第0個(gè)位置和第1個(gè)位置都為0,第2個(gè)位置為1。也就是說,我們需要將src中第0個(gè)元素和第1個(gè)元素求平均變成一個(gè)元素,然后第2個(gè)元素求mean也就是本身為一個(gè)元素。如果 index=[1, 0, 0]
,則意味著我們需要將src中第1個(gè)元素和第2個(gè)元素求平均變成一個(gè)元素,而第0個(gè)元素保持不變。
那么src中第幾個(gè)元素到底是如何定義的呢?這就需要用到 dim
參數(shù)了。
dim=0
意味著我們需要對src的維度0進(jìn)行操作:
tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
即src中第0個(gè)元素為 [1, 2, 3]
,第1個(gè)元素為 [4, 5, 6]
,第2個(gè)元素為 [7, 8, 9]
。
而如果 dim=1
,則第0個(gè)元素為 [1, 4, 7]
,第1個(gè)元素為 [2, 5, 8]
,第2個(gè)元素為 [3, 6, 9]
。
因此,如果有以下代碼:
src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # (3, 3) index = torch.tensor([0, 0, 1], dtype=torch.int64) out = scatter(src, index, dim=0, reduce='mean')
那么我們就應(yīng)該將src中的第0個(gè)元素為 [1, 2, 3]
和第1個(gè)元素為 [4, 5, 6]
求平均為 [2.5, 3.5, 4.5]
,然后第2個(gè)元素 [7, 8, 9]
保持不變,即:
tensor([[2.5000, 3.5000, 4.5000], [7.0000, 8.0000, 9.0000]])
2.2 順序問題
上面的例子中 index=[0, 0, 1]
,最后結(jié)果是將src中第0個(gè)元素和第1個(gè)元素求平均放到了位置0,然后src中第2個(gè)元素保持不變放到了位置1。
如果 index=[1, 1, 0]
,結(jié)果為:
tensor([[7.0000, 8.0000, 9.0000], [2.5000, 3.5000, 4.5000]])
可以發(fā)現(xiàn),上述結(jié)果是將src中第2個(gè)元素 [7, 8, 9]
保持不變放到了位置0,然后將src中第0個(gè)元素 [1, 2, 3]
和第1個(gè)元素 [4, 5, 6]
求平均保持不變放到了位置1。
也就是說,無論index怎么變化,都是優(yōu)先將index中0對應(yīng)位置的操作結(jié)果進(jìn)行放置。
2.3 維度問題
如果src的維度為(4, 3),而我們需要對 dim=0
操作,也就是一共有四個(gè)元素,那么index的長度應(yīng)該為4,即以下操作是不合法的:
src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) # (4, 3) index = torch.tensor([1, 1, 0], dtype=torch.int64) out = scatter(src, index, dim=0, reduce='mean') print(out)
報(bào)錯(cuò)為:
RuntimeError: The expanded size of the tensor (4) must match the existing size (3) at non-singleton dimension 0. Target sizes: [4, 3]. Tensor sizes: [3, 1]
正確做法應(yīng)該是:
src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) # (4, 3) index = torch.tensor([1, 1, 0, 2], dtype=torch.int64) out = scatter(src, index, dim=0, reduce='mean') print(out)
輸出為:
tensor([[ 7.0000, 8.0000, 9.0000],
[ 2.5000, 3.5000, 4.5000],
[10.0000, 11.0000, 12.0000]])
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
基于Python+Matplotlib實(shí)現(xiàn)直方圖的繪制
Matplotlib是Python的繪圖庫,它能讓使用者很輕松地將數(shù)據(jù)圖形化,并且提供多樣化的輸出格式。本文將為大家介紹如何用matplotlib繪制直方圖,感興趣的朋友可以學(xué)習(xí)一下2022-04-04利用Python實(shí)現(xiàn)數(shù)值積分的方法
這篇文章主要介紹了利用Python實(shí)現(xiàn)數(shù)值積分。本文主要用于對比使用Python來實(shí)現(xiàn)數(shù)學(xué)中積分的幾種計(jì)算方式,并和真值進(jìn)行對比,加深大家對積分運(yùn)算實(shí)現(xiàn)方式的理解2022-02-02Python多線程中阻塞(join)與鎖(Lock)使用誤區(qū)解析
這篇文章主要為大家詳細(xì)介紹了Python多線程中阻塞join與鎖Lock的使用誤區(qū),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-04-04pycharm 實(shí)現(xiàn)光標(biāo)快速移動(dòng)到括號外或行尾的操作
這篇文章主要介紹了pycharm 實(shí)現(xiàn)光標(biāo)快速移動(dòng)到括號外或行尾的操作,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2021-02-02python?實(shí)現(xiàn)?redis?數(shù)據(jù)庫的操作
這篇文章主要介紹了python?包?redis?數(shù)據(jù)庫的操作教程,redis?是一個(gè)?Key-Value?數(shù)據(jù)庫,下文基于python的相關(guān)資料展開對redis?數(shù)據(jù)庫操作的詳細(xì)介紹,需要的小伙伴可以參考一下2022-04-04python對一個(gè)數(shù)向上取整的實(shí)例方法
在本篇文章中小編給大家整理了關(guān)于python對一個(gè)數(shù)向上取整的實(shí)例方法,需要的朋友們可以跟著學(xué)習(xí)下。2020-06-06