python實現(xiàn)AdaBoost算法的示例
代碼
'''
數(shù)據(jù)集:Mnist
訓練集數(shù)量:60000(實際使用:10000)
測試集數(shù)量:10000(實際使用:1000)
層數(shù):40
------------------------------
運行結(jié)果:
正確率:97%
運行時長:65m
'''
import time
import numpy as np
def loadData(fileName):
'''
加載文件
:param fileName:要加載的文件路徑
:return: 數(shù)據(jù)集和標簽集
'''
# 存放數(shù)據(jù)及標記
dataArr = []
labelArr = []
# 讀取文件
fr = open(fileName)
# 遍歷文件中的每一行
for line in fr.readlines():
# 獲取當前行,并按“,”切割成字段放入列表中
# strip:去掉每行字符串首尾指定的字符(默認空格或換行符)
# split:按照指定的字符將字符串切割成每個字段,返回列表形式
curLine = line.strip().split(',')
# 將每行中除標記外的數(shù)據(jù)放入數(shù)據(jù)集中(curLine[0]為標記信息)
# 在放入的同時將原先字符串形式的數(shù)據(jù)轉(zhuǎn)換為整型
# 此外將數(shù)據(jù)進行了二值化處理,大于128的轉(zhuǎn)換成1,小于的轉(zhuǎn)換成0,方便后續(xù)計算
dataArr.append([int(int(num) > 128) for num in curLine[1:]])
# 將標記信息放入標記集中
# 放入的同時將標記轉(zhuǎn)換為整型
# 轉(zhuǎn)換成二分類任務
# 標簽0設置為1,反之為-1
if int(curLine[0]) == 0:
labelArr.append(1)
else:
labelArr.append(-1)
# 返回數(shù)據(jù)集和標記
return dataArr, labelArr
def calc_e_Gx(trainDataArr, trainLabelArr, n, div, rule, D):
'''
計算分類錯誤率
:param trainDataArr:訓練數(shù)據(jù)集數(shù)字
:param trainLabelArr: 訓練標簽集數(shù)組
:param n: 要操作的特征
:param div:劃分點
:param rule:正反例標簽
:param D:權(quán)值分布D
:return:預測結(jié)果, 分類誤差率
'''
# 初始化分類誤差率為0
e = 0
# 將訓練數(shù)據(jù)矩陣中特征為n的那一列單獨剝出來做成數(shù)組。因為其他元素我們并不需要,
# 直接對龐大的訓練集進行操作的話會很慢
x = trainDataArr[:, n]
# 同樣將標簽也轉(zhuǎn)換成數(shù)組格式,x和y的轉(zhuǎn)換只是單純?yōu)榱颂岣哌\行速度
# 測試過相對直接操作而言性能提升很大
y = trainLabelArr
predict = []
# 依據(jù)小于和大于的標簽依據(jù)實際情況會不同,在這里直接進行設置
if rule == 'LisOne':
L = 1
H = -1
else:
L = -1
H = 1
# 遍歷所有樣本的特征m
for i in range(trainDataArr.shape[0]):
if x[i] < div:
# 如果小于劃分點,則預測為L
# 如果設置小于div為1,那么L就是1,
# 如果設置小于div為-1,L就是-1
predict.append(L)
# 如果預測錯誤,分類錯誤率要加上該分錯的樣本的權(quán)值(8.1式)
if y[i] != L:
e += D[i]
elif x[i] >= div:
# 與上面思想一樣
predict.append(H)
if y[i] != H:
e += D[i]
# 返回預測結(jié)果和分類錯誤率e
# 預測結(jié)果其實是為了后面做準備的,在算法8.1第四步式8.4中exp內(nèi)部有個Gx,要用在那個地方
# 以此來更新新的D
return np.array(predict), e
def createSigleBoostingTree(trainDataArr, trainLabelArr, D):
'''
創(chuàng)建單層提升樹
:param trainDataArr:訓練數(shù)據(jù)集數(shù)組
:param trainLabelArr: 訓練標簽集數(shù)組
:param D: 算法8.1中的D
:return: 創(chuàng)建的單層提升樹
'''
# 獲得樣本數(shù)目及特征數(shù)量
m, n = np.shape(trainDataArr)
# 單層樹的字典,用于存放當前層提升樹的參數(shù)
# 也可以認為該字典代表了一層提升樹
sigleBoostTree = {}
# 初始化分類誤差率,分類誤差率在算法8.1步驟(2)(b)有提到
# 誤差率最高也只能100%,因此初始化為1
sigleBoostTree['e'] = 1
# 對每一個特征進行遍歷,尋找用于劃分的最合適的特征
for i in range(n):
# 因為特征已經(jīng)經(jīng)過二值化,只能為0和1,因此分切分時分為-0.5, 0.5, 1.5三擋進行切割
for div in [-0.5, 0.5, 1.5]:
# 在單個特征內(nèi)對正反例進行劃分時,有兩種情況:
# 可能是小于某值的為1,大于某值得為-1,也可能小于某值得是-1,反之為1
# 因此在尋找最佳提升樹的同時對于兩種情況也需要遍歷運行
# LisOne:Low is one:小于某值得是1
# HisOne:High is one:大于某值得是1
for rule in ['LisOne', 'HisOne']:
# 按照第i個特征,以值div進行切割,進行當前設置得到的預測和分類錯誤率
Gx, e = calc_e_Gx(trainDataArr, trainLabelArr, i, div, rule, D)
# 如果分類錯誤率e小于當前最小的e,那么將它作為最小的分類錯誤率保存
if e < sigleBoostTree['e']:
sigleBoostTree['e'] = e
# 同時也需要存儲最優(yōu)劃分點、劃分規(guī)則、預測結(jié)果、特征索引
# 以便進行D更新和后續(xù)預測使用
sigleBoostTree['div'] = div
sigleBoostTree['rule'] = rule
sigleBoostTree['Gx'] = Gx
sigleBoostTree['feature'] = i
# 返回單層的提升樹
return sigleBoostTree
def createBosstingTree(trainDataList, trainLabelList, treeNum=50):
'''
創(chuàng)建提升樹
創(chuàng)建算法依據(jù)“8.1.2 AdaBoost算法” 算法8.1
:param trainDataList:訓練數(shù)據(jù)集
:param trainLabelList: 訓練測試集
:param treeNum: 樹的層數(shù)
:return: 提升樹
'''
# 將數(shù)據(jù)和標簽轉(zhuǎn)化為數(shù)組形式
trainDataArr = np.array(trainDataList)
trainLabelArr = np.array(trainLabelList)
# 沒增加一層數(shù)后,當前最終預測結(jié)果列表
finallpredict = [0] * len(trainLabelArr)
# 獲得訓練集數(shù)量以及特征個數(shù)
m, n = np.shape(trainDataArr)
# 依據(jù)算法8.1步驟(1)初始化D為1/N
D = [1 / m] * m
# 初始化提升樹列表,每個位置為一層
tree = []
# 循環(huán)創(chuàng)建提升樹
for i in range(treeNum):
# 得到當前層的提升樹
curTree = createSigleBoostingTree(trainDataArr, trainLabelArr, D)
# 根據(jù)式8.2計算當前層的alpha
alpha = 1 / 2 * np.log((1 - curTree['e']) / curTree['e'])
# 獲得當前層的預測結(jié)果,用于下一步更新D
Gx = curTree['Gx']
# 依據(jù)式8.4更新D
# 考慮到該式每次只更新D中的一個w,要循環(huán)進行更新知道所有w更新結(jié)束會很復雜(其實
# 不是時間上的復雜,只是讓人感覺每次單獨更新一個很累),所以該式以向量相乘的形式,
# 一個式子將所有w全部更新完。
# 該式需要線性代數(shù)基礎,如果不太熟練建議補充相關知識,當然了,單獨更新w也一點問題
# 沒有
# np.multiply(trainLabelArr, Gx):exp中的y*Gm(x),結(jié)果是一個行向量,內(nèi)部為yi*Gm(xi)
# np.exp(-1 * alpha * np.multiply(trainLabelArr, Gx)):上面求出來的行向量內(nèi)部全體
# 成員再乘以-αm,然后取對數(shù),和書上式子一樣,只不過書上式子內(nèi)是一個數(shù),這里是一個向量
# D是一個行向量,取代了式中的wmi,然后D求和為Zm
# 書中的式子最后得出來一個數(shù)w,所有數(shù)w組合形成新的D
# 這里是直接得到一個向量,向量內(nèi)元素是所有的w
# 本質(zhì)上結(jié)果是相同的
D = np.multiply(D, np.exp(-1 * alpha * np.multiply(trainLabelArr, Gx))) / sum(D)
# 在當前層參數(shù)中增加alpha參數(shù),預測的時候需要用到
curTree['alpha'] = alpha
# 將當前層添加到提升樹索引中。
tree.append(curTree)
# -----以下代碼用來輔助,可以去掉---------------
# 根據(jù)8.6式將結(jié)果加上當前層乘以α,得到目前的最終輸出預測
finallpredict += alpha * Gx
# 計算當前最終預測輸出與實際標簽之間的誤差
error = sum([1 for i in range(len(trainDataList)) if np.sign(finallpredict[i]) != trainLabelArr[i]])
# 計算當前最終誤差率
finallError = error / len(trainDataList)
# 如果誤差為0,提前退出即可,因為沒有必要再計算算了
if finallError == 0:
return tree
# 打印一些信息
print('iter:%d:%d, sigle error:%.4f, finall error:%.4f' % (i, treeNum, curTree['e'], finallError))
# 返回整個提升樹
return tree
def predict(x, div, rule, feature):
'''
輸出單獨層預測結(jié)果
:param x: 預測樣本
:param div: 劃分點
:param rule: 劃分規(guī)則
:param feature: 進行操作的特征
:return:
'''
# 依據(jù)劃分規(guī)則定義小于及大于劃分點的標簽
if rule == 'LisOne':
L = 1
H = -1
else:
L = -1
H = 1
# 判斷預測結(jié)果
if x[feature] < div:
return L
else:
return H
def test(testDataList, testLabelList, tree):
'''
測試
:param testDataList:測試數(shù)據(jù)集
:param testLabelList: 測試標簽集
:param tree: 提升樹
:return: 準確率
'''
# 錯誤率計數(shù)值
errorCnt = 0
# 遍歷每一個測試樣本
for i in range(len(testDataList)):
# 預測結(jié)果值,初始為0
result = 0
# 依據(jù)算法8.1式8.6
# 預測式子是一個求和式,對于每一層的結(jié)果都要進行一次累加
# 遍歷每層的樹
for curTree in tree:
# 獲取該層參數(shù)
div = curTree['div']
rule = curTree['rule']
feature = curTree['feature']
alpha = curTree['alpha']
# 將當前層結(jié)果加入預測中
result += alpha * predict(testDataList[i], div, rule, feature)
# 預測結(jié)果取sign值,如果大于0 sign為1,反之為0
if np.sign(result) != testLabelList[i]:
errorCnt += 1
# 返回準確率
return 1 - errorCnt / len(testDataList)
if __name__ == '__main__':
# 開始時間
start = time.time()
# 獲取訓練集
print('start read transSet')
trainDataList, trainLabelList = loadData('../Mnist/mnist_train.csv')
# 獲取測試集
print('start read testSet')
testDataList, testLabelList = loadData('../Mnist/mnist_test.csv')
# 創(chuàng)建提升樹
print('start init train')
tree = createBosstingTree(trainDataList[:10000], trainLabelList[:10000], 40)
# 測試
print('start to test')
accuracy = test(testDataList[:1000], testLabelList[:1000], tree)
print('the accuracy is:%d' % (accuracy * 100), '%')
# 結(jié)束時間
end = time.time()
print('time span:', end - start)
程序運行結(jié)果
start read transSet
start read testSet
start init train
iter:0:40, sigle error:0.0804, finall error:0.0804
iter:1:40, sigle error:0.1448, finall error:0.0804
iter:2:40, sigle error:0.1362, finall error:0.0585
iter:3:40, sigle error:0.1864, finall error:0.0667
iter:4:40, sigle error:0.2249, finall error:0.0474
iter:5:40, sigle error:0.2634, finall error:0.0437
iter:6:40, sigle error:0.2626, finall error:0.0377
iter:7:40, sigle error:0.2935, finall error:0.0361
iter:8:40, sigle error:0.3230, finall error:0.0333
iter:9:40, sigle error:0.3034, finall error:0.0361
iter:10:40, sigle error:0.3375, finall error:0.0325
iter:11:40, sigle error:0.3364, finall error:0.0340
iter:12:40, sigle error:0.3473, finall error:0.0309
iter:13:40, sigle error:0.3006, finall error:0.0294
iter:14:40, sigle error:0.3267, finall error:0.0275
iter:15:40, sigle error:0.3584, finall error:0.0288
iter:16:40, sigle error:0.3492, finall error:0.0257
iter:17:40, sigle error:0.3506, finall error:0.0256
iter:18:40, sigle error:0.3665, finall error:0.0240
iter:19:40, sigle error:0.3769, finall error:0.0251
iter:20:40, sigle error:0.3828, finall error:0.0213
iter:21:40, sigle error:0.3733, finall error:0.0229
iter:22:40, sigle error:0.3785, finall error:0.0218
iter:23:40, sigle error:0.3867, finall error:0.0219
iter:24:40, sigle error:0.3850, finall error:0.0208
iter:25:40, sigle error:0.3823, finall error:0.0201
iter:26:40, sigle error:0.3825, finall error:0.0204
iter:27:40, sigle error:0.3874, finall error:0.0188
iter:28:40, sigle error:0.3952, finall error:0.0186
iter:29:40, sigle error:0.4018, finall error:0.0193
iter:30:40, sigle error:0.3889, finall error:0.0177
iter:31:40, sigle error:0.3939, finall error:0.0183
iter:32:40, sigle error:0.3838, finall error:0.0182
iter:33:40, sigle error:0.4021, finall error:0.0171
iter:34:40, sigle error:0.4119, finall error:0.0164
iter:35:40, sigle error:0.4093, finall error:0.0164
iter:36:40, sigle error:0.4135, finall error:0.0167
iter:37:40, sigle error:0.4099, finall error:0.0171
iter:38:40, sigle error:0.3871, finall error:0.0163
iter:39:40, sigle error:0.4085, finall error:0.0154
start to test
the accuracy is:97 %
time span: 3777.730945825577
以上就是python實現(xiàn)AdaBoost算法的示例的詳細內(nèi)容,更多關于python實現(xiàn)AdaBoost算法的資料請關注腳本之家其它相關文章!
相關文章
Python Pygame實戰(zhàn)之超級炸彈人游戲的實現(xiàn)
如今的玩家們在無聊的時候會玩些什么游戲呢?王者還是吃雞是最多的選擇。但在80、90年代的時候多是一些很簡單的游戲:《超級瑪麗》、《魂斗羅》等。本文將利用Pygame制作另一個經(jīng)典游戲—炸彈人,感興趣的可以了解一下2022-03-03
pip 20.3 新版本發(fā)布!即將拋棄 Python 2.x(推薦)
這篇文章主要介紹了pip 20.3 新版本發(fā)布!即將拋棄 Python 2.x,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2020-12-12
python3對接mysql數(shù)據(jù)庫實例詳解
這篇文章主要介紹了python3對接mysql數(shù)據(jù)庫,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2019-04-04
tensorflow 固定部分參數(shù)訓練,只訓練部分參數(shù)的實例
今天小編就為大家分享一篇tensorflow 固定部分參數(shù)訓練,只訓練部分參數(shù)的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-01-01

