别再死记硬背了!用Python手撸一个ID3决策树,从熵到分类器一次搞懂
别再死记硬背了!用Python手撸一个ID3决策树,从熵到分类器一次搞懂
决策树算法作为机器学习中最直观的模型之一,常被比作"机器学习界的Hello World"。但很多初学者在学习时容易陷入两个极端:要么被数学公式吓退,要么只会调用sklearn的DecisionTreeClassifier而不知其所以然。今天,我们将打破这种局面——用不到200行Python代码,从熵的计算开始,一步步构建完整的ID3决策树分类器。
1. 决策树与信息熵:从概念到代码
1.1 为什么需要信息熵
想象你在玩20个问题的游戏:每次提问都希望能最大程度地缩小答案范围。信息熵正是量化这种"不确定性"的数学工具。在决策树中,我们通过计算信息增益来决定每个节点的分裂特征,而这一切都建立在信息熵的基础上。
熵的计算公式看似简单:
H(X) = -Σ p(x)log₂p(x)
但如何将其转化为可运行的Python代码?下面是一个直观的实现:
import numpy as np
def calc_entropy(labels):
"""计算信息熵"""
unique_labels, counts = np.unique(labels, return_counts=True)
probabilities = counts / len(labels)
return -np.sum(probabilities * np.log2(probabilities))
这个函数的核心逻辑:
- 使用
np.unique统计每个类别出现的次数 - 计算每个类别的概率
- 应用熵公式求和
注意:在实际应用中,当概率为0时,log2(0)无定义,但NumPy会正确处理这种情况返回0
1.2 信息增益的计算
信息增益衡量的是特征对分类不确定性的减少程度。计算步骤分为:
- 计算原始数据集的信息熵(父节点熵)
- 按特征分割数据集后,计算加权平均熵(子节点熵)
- 两者相减得到信息增益
对应的Python实现:
def information_gain(data, labels, feature_idx):
"""计算指定特征的信息增益"""
parent_entropy = calc_entropy(labels)
# 获取该特征的所有唯一值
feature_values = np.unique(data[:, feature_idx])
# 计算加权子节点熵
child_entropy = 0
for value in feature_values:
mask = data[:, feature_idx] == value
subset_labels = labels[mask]
weight = len(subset_labels) / len(labels)
child_entropy += weight * calc_entropy(subset_labels)
return parent_entropy - child_entropy
2. 构建决策树的核心逻辑
2.1 递归构建树结构
决策树的构建本质上是一个递归过程,包含三个关键步骤:
- 选择最佳分裂特征 :计算所有特征的信息增益,选择增益最大的
- 创建分支节点 :根据选定特征的不同取值创建分支
- 递归处理子集 :对每个分支对应的数据子集重复上述过程
实现这一逻辑的Python代码如下:
def build_tree(data, labels, feature_names):
# 终止条件1:所有样本属于同一类别
if len(np.unique(labels)) == 1:
return labels[0]
# 终止条件2:没有更多特征可用于分裂
if data.shape[1] == 0:
return np.bincount(labels).argmax()
# 选择最佳分裂特征
best_feature = select_best_feature(data, labels)
best_feature_name = feature_names[best_feature]
# 创建树节点
tree = {best_feature_name: {}}
# 获取该特征的所有唯一值
feature_values = np.unique(data[:, best_feature])
# 递归构建子树
for value in feature_values:
mask = data[:, best_feature] == value
subset_data = np.delete(data[mask], best_feature, axis=1)
subset_labels = labels[mask]
subset_feature_names = np.delete(feature_names, best_feature)
tree[best_feature_name][value] = build_tree(
subset_data, subset_labels, subset_feature_names)
return tree
2.2 处理边界情况
在实际编码中,我们需要处理几种特殊情况:
- 连续特征处理 :ID3算法原本只处理离散特征,可通过二分法扩展
- 缺失值处理 :可采用常见值填充或概率分配
- 过拟合预防 :设置最大深度或最小样本数限制
以下是增强版的终止条件判断:
def should_stop(data, labels, max_depth, current_depth, min_samples):
# 所有样本属于同一类别
if len(np.unique(labels)) == 1:
return True
# 达到最大深度限制
if current_depth >= max_depth:
return True
# 样本数小于最小限制
if len(labels) < min_samples:
return True
# 没有更多特征可用于分裂
if data.shape[1] == 0:
return True
return False
3. 完整ID3决策树实现
3.1 决策树分类器类
将上述功能封装成一个完整的类,提高代码复用性:
class ID3DecisionTree:
def __init__(self, max_depth=None, min_samples_split=2):
self.max_depth = max_depth
self.min_samples_split = min_samples_split
self.tree = None
self.feature_names = None
def fit(self, data, labels, feature_names):
self.feature_names = feature_names
self.tree = self._build_tree(
data, labels, feature_names, current_depth=0)
def _build_tree(self, data, labels, feature_names, current_depth):
# 终止条件判断
if self._should_stop(data, labels, current_depth):
return self._make_leaf_node(labels)
# 选择最佳分裂特征
best_idx = self._select_best_feature(data, labels)
best_name = feature_names[best_idx]
# 创建树节点
node = {best_name: {}}
# 获取特征唯一值并递归构建子树
feature_values = np.unique(data[:, best_idx])
for value in feature_values:
mask = data[:, best_idx] == value
subset_data = np.delete(data[mask], best_idx, axis=1)
subset_labels = labels[mask]
subset_features = np.delete(feature_names, best_idx)
node[best_name][value] = self._build_tree(
subset_data, subset_labels, subset_features, current_depth+1)
return node
def predict(self, X):
return np.array([self._traverse_tree(x, self.tree) for x in X])
def _traverse_tree(self, sample, node):
if not isinstance(node, dict):
return node
feature_name = next(iter(node))
feature_idx = np.where(self.feature_names == feature_name)[0][0]
feature_value = sample[feature_idx]
if feature_value in node[feature_name]:
return self._traverse_tree(sample, node[feature_name][feature_value])
else:
# 处理未见过的特征值
return self._handle_unknown_value(node[feature_name])
3.2 可视化决策树
理解决策树的最好方式就是可视化其结构。我们可以使用简单的文本缩进来展示:
def print_tree(node, indent=""):
if not isinstance(node, dict):
print(indent + "预测: " + str(node))
return
feature_name = next(iter(node))
print(indent + feature_name)
for value, subtree in node[feature_name].items():
print(indent + "├── " + str(value) + ":")
print_tree(subtree, indent + "│ ")
4. 实战:用自制决策树解决真实问题
4.1 准备示例数据集
让我们创建一个简单的贷款审批数据集:
# 特征:年龄(0:青年,1:中年,2:老年),有工作(0:否,1:是),有房子(0:否,1:是)
data = np.array([
[0, 0, 0],
[0, 0, 1],
[0, 1, 1],
[1, 0, 0],
[1, 0, 1],
[1, 1, 1],
[2, 0, 0],
[2, 0, 1],
[2, 1, 1]
])
# 标签:0:拒绝贷款,1:批准贷款
labels = np.array([0, 1, 1, 0, 1, 1, 0, 1, 1])
feature_names = np.array(['年龄', '有工作', '有房子'])
4.2 训练并测试模型
# 初始化并训练模型
tree = ID3DecisionTree(max_depth=3)
tree.fit(data, labels, feature_names)
# 打印树结构
print_tree(tree.tree)
# 预测新样本
test_samples = np.array([
[0, 1, 0], # 青年,有工作,没房子
[1, 0, 1], # 中年,没工作,有房子
[2, 1, 0] # 老年,有工作,没房子
])
predictions = tree.predict(test_samples)
print("预测结果:", predictions)
4.3 性能优化技巧
- 特征预排序 :对连续特征提前排序,加速最优分割点查找
- 并行计算 :在多核CPU上并行计算不同特征的信息增益
- 缓存中间结果 :避免重复计算相同子集的信息熵
- 使用Cython加速 :将计算密集型部分用Cython重写
优化后的信息增益计算示例:
from joblib import Parallel, delayed
def parallel_information_gain(data, labels, feature_idx):
# ...同前...
def select_best_feature_parallel(data, labels, n_jobs=-1):
n_features = data.shape[1]
gains = Parallel(n_jobs=n_jobs)(
delayed(parallel_information_gain)(data, labels, i)
for i in range(n_features)
)
return np.argmax(gains)
5. 从ID3到C4.5:决策树算法的演进
虽然我们实现了基础的ID3算法,但现代决策树通常使用其改进版本C4.5。主要改进包括:
| 特性 | ID3算法 | C4.5算法 |
|---|---|---|
| 分裂标准 | 信息增益 | 信息增益比 |
| 处理连续特征 | 不支持 | 支持 |
| 缺失值处理 | 不支持 | 支持 |
| 剪枝策略 | 无 | 悲观剪枝 |
| 多叉树 | 是 | 是 |
实现信息增益比的Python代码:
def gain_ratio(data, labels, feature_idx):
info_gain = information_gain(data, labels, feature_idx)
# 计算分裂信息(Split Information)
feature_values = data[:, feature_idx]
_, counts = np.unique(feature_values, return_counts=True)
probabilities = counts / len(feature_values)
split_info = -np.sum(probabilities * np.log2(probabilities))
# 避免除以0
if split_info == 0:
return 0
return info_gain / split_info
决策树算法的魅力在于其直观性和可解释性。通过这次从零实现,我深刻体会到:在机器学习中,真正理解一个算法的最佳方式就是亲手实现它。当看到自己编写的树结构能够正确分类样本时,那种成就感远胜过调用现成库函数。
更多推荐
所有评论(0)