Python 手写实现鸢尾花决策树
Iris数据集是常用的分类实验数据集,早在1936年,模式识别的先驱Fisher就在论文中使用了它 (直至今日该论文仍然被频繁引用)。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性:花萼长度(sepal length),花萼宽度(sepal width),花瓣长度(petal length),花瓣宽度(petal
Iris数据集是常用的分类实验数据集,早在1936年,模式识别的先驱Fisher就在论文中使用了它 (直至今日该论文仍然被频繁引用)。
Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性:花萼长度(sepal length),花萼宽度(sepal width),花瓣长度(petal length),花瓣宽度(petal width),可通过4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。在三个类别中,其中有一个类别和其他两个类别是线性可分的。
在sklearn中已内置了此数据集。
核心算法思想就是:
每到一个节点,对于现在数据集中四个属性的所有值,根据小于和大于此属性值分成两个数据集,并计算香农熵,取所有香农熵最大增益的那个特征及值作为划分标准;划分直到所有数据均为一类数据。
为了说明这一思想,举个例子:
现在数据集中有三个数据:
1:0.1 0.2 0.3 0.4 0
2:0.2 0.3 0.4 0.5 1
3:0.3 0.4 0.4 0.6 2
此时,数据集不是同一类数据,所以要根据某个标准进行划分
为了找到那个标准,先考察pos=0,即第一个特征。将第一个特征的三个值0.1,0.2,0.3分别作为标准,比如:
将0.1作为标准,pos=0小于等于0.1的数据,即数据1划分为一个数据集,大于0.1的数据集,即数据2,3划分为一个数据集
则得到两个数据集:{1},{2,3}
计算此时的香农熵,接下来计算第一个特征的其余值作为标准时的香农熵,之后计算出第2,3,4个属性所有值的香农熵,这些香农熵中最大的那个pos和值即为我们决策树的节点
下面是这一逻辑的代码:
#选择最好的特征值进行分类
def choose_best_split(data_set):
base_Ent=calculate_Ent(data_set)
best_increase=0.0
best_feature=[-1,-1]
for i in range(4):
features=[j[i] for j in data_set]
unique=set(features)
for feature in unique:
less_Set,more_Set=spliit_Set(data_set, i, feature)
tmp=len(less_Set)/float(len(data_set))
new_Ent=tmp*calculate_Ent(less_Set)
new_Ent+=(1-tmp)*calculate_Ent(more_Set)
increase=base_Ent-new_Ent
if increase>best_increase:
best_increase=increase
best_feature=[i,feature]
return best_feature,best_increase
一、数据集初始化
将标签附到特征值之后:
#初始化数据集
def init_data_set():
iris = load_iris() #导入数据集iris
iris_feature = iris.data.tolist() #特征数据
iris_target = iris.target.tolist() #分类数据
for i in range(len(iris_feature)):
iris_feature[i].append(iris_target[i])
return iris_feature
二、划分数据集
将数据集划分成训练集和测试集
#划分数据集
def create_set(data_set,split_rate=0.8):
#0的是测试集,1的是训练集
length=len(data_set)
train_num=int(length*split_rate)
test_num=length-train_num
random_list=[1]*train_num
random_list.extend([0]*test_num)
random.shuffle(random_list)
test_set=[]
train_set=[]
for i in range(length):
if random_list[i]==0:
test_set.append(data_set[i])
else:
train_set.append(data_set[i])
return test_set,train_set
三、计算信息熵
计算香农熵
#计算信息熵
def calculate_Ent(data_set):
label={}
for i in data_set:
if i[-1] not in label.keys():
label[i[-1]]=1
else:
label[i[-1]]+=1
Ent=0.0
for i in label:
tmp=float(label[i])/len(data_set)
Ent-=tmp*log(tmp,2)
return Ent
四、根据pos和value划分成两个数据集
#根据pos和value划分成两个数据集
def spliit_Set(data_set,pos,value):
less_Set=[]
more_Set=[]
for item in data_set:
if item[pos]<value:
less_Set.append(item)
else:
more_Set.append(item)
return less_Set,more_Set
五、构造决策树
当数据集均为同一类别时停止构造:
#构造决策树
def create_tree(data_set):
myTree={}
label=[i[-1] for i in data_set]
label_set=set(label)
if len(label_set)==1:
myTree['class']=label[0]
return myTree
best_feature,best_increase=choose_best_split(data_set)
myTree['node']=best_feature
less_Set,more_Set=spliit_Set(data_set, best_feature[0], best_feature[1])
myTree['left']=create_tree(less_Set)
myTree['right']=create_tree(more_Set)
return myTree
六、可视化
若想要将决策树可视化,需要将决策树结果变成字典形式
#画决策树
def draw_tree(data_set):
myTree={}
label=[i[-1] for i in data_set]
label_set=set(label)
if len(label_set)==1:
return 'type:'+str(label[0])+'\nsample:'+str(len(data_set))
best_feature,best_increase=choose_best_split(data_set)
string='X['+str(best_feature[0])+']<'+str(best_feature[1])
string+='\nbest_increase='+str(round(best_increase,3))
string+='\nsample:'+str(len(data_set))
myTree[string]={}
less_Set,more_Set=spliit_Set(data_set, best_feature[0], best_feature[1])
myTree[string]['True']=draw_tree(less_Set)
myTree[string]['False']=draw_tree(more_Set)
return myTree
结果:
以下是完整代码:
# -*- coding: utf-8 -*-
"""
@author: starry_sky
"""
from sklearn.datasets import load_iris
import random
from math import log
import plot
#初始化数据集
def init_data_set():
iris = load_iris() #导入数据集iris
iris_feature = iris.data.tolist() #特征数据
iris_target = iris.target.tolist() #分类数据
for i in range(len(iris_feature)):
iris_feature[i].append(iris_target[i])
return iris_feature
#划分数据集
def create_set(data_set,split_rate=0.8):
#0的是测试集,1的是训练集
length=len(data_set)
train_num=int(length*split_rate)
test_num=length-train_num
random_list=[1]*train_num
random_list.extend([0]*test_num)
random.shuffle(random_list)
test_set=[]
train_set=[]
for i in range(length):
if random_list[i]==0:
test_set.append(data_set[i])
else:
train_set.append(data_set[i])
return test_set,train_set
#计算信息熵
def calculate_Ent(data_set):
label={}
for i in data_set:
if i[-1] not in label.keys():
label[i[-1]]=1
else:
label[i[-1]]+=1
Ent=0.0
for i in label:
tmp=float(label[i])/len(data_set)
Ent-=tmp*log(tmp,2)
return Ent
#根据pos和value划分成两个数据集
def spliit_Set(data_set,pos,value):
less_Set=[]
more_Set=[]
for item in data_set:
if item[pos]<value:
less_Set.append(item)
else:
more_Set.append(item)
return less_Set,more_Set
#选择最好的特征值进行分类
def choose_best_split(data_set):
base_Ent=calculate_Ent(data_set)
best_increase=0.0
best_feature=[-1,-1]
for i in range(4):
features=[j[i] for j in data_set]
unique=set(features)
for feature in unique:
less_Set,more_Set=spliit_Set(data_set, i, feature)
tmp=len(less_Set)/float(len(data_set))
new_Ent=tmp*calculate_Ent(less_Set)
new_Ent+=(1-tmp)*calculate_Ent(more_Set)
increase=base_Ent-new_Ent
if increase>best_increase:
best_increase=increase
best_feature=[i,feature]
return best_feature,best_increase
#构造决策树
def create_tree(data_set):
myTree={}
label=[i[-1] for i in data_set]
label_set=set(label)
if len(label_set)==1:
myTree['class']=label[0]
return myTree
best_feature,best_increase=choose_best_split(data_set)
myTree['node']=best_feature
less_Set,more_Set=spliit_Set(data_set, best_feature[0], best_feature[1])
myTree['left']=create_tree(less_Set)
myTree['right']=create_tree(more_Set)
return myTree
#画决策树
def draw_tree(data_set):
myTree={}
label=[i[-1] for i in data_set]
label_set=set(label)
if len(label_set)==1:
return 'type:'+str(label[0])+'\nsample:'+str(len(data_set))
best_feature,best_increase=choose_best_split(data_set)
string='X['+str(best_feature[0])+']<'+str(best_feature[1])
string+='\nbest_increase='+str(round(best_increase,3))
string+='\nsample:'+str(len(data_set))
myTree[string]={}
less_Set,more_Set=spliit_Set(data_set, best_feature[0], best_feature[1])
myTree[string]['True']=draw_tree(less_Set)
myTree[string]['False']=draw_tree(more_Set)
return myTree
data_set=init_data_set()
plot.createPlot(draw_tree(data_set))
将字典类型的决策树画出的plot函数和使用测试集来预剪枝的代码,在公众号中发送【python 决策树样例】可获取。
更多推荐
所有评论(0)