pytorch torch.gather函數(shù)的使用
pytorch torch.gather函數(shù)
torch.gather
是 PyTorch 中的一個用于從給定維度上按索引取值的函數(shù)。
它根據(jù)一個索引張量 index
,從源張量 input
中收集值,并返回一個新的張量。
torch.gather
常用于需要從張量的特定位置抽取元素的操作。
1. 函數(shù)簽名
torch.gather(input, dim, index, *, sparse_grad=False, out=None)
input
:輸入張量,表示要從中收集元素的源張量。dim
:要收集的維度索引。例如,對于一個二維張量,0 表示沿著行的維度,1 表示沿著列的維度。index
:索引張量,其形狀應與input
張量在除了dim
維度之外的其他維度上保持一致。索引張量中的值表示在input
張量對應維度上要收集的元素的索引。out
(可選):輸出張量,如果提供,結果將存儲在這個張量中。
2. 工作原理
torch.gather
在 dim
維度上,通過 index
指定的索引,從 input
中選取元素。
返回的張量的形狀與 index
的形狀相同。
3. 示例代碼
以下是一個簡單的示例代碼,演示如何使用 torch.gather
函數(shù):
import torch # 創(chuàng)建一個源張量 input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # 創(chuàng)建一個索引張量 index = torch.tensor([[0, 2, 1], [2, 0, 1], [1, 2, 0]]) # 在 dim=1 維度上使用 gather 函數(shù) result = torch.gather(input, dim=1, index=index) print("Input Tensor:") print(input) print("\nIndex Tensor:") print(index) print("\nResult Tensor:") print(result)
4. 輸出結果
Input Tensor:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])Index Tensor:
tensor([[0, 2, 1],
[2, 0, 1],
[1, 2, 0]])Result Tensor:
tensor([[1, 3, 2],
[6, 4, 5],
[8, 9, 7]])
5. 解釋
- 輸入張量 (
input
) 是一個3x3
的矩陣,每個元素代表一個值。 - 索引張量 (
index
) 指定了要從input
中提取的元素的索引。 - 結果張量 (
result
) 是根據(jù)index
從input
中提取的元素形成的張量。
在這個例子中:
- 對于
input
的第一行,index
提取了索引0, 2, 1
對應的元素1, 3, 2
。 - 對于
input
的第二行,index
提取了索引2, 0, 1
對應的元素6, 4, 5
。 - 對于
input
的第三行,index
提取了索引1, 2, 0
對應的元素8, 9, 7
。
總結
torch.gather
通過索引在指定維度上提取張量中的元素,是用于基于索引選擇數(shù)據(jù)的有用工具。
函數(shù)對批處理數(shù)據(jù)特別有用,例如在分類任務中提取對應類別的概率或得分。
索引張量的形狀必須與源張量在指定維度的形狀相匹配,以確保正確的取值操作。
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
python中元組創(chuàng)建、索引訪問和元組作用詳解
在Python中,元組是一種內(nèi)置的不可變序列,使用圓括號定義,元組的創(chuàng)建可以通過直接使用圓括號或者逗號分隔的方式進行,文中通過代碼介紹的非常詳細,需要的朋友可以參考下2024-11-11Python的Socket編程過程中實現(xiàn)UDP端口復用的實例分享
這篇文章主要介紹了Python的Socket編程過程中實現(xiàn)UDP端口復用的實例分享,文中作者用到了Python的twisted異步框架,需要的朋友可以參考下2016-03-03使用Python將Mysql的查詢數(shù)據(jù)導出到文件的方法
今天小編就為大家分享一篇關于使用Python將Mysql的查詢數(shù)據(jù)導出到文件的方法,小編覺得內(nèi)容挺不錯的,現(xiàn)在分享給大家,具有很好的參考價值,需要的朋友一起跟隨小編來看看吧2019-02-02