python-ID3(理解)
来自机器学习实战一书# !/usr/bin/python# -*- coding: utf-8 -*-from math import logimport operatordef createDataSet():dataSet = [[1, 1, "yes"],[1, 1, "yes"],[1, 0, "No"],
·
来自机器学习实战一书
# !/usr/bin/python
# -*- coding: utf-8 -*-
from math import log
import operator
def createDataSet():
    """Build the toy 'fish identification' dataset from Machine Learning in Action.

    Returns:
        tuple: (dataSet, labels) where dataSet is a list of rows, each row
        being [feature1, feature2, classLabel], and labels names the two
        feature columns.
    """
    rows = [
        [1, 1, "yes"],
        [1, 1, "yes"],
        [1, 0, "No"],
        [0, 1, "No"],
        [0, 1, "No"],
    ]
    columnNames = ['no surfacing', 'flippers']
    return rows, columnNames
def calcShannonEnt(dataSet):
    """Compute the Shannon entropy of the last column (class label) of dataSet.

    Args:
        dataSet: list of rows; the final element of each row is the class label.

    Returns:
        float: entropy in bits, i.e. -sum(p * log2(p)) over label frequencies.
    """
    total = len(dataSet)
    # Tally how many times each class label appears.
    tallies = {}
    for row in dataSet:
        label = row[-1]
        tallies[label] = tallies.get(label, 0) + 1
    # Accumulate -p*log2(p) over the observed label distribution.
    entropy = 0
    for count in tallies.values():
        p = count / float(total)
        entropy -= p * log(p, 2)
    return entropy
def splitDataSet(dataSet, axis, value):
    """Select the rows whose feature at `axis` equals `value`, dropping that column.

    Args:
        dataSet: list of rows (discrete features).
        axis: index of the feature column to split on.
        value: keep only rows where row[axis] == value.

    Returns:
        list: matching rows with the `axis` column removed.
    """
    # Slice-and-concatenate removes column `axis` without mutating the input rows.
    return [row[:axis] + row[axis + 1:] for row in dataSet if row[axis] == value]
def chooseBestFeatureToSplit(dataSet):
    """Pick the feature index with the highest information gain.

    Args:
        dataSet: list of rows; all columns except the last are features,
            the last column is the class label.

    Returns:
        int: index of the best feature, or -1 when no split improves on
        the base entropy.
    """
    featureTotal = len(dataSet[0]) - 1          # last column is the label
    baseEntropy = calcShannonEnt(dataSet)
    bestGain, bestIndex = 0, -1
    for idx in range(featureTotal):
        # Entropy of the partition induced by splitting on feature `idx`,
        # weighted by the size of each branch.
        branchEntropy = 0
        for val in {row[idx] for row in dataSet}:
            branch = splitDataSet(dataSet, idx, val)
            weight = len(branch) / float(len(dataSet))
            branchEntropy += weight * calcShannonEnt(branch)
        gain = baseEntropy - branchEntropy
        if gain > bestGain:
            bestGain, bestIndex = gain, idx
    return bestIndex
def majorityCnt(classList):
    """Return the most common class label in classList (majority vote).

    Args:
        classList: list of class labels.

    Returns:
        The label with the highest count; ties resolve to whichever sorts
        first among the max-count labels.
    """
    classCount = {}
    for vote in classList:
        # BUGFIX: the original did `classCount += 1` (a TypeError on a dict);
        # the per-label counter must be incremented instead.
        classCount[vote] = classCount.get(vote, 0) + 1
    # BUGFIX: dict.iteritems() is Python-2-only; items() works on both.
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def createTrees(dataSet, labels):
    """Recursively build an ID3 decision tree stored as nested dicts.

    Args:
        dataSet: list of rows; last column of each row is the class label.
        labels: feature names for the non-label columns of dataSet.

    Returns:
        Either a class label (leaf) or a dict of the form
        {featureName: {featureValue: subtree, ...}}.
    """
    classList = [example[-1] for example in dataSet]
    # Stop when every remaining sample has the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # All features consumed: fall back to the majority class.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # BUGFIX: the original did `del(labels[bestFeat])`, destructively mutating
    # the caller's labels list; build the reduced label list on a copy instead.
    remainingLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    for value in set(example[bestFeat] for example in dataSet):
        # Recurse on the rows matching this feature value, with the used
        # feature's column and label removed; pass a fresh copy per branch.
        myTree[bestFeatLabel][value] = createTrees(
            splitDataSet(dataSet, bestFeat, value), remainingLabels[:])
    return myTree
def classify(inputTree, featLabels, testVec):
    """Walk a learned decision tree to classify one sample.

    Args:
        inputTree: tree built by createTrees — a dict {featureName: {value: subtree}}.
        featLabels: feature names, in the same column order as testVec.
        testVec: single sample's feature values.

    Returns:
        The predicted class label.

    Raises:
        KeyError: if the sample has a feature value the tree never saw
            (the original code raised UnboundLocalError in that case).
    """
    # BUGFIX: `inputTree.keys()[0]` fails on Python 3 where dict_keys is not
    # subscriptable; next(iter(...)) gets the single root key on both versions.
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)       # column of this feature in testVec
    subtree = secondDict[testVec[featIndex]]     # follow the matching branch
    if isinstance(subtree, dict):
        # Internal node: keep descending.
        return classify(subtree, featLabels, testVec)
    # Leaf node: subtree is the class label itself.
    return subtree
def storeTree(inputTree, filename):
    """Persist a decision-tree dict to disk with pickle.

    Args:
        inputTree: the tree (nested dict) to save.
        filename: path of the file to write.
    """
    import pickle
    # BUGFIX: the original called pickle.dump(inputTree, filename), passing the
    # filename string instead of the file object, and opened the file in text
    # mode; pickle requires a binary-mode file object. `with` also guarantees
    # the handle is closed even if dump raises.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree(filename):
    """Load a pickled decision tree from disk.

    Args:
        filename: path of the pickle file written by storeTree.

    Returns:
        The unpickled tree (nested dict).
    """
    import pickle
    # BUGFIX: pickle data must be read in binary mode on Python 3; the original
    # opened in text mode and never closed the handle. `with` fixes both.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
# 测试
import tree
dataSet, labels= tree.createDataSet()
myTree = tree.createTrees(dataSet, labels)
myTree
{'no surfacing': {0: 'No', 1: {'flippers': {0: 'No', 1: 'yes'}}}}
更多推荐
已为社区贡献2条内容
所有评论(0)