pytorch中torch.topk()函數(shù)的快速理解
函數(shù)作用:


該函數(shù)的作用即按字面意思理解,topk:取數(shù)組的前k個(gè)元素進(jìn)行排序。
通常該函數(shù)返回2個(gè)值,第一個(gè)值為排序的數(shù)組,第二個(gè)值為該數(shù)組中獲取到的元素在原數(shù)組中的位置標(biāo)號(hào)。
舉個(gè)栗子:
import numpy as np
import torch
import torch.utils.data.dataset as Dataset
from torch.utils.data import Dataset,DataLoader
####################準(zhǔn)備一個(gè)數(shù)組#########################
tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],
[3,4,5,1,1,1,1,1,1,1,1],
[7,8,9,1,1,1,1,1,1,1,1],
[1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)
####################打印這個(gè)原數(shù)組#########################
print('tensor1:')
print(tensor1)
#################使用torch.topk()這個(gè)函數(shù)##################
print('使用torch.topk()這個(gè)函數(shù)得到:')
'''k=3代表從原數(shù)組中取得3個(gè)元素,dim=1表示從原數(shù)組中的第一維獲取元素
(在本例中是分別從[10,1,2,1,1,1,1,1,1,1,10]、[3,4,5,1,1,1,1,1,1,1,1]、
[7,8,9,1,1,1,1,1,1,1,1]、[1,4,7,1,1,1,1,1,1,1,1]這四個(gè)數(shù)組中獲取3個(gè)元素)
其中l(wèi)argest=True表示從大到小取元素'''
print(torch.topk(tensor1, k=3, dim=1, largest=True))
#################打印這個(gè)函數(shù)第一個(gè)返回值####################
print('函數(shù)第一個(gè)返回值topk[0]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])
#################打印這個(gè)函數(shù)第二個(gè)返回值####################
print('函數(shù)第二個(gè)返回值topk[1]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])
'''
#######################運(yùn)行結(jié)果##########################
tensor1:
tensor([[10., 1., 2., 1., 1., 1., 1., 1., 1., 1., 10.],
[ 3., 4., 5., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 7., 8., 9., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 1., 4., 7., 1., 1., 1., 1., 1., 1., 1., 1.]])
使用torch.topk()這個(gè)函數(shù)得到:
'得到的values是原數(shù)組dim=1的四組從大到小的三個(gè)元素值;
得到的indices是獲取到的元素值在原數(shù)組dim=1中的位置。'
torch.return_types.topk(
values=tensor([[10., 10., 2.],
[ 5., 4., 3.],
[ 9., 8., 7.],
[ 7., 4., 1.]]),
indices=tensor([[ 0, 10, 2],
[ 2, 1, 0],
[ 2, 1, 0],
[ 2, 1, 0]]))
函數(shù)第一個(gè)返回值topk[0]如下
tensor([[10., 10., 2.],
[ 5., 4., 3.],
[ 9., 8., 7.],
[ 7., 4., 1.]])
函數(shù)第二個(gè)返回值topk[1]如下
tensor([[ 0, 10, 2],
[ 2, 1, 0],
[ 2, 1, 0],
[ 2, 1, 0]])
'''
該函數(shù)功能經(jīng)常用來(lái)獲取張量或者數(shù)組中最大或者最小的元素以及索引位置,是一個(gè)經(jīng)常用到的基本函數(shù)。
實(shí)例演示
任務(wù)一:
取top1(最大值):
pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
print(pred)
values, indices = pred.topk(1, dim=0, largest=True, sorted=True)
print(indices)
print(values)
# 用max得到的結(jié)果,設(shè)置keepdim為T(mén)rue,避免降維。因?yàn)閠opk函數(shù)返回的index不降維,shape和輸入一致。
_, indices_max = pred.max(dim=0, keepdim=True)
print(indices_max)
print(indices_max == indices)
輸出:
tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
tensor([[1, 1, 1, 1, 1]])
tensor([[0.7265, 1.4164, 1.3443, 1.2035, 1.8823]])
tensor([[1, 1, 1, 1, 1]])
tensor([[True, True, True, True, True]])
任務(wù)二:
按行取出topk,將小于topk的置為inf:
pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
print(pred)
top_k = 2 # 按行求出每一行的最大的前兩個(gè)值
filter_value=-float('Inf')
indices_to_remove = pred < torch.topk(pred, top_k)[0][..., -1, None]
print(indices_to_remove)
pred[indices_to_remove] = filter_value # 對(duì)于topk之外的其他元素的logits值設(shè)為負(fù)無(wú)窮
print(pred)
輸出:
tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
tensor([[4],
[4],
[4],
[3]])
tensor([[0.4053],
[1.8823],
[1.7255],
[0.3849]])
tensor([[ True, False, True, True, False],
[ True, False, True, True, False],
[ True, True, False, True, False],
[ True, False, True, False, True]])
tensor([[ -inf, -0.3873, -inf, -inf, 0.4053],
[ -inf, 1.4164, -inf, -inf, 1.8823],
[ -inf, -inf, 1.2590, -inf, 1.7255],
[ -inf, 0.3041, -inf, 0.3849, -inf]])任務(wù)三:
import numpy as np
import torch
import torch.utils.data.dataset as Dataset
from torch.utils.data import Dataset,DataLoader
tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],
[3,4,5,1,1,1,1,1,1,1,1],
[7,8,9,1,1,1,1,1,1,1,1],
[1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)
# tensor2=torch.tensor([[3,2,1],
# [6,5,4],
# [1,4,7],
# [9,8,7]],dtype=torch.float32)
#
print('tensor1:')
print(tensor1)
print('直接輸出topk,會(huì)得到兩個(gè)東西,我們需要的是第二個(gè)indices')
print(torch.topk(tensor1, k=3, dim=1, largest=True))
print('topk[0]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])
print('topk[1]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])
'''
tensor1:
tensor([[10., 1., 2., 1., 1., 1., 1., 1., 1., 1., 10.],
[ 3., 4., 5., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 7., 8., 9., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 1., 4., 7., 1., 1., 1., 1., 1., 1., 1., 1.]])
直接輸出topk,會(huì)得到兩個(gè)東西,我們需要的是第二個(gè)indices
torch.return_types.topk(
values=tensor([[10., 10., 2.],
[ 5., 4., 3.],
[ 9., 8., 7.],
[ 7., 4., 1.]]),
indices=tensor([[ 0, 10, 2],
[ 2, 1, 0],
[ 2, 1, 0],
[ 2, 1, 0]]))
topk[0]如下
tensor([[10., 10., 2.],
[ 5., 4., 3.],
[ 9., 8., 7.],
[ 7., 4., 1.]])
topk[1]如下
tensor([[ 0, 10, 2],
[ 2, 1, 0],
[ 2, 1, 0],
[ 2, 1, 0]])
'''
總結(jié)
到此這篇關(guān)于pytorch中torch.topk()函數(shù)快速理解的文章就介紹到這了,更多相關(guān)pytorch torch.topk()函數(shù)理解內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
基于Python和C++實(shí)現(xiàn)刪除鏈表的節(jié)點(diǎn)
這篇文章主要介紹了基于Python和C++實(shí)現(xiàn)刪除鏈表的節(jié)點(diǎn),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-07-07
Streamlit+Echarts實(shí)現(xiàn)繪制精美圖表
在數(shù)據(jù)分析和可視化的領(lǐng)域,選擇合適的工具可以讓我們事半功倍,本文主要為大家介紹兩個(gè)工具,Streamlit和ECharts,感興趣的小伙伴可以跟隨小編一起了解下2023-09-09
用python3 urllib破解有道翻譯反爬蟲(chóng)機(jī)制詳解
這篇文章主要介紹了python破解網(wǎng)易反爬蟲(chóng)機(jī)制詳解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-08-08
詳解Django將秒轉(zhuǎn)換為xx天xx時(shí)xx分
這篇文章主要介紹了Django將秒轉(zhuǎn)換為xx天xx時(shí)xx分,本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-09-09
python腳本實(shí)現(xiàn)音頻m4a格式轉(zhuǎn)成MP3格式的實(shí)例代碼
這篇文章主要介紹了python腳本實(shí)現(xiàn)音頻m4a格式轉(zhuǎn)成MP3格式的實(shí)例代碼,非常不錯(cuò),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-10-10
Python利用pyreadline模塊實(shí)現(xiàn)交互式命令行開(kāi)發(fā)
交互式命令行是一種方便用戶(hù)進(jìn)行交互的工具,能夠使用戶(hù)與計(jì)算機(jī)進(jìn)行快速的交互操作,提高工作效率。本文主要介紹了如何利用pyreadline模塊實(shí)現(xiàn)交互式命令行開(kāi)發(fā),需要的可以參考一下2023-05-05
使用keras實(shí)現(xiàn)densenet和Xception的模型融合
這篇文章主要介紹了使用keras實(shí)現(xiàn)densenet和Xception的模型融合,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-05-05

