pytorch分類模型繪制混淆矩陣以及可視化詳解
Step 1. 獲取混淆矩陣
#首先定義一個(gè) 分類數(shù)*分類數(shù) 的空混淆矩陣 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds) # 使用torch.no_grad()可以顯著降低測試用例的GPU占用 with torch.no_grad(): for step, (imgs, targets) in enumerate(test_loader): # imgs: torch.Size([50, 3, 200, 200]) torch.FloatTensor # targets: torch.Size([50, 1]), torch.LongTensor 多了一維,所以我們要把其去掉 targets = targets.squeeze() # [50,1] -----> [50] # 將變量轉(zhuǎn)為gpu targets = targets.cuda() imgs = imgs.cuda() # print(step,imgs.shape,imgs.type(),targets.shape,targets.type()) out = model(imgs) #記錄混淆矩陣參數(shù) conf_matrix = confusion_matrix(out, targets, conf_matrix) conf_matrix=conf_matrix.cpu()
混淆矩陣的求取用到了confusion_matrix函數(shù),其定義如下:
def confusion_matrix(preds, labels, conf_matrix): preds = torch.argmax(preds, 1) for p, t in zip(preds, labels): conf_matrix[p, t] += 1 return conf_matrix
在當(dāng)我們的程序執(zhí)行結(jié)束 test_loader 后,我們可以得到本次數(shù)據(jù)的 混淆矩陣,接下來就要計(jì)算其 識別正確的個(gè)數(shù)以及混淆矩陣可視化:
conf_matrix=np.array(conf_matrix.cpu())# 將混淆矩陣從gpu轉(zhuǎn)到cpu再轉(zhuǎn)到np corrects=conf_matrix.diagonal(offset=0)#抽取對角線的每種分類的識別正確個(gè)數(shù) per_kinds=conf_matrix.sum(axis=1)#抽取每個(gè)分類數(shù)據(jù)總的測試條數(shù) print("混淆矩陣總元素個(gè)數(shù):{0},測試集總個(gè)數(shù):{1}".format(int(np.sum(conf_matrix)),test_num)) print(conf_matrix) # 獲取每種Emotion的識別準(zhǔn)確率 print("每種情感總個(gè)數(shù):",per_kinds) print("每種情感預(yù)測正確的個(gè)數(shù):",corrects) print("每種情感的識別準(zhǔn)確率為:{0}".format([rate*100 for rate in corrects/per_kinds]))
執(zhí)行此步的輸出結(jié)果如下所示:
Step 2. 混淆矩陣可視化
對上邊求得的混淆矩陣可視化
# 繪制混淆矩陣 Emotion=8#這個(gè)數(shù)值是具體的分類數(shù),大家可以自行修改 labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每種類別的標(biāo)簽 # 顯示數(shù)據(jù) plt.imshow(conf_matrix, cmap=plt.cm.Blues) # 在圖中標(biāo)注數(shù)量/概率信息 thresh = conf_matrix.max() / 2 #數(shù)值顏色閾值,如果數(shù)值超過這個(gè),就顏色加深。 for x in range(Emotion_kinds): for y in range(Emotion_kinds): # 注意這里的matrix[y, x]不是matrix[x, y] info = int(conf_matrix[y, x]) plt.text(x, y, info, verticalalignment='center', horizontalalignment='center', color="white" if info > thresh else "black") plt.tight_layout()#保證圖不重疊 plt.yticks(range(Emotion_kinds), labels) plt.xticks(range(Emotion_kinds), labels,rotation=45)#X軸字體傾斜45° plt.show() plt.close()
好了,以下就是最終的可視化的混淆矩陣?yán)玻?/p>
其它分類指標(biāo)的獲取
例如 F1分?jǐn)?shù)、TP、TN、FP、FN、精確率、召回率 等指標(biāo), 待補(bǔ)充哈(因?yàn)闀簳r(shí)還沒用到)~
總結(jié)
到此這篇關(guān)于pytorch分類模型繪制混淆矩陣以及可視化詳?shù)奈恼戮徒榻B到這了,更多相關(guān)pytorch繪制混淆矩陣內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
淺談盤點(diǎn)5種基于Python生成的個(gè)性化語音方法
這篇文章主要介紹了淺談盤點(diǎn)5種基于Python生成的個(gè)性化語音方法,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-02-02用python基于appium模塊開發(fā)一個(gè)自動收取能量的小助手
大家都有了解過螞蟻森林吧,本篇文章帶給你自動收取螞蟻森林能量的思路與方法,基于appium模塊開發(fā)一個(gè)自動收取能量的小助手,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的價(jià)值2021-09-09Python深度學(xué)習(xí)實(shí)戰(zhàn)PyQt5布局管理項(xiàng)目示例詳解
本文具體介紹基本的水平布局、垂直布局、柵格布局、表格布局和進(jìn)階的嵌套布局和容器布局,最后通過案例帶小白創(chuàng)建一個(gè)有型的圖形布局窗口2021-10-10詳解python函數(shù)的閉包問題(內(nèi)部函數(shù)與外部函數(shù)詳述)
這篇文章主要介紹了python函數(shù)的閉包問題,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-05-05python使用pypdf2實(shí)現(xiàn)pdf文檔解密
利用pypdf2完成pdf的解密,這里的事例是python3環(huán)境下的,當(dāng)然python2下也可以運(yùn)行,只需要修改名稱即可,文中通過代碼示例給大家介紹的非常詳細(xì),需要的朋友可以參考下2023-12-12python 生成器協(xié)程運(yùn)算實(shí)例
下面小編就為大家?guī)硪黄猵ython 生成器協(xié)程運(yùn)算實(shí)例。小編覺得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2017-09-09