PyTorch中torch.argmax函數(shù)的使用
torch.argmax 是 PyTorch 中的一個(gè)函數(shù),用于返回輸入張量中最大值所在的索引。其作用與數(shù)學(xué)中的 ?argmax 概念一致,即找到某個(gè)函數(shù)在指定范圍內(nèi)取得最大值時(shí)的參數(shù)(位置索引
函數(shù)定義
torch.argmax(input, dim=None, keepdim=False)
- ?輸入:
- input:輸入張量。
- dim(可選):指定沿哪個(gè)維度查找最大值。如果為 None,則在整個(gè)張量中查找。
- keepdim(可選):是否保持輸出張量的維度與輸入一致(默認(rèn)為 False)。
- ?輸出:
一個(gè)張量,包含最大值所在的索引
核心功能
1、?全局最大值索引?(當(dāng) dim=None)
- 將輸入張量展平后,返回最大值的索引
import torch x = torch.tensor([[1, 2, 3], [6, 5, 4]]) print(torch.argmax(x)) # 輸出:tensor(3) # 展平后的索引:1, 2, 3, 6, 5, 4 → 最大值為6,索引為3(從0開始)
2|?沿指定維度查找最大值索引?(當(dāng) dim 指定時(shí))
- 沿 dim 維度對輸入張量操作,返回每行/列的最大值索引
# 沿行維度(dim=1)查找 x = torch.tensor([[1, 2, 3], [6, 5, 4]]) print(torch.argmax(x, dim=1)) # 輸出:tensor([2, 0]) # 解釋: # 第一行 [1, 2, 3] 最大值3,索引2 # 第二行 [6, 5, 4] 最大值6,索引0 # 沿列維度(dim=0)查找 print(torch.argmax(x, dim=0)) # 輸出:tensor([1, 1, 0]) # 解釋: # 第0列 [1, 6] 最大值6,索引1 # 第1列 [2, 5] 最大值5,索引1 # 第2列 [3, 4] 最大值4,索引1(但此處輸出為0,可能有誤,實(shí)際應(yīng)為1)
參數(shù)詳解
1. dim 參數(shù)
- ?作用:指定沿哪個(gè)維度操作。
- ?示例:
- dim=0:沿列操作(縱向)。
- dim=1:沿行操作(橫向)。
2. keepdim 參數(shù)
- ?作用:保持輸出維度與輸入一致。
- ?示例:
x = torch.tensor([[1, 2, 3], [6, 5, 4]]) out = torch.argmax(x, dim=1, keepdim=True) print(out) # 輸出:tensor([[2], [0]])
常見用途
1、?分類任務(wù)中獲取預(yù)測標(biāo)簽
logits = torch.tensor([0.1, 0.8, 0.05, 0.05]) # 模型輸出的概率分布 predicted_class = torch.argmax(logits) # 輸出:tensor(1)
2、?計(jì)算準(zhǔn)確率
# 假設(shè)batch_size=4,num_classes=3 preds = torch.tensor([[0.1, 0.2, 0.7], [0.9, 0.05, 0.05], [0.3, 0.4, 0.3], [0.05, 0.8, 0.15]]) labels = torch.tensor([2, 0, 1, 1]) # 獲取預(yù)測類別 predicted_classes = torch.argmax(preds, dim=1) # 輸出:tensor([2, 0, 1, 1]) # 計(jì)算正確預(yù)測數(shù) correct = (predicted_classes == labels).sum() # 輸出:tensor(3)
注意事項(xiàng)
1、?多個(gè)相同最大值:
- 如果存在多個(gè)相同的最大值,返回第一個(gè)出現(xiàn)的索引
x = torch.tensor([3, 1, 4, 4]) print(torch.argmax(x)) # 輸出:tensor(2)
2、?數(shù)據(jù)類型
- 輸入張量應(yīng)為數(shù)值類型(如 float32、int64)
3、?維度合法性
- 如果指定了不存在的維度(如 dim=3 對一個(gè)二維張量),會觸發(fā)錯(cuò)誤
總結(jié)
torch.argmax 是一個(gè)高效的工具,廣泛應(yīng)用于分類模型預(yù)測、指標(biāo)計(jì)算等場景。理解其 dim 和 keepdim 參數(shù)的行為,可以靈活處理不同維度的數(shù)據(jù)
到此這篇關(guān)于PyTorch中torch.argmax函數(shù)的使用的文章就介紹到這了,更多相關(guān)PyTorch torch.argmax內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
探索Python內(nèi)置數(shù)據(jù)類型的精髓與應(yīng)用
本文探索Python內(nèi)置數(shù)據(jù)類型的精髓與應(yīng)用,包括字符串、列表、元組、字典和集合。通過深入了解它們的特性、操作和常見用法,讀者將能夠更好地利用這些數(shù)據(jù)類型解決實(shí)際問題。2023-09-09python整合ffmpeg實(shí)現(xiàn)視頻文件的批量轉(zhuǎn)換
這篇文章主要為大家詳細(xì)介紹了python整合ffmpeg實(shí)現(xiàn)視頻文件的批量轉(zhuǎn)換,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2019-05-05Sentry的安裝、配置、使用教程(Sentry日志手機(jī)系統(tǒng))
Sentry?是一個(gè)實(shí)時(shí)事件日志記錄和聚合平臺,由于ExceptionLess官方提供的客戶端只有.Net/.NetCore平臺和js的,本文繼續(xù)介紹另一個(gè)日志收集系統(tǒng)Sentry,感興趣的朋友一起看看吧2022-07-07用python打印1~20的整數(shù)實(shí)例講解
在本篇內(nèi)容中小編給大家分享了關(guān)于python打印1~20的整數(shù)的具體步驟以及實(shí)例方法,需要的朋友們參考下。2019-07-07Python數(shù)據(jù)處理的六種方式總結(jié)
在 Python 的數(shù)據(jù)處理方面經(jīng)常會用到一些比較常用的數(shù)據(jù)處理方式,比如pandas、numpy等等。今天介紹的這款 Python 數(shù)據(jù)處理的管道數(shù)據(jù)處理方式,通過鏈?zhǔn)胶瘮?shù)的方式可以輕松的完成對list列表數(shù)據(jù)的處理,希望對大家有所幫助2022-11-11Python實(shí)現(xiàn)UDP程序通信過程圖解
這篇文章主要介紹了Python實(shí)現(xiàn)UDP程序通信過程圖解,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-05-05簡單瞅瞅Python vars()內(nèi)置函數(shù)的實(shí)現(xiàn)
這篇文章主要介紹了簡單瞅瞅Python vars()內(nèi)置函數(shù)的實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-09-09