pytorch torch.gather函數(shù)的使用
pytorch torch.gather函數(shù)
torch.gather
是 PyTorch 中的一個用于從給定維度上按索引取值的函數(shù)。
它根據(jù)一個索引張量 index
,從源張量 input
中收集值,并返回一個新的張量。
torch.gather
常用于需要從張量的特定位置抽取元素的操作。
1. 函數(shù)簽名
torch.gather(input, dim, index, *, sparse_grad=False, out=None)
input
:輸入張量,表示要從中收集元素的源張量。dim
:要收集的維度索引。例如,對于一個二維張量,0 表示沿著行的維度,1 表示沿著列的維度。index
:索引張量,其形狀應(yīng)與input
張量在除了dim
維度之外的其他維度上保持一致。索引張量中的值表示在input
張量對應(yīng)維度上要收集的元素的索引。out
(可選):輸出張量,如果提供,結(jié)果將存儲在這個張量中。
2. 工作原理
torch.gather
在 dim
維度上,通過 index
指定的索引,從 input
中選取元素。
返回的張量的形狀與 index
的形狀相同。
3. 示例代碼
以下是一個簡單的示例代碼,演示如何使用 torch.gather
函數(shù):
import torch # 創(chuàng)建一個源張量 input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # 創(chuàng)建一個索引張量 index = torch.tensor([[0, 2, 1], [2, 0, 1], [1, 2, 0]]) # 在 dim=1 維度上使用 gather 函數(shù) result = torch.gather(input, dim=1, index=index) print("Input Tensor:") print(input) print("\nIndex Tensor:") print(index) print("\nResult Tensor:") print(result)
4. 輸出結(jié)果
Input Tensor:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])Index Tensor:
tensor([[0, 2, 1],
[2, 0, 1],
[1, 2, 0]])Result Tensor:
tensor([[1, 3, 2],
[6, 4, 5],
[8, 9, 7]])
5. 解釋
- 輸入張量 (
input
) 是一個3x3
的矩陣,每個元素代表一個值。 - 索引張量 (
index
) 指定了要從input
中提取的元素的索引。 - 結(jié)果張量 (
result
) 是根據(jù)index
從input
中提取的元素形成的張量。
在這個例子中:
- 對于
input
的第一行,index
提取了索引0, 2, 1
對應(yīng)的元素1, 3, 2
。 - 對于
input
的第二行,index
提取了索引2, 0, 1
對應(yīng)的元素6, 4, 5
。 - 對于
input
的第三行,index
提取了索引1, 2, 0
對應(yīng)的元素8, 9, 7
。
總結(jié)
torch.gather
通過索引在指定維度上提取張量中的元素,是用于基于索引選擇數(shù)據(jù)的有用工具。
函數(shù)對批處理數(shù)據(jù)特別有用,例如在分類任務(wù)中提取對應(yīng)類別的概率或得分。
索引張量的形狀必須與源張量在指定維度的形狀相匹配,以確保正確的取值操作。
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
python中元組創(chuàng)建、索引訪問和元組作用詳解
在Python中,元組是一種內(nèi)置的不可變序列,使用圓括號定義,元組的創(chuàng)建可以通過直接使用圓括號或者逗號分隔的方式進(jìn)行,文中通過代碼介紹的非常詳細(xì),需要的朋友可以參考下2024-11-11小學(xué)生也能看懂的python語法之循環(huán)語句精解
這篇文章主要介紹了詳解Python中的條件,循環(huán)語句,包括while循環(huán)for循環(huán),循環(huán)語句是學(xué)習(xí)各個編程語言的最基本的基礎(chǔ)知識,需要的朋友可以參考下2021-09-09Python的Socket編程過程中實現(xiàn)UDP端口復(fù)用的實例分享
這篇文章主要介紹了Python的Socket編程過程中實現(xiàn)UDP端口復(fù)用的實例分享,文中作者用到了Python的twisted異步框架,需要的朋友可以參考下2016-03-03用python3 urllib破解有道翻譯反爬蟲機(jī)制詳解
這篇文章主要介紹了python破解網(wǎng)易反爬蟲機(jī)制詳解,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2019-08-08使用Python將Mysql的查詢數(shù)據(jù)導(dǎo)出到文件的方法
今天小編就為大家分享一篇關(guān)于使用Python將Mysql的查詢數(shù)據(jù)導(dǎo)出到文件的方法,小編覺得內(nèi)容挺不錯的,現(xiàn)在分享給大家,具有很好的參考價值,需要的朋友一起跟隨小編來看看吧2019-02-02