關于torch.scatter與torch_scatter庫的使用整理
最近在做圖結構相關的算法,scatter能把鄰接矩陣里的信息修改,或者把鄰居分組算個sum或者reduce,挺方便的,簡單整理一下。
torch.scatter 與 tensor._scatter
Pytorch自帶的函數(shù),用來將作為 src 的tensor根據(jù) index 的描述填充到 input 中,
形式如下:
ouput = torch.scatter(input, dim, index, src) # 或者是 input.scatter_(dim, index, src)
兩個方法的功能是相同的,而帶下劃線的 _scatter 方法是將原tensor input 直接修改了,不帶的則會返回一個新的tensor output , input 不變。
其中 dim 決定 index 對應值是沿著哪個維度進行修改。而 src 為數(shù)據(jù)來源,當其為tensor張量時,shape要和index相同,這樣index中每個元素都能對應 src 中對應位置的信息。
理解 scatter 方法主要是要理解 index 實現(xiàn)的 src 和 input 之間的位置對應關系,舉個例子:
dim = 0 index = torch.tensor( [[0, 2, 2], [2, 1, 0]] )
dim 為0時,遵循的映射原則為: input[index[i][j]][j] = src[i][j] .
也就是說,將位置 (i, j) 中 dim 對應的位置改為 index[i][j] 的值。
如位置(1,0),index[1][0]為2,則映射后的位置為(2,0),意味著 input 中(2,0)的位置被更改為 src 中(1,0)位置的值。
我個人形象理解是這些值會沿著dim方向滑動,上面例子中src[1][0]位置的值滑到2,成為input中的新值,這樣理解起來更形象一點。
基本理解了上面這個例子,多維情況和不同dim的情況都可以類推了。
需要注意:src和input的dtype需要相同,不然會報
Expected self.dtype to be equal to src.dtype
不一樣就先轉(zhuǎn)換再使用。
t = torch.arange(6).view(2, 3) t = t.to(torch.float32) print(t) output = torch.scatter(torch.zeros((3, 3)), 0, torch.tensor([[0, 2, 2], [2, 1, 0]]), t) print(torch.zeros((3, 3)).scatter_(0, torch.tensor([[0, 2, 2], [2, 1, 0]]), t))
輸出:
tensor([[0., 1., 2.],
[3., 4., 5.]])
tensor([[0., 0., 5.],
[0., 4., 0.],
[3., 1., 2.]])
torch_scatter庫
這個第三方庫對矩陣的分組處理這個概念做了更進一步的封裝,通過index來指定分組信息,將元素分組后進行對應處理,
最基礎的scatter方法形式如下:
torch_scatter.scatter(src, index, dim, out, dim_size, reduce)
src: 數(shù)據(jù)源index:分組序列dim:分組遵循的維度out:輸出的tensor,可以不指定直接讓函數(shù)輸出dim_size:out不指定的時候,將輸出shape變?yōu)樵撝荡笮?;dim_size也不指定,就根據(jù)計算結果來reduce:分組的操作,包括sum,mul,mean,min和max操作
這個方法理解關鍵在 index 的分組方法,
舉個例子:
dim = 1 index = torch.tensor([[0, 1, 1]])
torch_scatter.scatter 對 index 的順序是沒有特定規(guī)定的,相同數(shù)字對應的元素即為一組。
比如例子中,維度1上的第0個元素為一組,第1和2元素為另一組。
這樣,按照分組進行reduce定義的計算即可獲得輸出。如:
t = torch.arange(12).view(4, 3) print(t) t_s = torch_scatter.scatter(t, torch.tensor([[0, 1, 1]]), dim=1, reduce='sum') print(t_s)
輸出:
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
tensor([[ 0, 3],
[ 3, 9],
[ 6, 15]])
可以看出,每行的后兩個元素求了和,與index定義相同。
要注意的是,index的 shape[0] 為1時,會自動對dim對應的維度上每一層進行相同的分組處理,如上例所示,index大小為(1, 3),即對src的三行數(shù)據(jù)都進行了分組處理。
而另一種分組方式,如需要每行分組不同,則需要index的shape和src的shape相同,如下例:
t = torch.arange(12).view(4, 3) print(t) t_s = torch_scatter.scatter(t, torch.tensor([[0, 1, 1], [1, 1, 0], [0, 1, 1], [1, 1, 0]]), dim=1, reduce='sum') print(t_s)
輸出:
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
tensor([[ 0, 3],
[ 5, 7],
[ 6, 15]])
shape不相同時,則會報錯提示:
RuntimeError: The expanded size of the tensor (3) must match the existing size (2) at non-singleton dimension 0 .
同時,該庫還給出了另外兩種方法,分別為 torch_scatter.segment_coo 和 torch_scatter.segment_csr .
torch_scatter.segment_coo
torch_scatter.segment_coo 和 scatter 的功能差不多,但它只支持index的shape[0]為1的狀態(tài),即每一行都為相同的分組方式。
同時,index中數(shù)值為順序排列,以提高計算速度。
torch_scatter.segment_csr
torch_scatter.segment_csr 的index格式不太相同,是一種區(qū)間格式,如[0, 2, 5],表示0,1為一組,2,3,4為一組,即取數(shù)值間的左閉右開區(qū)間。
這個方法是計算速度最快的。
官方文檔地址
torch_scatter庫doc
https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html
torch.scatter文檔
總結
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
- PyTorch中torch.load()的用法和應用
- python中torch.load中的map_location參數(shù)使用
- Pytorch中的torch.nn.Linear()方法用法解讀
- Pytorch中的torch.where函數(shù)使用
- python中的List sort()與torch.sort()
- PyTorch函數(shù)torch.cat與torch.stac的區(qū)別小結
- pytorch.range()和pytorch.arange()的區(qū)別及說明
- 使用with torch.no_grad():顯著減少測試時顯存占用
- PyTorch中torch.save()的用法和應用小結
相關文章
python數(shù)據(jù)結構之二叉樹的統(tǒng)計與轉(zhuǎn)換實例
這篇文章主要介紹了python數(shù)據(jù)結構之二叉樹的統(tǒng)計與轉(zhuǎn)換實例,例如統(tǒng)計二叉樹的葉子、分支節(jié)點,以及二叉樹的左右兩樹互換等,需要的朋友可以參考下2014-04-04
Python+OpenCV圖片去水印的多種方案實現(xiàn)
這篇文章主要為大家總結了Python結合OpenCV的幾種常見的水印去除方式,簡單圖片去水印效果良好,有需要的小伙伴可以跟隨小編一起了解下2025-02-02
python tensorflow學習之識別單張圖片的實現(xiàn)的示例
本篇文章主要介紹了python tensorflow學習之識別單張圖片的實現(xiàn)的示例,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2018-02-02
python實現(xiàn)AHP算法的方法實例(層次分析法)
這篇文章主要給大家介紹了關于python實現(xiàn)AHP算法(層次分析法)的相關資料,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2020-09-09
pandas數(shù)據(jù)清洗(缺失值和重復值的處理)
這篇文章主要介紹了pandas數(shù)據(jù)清洗(缺失值和重復值的處理),pandas對大數(shù)據(jù)有很多便捷的清洗用法,尤其針對缺失值和重復值,詳細介紹感興趣的小伙伴可以參考下面文章內(nèi)容2022-08-08

