四種Python機(jī)器學(xué)習(xí)超參數(shù)搜索方法總結(jié)
在建模時(shí)模型的超參數(shù)對(duì)精度有一定的影響,而設(shè)置和調(diào)整超參數(shù)的取值,往往稱為調(diào)參。
在實(shí)踐中調(diào)參往往依賴人工來進(jìn)行設(shè)置調(diào)整范圍,然后使用機(jī)器在超參數(shù)范圍內(nèi)進(jìn)行搜素。本文將演示在sklearn中支持的四種基礎(chǔ)超參數(shù)搜索方法:
- GridSearch
- RandomizedSearch
- HalvingGridSearch
- HalvingRandomSearch
原始模型
作為精度對(duì)比,我們最開始使用隨機(jī)森林來訓(xùn)練初始化模型,并在測(cè)試集計(jì)算精度:
# 數(shù)據(jù)讀取 df = pd.read_csv('https://mirror.coggle.club/dataset/heart.csv') X = df.drop(columns=['output']) y = df['output'] # 數(shù)據(jù)劃分 x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y) # 模型訓(xùn)練與計(jì)算準(zhǔn)確率 clf = RandomForestClassifier(random_state=0) clf.fit(x_train, y_train) clf.score(x_test, y_test)
模型最終在測(cè)試集精度為:0.802。
GridSearch
GridSearch是比較基礎(chǔ)的超參數(shù)搜索方法,中文名字網(wǎng)格搜索。其原理是在計(jì)算的過程中遍歷所有的超參數(shù)組合,然后搜索到最優(yōu)的結(jié)果。
如下代碼所示,我們對(duì)4個(gè)超參數(shù)進(jìn)行搜索,搜索空間為 5 * 3 * 2 * 3 = 90組超參數(shù)。對(duì)于每組超參數(shù)還需要計(jì)算5折交叉驗(yàn)證,則需要訓(xùn)練450次。
parameters = { 'max_depth': [2,4,5,6,7], 'min_samples_leaf': [1,2,3], 'min_weight_fraction_leaf': [0, 0.1], 'min_impurity_decrease': [0, 0.1, 0.2] } # Fitting 5 folds for each of 90 candidates, totalling 450 fits clf = GridSearchCV( RandomForestClassifier(random_state=0), parameters, refit=True, verbose=1, ) clf.fit(x_train, y_train) clf.best_estimator_.score(x_test, y_test)
模型最終在測(cè)試集精度為:0.815。
RandomizedSearch
RandomizedSearch是在一定范圍內(nèi)進(jìn)行搜索,且需要設(shè)置搜索的次數(shù),其默認(rèn)不會(huì)對(duì)所有的組合進(jìn)行搜索。
n_iter代表超參數(shù)組合的個(gè)數(shù),默認(rèn)會(huì)設(shè)置比所有組合次數(shù)少的取值,如下面設(shè)置的為10,則只進(jìn)行50次訓(xùn)練。
parameters = { 'max_depth': [2,4,5,6,7], 'min_samples_leaf': [1,2,3], 'min_weight_fraction_leaf': [0, 0.1], 'min_impurity_decrease': [0, 0.1, 0.2] } clf = RandomizedSearchCV( RandomForestClassifier(random_state=0), parameters, refit=True, verbose=1, n_iter=10, ) clf.fit(x_train, y_train) clf.best_estimator_.score(x_test, y_test)
模型最終在測(cè)試集精度為:0.815。
HalvingGridSearch
HalvingGridSearch和GridSearch非常相似,但在迭代的過程中是有參數(shù)組合減半的操作。
最開始使用所有的超參數(shù)組合,但使用最少的數(shù)據(jù),篩選其中最優(yōu)的超參數(shù),增加數(shù)據(jù)再進(jìn)行篩選。
HalvingGridSearch的思路和hyperband的思路非常相似,但是最樸素的實(shí)現(xiàn)。先使用少量數(shù)據(jù)篩選超參數(shù)組合,然后使用更多的數(shù)據(jù)驗(yàn)證精度。
n_iterations: 3 n_required_iterations: 5 n_possible_iterations: 3 min_resources_: 20 max_resources_: 227 aggressive_elimination: False factor: 3 ---------- iter: 0 n_candidates: 90 n_resources: 20 Fitting 5 folds for each of 90 candidates, totalling 450 fits ---------- iter: 1 n_candidates: 30 n_resources: 60 Fitting 5 folds for each of 30 candidates, totalling 150 fits ---------- iter: 2 n_candidates: 10 n_resources: 180 Fitting 5 folds for each of 10 candidates, totalling 50 fits ----------
模型最終在測(cè)試集精度為:0.855。
HalvingRandomSearch
HalvingRandomSearch和HalvingGridSearch類似,都是逐步增加樣本,減少超參數(shù)組合。但每次生成超參數(shù)組合,都是隨機(jī)篩選的。
n_iterations: 3 n_required_iterations: 3 n_possible_iterations: 3 min_resources_: 20 max_resources_: 227 aggressive_elimination: False factor: 3 ---------- iter: 0 n_candidates: 11 n_resources: 20 Fitting 5 folds for each of 11 candidates, totalling 55 fits ---------- iter: 1 n_candidates: 4 n_resources: 60 Fitting 5 folds for each of 4 candidates, totalling 20 fits ---------- iter: 2 n_candidates: 2 n_resources: 180 Fitting 5 folds for each of 2 candidates, totalling 10 fits
模型最終在測(cè)試集精度為:0.828。
總結(jié)與對(duì)比
HalvingGridSearch和HalvingRandomSearch比較適合在數(shù)據(jù)量比較大的情況使用,可以提高訓(xùn)練速度。如果計(jì)算資源充足,GridSearch和HalvingGridSearch會(huì)得到更好的結(jié)果。
后續(xù)我們將分享其他的一些高階調(diào)參庫(kù)的實(shí)現(xiàn),其中也會(huì)有數(shù)據(jù)量改變的思路。如在Optuna中,核心是參數(shù)組合的生成和剪枝、訓(xùn)練的樣本增加等細(xì)節(jié)。
到此這篇關(guān)于四種Python機(jī)器學(xué)習(xí)超參數(shù)搜索方法總結(jié)的文章就介紹到這了,更多相關(guān)Python超參數(shù)搜索內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
pycharm打開長(zhǎng)代碼文件CPU占用率過高的解決
這篇文章主要介紹了pycharm打開長(zhǎng)代碼文件CPU占用率過高的解決方案,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-09-09Python 相對(duì)路徑和絕對(duì)路徑及寫法演示
這篇文章主要介紹了Python 相對(duì)路徑絕對(duì)路徑的相關(guān)知識(shí),結(jié)合實(shí)例代碼介紹了Python 相對(duì)路徑、絕對(duì)路徑的寫法實(shí)例演示,需要的朋友可以參考下2023-02-02Python開發(fā)的HTTP庫(kù)requests詳解
Requests是用Python語(yǔ)言編寫,基于urllib,采用Apache2 Licensed開源協(xié)議的HTTP庫(kù)。它比urllib更加方便,可以節(jié)約我們大量的工作,完全滿足HTTP測(cè)試需求。Requests的哲學(xué)是以PEP 20 的習(xí)語(yǔ)為中心開發(fā)的,所以它比urllib更加Pythoner。更重要的一點(diǎn)是它支持Python3哦!2017-08-08python Pexpect 實(shí)現(xiàn)輸密碼 scp 拷貝的方法
今天小編就為大家分享一篇python Pexpect 實(shí)現(xiàn)輸密碼 scp 拷貝的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-01-01Python?pandas按行、按列遍歷DataFrame的幾種方式
在python的DataFrame中,因?yàn)閿?shù)據(jù)中可以有多個(gè)行和列,而且每行代表一個(gè)數(shù)據(jù)樣本,我們可以將DataFrame看作數(shù)據(jù)表,那你知道如何按照數(shù)據(jù)表中的行遍歷嗎,下面這篇文章主要給大家介紹了關(guān)于Python?pandas按行、按列遍歷DataFrame的幾種方式,需要的朋友可以參考下2022-09-09