使用Python和scikit-learn創(chuàng)建混淆矩陣的示例詳解
一、混淆矩陣概述
在訓(xùn)練了有監(jiān)督的機(jī)器學(xué)習(xí)模型(例如分類(lèi)器)之后,您想知道它的工作情況。
這通常是通過(guò)將一小部分稱(chēng)為測(cè)試集的數(shù)據(jù)分開(kāi)來(lái)完成的,該數(shù)據(jù)用作模型以前從未見(jiàn)過(guò)的數(shù)據(jù)。
如果它在此數(shù)據(jù)集上表現(xiàn)良好,那么該模型很可能在其他數(shù)據(jù)上也表現(xiàn)良好 - 當(dāng)然,如果它是從與您的測(cè)試集相同的分布中采樣的。
現(xiàn)在,當(dāng)您測(cè)試您的模型時(shí),您向其提供數(shù)據(jù) - 并將預(yù)測(cè)與基本事實(shí)進(jìn)行比較,測(cè)量真陽(yáng)性、真陰性、假陽(yáng)性和假陰性的數(shù)量。這些隨后可以在視覺(jué)上吸引人的混淆矩陣中可視化。
在今天我們將學(xué)習(xí)如何使用 Scikit-learn 創(chuàng)建這樣的混淆矩陣,Scikit-learn 是當(dāng)今機(jī)器學(xué)習(xí)社區(qū)中使用最廣泛的機(jī)器學(xué)習(xí)框架之一。通過(guò)使用 Python 創(chuàng)建的示例,展示如何生成一個(gè)矩陣,您可以使用該矩陣輕松直觀地確定模型的性能。
1、示例1
一個(gè)混淆矩陣的例子
它是一個(gè)歸一化的混淆矩陣。它的描述了兩個(gè)度量:
True label,這是您的測(cè)試集所代表的基本事實(shí)。
Predicted label,即機(jī)器學(xué)習(xí)模型對(duì)與真實(shí)標(biāo)簽對(duì)應(yīng)的特征生成的預(yù)測(cè)。
例如,在上面的模型中,對(duì)于所有真實(shí)標(biāo)簽 1,預(yù)測(cè)標(biāo)簽為 1。這意味著來(lái)自第 1 類(lèi)的所有樣本都被正確分類(lèi)。
對(duì)于其他類(lèi),性能也不錯(cuò),但稍差一些。如您所見(jiàn),對(duì)于第 2 類(lèi),一些樣本被預(yù)測(cè)為第 0 類(lèi)和第 1 類(lèi)的一部分。
簡(jiǎn)而言之,它回答了“對(duì)于我的真實(shí)標(biāo)簽/基本事實(shí),模型的預(yù)測(cè)效果如何?”這個(gè)問(wèn)題。
2、示例2
也可以從預(yù)測(cè)的角度看,問(wèn)題將變?yōu)?ldquo;對(duì)于我的預(yù)測(cè)標(biāo)簽,有多少預(yù)測(cè)實(shí)際上是預(yù)測(cè)類(lèi)別的一部分?”。這是相反的觀點(diǎn),但在許多機(jī)器學(xué)習(xí)案例中可能是一個(gè)有意義的問(wèn)題。
最優(yōu)情況,是整個(gè)真實(shí)標(biāo)簽集等于預(yù)測(cè)標(biāo)簽集。在這些情況下,除了從左上角到右下角的線之外,您會(huì)在各處看到零。然而,在實(shí)踐中,這種情況并不經(jīng)常發(fā)生。很可能更加分散,例如下面這個(gè) SVM 分類(lèi)器,其中需要許多支持向量來(lái)繪制不能完美工作但足夠充分的決策邊界:
二、使用Scikit-learn 創(chuàng)建混淆矩陣
現(xiàn)在創(chuàng)建一個(gè)混淆矩陣。將使用 Python 和 Scikit-learn。
創(chuàng)建混淆矩陣涉及多個(gè)步驟:
1、生成示例數(shù)據(jù)集。需要數(shù)據(jù)來(lái)訓(xùn)練我們的模型。因此,我們將首先生成數(shù)據(jù),以便我們接下來(lái)可以為 ML 模型類(lèi)做出適當(dāng)?shù)倪x擇。
2、選擇機(jī)器學(xué)習(xí)模型類(lèi)。顯然,如果我們要評(píng)估一個(gè)模型,我們需要訓(xùn)練一個(gè)模型。我們將首先選擇適合我們數(shù)據(jù)特征的特定類(lèi)型的模型。
3、構(gòu)建和訓(xùn)練 ML 模型。前兩個(gè)步驟的結(jié)果是我們最終得到了一個(gè)訓(xùn)練有素的模型。
4、生成混淆矩陣。最后,基于訓(xùn)練好的模型,我們可以創(chuàng)建我們的混淆矩陣。
1、相應(yīng)軟件包
需要以下包,假定已經(jīng)安裝好了Python環(huán)境、Scikit-learn、Numpy、Matplotlib、Mlxtend
2、生成示例數(shù)據(jù)集
第一步是生成示例數(shù)據(jù)集。我們也將為此目的使用 Scikit-learn。首先,創(chuàng)建一個(gè)名為 的文件confusion-matrix.py
。
(1)導(dǎo)入相關(guān)的包
# Imports from sklearn.datasets import make_blobs from sklearn.model_selection import train_test_split import numpy as np import matplotlib.pyplot as plt
Scikit-learn的make_blobs
功能可以生成樣本的“blob”或集群。這些斑點(diǎn)以某個(gè)點(diǎn)為中心,并且樣本基于某個(gè)標(biāo)準(zhǔn)偏差分散在該點(diǎn)周?chē)_@使您可以靈活地確定生成的數(shù)據(jù)集的位置和結(jié)構(gòu),從而使您可以試驗(yàn)各種 ML 模型。
在評(píng)估模型時(shí),我們需要確保數(shù)據(jù)集在訓(xùn)練數(shù)據(jù)和測(cè)試數(shù)據(jù)之間進(jìn)行分割。Scikit-learn使用train_test_split
函數(shù)實(shí)現(xiàn)分割。
(2)相關(guān)配置
# Configuration options blobs_random_seed = 42 centers = [(0,0), (5,5), (0,5), (2,3)] cluster_std = 1.3 frac_test_split = 0.33 num_features_for_samples = 4 num_samples_total = 5000
隨機(jī)種子描述了用于生成數(shù)據(jù)塊的偽隨機(jī)數(shù)生成器的初始化。您可能知道,沒(méi)有隨機(jī)數(shù)生成器是真正隨機(jī)的。更重要的是,它們的初始化方式也不同。配置固定種子可確保每次運(yùn)行腳本時(shí),隨機(jī)數(shù)生成器都以相同的方式初始化。如果出現(xiàn)奇怪的行為,您就知道它可能不是隨機(jī)數(shù)生成器。
中心描述了我們數(shù)據(jù)塊的二維空間中的中心。如您所見(jiàn),我們今天有 4 個(gè) blob。
聚類(lèi)標(biāo)準(zhǔn)差描述了從隨機(jī)點(diǎn)生成器使用的抽樣分布中抽取樣本的標(biāo)準(zhǔn)差。我們將其設(shè)置為 1.3;較低的數(shù)字會(huì)產(chǎn)生更好分離的集群,反之亦然。
訓(xùn)練/測(cè)試拆分的比例決定了為了測(cè)試目的拆分了多少數(shù)據(jù)。在我們的例子中,這是 33% 的數(shù)據(jù)。
我們樣本的特征數(shù)量是 4,并且確實(shí)描述了我們有多少目標(biāo):4,因?yàn)槲覀冇?4 個(gè)數(shù)據(jù)塊。
最后,生成的樣本數(shù)量。我們將其設(shè)置為 5000 個(gè)樣本。
(3)生成數(shù)據(jù)
# Generate data inputs, targets = make_blobs(n_samples = num_samples_total, centers = centers, n_features = num_features_for_samples, cluster_std = cluster_std) X_train, X_test, y_train, y_test = train_test_split(inputs, targets, test_size=frac_test_split, random_state=blobs_random_seed)
(4)保存數(shù)據(jù)(可選)
# Save and load temporarily np.save('./data_cf.npy', (X_train, X_test, y_train, y_test)) X_train, X_test, y_train, y_test = np.load('./data_cf.npy', allow_pickle=True)
(5)可視化數(shù)據(jù)
# Generate scatter plot for training data plt.scatter(X_train[:,0], X_train[:,1]) plt.title('Linearly separable data') plt.xlabel('X1') plt.ylabel('X2') plt.show()
3、訓(xùn)練一個(gè)SVM
(1)導(dǎo)入相關(guān)包
from sklearn import svm from sklearn.metrics import plot_confusion_matrix from mlxtend.plotting import plot_decision_regions
(2)訓(xùn)練分類(lèi)器
# Initialize SVM classifier clf = svm.SVC(kernel='linear') # 擬合數(shù)據(jù) clf = clf.fit(X_train, y_train)
4、生成混淆矩陣
它是評(píng)估步驟的一部分,我們用它來(lái)可視化它在測(cè)試集上的預(yù)測(cè)和泛化能力。
使用plot_confusion_matrix
調(diào)用為我們解決了這個(gè)問(wèn)題,我們只需向它提供分類(lèi)器 (clf
)、測(cè)試集 (X_test
和y_test
)、顏色圖以及是否對(duì)數(shù)據(jù)進(jìn)行歸一化。
# Generate confusion matrix matrix = plot_confusion_matrix(clf, X_test, y_test, cmap=plt.cm.Blues, normalize='true') plt.title('Confusion matrix for our classifier') plt.show(matrix) plt.show()
5、可視化邊界
如果要生成邊界圖,需要安裝 Mlxtend
# Get support vectors support_vectors = clf.support_vectors_ # Visualize support vectors plt.scatter(X_train[:,0], X_train[:,1]) plt.scatter(support_vectors[:,0], support_vectors[:,1], color='red') plt.title('Linearly separable data with support vectors') plt.xlabel('X1') plt.ylabel('X2') plt.show() # Plot decision boundary plot_decision_regions(X_test, y_test, clf=clf, legend=2) plt.show()
唯一表現(xiàn)不佳的班級(jí)是第 3 類(lèi),得分為 0.68。這可以通過(guò)查看決策邊界圖中的類(lèi)來(lái)解釋。在這里,由于這些樣本被其他樣本包圍,很明顯模型在生成決策邊界時(shí)遇到了很大的困難。例如,我們可以通過(guò)使用考慮到這一點(diǎn)的不同內(nèi)核函數(shù)來(lái)解決這個(gè)問(wèn)題,從而確保更好的可分離性。
? 以上就是我們使用 Python 和 Scikit-learn 創(chuàng)建了一個(gè)混淆矩陣。在研究了混淆矩陣是什么,以及它如何顯示真陽(yáng)性、真陰性、假陽(yáng)性和假陰性之后,我們給出了一個(gè)自己創(chuàng)建示例。
該示例包括生成數(shù)據(jù)集、為數(shù)據(jù)集選擇合適的機(jī)器學(xué)習(xí)模型、構(gòu)建、配置和訓(xùn)練它,最后解釋結(jié)果,即混淆矩陣。
到此這篇關(guān)于使用Python和scikit-learn創(chuàng)建混淆矩陣的文章就介紹到這了,更多相關(guān)Python和scikit-learn混淆矩陣內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python3+PyQt5圖形項(xiàng)的自定義和交互 python3實(shí)現(xiàn)page Designer應(yīng)用程序
這篇文章主要為大家詳細(xì)介紹了python3+PyQt5圖形項(xiàng)的自定義和交互,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-04-04python EasyOCR庫(kù)實(shí)例用法介紹
在本篇文章里小編給大家整理的是一篇關(guān)于python EasyOCR庫(kù)實(shí)例用法介紹,有需要的朋友們可以跟著學(xué)習(xí)下。2021-07-07Python多進(jìn)程方式抓取基金網(wǎng)站內(nèi)容的方法分析
這篇文章主要介紹了Python多進(jìn)程方式抓取基金網(wǎng)站內(nèi)容的方法,結(jié)合實(shí)例形式分析了Python多進(jìn)程抓取網(wǎng)站內(nèi)容相關(guān)實(shí)現(xiàn)技巧與操作注意事項(xiàng),需要的朋友可以參考下2019-06-06Python測(cè)試框架:pytest學(xué)習(xí)筆記
這篇文章主要介紹了Python測(cè)試框架:pytest的相關(guān)資料,幫助大家更好的利用python進(jìn)行單元測(cè)試,感興趣的朋友可以了解下2020-10-10