Pytorch中torch.argmax()函數(shù)使用及說明
torch.argmax()函數(shù)解析
1. 官網(wǎng)鏈接
torch.argmax(),如下圖所示:
2. torch.argmax(input)函數(shù)解析
torch.argmax(input) → LongTensor
將輸入input張量,無論有幾維,首先將其reshape排列成一個(gè)一維向量,然后找出這個(gè)一維向量里面最大值的索引。
3. 代碼舉例
import torch x = torch.randn(3,4) y = torch.argmax(x)#對應(yīng)于x中最大元素的索引值 x,y
輸出結(jié)果如下:
import torch
x = torch.randn(3,4)
y = torch.argmax(x)#對應(yīng)于x中最大元素的索引值
x,y
4. torch.argmax(input,dim) 函數(shù)解析
torch.argmax(input, dim, keepdim=False) → LongTensor
函數(shù)返回其他所有維在這個(gè)維度上面張量最大值的索引。
torch.argmax()函數(shù)中dim表示該維度會(huì)消失,可以理解為最終結(jié)果該維度大小是1,表示將該維度壓縮成維度大小為1。
舉例理解:
對于一個(gè)維度為(d0,d1) 的矩陣來說,dim=1表示求每一行中最大數(shù)的在該行中的列號(hào),最后得到的就是一個(gè)維度為(d0,1) 的二維矩陣,最終列這一維度大小為1就要消失了,最終結(jié)果變成一維張量(d0);
dim=0表示求每一列中最大數(shù)的在該列中的行號(hào),最后我們得到的就是一個(gè)維度為(1,d1) 的二維矩陣,結(jié)果行這一維度大小為1就要消失了,最終結(jié)果變成一維張量(d1)。
因此,我們想要求每一行最大的列標(biāo)號(hào),我們就要指定dim=1,表示我們不要列了,保留行的size就可以了。
假如我們想求每一列的最大行標(biāo),就可以指定dim=0,表示我們不要行了,求出每一列的最大值的下標(biāo),最后得到(1,d1)的一維矩陣。
5. 代碼舉例
5.1 輸入二維張量torch.Size([3, 4]),dim=0表示將dim=0這個(gè)維度大小由3壓縮成1,然后找到dim=0這三個(gè)值中最大值的索引,這個(gè)索引表示dim=0行索引標(biāo)號(hào),結(jié)果張量維度變?yōu)閠orch.Size([4])。
import torch x = torch.randn(3,4) y = torch.argmax(x,dim=0)#dim=0表示將dim=0這個(gè)維度大小由3壓縮成1,然后找到dim=0這三個(gè)值中最大值的索引,這個(gè)索引表示dim=0行索引標(biāo)號(hào) x,x.shape,y,y.shape
輸出結(jié)果如下:
(tensor([[ 2.6347, 0.6456, -1.0461, -1.5154],
[-1.3955, -1.2618, -0.5886, -0.5947],
[-1.5272, -2.0960, 0.9428, -0.9532]]),
torch.Size([3, 4]),
tensor([0, 0, 2, 1]),
torch.Size([4]))
5.2 輸入二維張量torch.Size([3, 4]),dim=1表示將dim=1這個(gè)維度大小由4壓縮成1,然后找到dim=1這四個(gè)值中最大值的索引,這個(gè)索引表示dim=1列索引標(biāo)號(hào),結(jié)果張量維度變?yōu)閠orch.Size([3])。
import torch x = torch.randn(3,4) y = torch.argmax(x,dim=1)#dim=1表示將dim=1這個(gè)維度大小由4壓縮成1,然后找到dim=1這四個(gè)值中最大值的索引,這個(gè)索引表示dim=1列索引標(biāo)號(hào) x,x.shape,y,y.shape
輸出結(jié)果如下:
(tensor([[ 0.1549, 0.4331, 0.3575, 1.1077],
[ 2.0233, 2.0085, -0.6101, -1.8547],
[-0.5101, -0.4052, 0.3458, -0.7802]]),
torch.Size([3, 4]),
tensor([3, 0, 2]),
torch.Size([3]))
5.3 輸入三維張量torch.Size([2, 3, 4]),dim=0表示將dim=0這個(gè)維度大小由2壓縮成1,然后找到dim=0這兩個(gè)值中最大值的索引,這個(gè)索引表示dim=0維索引標(biāo)號(hào)。
dim=0,即將第一個(gè)維度消除,也就是將兩個(gè)[34]矩陣只保留一個(gè),因此要在兩組中作比較,即將上下兩個(gè)[34]的矩陣分別在對應(yīng)的位置上比較大小,結(jié)果矩陣張量維度變?yōu)閠orch.Size([3, 4])。
import torch x = torch.randn(2,3,4) y = torch.argmax(x,dim=0)#dim=0表示將dim=0這個(gè)維度大小由2壓縮成1,然后找到dim=0這兩個(gè)值中最大值的索引,這個(gè)索引表示dim=0維索引標(biāo)號(hào) x,x.shape,y,y.shape
輸出結(jié)果如下:
(tensor([[[-1.4430, 0.0306, -1.0396, 0.1219],
[ 0.1016, 0.0889, 0.8005, 0.3320],
[-1.0518, -1.4526, -0.4586, -0.1474]],
[[ 1.2274, 1.5806, 0.5444, -0.3088],
[-0.8672, 0.3843, 1.2377, 2.1596],
[ 0.0671, 0.0847, 0.5607, -0.7492]]]),
torch.Size([2, 3, 4]),
tensor([[1, 1, 1, 0],
[0, 1, 1, 1],
[1, 1, 1, 0]]),
torch.Size([3, 4]))
5.4 輸入三維張量torch.Size([2, 3, 4]),dim=1表示將dim=1這個(gè)維度大小由3壓縮成1,然后找到dim=1這三個(gè)值中最大值的索引,這個(gè)索引表示dim=1維索引標(biāo)號(hào)。
dim=1,即將第二個(gè)維度消除(縱向壓縮成一維),結(jié)果矩陣張量維度變?yōu)閠orch.Size([2, 4])。
import torch x = torch.randn(2,3,4) y = torch.argmax(x,dim=1)#dim=1表示將dim=1這個(gè)維度大小由3壓縮成1,然后找到dim=1這三個(gè)值中最大值的索引,這個(gè)索引表示dim=1維索引標(biāo)號(hào) x,x.shape,y,y.shape
輸出結(jié)果如下:
(tensor([[[-1.7136, 0.5528, 0.5171, 1.2978],
[ 1.0250, -0.2687, 0.6727, -0.2013],
[ 0.1366, -1.0563, 0.1965, 1.5303]],
[[-0.0048, 1.6265, -1.0341, -0.3994],
[ 1.5536, 0.9739, -0.0913, 0.0889],
[-0.6703, -0.9099, -0.6400, -0.1807]]]),
torch.Size([2, 3, 4]),
tensor([[1, 0, 1, 2],
[1, 0, 1, 1]]),
torch.Size([2, 4]))
5.5 輸入三維張量torch.Size([2, 3, 4]),dim=2表示將dim=2這個(gè)維度大小由4壓縮成1,然后找到dim=2這四個(gè)值中最大值的索引,這個(gè)索引表示dim=2維索引標(biāo)號(hào)。dim=2,即將第三個(gè)維度消除(橫向壓縮成一維),結(jié)果矩陣張量維度變?yōu)閠orch.Size([2, 3])。
import torch x = torch.randn(2,3,4) y = torch.argmax(x,dim=2)#dim=2表示將dim=2這個(gè)維度大小由4壓縮成1,然后找到dim=2這四個(gè)值中最大值的索引,這個(gè)索引表示dim=2維索引標(biāo)號(hào) x,x.shape,y,y.shape
輸出結(jié)果如下:
(tensor([[[-0.3493, 0.8838, 0.5876, -0.3967],
[-1.5795, 2.6964, 0.7266, 0.3517],
[-0.6949, -1.4385, -0.0993, 0.1679]],
[[-0.4924, -0.8955, 0.5511, 0.6287],
[ 0.2338, -0.5787, -0.2081, -1.3032],
[ 0.6429, 0.0949, 0.3319, -0.8551]]]),
torch.Size([2, 3, 4]),
tensor([[1, 1, 3],
[3, 0, 0]]),
torch.Size([2, 3]))
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python Numpy中數(shù)據(jù)的常用保存與讀取方法
這篇文章主要介紹了Python Numpy中數(shù)據(jù)的常用保存與讀取方法,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-04-04python二分查找算法的遞歸實(shí)現(xiàn)方法
這篇文章主要介紹了python二分查找算法的遞歸實(shí)現(xiàn)方法,結(jié)合實(shí)例形式分析了Python二分查找算法的相關(guān)實(shí)現(xiàn)技巧,需要的朋友可以參考下2016-05-05Python虛擬環(huán)境Virtualenv使用教程
這篇文章主要介紹了Python虛擬環(huán)境Virtualenv簡明教程,本文整合了兩篇關(guān)于Virtualenv的使用教程,相信大家有通過本文一定可以學(xué)會(huì)如何使用Virtualenv,需要的朋友可以參考下2015-05-05PyTorch中tensor.backward()函數(shù)的詳細(xì)介紹及功能實(shí)現(xiàn)
backward()?函數(shù)是PyTorch框架中自動(dòng)求梯度功能的一部分,它負(fù)責(zé)執(zhí)行反向傳播算法以計(jì)算模型參數(shù)的梯度,這篇文章主要介紹了PyTorch中tensor.backward()函數(shù)的詳細(xì)介紹,需要的朋友可以參考下2024-02-02Python存儲(chǔ)json數(shù)據(jù)發(fā)生亂碼的解決方法
當(dāng)使用json.dump()把python對象轉(zhuǎn)換為json后存儲(chǔ)到文件中時(shí),文件可能會(huì)出現(xiàn)亂碼的問題,本篇文章可以幫助您解決亂碼問題,文中通過圖文介紹的非常詳細(xì),需要的朋友可以參考下2023-09-09anaconda jupyter不能導(dǎo)入安裝的lightgbm解決方案
這篇文章主要介紹了anaconda jupyter不能導(dǎo)入安裝的lightgbm解決方案,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2021-03-03使用python-opencv讀取視頻,計(jì)算視頻總幀數(shù)及FPS的實(shí)現(xiàn)
今天小編就為大家分享一篇使用python-opencv讀取視頻,計(jì)算視頻總幀數(shù)及FPS的實(shí)現(xiàn)方式,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-12-12