Implementing Cross-Validation in sklearn
sklearn is a comprehensive and very usable third-party library for machine learning in Python; everyone who has used it speaks well of it. This post records the various cross-validation utilities in sklearn, largely following the official document Cross-validation: evaluating estimator performance. If your English is up to it, read the official docs directly; they cover the details thoroughly.
First, import the required libraries and the dataset:
In [1]: import numpy as np

In [2]: from sklearn.model_selection import train_test_split

In [3]: from sklearn.datasets import load_iris

In [4]: from sklearn import svm

In [5]: iris = load_iris()

In [6]: iris.data.shape, iris.target.shape
Out[6]: ((150, 4), (150,))
1. train_test_split
Quickly shuffles a dataset and splits it into a training set and a test set.
The data are shuffled, then divided according to the given test_size (a stratified variant is sketched after the example below).
In [7]: X_train, X_test, y_train, y_test = train_test_split(
   ...:     iris.data, iris.target, test_size=.4, random_state=0)  # 60/40 train/test split

In [8]: X_train.shape, y_train.shape
Out[8]: ((90, 4), (90,))

In [9]: X_test.shape, y_test.shape
Out[9]: ((60, 4), (60,))

In [10]: iris.data[:5]
Out[10]:
array([[ 5.1,  3.5,  1.4,  0.2],
       [ 4.9,  3. ,  1.4,  0.2],
       [ 4.7,  3.2,  1.3,  0.2],
       [ 4.6,  3.1,  1.5,  0.2],
       [ 5. ,  3.6,  1.4,  0.2]])

In [11]: X_train[:5]
Out[11]:
array([[ 6. ,  3.4,  4.5,  1.6],
       [ 4.8,  3.1,  1.6,  0.2],
       [ 5.8,  2.7,  5.1,  1.9],
       [ 5.6,  2.7,  4.2,  1.3],
       [ 5.6,  2.9,  3.6,  1.3]])

In [12]: clf = svm.SVC(kernel='linear', C=1).fit(X_train, y_train)

In [13]: clf.score(X_test, y_test)
Out[13]: 0.96666666666666667
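If the class balance should be preserved in both halves, train_test_split also accepts a stratify parameter; a minimal sketch, reusing the iris data loaded above:

# Stratified variant: both halves keep iris's equal class proportions.
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=.4, random_state=0,
    stratify=iris.target)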
2. cross_val_score
Runs cross-validation on the dataset a specified number of times and scores every round.
By default, the score is the estimator's own score method (accuracy for classifiers). Other metrics for classification or regression can be selected by passing the scoring parameter to cross_val_score, e.g. scoring='f1_macro' (a short sketch follows the example below); the metrics themselves live in sklearn.metrics, importable via from sklearn import metrics.
When cv is given as an int, the dataset is split with KFold, or with StratifiedKFold for classification targets; both are introduced below.
In [15]: from sklearn.model_selection import cross_val_score

In [16]: clf = svm.SVC(kernel='linear', C=1)

In [17]: scores = cross_val_score(clf, iris.data, iris.target, cv=5)

In [18]: scores
Out[18]: array([ 0.96666667,  1.        ,  0.96666667,  0.96666667,  1.        ])

In [19]: scores.mean()
Out[19]: 0.98000000000000009
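As a short sketch of the scoring parameter described above, the same evaluation can use macro-averaged F1 instead of the estimator's default accuracy:

# Score each fold with macro-averaged F1 rather than accuracy.
f1_scores = cross_val_score(clf, iris.data, iris.target, cv=5,
                            scoring='f1_macro')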
Besides the default cross-validation behaviour, the strategy can be specified explicitly, e.g. the number of splits and the train/test proportions:
In [20]: from sklearn.model_selection import ShuffleSplit

In [21]: n_samples = iris.data.shape[0]

In [22]: cv = ShuffleSplit(n_splits=3, test_size=.3, random_state=0)

In [23]: cross_val_score(clf, iris.data, iris.target, cv=cv)
Out[23]: array([ 0.97777778,  0.97777778,  1.        ])
A pipeline can also be used inside cross_val_score to chain preprocessing with the estimator:
In [24]: from sklearn import preprocessing

In [25]: from sklearn.pipeline import make_pipeline

In [26]: clf = make_pipeline(preprocessing.StandardScaler(), svm.SVC(C=1))

In [27]: cross_val_score(clf, iris.data, iris.target, cv=cv)
Out[27]: array([ 0.97777778,  0.93333333,  0.95555556])
3. cross_val_predict
cross_val_predict is very similar to cross_val_score, but instead of returning scores it returns the estimator's predictions (class labels, or values for regression) for each sample, made while that sample was in the test fold. This matters for later model improvement: comparing these predictions with the actual targets pinpoints exactly where the model goes wrong, which is very useful for parameter tuning and troubleshooting.
In [28]: from sklearn.model_selection import cross_val_predict

In [29]: from sklearn import metrics

In [30]: predicted = cross_val_predict(clf, iris.data, iris.target, cv=10)

In [31]: predicted
Out[31]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [32]: metrics.accuracy_score(iris.target, predicted)
Out[32]: 0.96666666666666667
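To actually pinpoint the errors mentioned above, compare the cross-validated predictions element-wise with the targets; a minimal sketch using the arrays already computed:

# Indices where the prediction disagrees with the true target
# (with accuracy 0.9667 on 150 samples, there are 5 such positions).
errors = np.where(predicted != iris.target)[0]
print(errors)
print(iris.target[errors], predicted[errors])  # true vs. predicted labels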
4. KFold
K-fold cross-validation is the standard scheme for dividing a dataset into K parts: across the K splits, every sample appears in the training set and exactly once in a test set, and within a single split the two sets never overlap. It amounts to sampling without replacement.
In [33]: from sklearn.model_selection import KFold

In [34]: X = ['a', 'b', 'c', 'd']

In [35]: kf = KFold(n_splits=2)

In [36]: for train, test in kf.split(X):
    ...:     print(train, test)
    ...:     print(np.array(X)[train], np.array(X)[test])
    ...:
[2 3] [0 1]
['c' 'd'] ['a' 'b']
[0 1] [2 3]
['a' 'b'] ['c' 'd']
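By default KFold keeps the samples in their original order; it also accepts shuffle and random_state to randomize which samples land in each fold. A small sketch:

# Shuffle the indices before forming the folds (reproducible via random_state).
kf = KFold(n_splits=2, shuffle=True, random_state=0)
for train, test in kf.split(X):
    print(train, test)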
5. LeaveOneOut
LeaveOneOut is really a special case of KFold (one sample per fold); it is defined separately only because it is used so often, and it can be implemented entirely with KFold.
In [37]: from sklearn.model_selection import LeaveOneOut

In [38]: X = [1, 2, 3, 4]

In [39]: loo = LeaveOneOut()

In [41]: for train, test in loo.split(X):
    ...:     print(train, test)
    ...:
[1 2 3] [0]
[0 2 3] [1]
[0 1 3] [2]
[0 1 2] [3]

# implementing LeaveOneOut with KFold
In [42]: kf = KFold(n_splits=len(X))

In [43]: for train, test in kf.split(X):
    ...:     print(train, test)
    ...:
[1 2 3] [0]
[0 2 3] [1]
[0 1 3] [2]
[0 1 2] [3]
6. LeavePOut
This one resembles LeaveOneOut, but it is not simply KFold with more folds: it generates every possible test set of p samples (all C(n, p) combinations), so for p > 1 the test sets overlap and reproducing it with KFold is awkward (a sketch of an equivalent construction follows the example below).
In [44]: from sklearn.model_selection import LeavePOut

In [45]: X = np.ones(4)

In [46]: lpo = LeavePOut(p=2)

In [47]: for train, test in lpo.split(X):
    ...:     print(train, test)
    ...:
[2 3] [0 1]
[1 3] [0 2]
[1 2] [0 3]
[0 3] [1 2]
[0 2] [1 3]
[0 1] [2 3]
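Because LeavePOut enumerates every size-p subset, the same splits can be reproduced with itertools.combinations instead of KFold; a minimal sketch of the equivalence:

from itertools import combinations

n, p = 4, 2
indices = np.arange(n)
# Each of the C(n, p) combinations becomes a test set; the rest trains.
for test in combinations(indices, p):
    train = np.setdiff1d(indices, np.array(test))
    print(train, np.array(test))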
7. ShuffleSplit
At first glance ShuffleSplit looks like LeavePOut, but the two are completely different. LeavePOut enumerates its splits so that the test sets, taken together, cover the whole dataset without replacement. ShuffleSplit instead draws each split independently at random, so the same sample can land in the test set of several iterations; only after enough iterations is every sample likely to have appeared in a test set at all.
In [48]: from sklearn.model_selection import ShuffleSplit

In [49]: X = np.arange(5)

In [50]: ss = ShuffleSplit(n_splits=3, test_size=.25, random_state=0)

In [51]: for train_index, test_index in ss.split(X):
    ...:     print(train_index, test_index)
    ...:
[1 3 4] [2 0]
[1 4 3] [0 2]
[4 0 2] [1 3]
8. StratifiedKFold
This one is more interesting: it is K-fold splitting stratified by the labels y, so each fold preserves roughly the same class proportions as the full dataset.
In [52]: from sklearn.model_selection import StratifiedKFold

In [53]: X = np.ones(10)

In [54]: y = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]

In [55]: skf = StratifiedKFold(n_splits=3)

In [56]: for train, test in skf.split(X, y):
    ...:     print(train, test)
    ...:
[2 3 6 7 8 9] [0 1 4 5]
[0 1 3 4 5 8 9] [2 6 7]
[0 1 2 4 5 6 7] [3 8 9]
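To see the stratification at work, count the labels that land in each test fold; with the 4:6 class ratio above, every fold keeps roughly that ratio. A quick check (converting y to an array for fancy indexing):

y = np.array(y)
for train, test in skf.split(X, y):
    # np.bincount reports how many class-0 and class-1 samples each test fold got.
    print(test, np.bincount(y[test]))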
9. GroupKFold
Similar in spirit to StratifiedKFold, except that folds are formed from whole groups: all samples of a group stay together, so the same group never appears in both the training set and the test set of a split.
In [57]: from sklearn.model_selection import GroupKFold

In [58]: X = [.1, .2, 2.2, 2.4, 2.3, 4.55, 5.8, 8.8, 9, 10]

In [59]: y = ['a', 'b', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd']

In [60]: groups = [1, 1, 1, 2, 2, 2, 3, 3, 3, 3]

In [61]: gkf = GroupKFold(n_splits=3)

In [62]: for train, test in gkf.split(X, y, groups=groups):
    ...:     print(train, test)
    ...:
[0 1 2 3 4 5] [6 7 8 9]
[0 1 2 6 7 8 9] [3 4 5]
[3 4 5 6 7 8 9] [0 1 2]
10. LeaveOneGroupOut
This reduces the randomness of GroupKFold even further: each split simply holds out one of the given groups as the test set.
In [63]: from sklearn.model_selection import LeaveOneGroupOut

In [64]: X = [1, 5, 10, 50, 60, 70, 80]

In [65]: y = [0, 1, 1, 2, 2, 2, 2]

In [66]: groups = [1, 1, 2, 2, 3, 3, 3]

In [67]: logo = LeaveOneGroupOut()

In [68]: for train, test in logo.split(X, y, groups=groups):
    ...:     print(train, test)
    ...:
[2 3 4 5 6] [0 1]
[0 1 4 5 6] [2 3]
[0 1 2 3] [4 5 6]
11. LeavePGroupsOut
Not much to add here; it is the same as the previous splitter, except that P groups are held out at a time rather than one:
from sklearn.model_selection import LeavePGroupsOut

X = np.arange(6)
y = [1, 1, 1, 2, 2, 2]
groups = [1, 1, 2, 2, 3, 3]
lpgo = LeavePGroupsOut(n_groups=2)
for train, test in lpgo.split(X, y, groups=groups):
    print(train, test)

[4 5] [0 1 2 3]
[2 3] [0 1 4 5]
[0 1] [2 3 4 5]
12. GroupShuffleSplit
This one draws random splits at the group level, so (unlike GroupKFold) the same group may appear in the test set of several iterations:
In [75]: from sklearn.model_selection import GroupShuffleSplit

In [76]: X = [.1, .2, 2.2, 2.4, 2.3, 4.55, 5.8, .001]

In [77]: y = ['a', 'b', 'b', 'b', 'c', 'c', 'c', 'a']

In [78]: groups = [1, 1, 2, 2, 3, 3, 4, 4]

In [79]: gss = GroupShuffleSplit(n_splits=4, test_size=.5, random_state=0)

In [80]: for train, test in gss.split(X, y, groups=groups):
    ...:     print(train, test)
    ...:
[0 1 2 3] [4 5 6 7]
[2 3 6 7] [0 1 4 5]
[2 3 4 5] [0 1 6 7]
[4 5 6 7] [0 1 2 3]
13. TimeSeriesSplit
Designed for time series, this prevents future data from leaking into training: each test set consists of samples that come strictly after its training set, and successive training sets grow by absorbing the earlier test samples.
In [81]: from sklearn.model_selection import TimeSeriesSplit

In [82]: X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])

In [83]: tscv = TimeSeriesSplit(n_splits=3)

In [84]: for train, test in tscv.split(X):
    ...:     print(train, test)
    ...:
[0 1 2] [3]
[0 1 2 3] [4]
[0 1 2 3 4] [5]
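Like the other splitters, a TimeSeriesSplit instance can be passed as the cv argument of cross_val_score so that every fold trains only on the past. A minimal sketch with hypothetical sequential targets (mean absolute error, since R^2 is undefined on single-sample test sets):

from sklearn.linear_model import LinearRegression

y = np.array([1, 2, 3, 4, 5, 6])  # made-up targets purely for illustration
scores = cross_val_score(LinearRegression(), X, y, cv=tscv,
                         scoring='neg_mean_absolute_error')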
This repo records some Python tips, books, learning links, and so on; stars are welcome.
GitHub address