python實(shí)現(xiàn)基于信息增益的決策樹歸納
本文實(shí)例為大家分享了基于信息增益的決策樹歸納的Python實(shí)現(xiàn)代碼,供大家參考,具體內(nèi)容如下
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt
from copy import copy
#加載訓(xùn)練數(shù)據(jù)
#文件格式:屬性標(biāo)號(hào),是否連續(xù)【yes|no】,屬性說明
attribute_file_dest = 'F:\\bayes_categorize\\attribute.dat'
attribute_file = open(attribute_file_dest)
#文件格式:rec_id,attr1_value,attr2_value,...,attrn_value,class_id
trainning_data_file_dest = 'F:\\bayes_categorize\\trainning_data.dat'
trainning_data_file = open(trainning_data_file_dest)
#文件格式:class_id,class_desc
class_desc_file_dest = 'F:\\bayes_categorize\\class_desc.dat'
class_desc_file = open(class_desc_file_dest)
root_attr_dict = {}
for line in attribute_file :
line = line.strip()
fld_list = line.split(',')
root_attr_dict[int(fld_list[0])] = tuple(fld_list[1:])
class_dict = {}
for line in class_desc_file :
line = line.strip()
fld_list = line.split(',')
class_dict[int(fld_list[0])] = fld_list[1]
trainning_data_dict = {}
class_member_set_dict = {}
for line in trainning_data_file :
line = line.strip()
fld_list = line.split(',')
rec_id = int(fld_list[0])
a1 = int(fld_list[1])
a2 = int(fld_list[2])
a3 = float(fld_list[3])
c_id = int(fld_list[4])
if c_id not in class_member_set_dict :
class_member_set_dict[c_id] = set()
class_member_set_dict[c_id].add(rec_id)
trainning_data_dict[rec_id] = (a1 , a2 , a3 , c_id)
attribute_file.close()
class_desc_file.close()
trainning_data_file.close()
class_possibility_dict = {}
for c_id in class_member_set_dict :
class_possibility_dict[c_id] = (len(class_member_set_dict[c_id]) + 0.0)/len(trainning_data_dict)
#等待分類的數(shù)據(jù)
data_to_classify_file_dest = 'F:\\bayes_categorize\\trainning_data_new.dat'
data_to_classify_file = open(data_to_classify_file_dest)
data_to_classify_dict = {}
for line in data_to_classify_file :
line = line.strip()
fld_list = line.split(',')
rec_id = int(fld_list[0])
a1 = int(fld_list[1])
a2 = int(fld_list[2])
a3 = float(fld_list[3])
c_id = int(fld_list[4])
data_to_classify_dict[rec_id] = (a1 , a2 , a3 , c_id)
data_to_classify_file.close()
'''
決策樹的表達(dá)
結(jié)點(diǎn)的需求:
1、指示出是哪一種分區(qū) 一共3種 一是離散窮舉 二是連續(xù)有分裂點(diǎn) 三是離散有判別集合 零是葉子結(jié)點(diǎn)
2、保存分類所需信息
3、子結(jié)點(diǎn)列表
每個(gè)結(jié)點(diǎn)用Tuple類型表示
元素一是整形,取值123 分別對(duì)應(yīng)兩種分裂類型
元素二是集合類型 對(duì)于1保存所有的離散值 對(duì)于2保存分裂點(diǎn) 對(duì)于3保存判別集合 對(duì)于0保存分類結(jié)果類標(biāo)號(hào)
元素三是dict key對(duì)于1來說是某個(gè)的離散值 對(duì)于23來說只有12兩種 對(duì)于2來說1代表小于等于分裂點(diǎn)
對(duì)于3來說1代表屬于判別集合
'''
#對(duì)于一個(gè)成員列表,計(jì)算其熵
#公式為 Info_D = - sum(pi * log2 (pi)) pi為一個(gè)元素屬于Ci的概率,用|Ci|/|D|計(jì)算 ,對(duì)所有分類求和
def get_entropy( member_list ) :
#成員總數(shù)
mem_cnt = len(member_list)
#首先找出member中所包含的分類
class_dict = {}
for mem_id in member_list :
c_id = trainning_data_dict[mem_id][3]
if c_id not in class_dict :
class_dict[c_id] = set()
class_dict[c_id].add(mem_id)
tmp_sum = 0.0
for c_id in class_dict :
pi = ( len(class_dict[c_id]) + 0.0 ) / mem_cnt
tmp_sum += pi * mlab.log2(pi)
tmp_sum = -tmp_sum
return tmp_sum
def attribute_selection_method( member_list , attribute_dict ) :
#先計(jì)算原始的熵
info_D = get_entropy(member_list)
max_info_Gain = 0.0
attr_get = 0
split_point = 0.0
for attr_id in attribute_dict :
#對(duì)于每一個(gè)屬性計(jì)算劃分后的熵
#信息增益等于原始的熵減去劃分后的熵
info_D_new = 0
#如果是連續(xù)屬性
if attribute_dict[attr_id][0] == 'yes' :
#先得到memberlist中此屬性的取值序列,把序列中每一對(duì)相鄰項(xiàng)的中值作為劃分點(diǎn)計(jì)算熵
#找出其中最小的,作為此連續(xù)屬性的劃分點(diǎn)
value_list = []
for mem_id in member_list :
value_list.append(trainning_data_dict[mem_id][attr_id - 1])
#獲取相鄰元素的中值序列
mid_value_list = []
value_list.sort()
#print value_list
last_value = None
for value in value_list :
if value == last_value :
continue
if last_value is not None :
mid_value_list.append((last_value+value)/2)
last_value = value
#print mid_value_list
#對(duì)于中值序列做循環(huán)
#計(jì)算以此值做為劃分點(diǎn)的熵
#總的熵等于兩個(gè)劃分的熵乘以兩個(gè)劃分的比重
min_info = 1000000000.0
total_mens = len(member_list) + 0.0
for mid_value in mid_value_list :
#小于mid_value的mem
less_list = []
#大于
more_list = []
for tmp_mem_id in member_list :
if trainning_data_dict[tmp_mem_id][attr_id - 1] <= mid_value :
less_list.append(tmp_mem_id)
else :
more_list.append(tmp_mem_id)
sum_info = len(less_list)/total_mens * get_entropy(less_list) \
+ len(more_list)/total_mens * get_entropy(more_list)
if sum_info < min_info :
min_info = sum_info
split_point = mid_value
info_D_new = min_info
#如果是離散屬性
else :
#計(jì)算劃分后的熵
#采用循環(huán)累加的方式
attr_value_member_dict = {} #鍵為attribute value , 值為memberlist
for tmp_mem_id in member_list :
attr_value = trainning_data_dict[tmp_mem_id][attr_id - 1]
if attr_value not in attr_value_member_dict :
attr_value_member_dict[attr_value] = []
attr_value_member_dict[attr_value].append(tmp_mem_id)
#將每個(gè)離散值的熵乘以比重加到這上面
total_mens = len(member_list) + 0.0
sum_info = 0.0
for a_value in attr_value_member_dict :
sum_info += len(attr_value_member_dict[a_value])/total_mens \
* get_entropy(attr_value_member_dict[a_value])
info_D_new = sum_info
info_Gain = info_D - info_D_new
if info_Gain > max_info_Gain :
max_info_Gain = info_Gain
attr_get = attr_id
#如果是離散的
#print 'attr_get ' + str(attr_get)
if attribute_dict[attr_get][0] == 'no' :
return (1 , attr_get , split_point)
else :
return (2 , attr_get , split_point)
#第三類先不考慮
def get_decision_tree(father_node , key , member_list , attr_dict ) :
#最終的結(jié)果是新建一個(gè)結(jié)點(diǎn),并且添加到father_node的sub_node_dict,對(duì)key為鍵
#檢查memberlist 如果都是同類的,則生成一個(gè)葉子結(jié)點(diǎn),set里面保存類標(biāo)號(hào)
class_set = set()
for mem_id in member_list :
class_set.add(trainning_data_dict[mem_id][3])
if len(class_set) == 1 :
father_node[2][key] = (0 , (1 , class_set) , {} )
return
#檢查attribute_list,如果為空,產(chǎn)生葉子結(jié)點(diǎn),類標(biāo)號(hào)為memberlist中多數(shù)元素的類標(biāo)號(hào)
#如果幾個(gè)類的成員等量,則打印提示,并且全部添加到set里面
if not attr_dict :
class_cnt_dict = {}
for mem_id in member_list :
c_id = trainning_data_dict[mem_id][3]
if c_id not in class_cnt_dict :
class_cnt_dict[c_id] = 1
else :
class_cnt_dict[c_id] += 1
class_set = set()
max_cnt = 0
for c_id in class_cnt_dict :
if class_cnt_dict[c_id] > max_cnt :
max_cnt = class_cnt_dict[c_id]
class_set.clear()
class_set.add(c_id)
elif class_cnt_dict[c_id] == max_cnt :
class_set.add(c_id)
if len(class_set) > 1 :
print 'more than one class !'
father_node[2][key] = (0 , (1 , class_set ) , {} )
return
#找出最好的分區(qū)方案 , 暫不考慮第三種劃分方法
#比較所有離散屬性和所有連續(xù)屬性的所有中值點(diǎn)劃分的信息增益
split_criterion = attribute_selection_method(member_list , attr_dict)
#print split_criterion
selected_plan_id = split_criterion[0]
selected_attr_id = split_criterion[1]
#如果采用的是離散屬性做為分區(qū)方案,刪除這個(gè)屬性
new_attr_dict = copy(attr_dict)
if attr_dict[selected_attr_id][0] == 'no' :
del new_attr_dict[selected_attr_id]
#建立一個(gè)結(jié)點(diǎn)new_node,father_node[2][key] = new_node
#然后對(duì)new node的每一個(gè)key , sub_member_list,
#調(diào)用 get_decision_tree(new_node , new_key , sub_member_list , new_attribute_dict)
#實(shí)現(xiàn)遞歸
ele2 = ( selected_attr_id , set() )
#如果是1 , ele2保存所有離散值
if selected_plan_id == 1 :
for mem_id in member_list :
ele2[1].add(trainning_data_dict[mem_id][selected_attr_id - 1])
#如果是2,ele2保存分裂點(diǎn)
elif selected_plan_id == 2 :
ele2[1].add(split_criterion[2])
#如果是3則保存判別集合,先不管
else :
print 'not completed'
pass
new_node = ( selected_plan_id , ele2 , {} )
father_node[2][key] = new_node
#生成KEY,并遞歸調(diào)用
if selected_plan_id == 1 :
#每個(gè)attr_value是一個(gè)key
attr_value_member_dict = {}
for mem_id in member_list :
attr_value = trainning_data_dict[mem_id][selected_attr_id - 1 ]
if attr_value not in attr_value_member_dict :
attr_value_member_dict[attr_value] = []
attr_value_member_dict[attr_value].append(mem_id)
for attr_value in attr_value_member_dict :
get_decision_tree(new_node , attr_value , attr_value_member_dict[attr_value] , new_attr_dict)
pass
elif selected_plan_id == 2 :
#key 只有12 , 小于等于分裂點(diǎn)的是1 , 大于的是2
less_list = []
more_list = []
for mem_id in member_list :
attr_value = trainning_data_dict[mem_id][selected_attr_id - 1 ]
if attr_value <= split_criterion[2] :
less_list.append(mem_id)
else :
more_list.append(mem_id)
#if len(less_list) != 0 :
get_decision_tree(new_node , 1 , less_list , new_attr_dict)
#if len(more_list) != 0 :
get_decision_tree(new_node , 2 , more_list , new_attr_dict)
pass
#如果是3則保存判別集合,先不管
else :
print 'not completed'
pass
def get_class_sub(node , tp ) :
#
attr_id = node[1][0]
plan_id = node[0]
key = 0
if plan_id == 0 :
return node[1][1]
elif plan_id == 1 :
key = tp[attr_id - 1]
elif plan_id == 2 :
split_point = tuple(node[1][1])[0]
attr_value = tp[attr_id - 1]
if attr_value <= split_point :
key = 1
else :
key = 2
else :
print 'error'
return set()
return get_class_sub(node[2][key] , tp )
def get_class(r_node , tp) :
#tp為一組屬性值
if r_node[0] != -1 :
print 'error'
return set()
if 1 in r_node[2] :
return get_class_sub(r_node[2][1] , tp)
else :
print 'error'
return set()
if __name__ == '__main__' :
root_node = ( -1 , set() , {} )
mem_list = trainning_data_dict.keys()
get_decision_tree(root_node , 1 , mem_list , root_attr_dict )
#測(cè)試分類器的準(zhǔn)確率
diff_cnt = 0
for mem_id in data_to_classify_dict :
c_id = get_class(root_node , data_to_classify_dict[mem_id][0:3])
if tuple(c_id)[0] != data_to_classify_dict[mem_id][3] :
print tuple(c_id)[0]
print data_to_classify_dict[mem_id][3]
print 'different'
diff_cnt += 1
print diff_cnt
以上就是本文的全部?jī)?nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
解決Python requests 報(bào)錯(cuò)方法集錦
這篇文章主要介紹了解決Python requests 報(bào)錯(cuò)方法集錦的相關(guān)資料,需要的朋友可以參考下2017-03-03
批量將ppt轉(zhuǎn)換為pdf的Python代碼 只要27行!
這篇文章主要為大家詳細(xì)介紹了批量將ppt轉(zhuǎn)換為pdf的Python代碼,只要27行,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-02-02
Python-pip配置國內(nèi)鏡像源快速下載包的方法詳解
pip如果不配置國內(nèi)鏡像源的話,下載包的速度非常慢,畢竟默認(rèn)的源在國外呢,這篇文章主要介紹了Python-pip配置國內(nèi)鏡像源快速下載包的方法詳解,需要的朋友可以參考下2024-01-01
Python的pytest測(cè)試框架中fixture的使用詳解
這篇文章主要介紹了pytest中fixture的使用詳解,pytest是一個(gè)非常成熟的全功能的Python測(cè)試框架,能夠支持簡(jiǎn)單的單元測(cè)試和復(fù)雜的功能測(cè)試,還可以用來做selenium/appnium等自動(dòng)化測(cè)試、接口自動(dòng)化測(cè)試,需要的朋友可以參考下2023-07-07
Python使用pptx實(shí)現(xiàn)復(fù)制頁面到其他PPT中
這篇文章主要為大家詳細(xì)介紹了python如何使用pptx庫實(shí)現(xiàn)從一個(gè)ppt復(fù)制頁面到另一個(gè)ppt里面,文中的示例代碼講解詳細(xì),感興趣的可以嘗試一下2023-02-02
opencv中圖像疊加/圖像融合/按位操作的實(shí)現(xiàn)
這篇文章主要介紹了opencv中圖像疊加/圖像融合/按位操作的實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-04-04
Python?PaddleGAN實(shí)現(xiàn)調(diào)整照片人物年齡
這篇文章主要介紹了通過PaddleGAN實(shí)現(xiàn)照片人物的老年化和年輕化處理,文中的示例代碼講解有效,對(duì)我們學(xué)習(xí)或工作有一定的幫助,感興趣的可以學(xué)習(xí)一下2021-12-12

