欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Python利用Pytorch實現繪制ROC與PR曲線圖

 更新時間:2022年12月30日 09:14:20   作者:Vertira  
這篇文章主要和大家分享一下Python利用Pytorch實現繪制ROC與PR曲線圖的相關代碼,文中的示例代碼講解詳細,具有一定的借鑒價值,需要的可以參考一下

Pytorch 多分類模型繪制 ROC, PR 曲線(代碼 親測 可用)

ROC曲線

示例代碼

import torch
import torch.nn as nn
import os
import numpy as np
from torchvision.datasets import ImageFolder
from utils.transform import get_transform_for_test
from senet.se_resnet import FineTuneSEResnet50
from scipy import interp
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn.metrics import roc_curve, auc, f1_score, precision_recall_curve, average_precision_score
 
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
 
data_root = r'D:\TJU\GBDB\set113\set113_images\test1'    # 測試集路徑
test_weights_path = r"C:\Users\admin\Desktop\fsdownload\epoch_0278_top1_70.565_'checkpoint.pth.tar'"    # 預訓練模型參數
num_class = 113    # 類別數量
gpu = "cuda:0"  
 
 
# mean=[0.948078, 0.93855226, 0.9332005], var=[0.14589554, 0.17054074, 0.18254866]
def test(model, test_path):
    # 加載測試集和預訓練模型參數
    test_dir = os.path.join(data_root, 'test_images')
    class_list = list(os.listdir(test_dir))
    class_list.sort()
    transform_test = get_transform_for_test(mean=[0.948078, 0.93855226, 0.9332005],
                                            var=[0.14589554, 0.17054074, 0.18254866])
    test_dataset = ImageFolder(test_dir, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, shuffle=False, drop_last=False, pin_memory=True, num_workers=1)
    checkpoint = torch.load(test_path)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
 
    score_list = []     # 存儲預測得分
    label_list = []     # 存儲真實標簽
    for i, (inputs, labels) in enumerate(test_loader):
        inputs = inputs.cuda()
        labels = labels.cuda()
 
        outputs = model(inputs)
        # prob_tmp = torch.nn.Softmax(dim=1)(outputs) # (batchsize, nclass)
        score_tmp = outputs  # (batchsize, nclass)
 
        score_list.extend(score_tmp.detach().cpu().numpy())
        label_list.extend(labels.cpu().numpy())
 
    score_array = np.array(score_list)
    # 將label轉換成onehot形式
    label_tensor = torch.tensor(label_list)
    label_tensor = label_tensor.reshape((label_tensor.shape[0], 1))
    label_onehot = torch.zeros(label_tensor.shape[0], num_class)
    label_onehot.scatter_(dim=1, index=label_tensor, value=1)
    label_onehot = np.array(label_onehot)
 
    print("score_array:", score_array.shape)  # (batchsize, classnum)
    print("label_onehot:", label_onehot.shape)  # torch.Size([batchsize, classnum])
 
    # 調用sklearn庫,計算每個類別對應的fpr和tpr
    fpr_dict = dict()
    tpr_dict = dict()
    roc_auc_dict = dict()
    for i in range(num_class):
        fpr_dict[i], tpr_dict[i], _ = roc_curve(label_onehot[:, i], score_array[:, i])
        roc_auc_dict[i] = auc(fpr_dict[i], tpr_dict[i])
    # micro
    fpr_dict["micro"], tpr_dict["micro"], _ = roc_curve(label_onehot.ravel(), score_array.ravel())
    roc_auc_dict["micro"] = auc(fpr_dict["micro"], tpr_dict["micro"])
 
    # macro
    # First aggregate all false positive rates
    all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(num_class)]))
    # Then interpolate all ROC curves at this points
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(num_class):
        mean_tpr += interp(all_fpr, fpr_dict[i], tpr_dict[i])
    # Finally average it and compute AUC
    mean_tpr /= num_class
    fpr_dict["macro"] = all_fpr
    tpr_dict["macro"] = mean_tpr
    roc_auc_dict["macro"] = auc(fpr_dict["macro"], tpr_dict["macro"])
 
    # 繪制所有類別平均的roc曲線
    plt.figure()
    lw = 2
    plt.plot(fpr_dict["micro"], tpr_dict["micro"],
             label='micro-average ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc_dict["micro"]),
             color='deeppink', linestyle=':', linewidth=4)
 
    plt.plot(fpr_dict["macro"], tpr_dict["macro"],
             label='macro-average ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc_dict["macro"]),
             color='navy', linestyle=':', linewidth=4)
 
    colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
    for i, color in zip(range(num_class), colors):
        plt.plot(fpr_dict[i], tpr_dict[i], color=color, lw=lw,
                 label='ROC curve of class {0} (area = {1:0.2f})'
                       ''.format(i, roc_auc_dict[i]))
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Some extension of Receiver operating characteristic to multi-class')
    plt.legend(loc="lower right")
    plt.savefig('set113_roc.jpg')
    plt.show()
 
 
if __name__ == '__main__':
    # 加載模型
    seresnet = FineTuneSEResnet50(num_class=num_class)
    device = torch.device(gpu)
    seresnet = seresnet.to(device)
    test(seresnet, test_weights_path)

運行結果:

PR曲線

示例代碼

import torch
import torch.nn as nn
import os
import numpy as np
from torchvision.datasets import ImageFolder
from utils.transform import get_transform_for_test
from senet.se_resnet import FineTuneSEResnet50
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, f1_score, precision_recall_curve, average_precision_score
 
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
 
data_root = r'D:\TJU\GBDB\set113\set113_images\test1'    # 測試集路徑
test_weights_path = r"C:\Users\admin\Desktop\fsdownload\epoch_0278_top1_70.565_'checkpoint.pth.tar'"    # 預訓練模型參數
num_class = 113    # 類別數量
gpu = "cuda:0"    
 
 
# mean=[0.948078, 0.93855226, 0.9332005], var=[0.14589554, 0.17054074, 0.18254866]
def test(model, test_path):
    # 加載測試集和預訓練模型參數
    test_dir = os.path.join(data_root, 'test_images')
    class_list = list(os.listdir(test_dir))
    class_list.sort()
    transform_test = get_transform_for_test(mean=[0.948078, 0.93855226, 0.9332005],
                                            var=[0.14589554, 0.17054074, 0.18254866])
    test_dataset = ImageFolder(test_dir, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, shuffle=False, drop_last=False, pin_memory=True, num_workers=1)
    checkpoint = torch.load(test_path)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
 
    score_list = []     # 存儲預測得分
    label_list = []     # 存儲真實標簽
    for i, (inputs, labels) in enumerate(test_loader):
        inputs = inputs.cuda()
        labels = labels.cuda()
 
        outputs = model(inputs)
        # prob_tmp = torch.nn.Softmax(dim=1)(outputs) # (batchsize, nclass)
        score_tmp = outputs  # (batchsize, nclass)
 
        score_list.extend(score_tmp.detach().cpu().numpy())
        label_list.extend(labels.cpu().numpy())
 
    score_array = np.array(score_list)
    # 將label轉換成onehot形式
    label_tensor = torch.tensor(label_list)
    label_tensor = label_tensor.reshape((label_tensor.shape[0], 1))
    label_onehot = torch.zeros(label_tensor.shape[0], num_class)
    label_onehot.scatter_(dim=1, index=label_tensor, value=1)
    label_onehot = np.array(label_onehot)
    print("score_array:", score_array.shape)  # (batchsize, classnum) softmax
    print("label_onehot:", label_onehot.shape)  # torch.Size([batchsize, classnum]) onehot
 
    # 調用sklearn庫,計算每個類別對應的precision和recall
    precision_dict = dict()
    recall_dict = dict()
    average_precision_dict = dict()
    for i in range(num_class):
        precision_dict[i], recall_dict[i], _ = precision_recall_curve(label_onehot[:, i], score_array[:, i])
        average_precision_dict[i] = average_precision_score(label_onehot[:, i], score_array[:, i])
        print(precision_dict[i].shape, recall_dict[i].shape, average_precision_dict[i])
 
    # micro
    precision_dict["micro"], recall_dict["micro"], _ = precision_recall_curve(label_onehot.ravel(),
                                                                              score_array.ravel())
    average_precision_dict["micro"] = average_precision_score(label_onehot, score_array, average="micro")
    print('Average precision score, micro-averaged over all classes: {0:0.2f}'.format(average_precision_dict["micro"]))
 
    # 繪制所有類別平均的pr曲線
    plt.figure()
    plt.step(recall_dict['micro'], precision_dict['micro'], where='post')
 
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])
    plt.title(
        'Average precision score, micro-averaged over all classes: AP={0:0.2f}'
        .format(average_precision_dict["micro"]))
    plt.savefig("set113_pr_curve.jpg")
    # plt.show()
 
 
if __name__ == '__main__':
    # 加載模型
    seresnet = FineTuneSEResnet50(num_class=num_class)
    device = torch.device(gpu)
    seresnet = seresnet.to(device)
    test(seresnet, test_weights_path)

運行結果:

到此這篇關于Python利用Pytorch實現繪制ROC與PR曲線圖的文章就介紹到這了,更多相關Python繪制ROC PR曲線圖內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!

相關文章

  • Django實現表單驗證

    Django實現表單驗證

    這篇文章主要為大家詳細介紹了Django實現表單驗證的相關資料,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2018-09-09
  • 使用Django框架中ORM系統(tǒng)實現對數據庫數據增刪改查

    使用Django框架中ORM系統(tǒng)實現對數據庫數據增刪改查

    這篇文章主要介紹了使用Django的ORM實現對數據庫數據增刪改查方法,文中附含詳細示例代碼以及過程詳解,有需要的朋友可以借鑒參考下
    2021-09-09
  • Python復制文件操作實例詳解

    Python復制文件操作實例詳解

    這篇文章主要介紹了Python復制文件操作的方法,涉及Python針對文件與目錄的復制及刪除操作相關技巧,具有一定參考借鑒價值,需要的朋友可以參考下
    2015-11-11
  • numpy實現RNN原理實現

    numpy實現RNN原理實現

    這篇文章主要介紹了numpy實現RNN原理實現,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2021-03-03
  • 使用Python的Zato發(fā)送AMQP消息的教程

    使用Python的Zato發(fā)送AMQP消息的教程

    這篇文章主要介紹了使用Python的Zato發(fā)送AMQP消息的教程,主要是基于一些Zato的圖形化界面進行操作,需要的朋友可以參考下
    2015-04-04
  • Pandas 內置的十種畫圖方法

    Pandas 內置的十種畫圖方法

    這篇文章主要介紹了Pandas 內置的十種畫圖方法,Pandas是非常常見的數據分析工具,我們一般都會處理好處理數據然后使用searbon或matplotlib來進行繪制
    2022-09-09
  • Python迭代器和生成器介紹

    Python迭代器和生成器介紹

    這篇文章主要介紹了Python迭代器和生成器介紹,本文分別用代碼實例講解了Python的迭代器和生成器,需要的朋友可以參考下
    2015-03-03
  • python3+selenium實現qq郵箱登陸并發(fā)送郵件功能

    python3+selenium實現qq郵箱登陸并發(fā)送郵件功能

    這篇文章主要為大家詳細介紹了python3+selenium實現qq郵箱登陸,并發(fā)送郵件功能,文中示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2019-01-01
  • Python實現字符串中某個字母的替代功能

    Python實現字符串中某個字母的替代功能

    小編想實現這樣一個功能:將輸入字符串中的字母 “i” 變成字母 “p”。想著很簡單,怎么實現呢?下面小編給大家?guī)砹薖ython實現字符串中某個字母的替代功能,感興趣的朋友一起看看吧
    2019-10-10
  • 如何查看Mac本機的Python3安裝路徑

    如何查看Mac本機的Python3安裝路徑

    這篇文章主要介紹了如何查看Mac本機的Python3安裝路徑問題,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教
    2023-03-03

最新評論