别再死记硬背了!用Python手撸ID3决策树,搞懂信息增益到底怎么算
·
别再死记硬背了!用Python手撸ID3决策树,搞懂信息增益到底怎么算
决策树算法是机器学习中最直观、最易解释的模型之一,但很多初学者在学习ID3算法时,往往被"信息增益"、"熵"这些概念卡住,陷入公式推导的泥潭。本文将通过Python代码实战,带你把抽象的信息论概念转化为可运行的代码逻辑,真正理解决策树是如何"思考"的。
1. 从生活案例理解决策树的本质
想象你在周末决定是否去爬山,可能会考虑以下因素:
- 天气如何?(晴朗/多云/下雨)
- 空气质量指数是否低于50?
- 昨晚睡眠是否充足?
这个决策过程本质上就是一个树形结构:
如果 天气==晴朗:
如果 空气质量<50:
去爬山
否则:
不去
否则如果 天气==多云:
如果 睡眠充足:
去爬山
否则:
不去
否则:
不去
决策树的核心优势 在于:
- 模型可解释性强,决策过程如同白盒
- 对数据预处理要求较低,能同时处理数值和类别特征
- 非参数方法,不需要假设数据分布
在Python中,我们可以用字典嵌套的方式表示这棵树:
decision_tree = {
'weather': {
'sunny': {
'aqi<50': 'go',
'aqi>=50': 'no'
},
'cloudy': {
'sleep_well': 'go',
'sleep_bad': 'no'
},
'rainy': 'no'
}
}
2. 信息熵:混乱度的数学度量
熵的概念源于热力学,表示系统的混乱程度。在信息论中,香农借用了这个概念来描述信息的不确定性。举个例子:
- 情况A:一个袋子有4红球和4蓝球
- 情况B:一个袋子有7红球和1蓝球
显然,情况A的"混乱度"更高,因为更难预测随机摸出一个球的颜色。这就是熵的直观理解。
计算熵的数学公式为: $$ H(X) = -\sum_{i=1}^{n} p(x_i) \log_2 p(x_i) $$
Python实现如下:
import numpy as np
def entropy(labels):
"""计算标签的信息熵"""
_, counts = np.unique(labels, return_counts=True)
probabilities = counts / len(labels)
return -np.sum(probabilities * np.log2(probabilities))
测试不同分布的熵值:
print(entropy([0,0,0,0,1,1,1,1])) # 1.0 (最大混乱度)
print(entropy([0,0,0,0,0,0,1,1])) # 0.543
print(entropy([0,0,0,0,0,0,0,1])) # 0.168
3. 信息增益:决策树的分裂准则
信息增益表示某个特征能为分类带来多少信息量。计算步骤:
- 计算原始数据集的熵(父节点熵)
- 按特征分割数据集,计算各子集的熵(子节点熵)
- 计算子节点熵的加权平均
- 信息增益 = 父节点熵 - 子节点熵
Python实现:
def information_gain(data, labels, feature_idx):
"""计算指定特征的信息增益"""
# 父节点熵
parent_entropy = entropy(labels)
# 获取该特征的所有唯一值
feature_values = np.unique(data[:, feature_idx])
# 计算子节点加权熵
child_entropy = 0
for value in feature_values:
subset_mask = data[:, feature_idx] == value
subset_labels = labels[subset_mask]
weight = len(subset_labels) / len(labels)
child_entropy += weight * entropy(subset_labels)
return parent_entropy - child_entropy
实际案例演示:
# 天气数据集
data = np.array([
['sunny', 'high', 'weak'],
['sunny', 'high', 'strong'],
['overcast', 'high', 'weak'],
['rain', 'high', 'weak'],
['rain', 'normal', 'weak'],
['rain', 'normal', 'strong'],
['overcast', 'normal', 'strong'],
['sunny', 'high', 'weak']
])
labels = np.array(['no', 'no', 'yes', 'yes', 'yes', 'no', 'yes', 'no'])
# 计算各特征的信息增益
for i in range(data.shape[1]):
print(f"Feature {i} gain: {information_gain(data, labels, i):.3f}")
输出结果:
Feature 0 gain: 0.246
Feature 1 gain: 0.029
Feature 2 gain: 0.151
这表明"天气"(Feature 0)是最佳分裂特征。
4. 构建完整的ID3决策树
有了信息增益的计算方法,我们可以递归构建决策树:
def id3(data, labels, feature_names, used_features=set()):
# 终止条件1:所有标签相同
if len(np.unique(labels)) == 1:
return labels[0]
# 终止条件2:没有可用特征
if len(used_features) == len(feature_names):
return np.bincount(labels).argmax()
# 选择最佳特征
best_gain = -1
best_feature_idx = None
for i in range(len(feature_names)):
if i not in used_features:
gain = information_gain(data, labels, i)
if gain > best_gain:
best_gain = gain
best_feature_idx = i
# 构建子树
tree = {feature_names[best_feature_idx]: {}}
used_features.add(best_feature_idx)
# 递归处理每个特征值
for value in np.unique(data[:, best_feature_idx]):
subset_mask = data[:, best_feature_idx] == value
subset_data = data[subset_mask]
subset_labels = labels[subset_mask]
if len(subset_labels) == 0:
tree[feature_names[best_feature_idx]][value] = np.bincount(labels).argmax()
else:
tree[feature_names[best_feature_idx]][value] = id3(
subset_data, subset_labels, feature_names, used_features.copy())
return tree
使用示例:
feature_names = ['outlook', 'humidity', 'wind']
tree = id3(data, labels, feature_names)
print(tree)
输出结果:
{
'outlook': {
'sunny': 'no',
'overcast': 'yes',
'rain': {
'wind': {
'weak': 'yes',
'strong': 'no'
}
}
}
}
5. 决策树的预测与可视化
构建好决策树后,我们可以实现预测函数:
def predict(tree, sample, feature_names):
if not isinstance(tree, dict):
return tree
feature = list(tree.keys())[0]
feature_idx = feature_names.index(feature)
value = sample[feature_idx]
subtree = tree[feature][value]
return predict(subtree, sample, feature_names)
测试预测:
test_sample = ['rain', 'normal', 'weak']
print(predict(tree, test_sample, feature_names)) # 输出: 'yes'
为了更直观理解,我们可以用graphviz可视化决策树:
from graphviz import Digraph
def visualize_tree(tree, feature_names, dot=None):
if dot is None:
dot = Digraph(comment='Decision Tree')
root = list(tree.keys())[0]
dot.node(root, root)
for value, subtree in tree[root].items():
if isinstance(subtree, dict):
child = list(subtree.keys())[0]
dot.node(child, child)
dot.edge(root, child, label=str(value))
visualize_tree(subtree, feature_names, dot)
else:
leaf = f"{value}_{subtree}"
dot.node(leaf, str(subtree))
dot.edge(root, leaf, label=str(value))
return dot
visualize_tree(tree, feature_names).render('decision_tree', view=True)
6. 决策树的局限与改进
虽然ID3算法简单直观,但存在几个明显问题:
- 倾向于选择取值多的特征 :比如"用户ID"这种唯一标识符会获得最大信息增益,但毫无意义
- 无法处理连续值特征 :需要先离散化处理
- 容易过拟合 :树可能生长得太深,捕捉到噪声
改进方案对比:
| 算法 | 改进点 | 适用场景 |
|---|---|---|
| C4.5 | 使用信息增益比,支持连续值 | 类别型特征为主 |
| CART | 使用基尼系数,支持回归任务 | 需要二叉树结构 |
| 随机森林 | 集成多棵树,减少过拟合 | 高维数据 |
基尼系数实现示例:
def gini(labels):
_, counts = np.unique(labels, return_counts=True)
probabilities = counts / len(labels)
return 1 - np.sum(probabilities ** 2)
在实际项目中,建议直接使用scikit-learn的实现:
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
import matplotlib.pyplot as plt
clf = DecisionTreeClassifier(criterion='entropy', max_depth=3)
clf.fit(data_encoded, labels)
plt.figure(figsize=(12,8))
tree.plot_tree(clf, feature_names=feature_names,
class_names=['no', 'yes'], filled=True)
plt.show()
更多推荐
所有评论(0)