欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

關(guān)于torch.scatter與torch_scatter庫的使用整理

 更新時(shí)間:2023年09月11日 14:36:18   作者:回爐重造P  
這篇文章主要介紹了關(guān)于torch.scatter與torch_scatter庫的使用整理,具有很好的參考價(jià)值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教

最近在做圖結(jié)構(gòu)相關(guān)的算法,scatter能把鄰接矩陣?yán)锏男畔⑿薷?,或者把鄰居分組算個sum或者reduce,挺方便的,簡單整理一下。

torch.scatter 與 tensor._scatter

Pytorch自帶的函數(shù),用來將作為 src 的tensor根據(jù) index 的描述填充到 input 中,

形式如下:

ouput = torch.scatter(input, dim, index, src)
# 或者是
input.scatter_(dim, index, src)

兩個方法的功能是相同的,而帶下劃線的 _scatter 方法是將原tensor input 直接修改了,不帶的則會返回一個新的tensor output input 不變。

其中 dim 決定 index 對應(yīng)值是沿著哪個維度進(jìn)行修改。而 src 為數(shù)據(jù)來源,當(dāng)其為tensor張量時(shí),shape要和index相同,這樣index中每個元素都能對應(yīng) src 中對應(yīng)位置的信息。

理解 scatter 方法主要是要理解 index 實(shí)現(xiàn)的 src input 之間的位置對應(yīng)關(guān)系,舉個例子:

dim = 0
index = torch.tensor(
	[[0, 2, 2], 
	[2, 1, 0]]
)

dim 為0時(shí),遵循的映射原則為: input[index[i][j]][j] = src[i][j] .

也就是說,將位置 (i, j) 中 dim 對應(yīng)的位置改為 index[i][j] 的值。

如位置(1,0),index[1][0]為2,則映射后的位置為(2,0),意味著 input 中(2,0)的位置被更改為 src 中(1,0)位置的值。

我個人形象理解是這些值會沿著dim方向滑動,上面例子中src[1][0]位置的值滑到2,成為input中的新值,這樣理解起來更形象一點(diǎn)。

基本理解了上面這個例子,多維情況和不同dim的情況都可以類推了。

需要注意:src和input的dtype需要相同,不然會報(bào)

Expected self.dtype to be equal to src.dtype

不一樣就先轉(zhuǎn)換再使用。

t = torch.arange(6).view(2, 3)
t = t.to(torch.float32)
print(t)
output = torch.scatter(torch.zeros((3, 3)), 0, torch.tensor([[0, 2, 2], [2, 1, 0]]), t)
print(torch.zeros((3, 3)).scatter_(0, torch.tensor([[0, 2, 2], [2, 1, 0]]), t))

輸出:

tensor([[0., 1., 2.],
        [3., 4., 5.]])
tensor([[0., 0., 5.],
        [0., 4., 0.],
        [3., 1., 2.]])

torch_scatter庫

這個第三方庫對矩陣的分組處理這個概念做了更進(jìn)一步的封裝,通過index來指定分組信息,將元素分組后進(jìn)行對應(yīng)處理,

最基礎(chǔ)的scatter方法形式如下:

torch_scatter.scatter(src, index, dim, out, dim_size, reduce)
  • src : 數(shù)據(jù)源
  • index :分組序列
  • dim :分組遵循的維度
  • out :輸出的tensor,可以不指定直接讓函數(shù)輸出
  • dim_size :out不指定的時(shí)候,將輸出shape變?yōu)樵撝荡笮?;dim_size也不指定,就根據(jù)計(jì)算結(jié)果來
  • reduce :分組的操作,包括sum,mul,mean,min和max操作

這個方法理解關(guān)鍵在 index 的分組方法,

舉個例子:

dim = 1
index = torch.tensor([[0, 1, 1]])

torch_scatter.scatter index 的順序是沒有特定規(guī)定的,相同數(shù)字對應(yīng)的元素即為一組。

比如例子中,維度1上的第0個元素為一組,第1和2元素為另一組。

這樣,按照分組進(jìn)行reduce定義的計(jì)算即可獲得輸出。如:

t = torch.arange(12).view(4, 3)
print(t)
t_s = torch_scatter.scatter(t, torch.tensor([[0, 1, 1]]), dim=1, reduce='sum')
print(t_s)

輸出:

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
tensor([[ 0,  3],
        [ 3,  9],
        [ 6, 15]])

可以看出,每行的后兩個元素求了和,與index定義相同。

要注意的是,index的 shape[0] 為1時(shí),會自動對dim對應(yīng)的維度上每一層進(jìn)行相同的分組處理,如上例所示,index大小為(1, 3),即對src的三行數(shù)據(jù)都進(jìn)行了分組處理。

而另一種分組方式,如需要每行分組不同,則需要index的shape和src的shape相同,如下例:

t = torch.arange(12).view(4, 3)
print(t)
t_s = torch_scatter.scatter(t, torch.tensor([[0, 1, 1], [1, 1, 0], [0, 1, 1], [1, 1, 0]]), dim=1, reduce='sum')
print(t_s)

輸出:

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
tensor([[ 0,  3],
        [ 5,  7],
        [ 6, 15]])

shape不相同時(shí),則會報(bào)錯提示:

RuntimeError: The expanded size of the tensor (3) must match the existing size (2) at non-singleton dimension 0 .

同時(shí),該庫還給出了另外兩種方法,分別為 torch_scatter.segment_coo torch_scatter.segment_csr .

torch_scatter.segment_coo

torch_scatter.segment_coo scatter 的功能差不多,但它只支持index的shape[0]為1的狀態(tài),即每一行都為相同的分組方式。

同時(shí),index中數(shù)值為順序排列,以提高計(jì)算速度。

torch_scatter.segment_csr

torch_scatter.segment_csr 的index格式不太相同,是一種區(qū)間格式,如[0, 2, 5],表示0,1為一組,2,3,4為一組,即取數(shù)值間的左閉右開區(qū)間。

這個方法是計(jì)算速度最快的。

官方文檔地址

torch_scatter庫doc

https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html

torch.scatter文檔

https://pytorch-cn.readthedocs.io/zh/latest/package_references/Tensor/#scatter_input-dim-index-src-tensor

總結(jié)

以上為個人經(jīng)驗(yàn),希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • python網(wǎng)絡(luò)爬蟲精解之XPath的使用說明

    python網(wǎng)絡(luò)爬蟲精解之XPath的使用說明

    XPath 是一門在 XML 文檔中查找信息的語言。XPath 可用來在 XML 文檔中對元素和屬性進(jìn)行遍歷。XPath 是 W3C XSLT 標(biāo)準(zhǔn)的主要元素,并且 XQuery 和 XPointer 都構(gòu)建于 XPath 表達(dá)之上
    2021-09-09
  • Python高效計(jì)算庫Joblib的入門教程

    Python高效計(jì)算庫Joblib的入門教程

    Joblib庫是一個用于在Python中進(jìn)行高效計(jì)算的開源庫,提供內(nèi)存映射和并行計(jì)算工具,本文就來介紹一下Joblib庫的使用,具有一定的參考價(jià)值,感興趣的可以了解一下
    2025-01-01
  • Python Tkinter GUI編程入門介紹

    Python Tkinter GUI編程入門介紹

    這篇文章主要介紹了Python Tkinter GUI編程入門介紹,本文講解了Tkinter介紹、Tkinter的使用、Tkinter的幾何管理器等內(nèi)容,并給出了一個完整示例,需要的朋友可以參考下
    2015-03-03
  • django使用xlwt導(dǎo)出excel文件實(shí)例代碼

    django使用xlwt導(dǎo)出excel文件實(shí)例代碼

    這篇文章主要介紹了django使用xlwt導(dǎo)出excel文件實(shí)例代碼,分享了相關(guān)代碼示例,小編覺得還是挺不錯的,具有一定借鑒價(jià)值,需要的朋友可以參考下
    2018-02-02
  • python數(shù)據(jù)結(jié)構(gòu)之二叉樹的統(tǒng)計(jì)與轉(zhuǎn)換實(shí)例

    python數(shù)據(jù)結(jié)構(gòu)之二叉樹的統(tǒng)計(jì)與轉(zhuǎn)換實(shí)例

    這篇文章主要介紹了python數(shù)據(jù)結(jié)構(gòu)之二叉樹的統(tǒng)計(jì)與轉(zhuǎn)換實(shí)例,例如統(tǒng)計(jì)二叉樹的葉子、分支節(jié)點(diǎn),以及二叉樹的左右兩樹互換等,需要的朋友可以參考下
    2014-04-04
  • Python+OpenCV圖片去水印的多種方案實(shí)現(xiàn)

    Python+OpenCV圖片去水印的多種方案實(shí)現(xiàn)

    這篇文章主要為大家總結(jié)了Python結(jié)合OpenCV的幾種常見的水印去除方式,簡單圖片去水印效果良好,有需要的小伙伴可以跟隨小編一起了解下
    2025-02-02
  • Python階乘求和的代碼詳解

    Python階乘求和的代碼詳解

    在本篇文章里小編給大家整理的是關(guān)于Python階乘求和的代碼實(shí)例,有需要的朋友們可以跟著學(xué)習(xí)下。
    2020-02-02
  • python tensorflow學(xué)習(xí)之識別單張圖片的實(shí)現(xiàn)的示例

    python tensorflow學(xué)習(xí)之識別單張圖片的實(shí)現(xiàn)的示例

    本篇文章主要介紹了python tensorflow學(xué)習(xí)之識別單張圖片的實(shí)現(xiàn)的示例,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2018-02-02
  • python實(shí)現(xiàn)AHP算法的方法實(shí)例(層次分析法)

    python實(shí)現(xiàn)AHP算法的方法實(shí)例(層次分析法)

    這篇文章主要給大家介紹了關(guān)于python實(shí)現(xiàn)AHP算法(層次分析法)的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-09-09
  • pandas數(shù)據(jù)清洗(缺失值和重復(fù)值的處理)

    pandas數(shù)據(jù)清洗(缺失值和重復(fù)值的處理)

    這篇文章主要介紹了pandas數(shù)據(jù)清洗(缺失值和重復(fù)值的處理),pandas對大數(shù)據(jù)有很多便捷的清洗用法,尤其針對缺失值和重復(fù)值,詳細(xì)介紹感興趣的小伙伴可以參考下面文章內(nèi)容
    2022-08-08

最新評論