淺談keras中自定義二分類(lèi)任務(wù)評(píng)價(jià)指標(biāo)metrics的方法以及代碼
對(duì)于二分類(lèi)任務(wù),keras現(xiàn)有的評(píng)價(jià)指標(biāo)只有binary_accuracy,即二分類(lèi)準(zhǔn)確率,但是評(píng)估模型的性能有時(shí)需要一些其他的評(píng)價(jià)指標(biāo),例如精確率,召回率,F(xiàn)1-score等等,因此需要使用keras提供的自定義評(píng)價(jià)函數(shù)功能構(gòu)建出針對(duì)二分類(lèi)任務(wù)的各類(lèi)評(píng)價(jià)指標(biāo)。
keras提供的自定義評(píng)價(jià)函數(shù)功能需要以如下兩個(gè)張量作為輸入,并返回一個(gè)張量作為輸出。
y_true:數(shù)據(jù)集真實(shí)值組成的一階張量。
y_pred:數(shù)據(jù)集輸出值組成的一階張量。
tf.round()可對(duì)張量四舍五入,因此tf.round(y_pred)即是預(yù)測(cè)值張量。
1-tf.round(y_pred)即是預(yù)測(cè)值張量取反。
1-y_true即是真實(shí)值張量取反。
tf.reduce_sum()可對(duì)張量求和。
由此可以根據(jù)定義構(gòu)建出四個(gè)基礎(chǔ)指標(biāo)TP、TN、FP、FN,然后進(jìn)一步構(gòu)建出進(jìn)階指標(biāo)precision、recall、F1score,最后在編譯階段引用上述自定義評(píng)價(jià)指標(biāo)即可。
keras中自定義二分類(lèi)任務(wù)常用評(píng)價(jià)指標(biāo)及其引用的代碼如下
import tensorflow as tf #精確率評(píng)價(jià)指標(biāo) def metric_precision(y_true,y_pred): TP=tf.reduce_sum(y_true*tf.round(y_pred)) TN=tf.reduce_sum((1-y_true)*(1-tf.round(y_pred))) FP=tf.reduce_sum((1-y_true)*tf.round(y_pred)) FN=tf.reduce_sum(y_true*(1-tf.round(y_pred))) precision=TP/(TP+FP) return precision #召回率評(píng)價(jià)指標(biāo) def metric_recall(y_true,y_pred): TP=tf.reduce_sum(y_true*tf.round(y_pred)) TN=tf.reduce_sum((1-y_true)*(1-tf.round(y_pred))) FP=tf.reduce_sum((1-y_true)*tf.round(y_pred)) FN=tf.reduce_sum(y_true*(1-tf.round(y_pred))) recall=TP/(TP+FN) return recall #F1-score評(píng)價(jià)指標(biāo) def metric_F1score(y_true,y_pred): TP=tf.reduce_sum(y_true*tf.round(y_pred)) TN=tf.reduce_sum((1-y_true)*(1-tf.round(y_pred))) FP=tf.reduce_sum((1-y_true)*tf.round(y_pred)) FN=tf.reduce_sum(y_true*(1-tf.round(y_pred))) precision=TP/(TP+FP) recall=TP/(TP+FN) F1score=2*precision*recall/(precision+recall) return F1score #編譯階段引用自定義評(píng)價(jià)指標(biāo)示例 model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy', metric_precision, metric_recall, metric_F1score])
補(bǔ)充知識(shí):keras sklearn下兩分類(lèi)/多分類(lèi)的技術(shù)雜談(交叉驗(yàn)證和評(píng)價(jià)指標(biāo))
一.前言
這篇博客是為了記錄論文補(bǔ)充實(shí)驗(yàn)中所遇到的問(wèn)題,以及解決方法,主要以程序的形式呈現(xiàn)。
二.對(duì)象
深度學(xué)習(xí)框架:keras
研究對(duì)象:兩分類(lèi)/多分類(lèi)
三.技術(shù)雜談
1.K-FOLD交叉驗(yàn)證
1.概念
對(duì)一個(gè)模型進(jìn)行K次訓(xùn)練,每次訓(xùn)練將整個(gè)數(shù)據(jù)集分為隨機(jī)的K份,K-1作為訓(xùn)練集,剩余的1份作為驗(yàn)證集,每次訓(xùn)練結(jié)束將驗(yàn)證集上的性能指標(biāo)保存下來(lái),最后對(duì)K個(gè)結(jié)果進(jìn)行平均得到最終的模型性能指標(biāo)。
2.優(yōu)缺點(diǎn)
優(yōu)點(diǎn):模型評(píng)估更加魯棒
缺點(diǎn):訓(xùn)練時(shí)間加大
3.代碼
① sklearn與keras獨(dú)立使用
from sklearn.model_selection import StratifiedKFold
import numpy
seed = 7 # 隨機(jī)種子
numpy.random.seed(seed) # 生成固定的隨機(jī)數(shù)
num_k = 5 # 多少折
# 整個(gè)數(shù)據(jù)集(自己定義)
X =
Y =
kfold = StratifiedKFold(n_splits=num_k, shuffle=True, random_state=seed) # 分層K折,保證類(lèi)別比例一致
cvscores = []
for train, test in kfold.split(X, Y):
# 可以用sequential或者function的方式建模(自己定義)
model =
model.compile() # 自定義
# 模型訓(xùn)練
model.fit(X[train], Y[train], epochs=150, batch_size=10, verbose=0)
# 模型測(cè)試
scores = model.evaluate(X[test], Y[test], verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100)) # 打印出驗(yàn)證集準(zhǔn)確率
cvscores.append(scores[1] * 100)
print("%.2f%% (+/- %.2f%%)" % (numpy.mean(cvscores), numpy.std(cvscores))) # 輸出k-fold的模型平均和標(biāo)準(zhǔn)差結(jié)果
② sklearn與keras結(jié)合使用
from keras.wrappers.scikit_learn import KerasClassifier # 使用keras下的sklearn API from sklearn.cross_validation import StratifiedKFold, cross_val_score import numpy as np seed = 7 # 隨機(jī)種子 numpy.random.seed(seed) # 生成固定的隨機(jī)數(shù) num_k = 5 # 多少折 # 整個(gè)數(shù)據(jù)集(自己定義) X = Y = # 創(chuàng)建模型 def model(): # 可以用sequential或者function的方式建模(自己定義) model = return model model = KerasClassifier(build_fn=model, epochs=150, batch_size=10) kfold = StratifiedKFold(Y, n_folds=num_k, shuffle=True, random_state=seed) results = cross_val_score(model, X, Y, cv=kfold) print(np.average(results)) # 輸出k-fold的模型平均結(jié)果
補(bǔ)充:引入keras的callbacks
只需要在①②中的model.fit中加入一個(gè)arg:callbacks=[keras.callbacks.ModelCheckpoint()] # 這樣可以保存下模型的權(quán)重,當(dāng)然了你也可以使用callbacks.TensorBoard保存下訓(xùn)練過(guò)程
2.二分類(lèi)/多分類(lèi)評(píng)價(jià)指標(biāo)
1.概念
二分類(lèi)就是說(shuō),一個(gè)目標(biāo)的標(biāo)簽只有兩種之一(例如:0或1,對(duì)應(yīng)的one-hot標(biāo)簽為[1,0]或[0,1])。對(duì)于這種問(wèn)題,一般可以采用softmax或者logistic回歸來(lái)完成,分別采用cross-entropy和mse損失函數(shù)來(lái)進(jìn)行網(wǎng)絡(luò)訓(xùn)練,分別輸出概率分布和單個(gè)的sigmoid預(yù)測(cè)值(0,1)。
多分類(lèi)就是說(shuō),一個(gè)目標(biāo)的標(biāo)簽是幾種之一(如:0,1,2…)
2.評(píng)價(jià)指標(biāo)
主要包含了:準(zhǔn)確率(accuracy),錯(cuò)誤率(error rate),精確率(precision),召回率(recall)= 真陽(yáng)率(TPR)= 靈敏度(sensitivity),F(xiàn)1-measure(包含了micro和macro兩種),假陽(yáng)率(FPR),特異度(specificity),ROC(receiver operation characteristic curve)(包含了micro和macro兩種),AUC(area under curve),P-R曲線(xiàn)(precision-recall),混淆矩陣
① 準(zhǔn)確率和錯(cuò)誤率
accuracy = (TP+TN)/ (P+N)或者accuracy = (TP+TN)/ (T+F)
error rate = (FP+FN) / (P+N)或者(FP+FN) / (T+F)
accuracy = 1 - error rate
可見(jiàn):準(zhǔn)確率、錯(cuò)誤率是對(duì)分類(lèi)器在整體數(shù)據(jù)上的評(píng)價(jià)指標(biāo)。
② 精確率
precision=TP /(TP+FP)
可見(jiàn):精確率是對(duì)分類(lèi)器在預(yù)測(cè)為陽(yáng)性的數(shù)據(jù)上的評(píng)價(jià)指標(biāo)。
③ 召回率/真陽(yáng)率/靈敏度
recall = TPR = sensitivity = TP/(TP+FN)
可見(jiàn):召回率/真陽(yáng)率/靈敏度是對(duì)分類(lèi)器在整個(gè)陽(yáng)性數(shù)據(jù)上的評(píng)價(jià)指標(biāo)。
④ F1-measure
F1-measure = 2 * (recall * precision / (recall + precision))
包含兩種:micro和macro(對(duì)于多類(lèi)別分類(lèi)問(wèn)題,注意區(qū)別于多標(biāo)簽分類(lèi)問(wèn)題)
1)micro
計(jì)算出所有類(lèi)別總的precision和recall,然后計(jì)算F1-measure
2)macro
計(jì)算出每一個(gè)類(lèi)的precison和recall后計(jì)算F1-measure,最后將F1-measure平均
可見(jiàn):F1-measure是對(duì)兩個(gè)矛盾指標(biāo)precision和recall的一種調(diào)和。
⑤ 假陽(yáng)率
FPR=FP / (FP+TN)
可見(jiàn):假陽(yáng)率是對(duì)分類(lèi)器在整個(gè)陰性數(shù)據(jù)上的評(píng)價(jià)指標(biāo),針對(duì)的是假陽(yáng)。
⑥ 特異度
specificity = 1- FPR
可見(jiàn):特異度是對(duì)分類(lèi)器在整個(gè)陰性數(shù)據(jù)上的評(píng)價(jià)指標(biāo),針對(duì)的是真陰。
⑦ ROC曲線(xiàn)和AUC
作用:靈敏度與特異度的綜合指標(biāo)
橫坐標(biāo):FPR/1-specificity
縱坐標(biāo):TPR/sensitivity/recall
AUC是ROC右下角的面積,越大,表示分類(lèi)器的性能越好
包含兩種:micro和macro(對(duì)于多類(lèi)別分類(lèi)問(wèn)題,注意區(qū)別于多標(biāo)簽分類(lèi)問(wèn)題)
假設(shè)一共有M個(gè)樣本,N個(gè)類(lèi)別。預(yù)測(cè)出來(lái)的概率矩陣P(M,N),標(biāo)簽矩陣L (M,N)
1)micro
根據(jù)P和L中的每一列(對(duì)整個(gè)數(shù)據(jù)集而言),計(jì)算出各閾值下的TPR和FPR,總共可以得到N組數(shù)據(jù),分別畫(huà)出N個(gè)ROC曲線(xiàn),最后取平均
2)macro
將P和L按行展開(kāi),然后轉(zhuǎn)置為兩列,最后畫(huà)出一個(gè)ROC曲線(xiàn)
⑧ P-R曲線(xiàn)
橫軸:recall
縱軸:precision
評(píng)判:1)直觀(guān)看,P-R包圍的面積越大越好,P=R的點(diǎn)越大越好;2)通過(guò)F1-measure來(lái)看
比較ROC和P-R: 當(dāng)樣本中的正、負(fù)比例不平衡的時(shí)候,ROC曲線(xiàn)基本保持不變,而P-R曲線(xiàn)變化很大,原因如下:
當(dāng)負(fù)樣本的比例增大時(shí),在召回率一定的情況下,那么表現(xiàn)較差的模型必然會(huì)召回更多的負(fù)樣本,TP降低,F(xiàn)P迅速增加(對(duì)于性能差的分類(lèi)器而言),precision就會(huì)降低,所以P-R曲線(xiàn)包圍的面積會(huì)變小。
⑨ 混淆矩陣
行表示的是樣本中的一種真類(lèi)別被預(yù)測(cè)的結(jié)果,列表示的是一種被預(yù)測(cè)的標(biāo)簽所對(duì)應(yīng)的真類(lèi)別。
3.代碼
注意:以下的代碼是合在一起寫(xiě)的,有注釋。
from sklearn import datasets
import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, precision_score, accuracy_score,recall_score, f1_score,roc_auc_score, precision_recall_fscore_support, roc_curve, classification_report
import matplotlib.pyplot as plt
iris = datasets.load_iris()
x, y = iris.data, iris.target
print("label:", y)
n_class = len(set(iris.target))
y_one_hot = label_binarize(y, np.arange(n_class))
# alpha = np.logspace(-2, 2, 20) #設(shè)置超參數(shù)范圍
# model = LogisticRegressionCV(Cs = alpha, cv = 3, penalty = 'l2') #使用L2正則化
model = LogisticRegression() # 內(nèi)置了最大迭代次數(shù)了,可修改
model.fit(x, y)
y_score = model.predict(x) # 輸出的是整數(shù)標(biāo)簽
mean_accuracy = model.score(x, y)
print("mean_accuracy: ", mean_accuracy)
print("predict label:", y_score)
print(y_score==y)
print(y_score.shape)
y_score_pro = model.predict_proba(x) # 輸出概率
print(y_score_pro)
print(y_score_pro.shape)
y_score_one_hot = label_binarize(y_score, np.arange(n_class)) # 這個(gè)函數(shù)的輸入必須是整數(shù)的標(biāo)簽哦
print(y_score_one_hot.shape)
obj1 = confusion_matrix(y, y_score) # 注意輸入必須是整數(shù)型的,shape=(n_samples, )
print('confusion_matrix\n', obj1)
print(y)
print('accuracy:{}'.format(accuracy_score(y, y_score))) # 不存在average
print('precision:{}'.format(precision_score(y, y_score,average='micro')))
print('recall:{}'.format(recall_score(y, y_score,average='micro')))
print('f1-score:{}'.format(f1_score(y, y_score,average='micro')))
print('f1-score-for-each-class:{}'.format(precision_recall_fscore_support(y, y_score))) # for macro
# print('AUC y_pred = one-hot:{}\n'.format(roc_auc_score(y_one_hot, y_score_one_hot,average='micro'))) # 對(duì)于multi-class輸入必須是proba,所以這種是錯(cuò)誤的
# AUC值
auc = roc_auc_score(y_one_hot, y_score_pro,average='micro') # 使用micro,會(huì)計(jì)算n_classes個(gè)roc曲線(xiàn),再取平均
print("AUC y_pred = proba:", auc)
# 畫(huà)ROC曲線(xiàn)
print("one-hot label ravelled shape:", y_one_hot.ravel().shape)
fpr, tpr, thresholds = roc_curve(y_one_hot.ravel(),y_score_pro.ravel()) # ravel()表示平鋪開(kāi)來(lái),因?yàn)檩斎氲膕hape必須是(n_samples,)
print("threshold: ", thresholds)
plt.plot(fpr, tpr, linewidth = 2,label='AUC=%.3f' % auc)
plt.plot([0,1],[0,1], 'k--') # 畫(huà)一條y=x的直線(xiàn),線(xiàn)條的顏色和類(lèi)型
plt.axis([0,1.0,0,1.0]) # 限制坐標(biāo)范圍
plt.xlabel('False Postivie Rate')
plt.ylabel('True Positive Rate')
plt.legend()
plt.show()
# p-r曲線(xiàn)針對(duì)的是二分類(lèi),這里就不描述了
ans = classification_report(y, y_score,digits=5) # 小數(shù)點(diǎn)后保留5位有效數(shù)字
print(ans)
以上這篇淺談keras中自定義二分類(lèi)任務(wù)評(píng)價(jià)指標(biāo)metrics的方法以及代碼就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Django項(xiàng)目如何給數(shù)據(jù)庫(kù)添加約束
這篇文章主要介紹了Django項(xiàng)目如何給數(shù)據(jù)庫(kù)添加約束,幫助大家更好的理解和學(xué)習(xí)使用Django框架,感興趣的朋友可以了解下2021-04-04
Pyinstaller 打包發(fā)布經(jīng)驗(yàn)總結(jié)
這篇文章主要介紹了Pyinstaller 打包發(fā)布經(jīng)驗(yàn)總結(jié),使用Pyinstaller打包Python項(xiàng)目包含了大量的坑,感興趣的可以一起來(lái)了解一下2020-06-06
Python全景系列之控制流程盤(pán)點(diǎn)及進(jìn)階技巧
這篇文章主要為大家介紹了Python全景系列之控制流程盤(pán)點(diǎn)及進(jìn)階技巧詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-05-05
PHP webshell檢查工具 python實(shí)現(xiàn)代碼
Web安全應(yīng)急響應(yīng)中,不免要檢查下服務(wù)器上是否被上傳了webshell,手工檢查比較慢,就寫(xiě)了個(gè)腳本來(lái)檢查了。Windows平臺(tái)下已經(jīng)有了lake2寫(xiě)的雷克圖的了,一般的檢查也夠用了,寫(xiě)了個(gè)Linux下面的,用python寫(xiě)的。2009-09-09
python中如何利用matplotlib畫(huà)多個(gè)并列的柱狀圖
python是一個(gè)很有趣的語(yǔ)言,可以在命令行窗口運(yùn)行,下面這篇文章主要給大家介紹了關(guān)于python中如何利用matplotlib畫(huà)多個(gè)并列的柱狀圖的相關(guān)資料,需要的朋友可以參考下2022-01-01
教你使用Python畫(huà)圣誕樹(shù)做浪漫的程序員
這不是圣誕節(jié)快到了,還不用Python繪制個(gè)圣誕樹(shù)和煙花讓女朋友開(kāi)心開(kāi)心,也算是親手做的,稍稍花了點(diǎn)心思,學(xué)會(huì)了趕緊畫(huà)給你的那個(gè)她吧2022-12-12
Python3.7基于hashlib和Crypto實(shí)現(xiàn)加簽驗(yàn)簽功能(實(shí)例代碼)
這篇文章主要介紹了Python3.7基于hashlib和Crypto實(shí)現(xiàn)加簽驗(yàn)簽功能,環(huán)境是基于python3.7,本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-12-12

