​Iris数据集是常用的分类实验数据集,早在1936年,模式识别的先驱Fisher就在论文中使用了它 (直至今日该论文仍然被频繁引用)。

在这里插入图片描述

Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性:花萼长度(sepal length),花萼宽度(sepal width),花瓣长度(petal length),花瓣宽度(petal width),可通过4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。在三个类别中,其中有一个类别和其他两个类别是线性可分的。

在sklearn中已内置了此数据集。

核心算法思想就是:

每到一个节点,对于现在数据集中四个属性的所有值,根据小于和大于此属性值分成两个数据集,并计算香农熵,取所有香农熵最大增益的那个特征及值作为划分标准;划分直到所有数据均为一类数据。

为了说明这一思想,举个例子:

现在数据集中有三个数据:

1:0.1 0.2 0.3 0.4 0

2:0.2 0.3 0.4 0.5 1

3:0.3 0.4 0.4 0.6 2

此时,数据集不是同一类数据,所以要根据某个标准进行划分

为了找到那个标准,先考察pos=0,即第一个特征。将第一个特征的三个值0.1,0.2,0.3分别作为标准,比如:

将0.1作为标准,pos=0小于等于0.1的数据,即数据1划分为一个数据集,大于0.1的数据集,即数据2,3划分为一个数据集

则得到两个数据集:{1},{2,3}

计算此时的香农熵,接下来计算第一个特征的其余值作为标准时的香农熵,之后计算出第2,3,4个属性所有值的香农熵,这些香农熵中最大的那个pos和值即为我们决策树的节点

下面是这一逻辑的代码:

#选择最好的特征值进行分类
def choose_best_split(data_set):
    base_Ent=calculate_Ent(data_set)
    best_increase=0.0
    best_feature=[-1,-1]
    for i in range(4):
        features=[j[i] for j in data_set]
        unique=set(features)
        for feature in unique:
            less_Set,more_Set=spliit_Set(data_set, i, feature)
            tmp=len(less_Set)/float(len(data_set))
            new_Ent=tmp*calculate_Ent(less_Set)
            new_Ent+=(1-tmp)*calculate_Ent(more_Set)
            increase=base_Ent-new_Ent
            if increase>best_increase:
                best_increase=increase
                best_feature=[i,feature]
    return best_feature,best_increase

一、数据集初始化

将标签附到特征值之后:

#初始化数据集
def init_data_set():
    iris = load_iris()  #导入数据集iris
    iris_feature = iris.data.tolist()    #特征数据
    iris_target = iris.target.tolist()   #分类数据
    for i in range(len(iris_feature)):
        iris_feature[i].append(iris_target[i])
    return iris_feature

二、划分数据集

将数据集划分成训练集和测试集

#划分数据集
def create_set(data_set,split_rate=0.8):
    #0的是测试集,1的是训练集
    length=len(data_set)
    train_num=int(length*split_rate)
    test_num=length-train_num
    random_list=[1]*train_num
    random_list.extend([0]*test_num)
    random.shuffle(random_list)
    test_set=[]
    train_set=[]
    for i in range(length):
        if random_list[i]==0:
            test_set.append(data_set[i])
        else:
            train_set.append(data_set[i])
    return test_set,train_set

三、计算信息熵

计算香农熵

#计算信息熵
def calculate_Ent(data_set):
    label={}
    for i in data_set:
        if i[-1] not in label.keys():
            label[i[-1]]=1
        else:
            label[i[-1]]+=1
    Ent=0.0
    for i in label:
        tmp=float(label[i])/len(data_set)
        Ent-=tmp*log(tmp,2)
    return Ent

四、根据pos和value划分成两个数据集

#根据pos和value划分成两个数据集
def spliit_Set(data_set,pos,value):
    less_Set=[]
    more_Set=[]
    for item in data_set:
        if item[pos]<value:
            less_Set.append(item)
        else:
            more_Set.append(item)
    return less_Set,more_Set

五、构造决策树

当数据集均为同一类别时停止构造:

#构造决策树
def create_tree(data_set):
    myTree={}
    label=[i[-1] for i in data_set]
    label_set=set(label)
    if len(label_set)==1:
        myTree['class']=label[0]
        return myTree
    best_feature,best_increase=choose_best_split(data_set)
    myTree['node']=best_feature
    less_Set,more_Set=spliit_Set(data_set, best_feature[0], best_feature[1])
    myTree['left']=create_tree(less_Set)
    myTree['right']=create_tree(more_Set)
    return myTree

六、可视化

若想要将决策树可视化,需要将决策树结果变成字典形式

#画决策树
def draw_tree(data_set):
    myTree={}
    label=[i[-1] for i in data_set]
    label_set=set(label)
    if len(label_set)==1:
        return 'type:'+str(label[0])+'\nsample:'+str(len(data_set))
    best_feature,best_increase=choose_best_split(data_set)
    string='X['+str(best_feature[0])+']<'+str(best_feature[1])
    string+='\nbest_increase='+str(round(best_increase,3))
    string+='\nsample:'+str(len(data_set))
    myTree[string]={}
    less_Set,more_Set=spliit_Set(data_set, best_feature[0], best_feature[1])
    myTree[string]['True']=draw_tree(less_Set)
    myTree[string]['False']=draw_tree(more_Set)
    return myTree 

结果:

在这里插入图片描述

以下是完整代码:

# -*- coding: utf-8 -*-
"""
@author: starry_sky
"""
​
​
from sklearn.datasets import load_iris
import random
from math import log 
import plot
​
#初始化数据集
def init_data_set():
    iris = load_iris()  #导入数据集iris
    iris_feature = iris.data.tolist()    #特征数据
    iris_target = iris.target.tolist()   #分类数据
    for i in range(len(iris_feature)):
        iris_feature[i].append(iris_target[i])
    return iris_feature
​
#划分数据集
def create_set(data_set,split_rate=0.8):
    #0的是测试集,1的是训练集
    length=len(data_set)
    train_num=int(length*split_rate)
    test_num=length-train_num
    random_list=[1]*train_num
    random_list.extend([0]*test_num)
    random.shuffle(random_list)
    test_set=[]
    train_set=[]
    for i in range(length):
        if random_list[i]==0:
            test_set.append(data_set[i])
        else:
            train_set.append(data_set[i])
    return test_set,train_set
​
#计算信息熵
def calculate_Ent(data_set):
    label={}
    for i in data_set:
        if i[-1] not in label.keys():
            label[i[-1]]=1
        else:
            label[i[-1]]+=1
    Ent=0.0
    for i in label:
        tmp=float(label[i])/len(data_set)
        Ent-=tmp*log(tmp,2)
    return Ent
​
#根据pos和value划分成两个数据集
def spliit_Set(data_set,pos,value):
    less_Set=[]
    more_Set=[]
    for item in data_set:
        if item[pos]<value:
            less_Set.append(item)
        else:
            more_Set.append(item)
    return less_Set,more_Set
​
#选择最好的特征值进行分类
def choose_best_split(data_set):
    base_Ent=calculate_Ent(data_set)
    best_increase=0.0
    best_feature=[-1,-1]
    for i in range(4):
        features=[j[i] for j in data_set]
        unique=set(features)
        for feature in unique:
            less_Set,more_Set=spliit_Set(data_set, i, feature)
            tmp=len(less_Set)/float(len(data_set))
            new_Ent=tmp*calculate_Ent(less_Set)
            new_Ent+=(1-tmp)*calculate_Ent(more_Set)
            increase=base_Ent-new_Ent
            if increase>best_increase:
                best_increase=increase
                best_feature=[i,feature]
    return best_feature,best_increase
    
#构造决策树
def create_tree(data_set):
    myTree={}
    label=[i[-1] for i in data_set]
    label_set=set(label)
    if len(label_set)==1:
        myTree['class']=label[0]
        return myTree
    best_feature,best_increase=choose_best_split(data_set)
    myTree['node']=best_feature
    less_Set,more_Set=spliit_Set(data_set, best_feature[0], best_feature[1])
    myTree['left']=create_tree(less_Set)
    myTree['right']=create_tree(more_Set)
    return myTree
​
#画决策树
def draw_tree(data_set):
    myTree={}
    label=[i[-1] for i in data_set]
    label_set=set(label)
    if len(label_set)==1:
        return 'type:'+str(label[0])+'\nsample:'+str(len(data_set))
    best_feature,best_increase=choose_best_split(data_set)
    string='X['+str(best_feature[0])+']<'+str(best_feature[1])
    string+='\nbest_increase='+str(round(best_increase,3))
    string+='\nsample:'+str(len(data_set))
    myTree[string]={}
    less_Set,more_Set=spliit_Set(data_set, best_feature[0], best_feature[1])
    myTree[string]['True']=draw_tree(less_Set)
    myTree[string]['False']=draw_tree(more_Set)
    return myTree 
​
data_set=init_data_set()
plot.createPlot(draw_tree(data_set))

将字典类型的决策树画出的plot函数和使用测试集来预剪枝的代码,在公众号中发送【python 决策树样例】可获取。
powerful code

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐