Decision tree pruning algorithms in Python: a detailed walkthrough
This article uses a worked example to show how to implement decision tree pruning algorithms in Python. It is shared for reference; the details follow.
A decision tree is a tree built from a sequence of decisions. In machine learning, a decision tree is a predictive model that maps object attributes to object values. Each internal node tests an attribute, each branch corresponds to one possible value of that attribute, and each leaf holds the value predicted for objects whose attributes follow the path from the root to that leaf. A decision tree has a single output; if several outputs are needed, a separate tree can be built for each one.
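ID3 algorithm: ID3 is one kind of decision tree algorithm. It is based on Occam's razor, that is, doing as much as possible with as little as possible. ID3 (Iterative Dichotomiser 3) was invented by Ross Quinlan. In line with Occam's razor, it prefers smaller decision trees over larger ones. However, it does not always produce the smallest possible tree, because it is a heuristic algorithm. In information theory, the smaller the expected information (the entropy remaining after a split), the larger the information gain and the purer the resulting subsets. The core idea of ID3 is to use information gain as the criterion for attribute selection: at each node it splits on the attribute with the largest information gain. The algorithm performs a top-down greedy search through the space of possible decision trees.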
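To make ID3's selection rule concrete, here is a minimal sketch. It computes the information gain of each discrete attribute on a tiny dataset and greedily picks the largest. The helper names (`entropy`, `information_gain`, `best_attribute`) and the toy rows are hypothetical, chosen for illustration only. Each row is assumed to hold its feature values followed by the class label in the last position, the same layout the implementation below uses.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (base 2) of a list of class labels."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def information_gain(rows, axis):
    """Entropy of the class column minus the weighted entropy of the
    subsets obtained by splitting on the discrete attribute at `axis`."""
    labels = [row[-1] for row in rows]
    n = len(rows)
    remainder = 0.0
    for value in set(row[axis] for row in rows):
        subset = [row[-1] for row in rows if row[axis] == value]
        remainder += len(subset) / n * entropy(subset)
    return entropy(labels) - remainder


def best_attribute(rows):
    """ID3's greedy choice: index of the attribute with the largest gain."""
    n_features = len(rows[0]) - 1
    return max(range(n_features), key=lambda i: information_gain(rows, i))


# Hypothetical toy data: [outlook, windy, play]
rows = [
    ["sunny", "no", "no"],
    ["sunny", "yes", "no"],
    ["rainy", "no", "yes"],
    ["rainy", "yes", "yes"],
]
print(information_gain(rows, 0))  # 1.0  (outlook separates the classes perfectly)
print(information_gain(rows, 1))  # 0.0  (windy tells us nothing)
print(best_attribute(rows))       # 0
```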
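Information entropy measures the uncertainty of a discrete random variable, here the class label of a randomly drawn sample. The more ordered a system is, the lower its entropy; the more disordered it is, the higher its entropy. Entropy can therefore be regarded as a measure of how ordered a system is.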
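For reference, the usual textbook definitions are given below. The notation is assumed here, because the text above does not fix one: $D$ is the sample set, $p_k$ is the proportion of class $k$ in $D$, and $D^v$ is the subset of $D$ that takes value $v$ on attribute $a$.

$$H(D) = -\sum_{k} p_k \log_2 p_k$$

$$\mathrm{Gain}(D, a) = H(D) - \sum_{v} \frac{|D^v|}{|D|}\, H(D^v)$$

As a worked example, a set of 4 samples with 3 positive and 1 negative has

$$H = -\left(\tfrac{3}{4}\log_2\tfrac{3}{4} + \tfrac{1}{4}\log_2\tfrac{1}{4}\right) \approx 0.311 + 0.5 = 0.811 \text{ bits}.$$

A 50/50 set gives exactly 1 bit and a pure set gives 0, which matches "more ordered, lower entropy".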
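Gini index: CART uses the Gini index as its splitting criterion, defined as

$$gini(T) = 1 - \sum_{j=1}^{n} p_j^2$$

where $p_j$ is the relative frequency of class $j$ in $T$. The more skewed the class distribution in $T$, the smaller $gini(T)$; it reaches 0 when $T$ contains a single class. Suppose $T$ is split into two subsets, $T_1$ with $N_1$ instances and $T_2$ with $N_2$ instances, where $N = N_1 + N_2$. The Gini index of the split is

$$gini_{split}(T) = \frac{N_1}{N}\, gini(T_1) + \frac{N_2}{N}\, gini(T_2)$$

The candidate split with the smallest $gini_{split}(T)$ is chosen to split the node.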
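A direct translation of the two formulas, as a sketch. `gini` and `gini_split` are hypothetical helper names, not part of the article's code, and the labels are made up. Note that the article's own implementation below uses information gain rather than the Gini index.

```python
from collections import Counter


def gini(labels):
    """gini(T) = 1 - sum_j p_j^2, where p_j is the relative frequency of class j."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def gini_split(left, right):
    """Size-weighted Gini index of splitting T into T1 (left) and T2 (right)."""
    n = len(left) + len(right)
    return len(left) / n * gini(left) + len(right) / n * gini(right)


# Hypothetical labels: a pure node, a balanced node, and one candidate split
print(gini(["a", "a", "a", "a"]))  # 0.0
print(gini(["a", "a", "b", "b"]))  # 0.5
print(gini_split(["a", "a"], ["a", "b", "b", "b"]))  # roughly 0.25
```

Among several candidate splits, CART would keep the one whose `gini_split` value is smallest.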
**Implementation**
First, the function calcShannonEnt computes the Shannon entropy of the dataset. It builds a dictionary that counts every possible class label.
```python
from math import log


def calcShannonEnt(dataSet):
    """Shannon entropy (base 2) of the class labels in the last column."""
    numEntries = len(dataSet)
    labelCounts = {}
    # Count the occurrences of every possible class label
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    # Compute the Shannon entropy with log base 2
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


# Split on a discrete feature: return every sample whose feature `axis`
# equals `value`, with that feature column removed
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
```
To split a dataset on a continuous feature, splitContinuousDataSet takes an extra argument, direction, which decides whether the samples below value or the samples above value are returned.
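The article does not include the body of `splitContinuousDataSet`. The following is a minimal sketch consistent with how `chooseBestFeatureToSplit` calls it: it is called once with direction 0 and once with direction 1, and the two halves are weighted by size. Which direction maps to which side is an assumption. Here 0 keeps values greater than the threshold and 1 keeps values less than or equal to it, matching the later binarization where samples `<=` the threshold are coded 1. For the entropy computation the mapping makes no difference, since both halves are summed.

```python
def splitContinuousDataSet(dataSet, axis, value, direction):
    """Return the rows on one side of threshold `value` for feature `axis`,
    with that feature column removed.

    Assumption: direction == 0 keeps rows with feature > value,
    direction == 1 keeps rows with feature <= value.
    """
    retDataSet = []
    for featVec in dataSet:
        if direction == 0:
            keep = featVec[axis] > value
        else:
            keep = featVec[axis] <= value
        if keep:
            retDataSet.append(featVec[:axis] + featVec[axis + 1:])
    return retDataSet


# Hypothetical rows: [density, texture, label]
data = [[0.2, "x", "no"], [0.5, "y", "yes"], [0.8, "x", "yes"]]
print(splitContinuousDataSet(data, 0, 0.35, 1))  # [['x', 'no']]
print(splitContinuousDataSet(data, 0, 0.35, 0))  # [['y', 'yes'], ['x', 'yes']]
```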
```python
def chooseBestFeatureToSplit(dataSet, labels):
    """Return the index of the feature with the largest information gain.

    Continuous (int/float) features are bi-partitioned at the best midpoint.
    If the winning feature is continuous, its column in dataSet is binarized
    in place (1 if <= threshold, else 0) and labels[bestFeature] is renamed
    to 'name<=threshold'.
    """
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = -1.0  # ensures a feature is chosen even if every gain is 0
    bestFeature = -1
    bestSplitDict = {}
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        # Continuous feature
        if isinstance(featList[0], (int, float)):
            # Generate n-1 candidate split points: midpoints of adjacent sorted values
            sortfeatList = sorted(featList)
            splitList = []
            for j in range(len(sortfeatList) - 1):
                splitList.append((sortfeatList[j] + sortfeatList[j + 1]) / 2.0)
            bestSplitEntropy = float("inf")
            # Entropy after splitting at the j-th candidate; remember the best one
            for j in range(len(splitList)):
                value = splitList[j]
                newEntropy = 0.0
                subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
                subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)
                prob0 = len(subDataSet0) / float(len(dataSet))
                newEntropy += prob0 * calcShannonEnt(subDataSet0)
                prob1 = len(subDataSet1) / float(len(dataSet))
                newEntropy += prob1 * calcShannonEnt(subDataSet1)
                if newEntropy < bestSplitEntropy:
                    bestSplitEntropy = newEntropy
                    bestSplit = j
            # Record the best split point for this feature
            bestSplitDict[labels[i]] = splitList[bestSplit]
            infoGain = baseEntropy - bestSplitEntropy
        # Discrete feature
        else:
            uniqueVals = set(featList)
            newEntropy = 0.0
            # Weighted entropy of the subsets produced by each feature value
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy += prob * calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    # If the best feature is continuous, binarize it at the recorded split point:
    # 1 if the value is <= bestSplitValue, otherwise 0
    if isinstance(dataSet[0][bestFeature], (int, float)):
        bestSplitValue = bestSplitDict[labels[bestFeature]]
        labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue)
        for row in dataSet:
            row[bestFeature] = 1 if row[bestFeature] <= bestSplitValue else 0
    return bestFeature
```
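The continuous branch above tries the n-1 midpoints between adjacent sorted values and keeps the one with the lowest weighted entropy. The sketch below isolates that step. `best_threshold` is a hypothetical helper, and it splits on `<=` / `>` directly instead of calling `splitContinuousDataSet`.

```python
from collections import Counter
from math import log2


def entropy(labels):
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def best_threshold(values, labels):
    """Try the midpoints between consecutive sorted values and return
    (threshold, weighted entropy) for the split with the lowest entropy."""
    s = sorted(values)
    candidates = [(s[j] + s[j + 1]) / 2.0 for j in range(len(s) - 1)]
    n = len(values)
    best = None
    for t in candidates:
        left = [y for x, y in zip(values, labels) if x <= t]
        right = [y for x, y in zip(values, labels) if x > t]
        h = len(left) / n * entropy(left) + len(right) / n * entropy(right)
        if best is None or h < best[1]:
            best = (t, h)
    return best


t, h = best_threshold([1, 2, 7, 9], ["no", "no", "yes", "yes"])
print(t)  # 4.5
```

With these hypothetical values the midpoint 4.5 separates the classes perfectly, so its weighted entropy is 0 and it wins over 1.5 and 8.0.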
```python
def classify(inputTree, featLabels, testVec):
    """Classify one sample by walking the nested-dict tree."""
    firstStr = next(iter(inputTree))  # Python 3: dict keys are not indexable
    secondDict = inputTree[firstStr]
    classLabel = None  # stays None if the sample's value has no branch
    if '<=' in firstStr:
        # Continuous feature stored as 'name<=threshold'; branches are keyed 1/0
        featkey, featvalue = firstStr.split('<=')
        featIndex = featLabels.index(featkey)
        judge = 1 if testVec[featIndex] <= float(featvalue) else 0
        for key in secondDict:
            if judge == int(key):
                if isinstance(secondDict[key], dict):
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
    else:
        # Discrete feature: follow the branch equal to the sample's value
        featIndex = featLabels.index(firstStr)
        for key in secondDict:
            if testVec[featIndex] == key:
                if isinstance(secondDict[key], dict):
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
    return classLabel
```
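To show the tree format that `classify` walks, here is a small iterative variant applied to a hand-written tree. The tree is `{feature: {branch_value: subtree_or_label}}`, and a key of the form `'name<=t'` marks a binarized continuous test with branches 1 and 0. The tree, the feature names and the function name `classify_iter` are hypothetical. Unlike the recursive version above, this variant raises `KeyError` for a value that has no branch.

```python
def classify_iter(tree, feat_labels, sample):
    """Follow branches until a leaf (a non-dict value) is reached."""
    while isinstance(tree, dict):
        feature = next(iter(tree))
        branches = tree[feature]
        if "<=" in feature:
            name, threshold = feature.split("<=")
            key = 1 if sample[feat_labels.index(name)] <= float(threshold) else 0
        else:
            key = sample[feat_labels.index(feature)]
        tree = branches[key]  # KeyError if the value has no branch
    return tree


# Hypothetical tree mixing a discrete and a continuous test
tree = {"texture": {"clear": {"density<=0.38": {1: "bad", 0: "good"}},
                    "blurry": "bad"}}
labels = ["texture", "density"]
print(classify_iter(tree, labels, ["clear", 0.5]))   # good
print(classify_iter(tree, labels, ["clear", 0.2]))   # bad
print(classify_iter(tree, labels, ["blurry", 0.9]))  # bad
```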
```python
def majorityCnt(classList):
    """Return the most frequent class label."""
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # max(classCount) would return the largest *key*; compare by count instead
    return max(classCount, key=classCount.get)


def testing_feat(feat, train_data, test_data, labels):
    """Validation errors if the node splits on `feat` and each branch predicts
    the majority training class of that branch (used for pre-pruning)."""
    class_list = [example[-1] for example in train_data]
    bestFeatIndex = labels.index(feat)
    train_data = [example[bestFeatIndex] for example in train_data]
    test_data = [(example[bestFeatIndex], example[-1]) for example in test_data]
    all_feat = set(train_data)
    error = 0.0
    for value in all_feat:
        class_feat = [class_list[i] for i in range(len(class_list)) if train_data[i] == value]
        major = majorityCnt(class_feat)
        for data in test_data:
            if data[0] == value and data[1] != major:
                error += 1.0
    return error
```
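The same majority vote can be written with `collections.Counter`. The second `print` shows a pitfall: calling `max` on a dict compares its keys, not its counts, so it does not perform a majority vote. `majority_vote` is a hypothetical helper name.

```python
from collections import Counter


def majority_vote(class_list):
    """Most frequent label; ties go to the label encountered first."""
    return Counter(class_list).most_common(1)[0][0]


print(majority_vote(["yes", "no", "no"]))  # no
print(max({"yes": 1, "no": 2}))            # yes  (largest key, not largest count)
```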
**Testing**

```python
def testing(myTree, data_test, labels):
    """Number of misclassified samples in data_test (used for post-pruning)."""
    error = 0.0
    for i in range(len(data_test)):
        if classify(myTree, labels, data_test[i]) != data_test[i][-1]:
            error += 1
    return float(error)


def testingMajor(major, data_test):
    """Number of samples in data_test whose class differs from `major`."""
    error = 0.0
    for i in range(len(data_test)):
        if major != data_test[i][-1]:
            error += 1
    return float(error)
```
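Pre-pruning compares two validation error counts at a node. The first is the error if the node becomes a leaf predicting the training majority, which is what `testingMajor` measures. The second is the error if the node splits and each branch predicts its own training majority, which is what `testing_feat` measures. The sketch below reproduces that comparison on made-up rows; `leaf_errors` and `split_errors` are hypothetical names.

```python
from collections import Counter


def majority(labels):
    return Counter(labels).most_common(1)[0][0]


def leaf_errors(train, valid):
    """Validation errors if the node is a leaf predicting the training majority."""
    major = majority([row[-1] for row in train])
    return sum(1 for row in valid if row[-1] != major)


def split_errors(train, valid, axis):
    """Validation errors if the node splits on discrete feature `axis`
    and each branch predicts the majority of its own training rows."""
    errors = 0
    for value in set(row[axis] for row in train):
        major = majority([row[-1] for row in train if row[axis] == value])
        errors += sum(1 for row in valid if row[axis] == value and row[-1] != major)
    return errors


# Hypothetical rows: [feature, label]
train = [["a", "yes"], ["a", "yes"], ["b", "no"], ["b", "yes"]]
valid = [["a", "yes"], ["b", "no"], ["b", "no"]]
# As in createTree's "prev" mode, split only if the split strictly lowers the error
print(leaf_errors(train, valid), split_errors(train, valid, 0))  # 2 0
```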
**Building the decision tree recursively**

```python
import copy


def createTree(dataSet, labels, data_full, labels_full, test_data, mode):
    """Recursively build the tree as nested dicts.

    mode: "unpro" (no pruning), "prev" (pre-pruning) or "post" (post-pruning);
    test_data is the validation set used by both pruning strategies.
    """
    classList = [example[-1] for example in dataSet]
    # All samples share one class: return it as a leaf
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No features left: return the majority class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    labels_copy = copy.deepcopy(labels)
    bestFeat = chooseBestFeatureToSplit(dataSet, labels)
    bestFeatLabel = labels[bestFeat]
    if mode == "unpro" or mode == "post":
        myTree = {bestFeatLabel: {}}
    elif mode == "prev":
        # Pre-pruning: split only if that lowers the validation error
        if testing_feat(bestFeatLabel, dataSet, test_data, labels_copy) < \
                testingMajor(majorityCnt(classList), test_data):
            myTree = {bestFeatLabel: {}}
        else:
            return majorityCnt(classList)
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    isDiscrete = isinstance(dataSet[0][bestFeat], str)  # 'unicode' in Python 2
    if isDiscrete:
        # Every value this feature takes in the full training set
        currentlabel = labels_full.index(labels[bestFeat])
        featValuesFull = [example[currentlabel] for example in data_full]
        uniqueValsFull = set(featValuesFull)
    del labels[bestFeat]
    for value in uniqueVals:
        subLabels = labels[:]
        if isDiscrete:
            uniqueValsFull.remove(value)
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels, data_full, labels_full,
            splitDataSet(test_data, bestFeat, value), mode=mode)
    if isDiscrete:
        # Values not seen at this node get the node's majority class
        for value in uniqueValsFull:
            myTree[bestFeatLabel][value] = majorityCnt(classList)
    if mode == "post":
        # Post-pruning: collapse to a leaf if the subtree does worse on validation data
        if testing(myTree, test_data, labels_copy) > testingMajor(majorityCnt(classList), test_data):
            return majorityCnt(classList)
    return myTree
```
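Post-pruning in `createTree` happens as the recursion unwinds. A subtree is built first, then replaced by its majority class if it makes strictly more errors on the validation rows that reach it. The sketch below applies the same rule to an already-built tree of discrete features. `post_prune`, `count_errors` and `classify_discrete` are hypothetical names. It assumes every row keeps all of its columns, so features are looked up by their position in `feat_labels`; this differs from the article, which drops a column at each level.

```python
from collections import Counter


def classify_discrete(tree, feat_labels, sample):
    """Follow discrete branches; returns None if a value has no branch."""
    while isinstance(tree, dict):
        feature = next(iter(tree))
        tree = tree[feature].get(sample[feat_labels.index(feature)])
    return tree


def count_errors(tree, feat_labels, rows):
    return sum(1 for r in rows if classify_discrete(tree, feat_labels, r) != r[-1])


def post_prune(tree, feat_labels, train, valid):
    """Bottom-up: prune children first, then replace this subtree by the
    majority training class if the subtree makes strictly more validation
    errors than that leaf (the rule createTree uses in "post" mode)."""
    if not isinstance(tree, dict) or not train:
        return tree
    feature = next(iter(tree))
    idx = feat_labels.index(feature)
    branches = tree[feature]
    for value in list(branches):
        branches[value] = post_prune(
            branches[value], feat_labels,
            [r for r in train if r[idx] == value],
            [r for r in valid if r[idx] == value])
    major = Counter(r[-1] for r in train).most_common(1)[0][0]
    leaf_errors = sum(1 for r in valid if r[-1] != major)
    if count_errors(tree, feat_labels, valid) > leaf_errors:
        return major
    return tree


# Hypothetical tree and data: [color, root, label]
tree = {"color": {"green": {"root": {"curled": "good", "stiff": "bad"}},
                  "dark": "good"}}
feat_labels = ["color", "root"]
train = [["green", "curled", "good"], ["green", "stiff", "bad"],
         ["green", "curled", "good"], ["dark", "curled", "good"]]
valid = [["green", "stiff", "good"], ["green", "curled", "good"],
         ["dark", "stiff", "good"]]
pruned = post_prune(tree, feat_labels, train, valid)
print(pruned)  # {'color': {'green': 'good', 'dark': 'good'}}
```

Because the rule is strict, the root split survives here even though both of its leaves now predict the same class. Pruning on ties as well would collapse it.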
**Loading the data**

```python
import pandas as pd


def load_data(file_name):
    """Read the CSV (the first column is an ID), use the first 11 rows for
    training and the remaining rows as the validation set."""
    df = pd.read_csv(file_name, sep=",")
    train_data = df.values[:11, 1:].tolist()
    test_data = df.values[11:, 1:].tolist()
    labels = df.columns.values[1:-1].tolist()
    return train_data, test_data, labels
```
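`load_data` assumes a specific CSV layout: an ID column first, the feature columns next, and the class label last. The first 11 rows are used for training and the rest for validation. The sketch below shows that layout with the standard-library `csv` module instead of pandas. It converts numeric cells to `float` so that the continuous-feature check in `chooseBestFeatureToSplit` still applies. The function name `load_rows` and the sample CSV are hypothetical.

```python
import csv
import io


def _to_number(cell):
    """Convert numeric cells to float so they are treated as continuous."""
    try:
        return float(cell)
    except ValueError:
        return cell


def load_rows(text, n_train=11):
    """Parse CSV text whose first column is an ID and last column is the class.
    Returns (train_rows, test_rows, feature_labels), mirroring
    df.values[:11, 1:], df.values[11:, 1:] and df.columns.values[1:-1]."""
    reader = csv.reader(io.StringIO(text))
    header = next(reader)
    rows = [[_to_number(cell) for cell in record[1:]] for record in reader]
    return rows[:n_train], rows[n_train:], header[1:-1]


text = "id,color,density,good\n1,green,0.697,yes\n2,dark,0.774,yes\n3,light,0.243,no\n"
train, test, labels = load_rows(text, n_train=2)
print(labels)  # ['color', 'density']
print(train)   # [['green', 0.697, 'yes'], ['dark', 0.774, 'yes']]
print(test)    # [['light', 0.243, 'no']]
```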
**Testing and plotting the tree**

```python
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="round4", color='red')  # style of decision nodes
leafNode = dict(boxstyle="circle", color='grey')     # style of leaf nodes
arrow_args = dict(arrowstyle="<-", color='blue')     # style of the arrows


# Count the leaf nodes of the tree
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


# Compute the maximum depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
```
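The plot layout depends on these two numbers: the leaf count sets the horizontal spacing (`totalW`) and the depth sets the vertical spacing (`totalD`). Here is a matplotlib-free sketch for checking them on a hand-written tree. `num_leafs` and `tree_depth` are hypothetical names that follow the same conventions as `getNumLeafs` and `getTreeDepth`, where a single test node has depth 1.

```python
def num_leafs(tree):
    """Count the leaves of a nested-dict tree {feature: {value: subtree_or_label}}."""
    if not isinstance(tree, dict):
        return 1
    branches = next(iter(tree.values()))
    return sum(num_leafs(child) for child in branches.values())


def tree_depth(tree):
    """Number of test nodes on the longest root-to-leaf path."""
    if not isinstance(tree, dict):
        return 0
    branches = next(iter(tree.values()))
    return 1 + max(tree_depth(child) for child in branches.values())


# Hypothetical tree
tree = {"texture": {"clear": {"root": {"curled": "good", "stiff": "bad"}},
                    "blurry": "bad"}}
print(num_leafs(tree), tree_depth(tree))  # 3 2
```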
```python
# Draw a node with an arrow from its parent
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


# Write the branch value halfway along an edge
def plotMidText(cntrPt, parentPt, txtString):
    lens = len(txtString)
    xMid = (parentPt[0] + cntrPt[0]) / 2.0 - lens * 0.002
    yMid = (parentPt[1] + cntrPt[1]) / 2.0
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    firstStr = next(iter(myTree))
    # Center this node above the leaves it spans
    cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
            plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
    plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # Layout: width from the number of leaves, height from the depth
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.x0ff = -0.5 / plotTree.totalW
    plotTree.y0ff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
```
```python
import json

if __name__ == "__main__":
    train_data, test_data, labels = load_data("dd.csv")
    data_full = train_data[:]
    labels_full = labels[:]
    mode = "post"  # "unpro": no pruning, "prev": pre-pruning, "post": post-pruning
    myTree = createTree(train_data, labels, data_full, labels_full, test_data, mode=mode)
    createPlot(myTree)
    print(json.dumps(myTree, ensure_ascii=False, indent=4))
```

Setting mode to "unpro", "prev" or "post" produces the unpruned, pre-pruned or post-pruned tree, respectively.