在pytorch 中計(jì)算精度、回歸率、F1 score等指標(biāo)的實(shí)例
pytorch中訓(xùn)練完網(wǎng)絡(luò)后,需要對(duì)學(xué)習(xí)的結(jié)果進(jìn)行測(cè)試。官網(wǎng)上例程用的方法統(tǒng)統(tǒng)都是正確率,使用的是torch.eq()這個(gè)函數(shù)。
但是為了更精細(xì)的評(píng)價(jià)結(jié)果,我們還需要計(jì)算其他各個(gè)指標(biāo)。在把官網(wǎng)API翻了一遍之后發(fā)現(xiàn)并沒(méi)有用于計(jì)算TP,TN,F(xiàn)P,F(xiàn)N的函數(shù)。。。
在動(dòng)了無(wú)數(shù)歪腦筋之后,心想pytorch完全支持numpy,那能不能直接進(jìn)行判斷,試了一下果然可以,上代碼:
# TP predict 和 label 同時(shí)為1 TP += ((pred_choice == 1) & (target.data == 1)).cpu().sum() # TN predict 和 label 同時(shí)為0 TN += ((pred_choice == 0) & (target.data == 0)).cpu().sum() # FN predict 0 label 1 FN += ((pred_choice == 0) & (target.data == 1)).cpu().sum() # FP predict 1 label 0 FP += ((pred_choice == 1) & (target.data == 0)).cpu().sum() p = TP / (TP + FP) r = TP / (TP + FN) F1 = 2 * r * p / (r + p) acc = (TP + TN) / (TP + TN + FP + FN
這樣就能看到各個(gè)指標(biāo)了。
因?yàn)閠arget是Variable所以需要用target.data取到對(duì)應(yīng)的tensor,又因?yàn)槭窃趃pu上算的,需要用 .cpu() 移到cpu上。
因?yàn)檫@是一個(gè)batch的統(tǒng)計(jì),所以需要用+=累計(jì)出整個(gè)epoch的統(tǒng)計(jì)。當(dāng)然,在epoch開(kāi)始之前需要清零
以上這篇在pytorch 中計(jì)算精度、回歸率、F1 score等指標(biāo)的實(shí)例就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
PyTorch中torch.nn.Linear實(shí)例詳解
torch.nn是包含了構(gòu)筑神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)基本元素的包,在這個(gè)包中可以找到任意的神經(jīng)網(wǎng)絡(luò)層,下面這篇文章主要給大家介紹了關(guān)于PyTorch中torch.nn.Linear的相關(guān)資料,文中通過(guò)實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-06-06python實(shí)現(xiàn)多線程及線程間通信的簡(jiǎn)單方法
這篇文章主要為大家介紹了python實(shí)現(xiàn)多線程及線程間通信的簡(jiǎn)單方法示例,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-07-07詳解通過(guò)API管理或定制開(kāi)發(fā)ECS實(shí)例
在本文里我們給大家整理了關(guān)于通過(guò)API管理或定制開(kāi)發(fā)ECS的相關(guān)實(shí)例內(nèi)容,有需要的朋友們參考學(xué)習(xí)下。2018-09-09python?import模塊時(shí)有錯(cuò)誤紅線的原因
這篇文章主要介紹了python?import模塊時(shí)有錯(cuò)誤紅線的原因及解決,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-02-02實(shí)例解析Python設(shè)計(jì)模式編程之橋接模式的運(yùn)用
這篇文章主要介紹了Python設(shè)計(jì)模式編程之橋接模式的運(yùn)用,橋接模式主張把抽象部分與它的實(shí)現(xiàn)部分分離,需要的朋友可以參考下2016-03-03