關(guān)于torch.scatter與torch_scatter庫的使用整理
最近在做圖結(jié)構(gòu)相關(guān)的算法,scatter能把鄰接矩陣?yán)锏男畔⑿薷?,或者把鄰居分組算個sum或者reduce,挺方便的,簡單整理一下。
torch.scatter 與 tensor._scatter
Pytorch自帶的函數(shù),用來將作為 src
的tensor根據(jù) index
的描述填充到 input
中,
形式如下:
ouput = torch.scatter(input, dim, index, src) # 或者是 input.scatter_(dim, index, src)
兩個方法的功能是相同的,而帶下劃線的 _scatter
方法是將原tensor input
直接修改了,不帶的則會返回一個新的tensor output
, input
不變。
其中 dim
決定 index
對應(yīng)值是沿著哪個維度進(jìn)行修改。而 src
為數(shù)據(jù)來源,當(dāng)其為tensor張量時(shí),shape要和index相同,這樣index中每個元素都能對應(yīng) src
中對應(yīng)位置的信息。
理解 scatter
方法主要是要理解 index
實(shí)現(xiàn)的 src
和 input
之間的位置對應(yīng)關(guān)系,舉個例子:
dim = 0 index = torch.tensor( [[0, 2, 2], [2, 1, 0]] )
dim
為0時(shí),遵循的映射原則為: input[index[i][j]][j] = src[i][j]
.
也就是說,將位置 (i, j) 中 dim
對應(yīng)的位置改為 index[i][j] 的值。
如位置(1,0),index[1][0]為2,則映射后的位置為(2,0),意味著 input
中(2,0)的位置被更改為 src
中(1,0)位置的值。
我個人形象理解是這些值會沿著dim方向滑動,上面例子中src[1][0]位置的值滑到2,成為input中的新值,這樣理解起來更形象一點(diǎn)。
基本理解了上面這個例子,多維情況和不同dim的情況都可以類推了。
需要注意:src和input的dtype需要相同,不然會報(bào)
Expected self.dtype to be equal to src.dtype
不一樣就先轉(zhuǎn)換再使用。
t = torch.arange(6).view(2, 3) t = t.to(torch.float32) print(t) output = torch.scatter(torch.zeros((3, 3)), 0, torch.tensor([[0, 2, 2], [2, 1, 0]]), t) print(torch.zeros((3, 3)).scatter_(0, torch.tensor([[0, 2, 2], [2, 1, 0]]), t))
輸出:
tensor([[0., 1., 2.],
[3., 4., 5.]])
tensor([[0., 0., 5.],
[0., 4., 0.],
[3., 1., 2.]])
torch_scatter庫
這個第三方庫對矩陣的分組處理這個概念做了更進(jìn)一步的封裝,通過index來指定分組信息,將元素分組后進(jìn)行對應(yīng)處理,
最基礎(chǔ)的scatter方法形式如下:
torch_scatter.scatter(src, index, dim, out, dim_size, reduce)
src
: 數(shù)據(jù)源index
:分組序列dim
:分組遵循的維度out
:輸出的tensor,可以不指定直接讓函數(shù)輸出dim_size
:out不指定的時(shí)候,將輸出shape變?yōu)樵撝荡笮?;dim_size也不指定,就根據(jù)計(jì)算結(jié)果來reduce
:分組的操作,包括sum,mul,mean,min和max操作
這個方法理解關(guān)鍵在 index
的分組方法,
舉個例子:
dim = 1 index = torch.tensor([[0, 1, 1]])
torch_scatter.scatter
對 index
的順序是沒有特定規(guī)定的,相同數(shù)字對應(yīng)的元素即為一組。
比如例子中,維度1上的第0個元素為一組,第1和2元素為另一組。
這樣,按照分組進(jìn)行reduce定義的計(jì)算即可獲得輸出。如:
t = torch.arange(12).view(4, 3) print(t) t_s = torch_scatter.scatter(t, torch.tensor([[0, 1, 1]]), dim=1, reduce='sum') print(t_s)
輸出:
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
tensor([[ 0, 3],
[ 3, 9],
[ 6, 15]])
可以看出,每行的后兩個元素求了和,與index定義相同。
要注意的是,index的 shape[0]
為1時(shí),會自動對dim對應(yīng)的維度上每一層進(jìn)行相同的分組處理,如上例所示,index大小為(1, 3),即對src的三行數(shù)據(jù)都進(jìn)行了分組處理。
而另一種分組方式,如需要每行分組不同,則需要index的shape和src的shape相同,如下例:
t = torch.arange(12).view(4, 3) print(t) t_s = torch_scatter.scatter(t, torch.tensor([[0, 1, 1], [1, 1, 0], [0, 1, 1], [1, 1, 0]]), dim=1, reduce='sum') print(t_s)
輸出:
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
tensor([[ 0, 3],
[ 5, 7],
[ 6, 15]])
shape不相同時(shí),則會報(bào)錯提示:
RuntimeError: The expanded size of the tensor (3) must match the existing size (2) at non-singleton dimension 0 .
同時(shí),該庫還給出了另外兩種方法,分別為 torch_scatter.segment_coo
和 torch_scatter.segment_csr
.
torch_scatter.segment_coo
torch_scatter.segment_coo
和 scatter
的功能差不多,但它只支持index的shape[0]為1的狀態(tài),即每一行都為相同的分組方式。
同時(shí),index中數(shù)值為順序排列,以提高計(jì)算速度。
torch_scatter.segment_csr
torch_scatter.segment_csr
的index格式不太相同,是一種區(qū)間格式,如[0, 2, 5],表示0,1為一組,2,3,4為一組,即取數(shù)值間的左閉右開區(qū)間。
這個方法是計(jì)算速度最快的。
官方文檔地址
torch_scatter庫doc
https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html
torch.scatter文檔
總結(jié)
以上為個人經(jīng)驗(yàn),希望能給大家一個參考,也希望大家多多支持腳本之家。
- PyTorch中torch.load()的用法和應(yīng)用
- python中torch.load中的map_location參數(shù)使用
- Pytorch中的torch.nn.Linear()方法用法解讀
- Pytorch中的torch.where函數(shù)使用
- python中的List sort()與torch.sort()
- PyTorch函數(shù)torch.cat與torch.stac的區(qū)別小結(jié)
- pytorch.range()和pytorch.arange()的區(qū)別及說明
- 使用with torch.no_grad():顯著減少測試時(shí)顯存占用
- PyTorch中torch.save()的用法和應(yīng)用小結(jié)
相關(guān)文章
python網(wǎng)絡(luò)爬蟲精解之XPath的使用說明
XPath 是一門在 XML 文檔中查找信息的語言。XPath 可用來在 XML 文檔中對元素和屬性進(jìn)行遍歷。XPath 是 W3C XSLT 標(biāo)準(zhǔn)的主要元素,并且 XQuery 和 XPointer 都構(gòu)建于 XPath 表達(dá)之上2021-09-09django使用xlwt導(dǎo)出excel文件實(shí)例代碼
這篇文章主要介紹了django使用xlwt導(dǎo)出excel文件實(shí)例代碼,分享了相關(guān)代碼示例,小編覺得還是挺不錯的,具有一定借鑒價(jià)值,需要的朋友可以參考下2018-02-02python數(shù)據(jù)結(jié)構(gòu)之二叉樹的統(tǒng)計(jì)與轉(zhuǎn)換實(shí)例
這篇文章主要介紹了python數(shù)據(jù)結(jié)構(gòu)之二叉樹的統(tǒng)計(jì)與轉(zhuǎn)換實(shí)例,例如統(tǒng)計(jì)二叉樹的葉子、分支節(jié)點(diǎn),以及二叉樹的左右兩樹互換等,需要的朋友可以參考下2014-04-04Python+OpenCV圖片去水印的多種方案實(shí)現(xiàn)
這篇文章主要為大家總結(jié)了Python結(jié)合OpenCV的幾種常見的水印去除方式,簡單圖片去水印效果良好,有需要的小伙伴可以跟隨小編一起了解下2025-02-02python tensorflow學(xué)習(xí)之識別單張圖片的實(shí)現(xiàn)的示例
本篇文章主要介紹了python tensorflow學(xué)習(xí)之識別單張圖片的實(shí)現(xiàn)的示例,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2018-02-02python實(shí)現(xiàn)AHP算法的方法實(shí)例(層次分析法)
這篇文章主要給大家介紹了關(guān)于python實(shí)現(xiàn)AHP算法(層次分析法)的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-09-09pandas數(shù)據(jù)清洗(缺失值和重復(fù)值的處理)
這篇文章主要介紹了pandas數(shù)據(jù)清洗(缺失值和重復(fù)值的處理),pandas對大數(shù)據(jù)有很多便捷的清洗用法,尤其針對缺失值和重復(fù)值,詳細(xì)介紹感興趣的小伙伴可以參考下面文章內(nèi)容2022-08-08