python 決策樹(shù)算法的實(shí)現(xiàn)
''' 數(shù)據(jù)集:Mnist 訓(xùn)練集數(shù)量:60000 測(cè)試集數(shù)量:10000 ------------------------------ 運(yùn)行結(jié)果:ID3(未剪枝) 正確率:85.9% 運(yùn)行時(shí)長(zhǎng):356s ''' import time import numpy as np def loadData(fileName): ''' 加載文件 :param fileName:要加載的文件路徑 :return: 數(shù)據(jù)集和標(biāo)簽集 ''' # 存放數(shù)據(jù)及標(biāo)記 dataArr = []; labelArr = [] # 讀取文件 fr = open(fileName) # 遍歷文件中的每一行 for line in fr.readlines(): # 獲取當(dāng)前行,并按“,”切割成字段放入列表中 # strip:去掉每行字符串首尾指定的字符(默認(rèn)空格或換行符) # split:按照指定的字符將字符串切割成每個(gè)字段,返回列表形式 curLine = line.strip().split(',') # 將每行中除標(biāo)記外的數(shù)據(jù)放入數(shù)據(jù)集中(curLine[0]為標(biāo)記信息) # 在放入的同時(shí)將原先字符串形式的數(shù)據(jù)轉(zhuǎn)換為整型 # 此外將數(shù)據(jù)進(jìn)行了二值化處理,大于128的轉(zhuǎn)換成1,小于的轉(zhuǎn)換成0,方便后續(xù)計(jì)算 dataArr.append([int(int(num) > 128) for num in curLine[1:]]) # 將標(biāo)記信息放入標(biāo)記集中 # 放入的同時(shí)將標(biāo)記轉(zhuǎn)換為整型 labelArr.append(int(curLine[0])) # 返回?cái)?shù)據(jù)集和標(biāo)記 return dataArr, labelArr def majorClass(labelArr): ''' 找到當(dāng)前標(biāo)簽集中占數(shù)目最大的標(biāo)簽 :param labelArr: 標(biāo)簽集 :return: 最大的標(biāo)簽 ''' # 建立字典,用于不同類(lèi)別的標(biāo)簽技術(shù) classDict = {} # 遍歷所有標(biāo)簽 for i in range(len(labelArr)): # 當(dāng)?shù)谝淮斡龅紸標(biāo)簽時(shí),字典內(nèi)還沒(méi)有A標(biāo)簽,這時(shí)候直接幅值加1是錯(cuò)誤的, # 所以需要判斷字典中是否有該鍵,沒(méi)有則創(chuàng)建,有就直接自增 if labelArr[i] in classDict.keys(): # 若在字典中存在該標(biāo)簽,則直接加1 classDict[labelArr[i]] += 1 else: # 若無(wú)該標(biāo)簽,設(shè)初值為1,表示出現(xiàn)了1次了 classDict[labelArr[i]] = 1 # 對(duì)字典依據(jù)值進(jìn)行降序排序 classSort = sorted(classDict.items(), key=lambda x: x[1], reverse=True) # 返回最大一項(xiàng)的標(biāo)簽,即占數(shù)目最多的標(biāo)簽 return classSort[0][0] def calc_H_D(trainLabelArr): ''' 計(jì)算數(shù)據(jù)集D的經(jīng)驗(yàn)熵,參考公式5.7 經(jīng)驗(yàn)熵的計(jì)算 :param trainLabelArr:當(dāng)前數(shù)據(jù)集的標(biāo)簽集 :return: 經(jīng)驗(yàn)熵 ''' # 初始化為0 H_D = 0 # 將當(dāng)前所有標(biāo)簽放入集合中,這樣只要有的標(biāo)簽都會(huì)在集合中出現(xiàn),且出現(xiàn)一次。 # 遍歷該集合就可以遍歷所有出現(xiàn)過(guò)的標(biāo)記并計(jì)算其Ck # 這么做有一個(gè)很重要的原因:首先假設(shè)一個(gè)背景,當(dāng)前標(biāo)簽集中有一些標(biāo)記已經(jīng)沒(méi)有了,比如說(shuō)標(biāo)簽集中 # 沒(méi)有0(這是很正常的,說(shuō)明當(dāng)前分支不存在這個(gè)標(biāo)簽)。 式5.7中有一項(xiàng)Ck,那按照式中的針對(duì)不同標(biāo)簽k # 計(jì)算Cl和D并求和時(shí),由于沒(méi)有0,那么C0=0,此時(shí)C0/D0=0,log2(C0/D0) = log2(0),事實(shí)上0并不在log的 # 定義區(qū)間內(nèi),出現(xiàn)了問(wèn)題 # 所以使用集合的方式先知道當(dāng)前標(biāo)簽中都出現(xiàn)了那些標(biāo)簽,隨后對(duì)每個(gè)標(biāo)簽進(jìn)行計(jì)算,如果沒(méi)出現(xiàn)的標(biāo)簽?zāi)且豁?xiàng)就 # 不在經(jīng)驗(yàn)熵中出現(xiàn)(未參與,對(duì)經(jīng)驗(yàn)熵?zé)o影響),保證log的計(jì)算能一直有定義 trainLabelSet = set([label for label in trainLabelArr]) # 遍歷每一個(gè)出現(xiàn)過(guò)的標(biāo)簽 for i in trainLabelSet: # 計(jì)算|Ck|/|D| # trainLabelArr == i:當(dāng)前標(biāo)簽集中為該標(biāo)簽的的位置 # 例如a = [1, 0, 0, 1], c = (a == 1): c == [True, false, false, True] # trainLabelArr[trainLabelArr == i]:獲得為指定標(biāo)簽的樣本 # trainLabelArr[trainLabelArr == i].size:獲得為指定標(biāo)簽的樣本的大小,即標(biāo)簽為i的樣本 # 數(shù)量,就是|Ck| # trainLabelArr.size:整個(gè)標(biāo)簽集的數(shù)量(也就是樣本集的數(shù)量),即|D| p = trainLabelArr[trainLabelArr == i].size / trainLabelArr.size # 對(duì)經(jīng)驗(yàn)熵的每一項(xiàng)累加求和 H_D += -1 * p * np.log2(p) # 返回經(jīng)驗(yàn)熵 return H_D def calcH_D_A(trainDataArr_DevFeature, trainLabelArr): ''' 計(jì)算經(jīng)驗(yàn)條件熵 :param trainDataArr_DevFeature:切割后只有feature那列數(shù)據(jù)的數(shù)組 :param trainLabelArr: 標(biāo)簽集數(shù)組 :return: 經(jīng)驗(yàn)條件熵 ''' # 初始為0 H_D_A = 0 # 在featue那列放入集合中,是為了根據(jù)集合中的數(shù)目知道該feature目前可取值數(shù)目是多少 trainDataSet = set([label for label in trainDataArr_DevFeature]) # 對(duì)于每一個(gè)特征取值遍歷計(jì)算條件經(jīng)驗(yàn)熵的每一項(xiàng) for i in trainDataSet: # 計(jì)算H(D|A) # trainDataArr_DevFeature[trainDataArr_DevFeature == i].size / trainDataArr_DevFeature.size:|Di| / |D| # calc_H_D(trainLabelArr[trainDataArr_DevFeature == i]):H(Di) H_D_A += trainDataArr_DevFeature[trainDataArr_DevFeature == i].size / trainDataArr_DevFeature.size \ * calc_H_D(trainLabelArr[trainDataArr_DevFeature == i]) # 返回得出的條件經(jīng)驗(yàn)熵 return H_D_A def calcBestFeature(trainDataList, trainLabelList): ''' 計(jì)算信息增益最大的特征 :param trainDataList: 當(dāng)前數(shù)據(jù)集 :param trainLabelList: 當(dāng)前標(biāo)簽集 :return: 信息增益最大的特征及最大信息增益值 ''' # 將數(shù)據(jù)集和標(biāo)簽集轉(zhuǎn)換為數(shù)組形式 # trainLabelArr轉(zhuǎn)換后需要轉(zhuǎn)置,這樣在取數(shù)時(shí)方便 # 例如a = np.array([1, 2, 3]); b = np.array([1, 2, 3]).T # 若不轉(zhuǎn)置,a[0] = [1, 2, 3],轉(zhuǎn)置后b[0] = 1, b[1] = 2 # 對(duì)于標(biāo)簽集來(lái)說(shuō),能夠很方便地取到每一位是很重要的 trainDataArr = np.array(trainDataList) trainLabelArr = np.array(trainLabelList).T # 獲取當(dāng)前特征數(shù)目,也就是數(shù)據(jù)集的橫軸大小 featureNum = trainDataArr.shape[1] # 初始化最大信息增益 maxG_D_A = -1 # 初始化最大信息增益的特征 maxFeature = -1 # 對(duì)每一個(gè)特征進(jìn)行遍歷計(jì)算 for feature in range(featureNum): # “5.2.2 信息增益”中“算法5.1(信息增益的算法)”第一步: # 1.計(jì)算數(shù)據(jù)集D的經(jīng)驗(yàn)熵H(D) H_D = calc_H_D(trainLabelArr) # 2.計(jì)算條件經(jīng)驗(yàn)熵H(D|A) # 由于條件經(jīng)驗(yàn)熵的計(jì)算過(guò)程中只涉及到標(biāo)簽以及當(dāng)前特征,為了提高運(yùn)算速度(全部樣本 # 做成的矩陣運(yùn)算速度太慢,需要剔除不需要的部分),將數(shù)據(jù)集矩陣進(jìn)行切割 # 數(shù)據(jù)集在初始時(shí)刻是一個(gè)Arr = 60000*784的矩陣,針對(duì)當(dāng)前要計(jì)算的feature,在訓(xùn)練集中切割下 # Arr[:, feature]這么一條來(lái),因?yàn)楹罄m(xù)計(jì)算中數(shù)據(jù)集中只用到這個(gè)(沒(méi)明白的跟著算一遍例5.2) # trainDataArr[:, feature]:在數(shù)據(jù)集中切割下這么一條 # trainDataArr[:, feature].flat:將這么一條轉(zhuǎn)換成豎著的列表 # np.array(trainDataArr[:, feature].flat):再轉(zhuǎn)換成一條豎著的矩陣,大小為60000*1(只是初始是 # 這么大,運(yùn)行過(guò)程中是依據(jù)當(dāng)前數(shù)據(jù)集大小動(dòng)態(tài)變的) trainDataArr_DevideByFeature = np.array(trainDataArr[:, feature].flat) # 3.計(jì)算信息增益G(D|A) G(D|A) = H(D) - H(D | A) G_D_A = H_D - calcH_D_A(trainDataArr_DevideByFeature, trainLabelArr) # 不斷更新最大的信息增益以及對(duì)應(yīng)的feature if G_D_A > maxG_D_A: maxG_D_A = G_D_A maxFeature = feature return maxFeature, maxG_D_A def getSubDataArr(trainDataArr, trainLabelArr, A, a): ''' 更新數(shù)據(jù)集和標(biāo)簽集 :param trainDataArr:要更新的數(shù)據(jù)集 :param trainLabelArr: 要更新的標(biāo)簽集 :param A: 要去除的特征索引 :param a: 當(dāng)data[A]== a時(shí),說(shuō)明該行樣本時(shí)要保留的 :return: 新的數(shù)據(jù)集和標(biāo)簽集 ''' # 返回的數(shù)據(jù)集 retDataArr = [] # 返回的標(biāo)簽集 retLabelArr = [] # 對(duì)當(dāng)前數(shù)據(jù)的每一個(gè)樣本進(jìn)行遍歷 for i in range(len(trainDataArr)): # 如果當(dāng)前樣本的特征為指定特征值a if trainDataArr[i][A] == a: # 那么將該樣本的第A個(gè)特征切割掉,放入返回的數(shù)據(jù)集中 retDataArr.append(trainDataArr[i][0:A] + trainDataArr[i][A + 1:]) # 將該樣本的標(biāo)簽放入返回標(biāo)簽集中 retLabelArr.append(trainLabelArr[i]) # 返回新的數(shù)據(jù)集和標(biāo)簽集 return retDataArr, retLabelArr def createTree(*dataSet): ''' 遞歸創(chuàng)建決策樹(shù) :param dataSet:(trainDataList, trainLabelList) <<-- 元祖形式 :return:新的子節(jié)點(diǎn)或該葉子節(jié)點(diǎn)的值 ''' # 設(shè)置Epsilon,“5.3.1 ID3算法”第4步提到需要將信息增益與閾值Epsilon比較,若小于則直接處理后返回T Epsilon = 0.1 # 從參數(shù)中獲取trainDataList和trainLabelList trainDataList = dataSet[0][0] trainLabelList = dataSet[0][1] # 打印信息:開(kāi)始一個(gè)子節(jié)點(diǎn)創(chuàng)建,打印當(dāng)前特征向量數(shù)目及當(dāng)前剩余樣本數(shù)目 print('start a node', len(trainDataList[0]), len(trainLabelList)) # 將標(biāo)簽放入一個(gè)字典中,當(dāng)前樣本有多少類(lèi),在字典中就會(huì)有多少項(xiàng) # 也相當(dāng)于去重,多次出現(xiàn)的標(biāo)簽就留一次。舉個(gè)例子,假如處理結(jié)束后字典的長(zhǎng)度為1,那說(shuō)明所有的樣本 # 都是同一個(gè)標(biāo)簽,那就可以直接返回該標(biāo)簽了,不需要再生成子節(jié)點(diǎn)了。 classDict = {i for i in trainLabelList} # 如果D中所有實(shí)例屬于同一類(lèi)Ck,則置T為單節(jié)點(diǎn)數(shù),并將Ck作為該節(jié)點(diǎn)的類(lèi),返回T # 即若所有樣本的標(biāo)簽一致,也就不需要再分化,返回標(biāo)記作為該節(jié)點(diǎn)的值,返回后這就是一個(gè)葉子節(jié)點(diǎn) if len(classDict) == 1: # 因?yàn)樗袠颖径际且恢碌模跇?biāo)簽集中隨便拿一個(gè)標(biāo)簽返回都行,這里用的第0個(gè)(因?yàn)槟悴⒉恢? # 當(dāng)前標(biāo)簽集的長(zhǎng)度是多少,但運(yùn)行中所有標(biāo)簽只要有長(zhǎng)度都會(huì)有第0位。 return trainLabelList[0] # 如果A為空集,則置T為單節(jié)點(diǎn)數(shù),并將D中實(shí)例數(shù)最大的類(lèi)Ck作為該節(jié)點(diǎn)的類(lèi),返回T # 即如果已經(jīng)沒(méi)有特征可以用來(lái)再分化了,就返回占大多數(shù)的類(lèi)別 if len(trainDataList[0]) == 0: # 返回當(dāng)前標(biāo)簽集中占數(shù)目最大的標(biāo)簽 return majorClass(trainLabelList) # 否則,按式5.10計(jì)算A中個(gè)特征值的信息增益,選擇信息增益最大的特征Ag Ag, EpsilonGet = calcBestFeature(trainDataList, trainLabelList) # 如果Ag的信息增益比小于閾值Epsilon,則置T為單節(jié)點(diǎn)樹(shù),并將D中實(shí)例數(shù)最大的類(lèi)Ck # 作為該節(jié)點(diǎn)的類(lèi),返回T if EpsilonGet < Epsilon: return majorClass(trainLabelList) # 否則,對(duì)Ag的每一可能值ai,依Ag=ai將D分割為若干非空子集Di,將Di中實(shí)例數(shù)最大的 # 類(lèi)作為標(biāo)記,構(gòu)建子節(jié)點(diǎn),由節(jié)點(diǎn)及其子節(jié)點(diǎn)構(gòu)成樹(shù)T,返回T treeDict = {Ag: {}} # 特征值為0時(shí),進(jìn)入0分支 # getSubDataArr(trainDataList, trainLabelList, Ag, 0):在當(dāng)前數(shù)據(jù)集中切割當(dāng)前feature,返回新的數(shù)據(jù)集和標(biāo)簽集 treeDict[Ag][0] = createTree(getSubDataArr(trainDataList, trainLabelList, Ag, 0)) treeDict[Ag][1] = createTree(getSubDataArr(trainDataList, trainLabelList, Ag, 1)) return treeDict def predict(testDataList, tree): ''' 預(yù)測(cè)標(biāo)簽 :param testDataList:樣本 :param tree: 決策樹(shù) :return: 預(yù)測(cè)結(jié)果 ''' # treeDict = copy.deepcopy(tree) # 死循環(huán),直到找到一個(gè)有效地分類(lèi) while True: # 因?yàn)橛袝r(shí)候當(dāng)前字典只有一個(gè)節(jié)點(diǎn) # 例如{73: {0: {74:6}}}看起來(lái)節(jié)點(diǎn)很多,但是對(duì)于字典的最頂層來(lái)說(shuō),只有73一個(gè)key,其余都是value # 若還是采用for來(lái)讀取的話不太合適,所以使用下行這種方式讀取key和value (key, value), = tree.items() # 如果當(dāng)前的value是字典,說(shuō)明還需要遍歷下去 if type(tree[key]).__name__ == 'dict': # 獲取目前所在節(jié)點(diǎn)的feature值,需要在樣本中刪除該feature # 因?yàn)樵趧?chuàng)建樹(shù)的過(guò)程中,feature的索引值永遠(yuǎn)是對(duì)于當(dāng)時(shí)剩余的feature來(lái)設(shè)置的 # 所以需要不斷地刪除已經(jīng)用掉的特征,保證索引相對(duì)位置的一致性 dataVal = testDataList[key] del testDataList[key] # 將tree更新為其子節(jié)點(diǎn)的字典 tree = value[dataVal] # 如果當(dāng)前節(jié)點(diǎn)的子節(jié)點(diǎn)的值是int,就直接返回該int值 # 例如{403: {0: 7, 1: {297:7}},dataVal=0 # 此時(shí)上一行tree = value[dataVal],將tree定位到了7,而7不再是一個(gè)字典了, # 這里就可以直接返回7了,如果tree = value[1],那就是一個(gè)新的子節(jié)點(diǎn),需要繼續(xù)遍歷下去 if type(tree).__name__ == 'int': # 返回該節(jié)點(diǎn)值,也就是分類(lèi)值 return tree else: # 如果當(dāng)前value不是字典,那就返回分類(lèi)值 return value def accuracy(testDataList, testLabelList, tree): ''' 測(cè)試準(zhǔn)確率 :param testDataList:待測(cè)試數(shù)據(jù)集 :param testLabelList: 待測(cè)試標(biāo)簽集 :param tree: 訓(xùn)練集生成的樹(shù) :return: 準(zhǔn)確率 ''' # 錯(cuò)誤次數(shù)計(jì)數(shù) errorCnt = 0 # 遍歷測(cè)試集中每一個(gè)測(cè)試樣本 for i in range(len(testDataList)): # 判斷預(yù)測(cè)與標(biāo)簽中結(jié)果是否一致 if testLabelList[i] != predict(testDataList[i], tree): errorCnt += 1 # 返回準(zhǔn)確率 return 1 - errorCnt / len(testDataList) if __name__ == '__main__': # 開(kāi)始時(shí)間 start = time.time() # 獲取訓(xùn)練集 trainDataList, trainLabelList = loadData('../Mnist/mnist_train.csv') # 獲取測(cè)試集 testDataList, testLabelList = loadData('../Mnist/mnist_test.csv') # 創(chuàng)建決策樹(shù) print('start create tree') tree = createTree((trainDataList, trainLabelList)) print('tree is:', tree) # 測(cè)試準(zhǔn)確率 print('start test') accur = accuracy(testDataList, testLabelList, tree) print('the accur is:', accur) # 結(jié)束時(shí)間 end = time.time() print('time span:', end - start)
以上就是python 決策樹(shù)算法的實(shí)現(xiàn)的詳細(xì)內(nèi)容,更多關(guān)于python 決策樹(shù)算法的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
- python實(shí)現(xiàn)LRU熱點(diǎn)緩存及原理
- LRUCache的實(shí)現(xiàn)原理及利用python實(shí)現(xiàn)的方法
- Python實(shí)現(xiàn)LRU算法的2種方法
- Python實(shí)現(xiàn)的一個(gè)簡(jiǎn)單LRU cache
- python 實(shí)現(xiàn)非極大值抑制算法(Non-maximum suppression, NMS)
- python實(shí)現(xiàn)粒子群算法
- Python實(shí)現(xiàn)七個(gè)基本算法的實(shí)例代碼
- Python實(shí)現(xiàn)EM算法實(shí)例代碼
- 工程師必須了解的LRU緩存淘汰算法以及python實(shí)現(xiàn)過(guò)程
相關(guān)文章
opencv+pyQt5實(shí)現(xiàn)圖片閾值編輯器/尋色塊閾值利器
這篇文章主要介紹了opencv+pyQt5實(shí)現(xiàn)圖片閾值編輯器/尋色塊閾值利器,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-11-11Python 語(yǔ)法錯(cuò)誤:"SyntaxError: invalid charac
本文給大家分享Python 語(yǔ)法錯(cuò)誤:“SyntaxError: invalid character in identifier“,原因及解決方法,文末給大家補(bǔ)充介紹了Python出現(xiàn)SyntaxError: invalid syntax的原因總結(jié),感興趣的朋友跟隨小編一起學(xué)習(xí)吧2023-02-02基于python全局設(shè)置id 自動(dòng)化測(cè)試元素定位過(guò)程解析
這篇文章主要介紹了基于python全局設(shè)置id 自動(dòng)化測(cè)試元素定位過(guò)程解析,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-09-09使用Python實(shí)現(xiàn)企業(yè)微信通知功能案例分析
這篇文章主要介紹了使用Python實(shí)現(xiàn)企業(yè)微信通知功能,主要目的是通過(guò)企業(yè)微信應(yīng)用給企業(yè)成員發(fā)消息,通過(guò)案例分析給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2022-04-04requests.post()方法中data和json參數(shù)的使用
這篇文章主要介紹了requests.post()方法中data和json參數(shù)的使用方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-02-02python讀寫(xiě)文件write和flush的實(shí)現(xiàn)方式
今天小編就為大家分享一篇python讀寫(xiě)文件write和flush的實(shí)現(xiàn)方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-02-02python 數(shù)據(jù)清洗之?dāng)?shù)據(jù)合并、轉(zhuǎn)換、過(guò)濾、排序
這篇文章主要介紹了python 數(shù)據(jù)清洗之?dāng)?shù)據(jù)合并、轉(zhuǎn)換、過(guò)濾、排序的相關(guān)資料,需要的朋友可以參考下2017-02-02pytorch快速搭建神經(jīng)網(wǎng)絡(luò)_Sequential操作
這篇文章主要介紹了pytorch快速搭建神經(jīng)網(wǎng)絡(luò)_Sequential操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-06-06