# From the book "Machine Learning in Action" (Peter Harrington)

#!/usr/bin/python
# -*- coding: utf-8 -*-

from math import log
import operator

def createDataSet():
    """Return a toy fish-identification data set and its feature names.

    Each row is [no surfacing, flippers, class label]; the last column
    is the class ('yes' = is a fish, 'No' = not a fish).
    """
    samples = [
        [1, 1, "yes"],
        [1, 1, "yes"],
        [1, 0, "No"],
        [0, 1, "No"],
        [0, 1, "No"],
    ]
    featureNames = ['no surfacing', 'flippers']
    return samples, featureNames

def calcShannonEnt(dataSet):
    """Compute the Shannon entropy of the class labels (last column).

    Args:
        dataSet: list of rows; the last element of each row is taken as
            the class label.
    Returns:
        Base-2 Shannon entropy of the label distribution.
    """
    total = len(dataSet)
    # Count how often each class label occurs.
    counts = {}
    for row in dataSet:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    # H = -sum(p * log2(p)) over all label probabilities.
    entropy = 0.0
    for occurrences in counts.values():
        p = occurrences / float(total)
        entropy -= p * log(p, 2)
    return entropy

def splitDataSet(dataSet, axis, value):
    """Select the rows whose feature at *axis* equals *value*, and strip
    that feature column from each selected row.

    Args:
        dataSet: list of rows (each row a list).
        axis: index of the feature to filter on.
        value: feature value that a row must have to be kept.
    Returns:
        A new list of rows; the input rows are not modified.
    """
    return [row[:axis] + row[axis + 1:]
            for row in dataSet if row[axis] == value]

def chooseBestFeatureToSplit(dataSet):
    """Pick the feature whose split yields the highest information gain.

    Args:
        dataSet: list of rows; the last column is the class label, all
            other columns are candidate (discrete) features.
    Returns:
        Index of the best feature, or -1 when no split improves on the
        base entropy.
    """
    baseEntropy = calcShannonEnt(dataSet)
    numFeatures = len(dataSet[0]) - 1
    bestGain = 0
    bestIndex = -1
    for featIdx in range(numFeatures):
        # Distinct values this feature takes across the data set.
        distinctValues = {row[featIdx] for row in dataSet}
        # Weighted entropy of the partition induced by this feature.
        splitEntropy = 0.0
        for val in distinctValues:
            subset = splitDataSet(dataSet, featIdx, val)
            weight = len(subset) / float(len(dataSet))
            splitEntropy += weight * calcShannonEnt(subset)
        gain = baseEntropy - splitEntropy
        if gain > bestGain:
            bestGain = gain
            bestIndex = featIdx
    return bestIndex

def majorityCnt(classList):
    """Return the most frequent class label in *classList*.

    Args:
        classList: list of class labels (hashable values).
    Returns:
        The label with the highest count.
    """
    classCount = {}
    for vote in classList:
        # Bug fix vs. original: it did `classCount += 1` (a TypeError on
        # a dict); the per-label counter must be incremented instead.
        classCount[vote] = classCount.get(vote, 0) + 1
    # Bug fix vs. original: dict.iteritems() was removed in Python 3;
    # dict.items() works in both 2 and 3.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTrees(dataSet, labels):
    """Recursively build a decision tree stored as nested dicts.

    Args:
        dataSet: list of rows; the last column is the class label.
        labels: feature names for the columns of dataSet.  NOTE: this
            list is mutated — the chosen feature's name is deleted, as
            in the original implementation.
    Returns:
        A class label (leaf node) or a dict of the form
        {featureName: {featureValue: subtree_or_label, ...}}.
    """
    classes = [row[-1] for row in dataSet]
    # Stop: every remaining sample has the same class.
    if classes.count(classes[0]) == len(classes):
        return classes[0]
    # Stop: all features consumed — fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classes)

    bestIdx = chooseBestFeatureToSplit(dataSet)
    bestName = labels[bestIdx]
    del labels[bestIdx]  # feature is consumed by this node
    subtree = {}
    for val in set(row[bestIdx] for row in dataSet):
        # Each branch recurses on its own copy of the remaining labels.
        subtree[val] = createTrees(splitDataSet(dataSet, bestIdx, val),
                                   labels[:])
    return {bestName: subtree}

def classify(inputTree, featLabels, testVec):
    """Classify one sample by walking a learned decision tree.

    Args:
        inputTree: nested-dict decision tree as built by createTrees().
        featLabels: feature names, used to locate each node's feature
            within testVec.
        testVec: a single sample's feature values (list).
    Returns:
        The predicted class label, or None when testVec's value for a
        node's feature has no matching branch in the tree.
    """
    # Bug fix vs. original: `inputTree.keys()[0]` raises a TypeError on
    # Python 3 (dict views are not subscriptable); next(iter(...)) gets
    # the single root key on both Python 2 and 3.
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    # Bug fix vs. original: classLabel was unbound (UnboundLocalError)
    # when no branch matched; return None in that case instead.
    classLabel = None
    for key in secondDict:
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                # Internal node: recurse into the subtree.
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                # Leaf node: the stored value is the class label.
                classLabel = secondDict[key]
    return classLabel

def storeTree(inputTree, filename):
    """Serialize a decision tree (nested dict) to a file with pickle.

    Args:
        inputTree: decision tree in nested-dict format.
        filename: path of the file to write.
    """
    import pickle
    # Bug fixes vs. original: pickle.dump was handed the filename string
    # instead of the file object, and pickle requires binary mode on
    # Python 3 ('wb', not 'w').  The with-statement guarantees the file
    # is closed even if dump fails.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)

def grabTree(filename):
    """Load a pickled decision tree saved by storeTree().

    Args:
        filename: path of the pickle file to read.
    Returns:
        The unpickled decision tree (nested dict).
    """
    import pickle
    # Bug fixes vs. original: pickle requires binary mode on Python 3
    # ('rb'), and the opened file was never closed.
    # NOTE(review): pickle.load must only be used on trusted files.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

# Demo: build a decision tree from the toy data set.
if __name__ == "__main__":
    import tree

    dataSet, labels = tree.createDataSet()
    # Bug fix vs. original: the function is named createTrees, not
    # creatTrees (AttributeError).
    myTree = tree.createTrees(dataSet, labels)
    print(myTree)

# Interactive result:
# >>> myTree
# {'no surfacing': {0: 'No', 1: {'flippers': {0: 'No', 1: 'yes'}}}}