欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch之torch_scatter.scatter_max()用法

 更新時(shí)間:2023年09月11日 11:45:10   作者:A2333fun  
這篇文章主要介紹了pytorch之torch_scatter.scatter_max()用法,具有很好的參考價(jià)值,希望對大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

torch_scatter.scatter_max()

torch_scatter.scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=None)

  • 根據(jù)index將src分組,求每一組中的最大值輸出到out
  • dim是維度

from torch_scatter import scatter_max
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
'''src根據(jù)index進(jìn)行分組'''
out, argmax = scatter_max(src, index, out=out)
print(out)
print(argmax)

輸出

tensor([[0., 0., 4., 3., 2., 0.],
        [2., 4., 3., 0., 0., 0.]])
tensor([[-1, -1,  3,  4,  0,  1],
        [ 1,  4,  3, -1, -1, -1]])

解釋

torch_scatter.scatter()使用

1. 參數(shù)

具體來講,scatter函數(shù)的作用就是將index中相同索引對應(yīng)位置的src元素進(jìn)行某種方式的操作,例如 sum mean 等,然后將這些操作結(jié)果按照索引順序進(jìn)行拼接。

下面我用具體的例子來進(jìn)行講解。

2. 示例

2.1 簡單示例

首先初始化src和index:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # (3, 3)
index = torch.tensor([0, 0, 1], dtype=torch.int64)

接著使用scatter函數(shù):

out = scatter(src, index, dim=0, reduce='mean')

我們觀察 index=[0, 0, 1] ,第0個(gè)位置和第1個(gè)位置都為0,第2個(gè)位置為1。也就是說,我們需要將src中第0個(gè)元素和第1個(gè)元素求平均變成一個(gè)元素,然后第2個(gè)元素求mean也就是本身為一個(gè)元素。如果 index=[1, 0, 0] ,則意味著我們需要將src中第1個(gè)元素和第2個(gè)元素求平均變成一個(gè)元素,而第0個(gè)元素保持不變。

那么src中第幾個(gè)元素到底是如何定義的呢?這就需要用到 dim 參數(shù)了。

dim=0 意味著我們需要對src的維度0進(jìn)行操作:

tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])

即src中第0個(gè)元素為 [1, 2, 3] ,第1個(gè)元素為 [4, 5, 6] ,第2個(gè)元素為 [7, 8, 9] 。

而如果 dim=1 ,則第0個(gè)元素為 [1, 4, 7] ,第1個(gè)元素為 [2, 5, 8] ,第2個(gè)元素為 [3, 6, 9]

因此,如果有以下代碼:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # (3, 3)
index = torch.tensor([0, 0, 1], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')

那么我們就應(yīng)該將src中的第0個(gè)元素為 [1, 2, 3] 和第1個(gè)元素為 [4, 5, 6] 求平均為 [2.5, 3.5, 4.5] ,然后第2個(gè)元素 [7, 8, 9] 保持不變,即:

tensor([[2.5000, 3.5000, 4.5000],
        [7.0000, 8.0000, 9.0000]])

2.2 順序問題

上面的例子中 index=[0, 0, 1] ,最后結(jié)果是將src中第0個(gè)元素和第1個(gè)元素求平均放到了位置0,然后src中第2個(gè)元素保持不變放到了位置1。

如果 index=[1, 1, 0] ,結(jié)果為:

tensor([[7.0000, 8.0000, 9.0000],
        [2.5000, 3.5000, 4.5000]])

可以發(fā)現(xiàn),上述結(jié)果是將src中第2個(gè)元素 [7, 8, 9] 保持不變放到了位置0,然后將src中第0個(gè)元素 [1, 2, 3] 和第1個(gè)元素 [4, 5, 6] 求平均保持不變放到了位置1。

也就是說,無論index怎么變化,都是優(yōu)先將index中0對應(yīng)位置的操作結(jié)果進(jìn)行放置。

2.3 維度問題

如果src的維度為(4, 3),而我們需要對 dim=0 操作,也就是一共有四個(gè)元素,那么index的長度應(yīng)該為4,即以下操作是不合法的:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])  # (4, 3)
index = torch.tensor([1, 1, 0], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')
print(out)

報(bào)錯(cuò)為:

RuntimeError: The expanded size of the tensor (4) must match the existing size (3) at non-singleton dimension 0.  Target sizes: [4, 3].  Tensor sizes: [3, 1]

正確做法應(yīng)該是:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])  # (4, 3)
index = torch.tensor([1, 1, 0, 2], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')
print(out)

輸出為:

tensor([[ 7.0000,  8.0000,  9.0000],
        [ 2.5000,  3.5000,  4.5000],
        [10.0000, 11.0000, 12.0000]])

總結(jié)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • 基于Python+Matplotlib實(shí)現(xiàn)直方圖的繪制

    基于Python+Matplotlib實(shí)現(xiàn)直方圖的繪制

    Matplotlib是Python的繪圖庫,它能讓使用者很輕松地將數(shù)據(jù)圖形化,并且提供多樣化的輸出格式。本文將為大家介紹如何用matplotlib繪制直方圖,感興趣的朋友可以學(xué)習(xí)一下
    2022-04-04
  • 利用Python實(shí)現(xiàn)數(shù)值積分的方法

    利用Python實(shí)現(xiàn)數(shù)值積分的方法

    這篇文章主要介紹了利用Python實(shí)現(xiàn)數(shù)值積分。本文主要用于對比使用Python來實(shí)現(xiàn)數(shù)學(xué)中積分的幾種計(jì)算方式,并和真值進(jìn)行對比,加深大家對積分運(yùn)算實(shí)現(xiàn)方式的理解
    2022-02-02
  • Python實(shí)現(xiàn)石頭剪刀布游戲

    Python實(shí)現(xiàn)石頭剪刀布游戲

    這篇文章主要為大家詳細(xì)介紹了Python實(shí)現(xiàn)石頭剪刀布游戲,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2021-01-01
  • 使用python實(shí)現(xiàn)ANN

    使用python實(shí)現(xiàn)ANN

    這篇文章主要為大家詳細(xì)介紹了使用python實(shí)現(xiàn)ANN的相關(guān)資料,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2017-12-12
  • Python多線程中阻塞(join)與鎖(Lock)使用誤區(qū)解析

    Python多線程中阻塞(join)與鎖(Lock)使用誤區(qū)解析

    這篇文章主要為大家詳細(xì)介紹了Python多線程中阻塞join與鎖Lock的使用誤區(qū),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2018-04-04
  • python連接mysql有哪些方法

    python連接mysql有哪些方法

    在本篇文章里小編給大家分享的是一篇關(guān)于python連接mysql的方法,有興趣的朋友們可以學(xué)習(xí)下。
    2020-06-06
  • pycharm 實(shí)現(xiàn)光標(biāo)快速移動(dòng)到括號外或行尾的操作

    pycharm 實(shí)現(xiàn)光標(biāo)快速移動(dòng)到括號外或行尾的操作

    這篇文章主要介紹了pycharm 實(shí)現(xiàn)光標(biāo)快速移動(dòng)到括號外或行尾的操作,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2021-02-02
  • python?實(shí)現(xiàn)?redis?數(shù)據(jù)庫的操作

    python?實(shí)現(xiàn)?redis?數(shù)據(jù)庫的操作

    這篇文章主要介紹了python?包?redis?數(shù)據(jù)庫的操作教程,redis?是一個(gè)?Key-Value?數(shù)據(jù)庫,下文基于python的相關(guān)資料展開對redis?數(shù)據(jù)庫操作的詳細(xì)介紹,需要的小伙伴可以參考一下
    2022-04-04
  • python對一個(gè)數(shù)向上取整的實(shí)例方法

    python對一個(gè)數(shù)向上取整的實(shí)例方法

    在本篇文章中小編給大家整理了關(guān)于python對一個(gè)數(shù)向上取整的實(shí)例方法,需要的朋友們可以跟著學(xué)習(xí)下。
    2020-06-06
  • pyqt5 QListWidget的用法解析

    pyqt5 QListWidget的用法解析

    這篇文章主要介紹了pyqt5 QListWidget的用法解析,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2021-03-03

最新評論