返回最大值的index pytorch方式
返回最大值的index
import torch a=torch.tensor([[.1,.2,.3], ? ? ? ? ? ? ? ? [1.1,1.2,1.3], ? ? ? ? ? ? ? ? [2.1,2.2,2.3], ? ? ? ? ? ? ? ? [3.1,3.2,3.3]]) print(a.argmax(dim=1)) print(a.argmax())
輸出:
tensor([ 2, 2, 2, 2])
tensor(11)
pytorch 找最大值
題意:使用神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn),從數(shù)組中找出最大值。
提供數(shù)據(jù):兩個(gè) csv 文件,一個(gè)存訓(xùn)練集:n 個(gè) m 維特征自然數(shù)數(shù)據(jù),另一個(gè)存每條數(shù)據(jù)對(duì)應(yīng)的 label ,就是每條數(shù)據(jù)中的最大值。
這里將隨機(jī)構(gòu)建訓(xùn)練集:
#%% import numpy as np import pandas as pd import torch import random import torch.utils.data as Data import torch.nn as nn import torch.optim as optim def GetData(m, n): dataset = [] for j in range(m): max_v = random.randint(0, 9) data = [random.randint(0, 9) for i in range(n)] dataset.append(data) label = [max(dataset[i]) for i in range(len(dataset))] data_list = np.column_stack((dataset, label)) data_list = data_list.astype(np.float32) return data_list #%% # 數(shù)據(jù)集封裝 重載函數(shù)len, getitem class GetMaxEle(Data.Dataset): def __init__(self, trainset): self.data = trainset def __getitem__(self, index): item = self.data[index] x = item[:-1] y = item[-1] return x, y def __len__(self): return len(self.data) # %% 定義網(wǎng)絡(luò)模型 class SingleNN(nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(SingleNN, self).__init__() self.hidden = nn.Linear(n_feature, n_hidden) self.relu = nn.ReLU() self.predict = nn.Linear(n_hidden, n_output) def forward(self, x): x = self.hidden(x) x = self.relu(x) x = self.predict(x) return x def train(m, n, batch_size, PATH): # 隨機(jī)生成 m 個(gè) n 個(gè)維度的訓(xùn)練樣本 data_list =GetData(m, n) dataset = GetMaxEle(data_list) trainset = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) net = SingleNN(n_feature=10, n_hidden=100, n_output=10) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) # total_epoch = 100 for epoch in range(total_epoch): for index, data in enumerate(trainset): input_x, labels = data labels = labels.long() optimizer.zero_grad() output = net(input_x) # print(output) # print(labels) loss = criterion(output, labels) loss.backward() optimizer.step() # scheduled_optimizer.step() print(f"Epoch {epoch}, loss:{loss.item()}") # %% 保存參數(shù) torch.save(net.state_dict(), PATH) #測(cè)試 def test(m, n, batch_size, PATH): data_list = GetData(m, n) dataset = GetMaxEle(data_list) testloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size) dataiter = iter(testloader) input_x, labels = dataiter.next() net = SingleNN(n_feature=10, n_hidden=100, n_output=10) net.load_state_dict(torch.load(PATH)) outputs = net(input_x) _, predicted = torch.max(outputs, 1) print("Ground_truth:",labels.numpy()) print("predicted:",predicted.numpy()) if __name__ == "__main__": m = 1000 n = 10 batch_size = 64 PATH = './max_list.pth' train(m, n, batch_size, PATH) test(m, n, batch_size, PATH)
初始的想法是使用全連接網(wǎng)絡(luò)+分類來(lái)實(shí)現(xiàn), 但是結(jié)果不盡人意,主要原因:不同類別之間的樣本量差太大,幾乎90%都是最大值。
比如代碼中隨機(jī)構(gòu)建 10 個(gè) 0~9 的數(shù)字構(gòu)成一個(gè)樣本[2, 3, 5, 8, 9, 5, 3, 9, 3, 6], 該樣本標(biāo)簽是9。
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
python批量處理PDF文檔輸出自定義關(guān)鍵詞的出現(xiàn)次數(shù)
這篇文章主要介紹了python批量處理PDF文檔,輸出自定義關(guān)鍵詞的出現(xiàn)次數(shù),文中有詳細(xì)的代碼示例,需要的朋友可以參考閱讀2023-04-04python學(xué)習(xí)之matplotlib繪制散點(diǎn)圖實(shí)例
這篇文章主要介紹了python學(xué)習(xí)之matplotlib繪制散點(diǎn)圖實(shí)例,具有一定借鑒價(jià)值,需要的朋友可以參考下。2017-12-12python數(shù)據(jù)可視化之條形圖畫(huà)法
這篇文章主要為大家詳細(xì)介紹了python數(shù)據(jù)可視化之條形圖畫(huà)法,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2022-04-04python使用Streamlit庫(kù)制作Web可視化頁(yè)面
一談到Web頁(yè)面,可能大家首先想到就是HTML,CSS或JavaScript。 本次小F就給大家介紹一下如何用Python制作一個(gè)數(shù)據(jù)可視化網(wǎng)頁(yè),使用到的是Streamlit庫(kù)。輕松的將一個(gè)Excel數(shù)據(jù)文件轉(zhuǎn)換為一個(gè)Web頁(yè)面,提供給所有人在線查看。2021-05-05關(guān)于對(duì)python中進(jìn)程的幾個(gè)概念理解
進(jìn)程由程序,數(shù)據(jù)和進(jìn)程控制塊組成,是正在執(zhí)行的程,程序的一次執(zhí)行過(guò)程,是資源調(diào)度的基本單位,下面這篇文章主要給大家介紹了關(guān)于對(duì)python中進(jìn)程的幾個(gè)概念理解,需要的朋友可以參考下2021-10-10python標(biāo)準(zhǔn)庫(kù)壓縮包模塊zipfile和tarfile詳解(常用標(biāo)準(zhǔn)庫(kù))
在我們常用的系統(tǒng)windows和Linux系統(tǒng)中有很多支持的壓縮包格式,包括但不限于以下種類:rar、zip、tar,這篇文章主要介紹了python標(biāo)準(zhǔn)庫(kù)壓縮包模塊zipfile和tarfile詳解(常用標(biāo)準(zhǔn)庫(kù)),需要的朋友可以參考下2022-06-06老生常談python函數(shù)參數(shù)的區(qū)別(必看篇)
下面小編就為大家?guī)?lái)一篇老生常談python函數(shù)參數(shù)的區(qū)別(必看篇)。小編覺(jué)得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2017-05-05Python數(shù)據(jù)結(jié)構(gòu)與算法之圖結(jié)構(gòu)(Graph)實(shí)例分析
這篇文章主要介紹了Python數(shù)據(jù)結(jié)構(gòu)與算法之圖結(jié)構(gòu)(Graph),結(jié)合實(shí)例形式分析了圖結(jié)構(gòu)的概念、原理、使用方法及相關(guān)操作技巧,需要的朋友可以參考下2017-09-09Python3按一定數(shù)據(jù)位數(shù)格式處理bin文件的方法
今天小編就為大家分享一篇Python3按一定數(shù)據(jù)位數(shù)格式處理bin文件的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-01-01