python實現(xiàn)決策樹分類
上一篇博客主要介紹了決策樹的原理,這篇主要介紹他的實現(xiàn),代碼環(huán)境python 3.4,實現(xiàn)的是ID3算法,首先為了后面matplotlib的繪圖方便,我把原來的中文數(shù)據(jù)集變成了英文。
原始數(shù)據(jù)集:
變化后的數(shù)據(jù)集在程序代碼中體現(xiàn),這就不截圖了
構(gòu)建決策樹的代碼如下:
#coding :utf-8 ''' 2017.6.25 author :Erin function: "decesion tree" ID3 ''' import numpy as np import pandas as pd from math import log import operator def load_data(): #data=np.array(data) data=[['teenager' ,'high', 'no' ,'same', 'no'], ['teenager', 'high', 'no', 'good', 'no'], ['middle_aged' ,'high', 'no', 'same', 'yes'], ['old_aged', 'middle', 'no' ,'same', 'yes'], ['old_aged', 'low', 'yes', 'same' ,'yes'], ['old_aged', 'low', 'yes', 'good', 'no'], ['middle_aged', 'low' ,'yes' ,'good', 'yes'], ['teenager' ,'middle' ,'no', 'same', 'no'], ['teenager', 'low' ,'yes' ,'same', 'yes'], ['old_aged' ,'middle', 'yes', 'same', 'yes'], ['teenager' ,'middle', 'yes', 'good', 'yes'], ['middle_aged' ,'middle', 'no', 'good', 'yes'], ['middle_aged', 'high', 'yes', 'same', 'yes'], ['old_aged', 'middle', 'no' ,'good' ,'no']] features=['age','input','student','level'] return data,features def cal_entropy(dataSet): ''' 輸入data ,表示帶最后標(biāo)簽列的數(shù)據(jù)集 計算給定數(shù)據(jù)集總的信息熵 {'是': 9, '否': 5} 0.9402859586706309 ''' numEntries = len(dataSet) labelCounts = {} for featVec in dataSet: label = featVec[-1] if label not in labelCounts.keys(): labelCounts[label] = 0 labelCounts[label] += 1 entropy = 0.0 for key in labelCounts.keys(): p_i = float(labelCounts[key]/numEntries) entropy -= p_i * log(p_i,2)#log(x,10)表示以10 為底的對數(shù) return entropy def split_data(data,feature_index,value): ''' 劃分?jǐn)?shù)據(jù)集 feature_index:用于劃分特征的列數(shù),例如“年齡” value:劃分后的屬性值:例如“青少年” ''' data_split=[]#劃分后的數(shù)據(jù)集 for feature in data: if feature[feature_index]==value: reFeature=feature[:feature_index] reFeature.extend(feature[feature_index+1:]) data_split.append(reFeature) return data_split def choose_best_to_split(data): ''' 根據(jù)每個特征的信息增益,選擇最大的劃分?jǐn)?shù)據(jù)集的索引特征 ''' count_feature=len(data[0])-1#特征個數(shù)4 #print(count_feature)#4 entropy=cal_entropy(data)#原數(shù)據(jù)總的信息熵 #print(entropy)#0.9402859586706309 max_info_gain=0.0#信息增益最大 split_fea_index = -1#信息增益最大,對應(yīng)的索引號 for i in range(count_feature): feature_list=[fe_index[i] for fe_index in data]#獲取該列所有特征值 ####################################### ''' print('feature_list') ['青少年', '青少年', '中年', '老年', '老年', '老年', '中年', '青少年', '青少年', '老年', '青少年', '中年', '中年', '老年'] 0.3467680694480959 #對應(yīng)上篇博客中的公式 =(1)*5/14 0.3467680694480959 0.6935361388961918 ''' # print(feature_list) unqval=set(feature_list)#去除重復(fù) Pro_entropy=0.0#特征的熵 for value in unqval:#遍歷改特征下的所有屬性 sub_data=split_data(data,i,value) pro=len(sub_data)/float(len(data)) Pro_entropy+=pro*cal_entropy(sub_data) #print(Pro_entropy) info_gain=entropy-Pro_entropy if(info_gain>max_info_gain): max_info_gain=info_gain split_fea_index=i return split_fea_index ################################################## def most_occur_label(labels): #sorted_label_count[0][0] 次數(shù)最多的類標(biāo)簽 label_count={} for label in labels: if label not in label_count.keys(): label_count[label]=0 else: label_count[label]+=1 sorted_label_count = sorted(label_count.items(),key = operator.itemgetter(1),reverse = True) return sorted_label_count[0][0] def build_decesion_tree(dataSet,featnames): ''' 字典的鍵存放節(jié)點信息,分支及葉子節(jié)點存放值 ''' featname = featnames[:] ################ classlist = [featvec[-1] for featvec in dataSet] #此節(jié)點的分類情況 if classlist.count(classlist[0]) == len(classlist): #全部屬于一類 return classlist[0] if len(dataSet[0]) == 1: #分完了,沒有屬性了 return Vote(classlist) #少數(shù)服從多數(shù) # 選擇一個最優(yōu)特征進行劃分 bestFeat = choose_best_to_split(dataSet) bestFeatname = featname[bestFeat] del(featname[bestFeat]) #防止下標(biāo)不準(zhǔn) DecisionTree = {bestFeatname:{}} # 創(chuàng)建分支,先找出所有屬性值,即分支數(shù) allvalue = [vec[bestFeat] for vec in dataSet] specvalue = sorted(list(set(allvalue))) #使有一定順序 for v in specvalue: copyfeatname = featname[:] DecisionTree[bestFeatname][v] = build_decesion_tree(split_data(dataSet,bestFeat,v),copyfeatname) return DecisionTree
繪制可視化圖的代碼如下:
def getNumLeafs(myTree): '計算決策樹的葉子數(shù)' # 葉子數(shù) numLeafs = 0 # 節(jié)點信息 sides = list(myTree.keys()) firstStr =sides[0] # 分支信息 secondDict = myTree[firstStr] for key in secondDict.keys(): # 遍歷所有分支 # 子樹分支則遞歸計算 if type(secondDict[key]).__name__=='dict': numLeafs += getNumLeafs(secondDict[key]) # 葉子分支則葉子數(shù)+1 else: numLeafs +=1 return numLeafs def getTreeDepth(myTree): '計算決策樹的深度' # 最大深度 maxDepth = 0 # 節(jié)點信息 sides = list(myTree.keys()) firstStr =sides[0] # 分支信息 secondDict = myTree[firstStr] for key in secondDict.keys(): # 遍歷所有分支 # 子樹分支則遞歸計算 if type(secondDict[key]).__name__=='dict': thisDepth = 1 + getTreeDepth(secondDict[key]) # 葉子分支則葉子數(shù)+1 else: thisDepth = 1 # 更新最大深度 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepth import matplotlib.pyplot as plt decisionNode = dict(boxstyle="sawtooth", fc="0.8") leafNode = dict(boxstyle="round4", fc="0.8") arrow_args = dict(arrowstyle="<-") # ================================================== # 輸入: # nodeTxt: 終端節(jié)點顯示內(nèi)容 # centerPt: 終端節(jié)點坐標(biāo) # parentPt: 起始節(jié)點坐標(biāo) # nodeType: 終端節(jié)點樣式 # 輸出: # 在圖形界面中顯示輸入?yún)?shù)指定樣式的線段(終端帶節(jié)點) # ================================================== def plotNode(nodeTxt, centerPt, parentPt, nodeType): '畫線(末端帶一個點)' createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args ) # ================================================================= # 輸入: # cntrPt: 終端節(jié)點坐標(biāo) # parentPt: 起始節(jié)點坐標(biāo) # txtString: 待顯示文本內(nèi)容 # 輸出: # 在圖形界面指定位置(cntrPt和parentPt中間)顯示文本內(nèi)容(txtString) # ================================================================= def plotMidText(cntrPt, parentPt, txtString): '在指定位置添加文本' # 中間位置坐標(biāo) xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30) # =================================== # 輸入: # myTree: 決策樹 # parentPt: 根節(jié)點坐標(biāo) # nodeTxt: 根節(jié)點坐標(biāo)信息 # 輸出: # 在圖形界面繪制決策樹 # =================================== def plotTree(myTree, parentPt, nodeTxt): '繪制決策樹' # 當(dāng)前樹的葉子數(shù) numLeafs = getNumLeafs(myTree) # 當(dāng)前樹的節(jié)點信息 sides = list(myTree.keys()) firstStr =sides[0] # 定位第一棵子樹的位置 cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) # 繪制當(dāng)前節(jié)點到子樹節(jié)點(含子樹節(jié)點)的信息 plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, decisionNode) # 獲取子樹信息 secondDict = myTree[firstStr] # 開始繪制子樹,縱坐標(biāo)-1。 plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD for key in secondDict.keys(): # 遍歷所有分支 # 子樹分支則遞歸 if type(secondDict[key]).__name__=='dict': plotTree(secondDict[key],cntrPt,str(key)) # 葉子分支則直接繪制 else: plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) # 子樹繪制完畢,縱坐標(biāo)+1。 plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD # ============================== # 輸入: # myTree: 決策樹 # 輸出: # 在圖形界面顯示決策樹 # ============================== def createPlot(inTree): '顯示決策樹' # 創(chuàng)建新的圖像并清空 - 無橫縱坐標(biāo) fig = plt.figure(1, facecolor='white') fig.clf() axprops = dict(xticks=[], yticks=[]) createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # 樹的總寬度 高度 plotTree.totalW = float(getNumLeafs(inTree)) plotTree.totalD = float(getTreeDepth(inTree)) # 當(dāng)前繪制節(jié)點的坐標(biāo) plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; # 繪制決策樹 plotTree(inTree, (0.5,1.0), '') plt.show() if __name__ == '__main__': data,features=load_data() split_fea_index=choose_best_to_split(data) newtree=build_decesion_tree(data,features) print(newtree) createPlot(newtree) ''' {'age': {'old_aged': {'level': {'same': 'yes', 'good': 'no'}}, 'teenager': {'student': {'no': 'no', 'yes': 'yes'}}, 'middle_aged': 'yes'}} '''
結(jié)果如下:
怎么用決策樹分類,將會在下一章。
以上就是本文的全部內(nèi)容,希望對大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
apache部署python程序出現(xiàn)503錯誤的解決方法
這篇文章主要給大家介紹了關(guān)于在apahce部署python程序出現(xiàn)503錯誤的解決方法,文中通過示例代碼介紹的非常詳細(xì),對同樣遇到這個問題的朋友們具有一定的參考學(xué)習(xí)價值,需要的朋友們下面來一起看看吧。2017-07-07解決pytorch下只打印tensor的數(shù)值不打印出device等信息的問題
這篇文章主要介紹了解決pytorch下只打印tensor的數(shù)值不打印出device等信息的問題,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2021-05-05linux系統(tǒng)使用python監(jiān)測網(wǎng)絡(luò)接口獲取網(wǎng)絡(luò)的輸入輸出
這篇文章主要介紹了linux系統(tǒng)使用python監(jiān)測網(wǎng)絡(luò)接口獲取網(wǎng)絡(luò)的輸入輸出信息,大家參考使用吧2014-01-01使用Python實現(xiàn)給企業(yè)微信發(fā)送消息功能
本文將介紹如何使用python3給企業(yè)微信發(fā)送消息,文中有詳細(xì)的圖文解說及代碼示例,對正在學(xué)習(xí)python的小伙伴很有幫助,需要的朋友可以參考下2021-12-12Python中關(guān)于列表的常規(guī)操作范例以及介紹
列表是一種有序的集合,可以隨時添加和刪除其中的元素。在python中使用的頻率非常高,本篇文章對大家的學(xué)習(xí)或工作具有一定的價值,需要的朋友可以參考下2021-09-09python學(xué)習(xí)-List移除某個值remove和統(tǒng)計值次數(shù)count
這篇文章主要介紹了?python學(xué)習(xí)-List移除某個值remove和統(tǒng)計值次數(shù)count,文章基于python的相關(guān)內(nèi)容展開詳細(xì)介紹,需要的小伙伴可以參考一下2022-04-04