从信息熵到分类预测:用Python手写ID3决策树的实战指南

决策树算法作为机器学习中最直观的模型之一,其核心思想是通过一系列规则对数据进行分类或回归。不同于神经网络这样的"黑箱"模型,决策树的结构清晰可见,每一步决策过程都像人类思考问题时的逻辑链条。本文将带您从零开始实现ID3决策树算法,不仅理解其背后的数学原理,更能通过代码实践将其转化为可运行的分类工具。

1. 决策树基础与信息论原理

决策树算法的核心在于如何选择最优的特征进行节点分裂。ID3算法使用信息增益作为特征选择标准,而理解信息增益的前提是掌握信息熵的概念。

信息熵(Information Entropy)由克劳德·香农提出,用来衡量系统的不确定性。在分类问题中,熵可以表示数据集的"混乱程度"。当数据集中所有样本都属于同一类别时,熵为0;当类别分布均匀时,熵达到最大值。

信息熵的数学定义为:

import numpy as np

def calc_info_entropy(labels):
    """计算信息熵"""
    unique_labels, counts = np.unique(labels, return_counts=True)
    probabilities = counts / len(labels)
    entropy = -np.sum(probabilities * np.log2(probabilities))
    return entropy

信息增益 则是特征选择的关键指标,表示使用某特征划分数据集后熵的减少量。信息增益越大,说明该特征对分类的贡献越大。计算信息增益需要先计算条件熵:

def calc_conditional_entropy(features, labels, feature_idx, value):
    """计算条件熵"""
    mask = features[:, feature_idx] == value
    sub_labels = labels[mask]
    p = len(sub_labels) / len(labels)
    return p * calc_info_entropy(sub_labels)

2. 核心算法实现:从信息增益到树构建

有了信息熵和条件熵的计算方法,我们可以实现信息增益的计算函数:

def calc_info_gain(features, labels, feature_idx):
    """计算信息增益"""
    base_entropy = calc_info_entropy(labels)
    unique_values = np.unique(features[:, feature_idx])
    cond_entropy = sum(
        calc_conditional_entropy(features, labels, feature_idx, value)
        for value in unique_values
    )
    return base_entropy - cond_entropy

基于信息增益,我们可以选择最优特征进行节点分裂:

def choose_best_feature(features, labels):
    """选择信息增益最大的特征"""
    n_features = features.shape[1]
    best_gain = -1
    best_feature = None
    
    for feature_idx in range(n_features):
        gain = calc_info_gain(features, labels, feature_idx)
        if gain > best_gain:
            best_gain = gain
            best_feature = feature_idx
            
    return best_feature

3. 递归构建决策树

决策树的构建是一个递归过程,核心思路是:

  1. 选择当前最优特征
  2. 根据特征值划分数据集
  3. 对每个子集递归构建子树
  4. 设置终止条件防止无限递归

实现代码如下:

def create_decision_tree(features, labels, feature_names=None):
    """递归构建决策树"""
    
    # 终止条件1:所有样本属于同一类别
    if len(np.unique(labels)) == 1:
        return labels[0]
    
    # 终止条件2:没有特征可分或所有样本特征相同
    if features.shape[1] == 0 or len(np.unique(features, axis=0)) == 1:
        return np.argmax(np.bincount(labels))
    
    # 选择最优特征
    best_feature_idx = choose_best_feature(features, labels)
    if feature_names is not None:
        best_feature_name = feature_names[best_feature_idx]
    else:
        best_feature_name = str(best_feature_idx)
    
    # 初始化树结构
    tree = {best_feature_name: {}}
    
    # 获取该特征的所有可能值
    unique_values = np.unique(features[:, best_feature_idx])
    
    # 递归构建子树
    for value in unique_values:
        mask = features[:, best_feature_idx] == value
        sub_features = np.delete(features[mask], best_feature_idx, axis=1)
        sub_labels = labels[mask]
        
        if feature_names is not None:
            sub_feature_names = np.delete(feature_names, best_feature_idx)
        else:
            sub_feature_names = None
            
        tree[best_feature_name][value] = create_decision_tree(
            sub_features, sub_labels, sub_feature_names
        )
    
    return tree

4. 决策树的预测与可视化

构建好决策树后,我们需要实现预测功能:

def predict(tree, sample, feature_names=None):
    """使用决策树进行预测"""
    if not isinstance(tree, dict):
        return tree
    
    feature_name = next(iter(tree))
    if feature_names is not None:
        feature_idx = np.where(feature_names == feature_name)[0][0]
    else:
        feature_idx = int(feature_name)
    
    feature_value = sample[feature_idx]
    subtree = tree[feature_name].get(feature_value, None)
    
    if subtree is None:
        return None
    
    return predict(subtree, sample, feature_names)

为了更直观地理解决策树的结构,我们可以使用文本方式打印树:

def print_tree(tree, indent=""):
    """打印决策树结构"""
    if not isinstance(tree, dict):
        print(indent + "预测结果: " + str(tree))
        return
    
    feature_name = next(iter(tree))
    print(indent + feature_name + ":")
    
    for value, subtree in tree[feature_name].items():
        print(indent + "  " + str(value) + " ->")
        print_tree(subtree, indent + "    ")

5. 实战案例:鸢尾花分类

让我们用一个实际案例来测试我们的决策树实现。使用经典的鸢尾花数据集:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# 加载数据
iris = load_iris()
X = iris.data
y = iris.target
feature_names = iris.feature_names

# 划分训练测试集
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# 构建决策树
tree = create_decision_tree(X_train, y_train, feature_names)

# 打印树结构
print_tree(tree)

# 评估模型
correct = 0
for sample, label in zip(X_test, y_test):
    pred = predict(tree, sample, feature_names)
    if pred == label:
        correct += 1

accuracy = correct / len(y_test)
print(f"测试准确率: {accuracy:.2f}")

6. 常见问题与优化策略

在实际实现决策树时,有几个关键点需要注意:

  1. 连续值处理 :ID3算法原本只支持离散特征,对于连续值需要先进行离散化处理
  2. 过拟合问题 :决策树容易过拟合,可以通过剪枝策略来优化
  3. 缺失值处理 :需要考虑如何处理特征缺失的样本
  4. 特征重要性 :可以通过计算特征在决策树中被使用的频率来评估特征重要性

对于连续特征的处理,我们可以添加一个离散化步骤:

def discretize_continuous_feature(feature_values, n_bins=3):
    """将连续特征离散化为n_bins个区间"""
    percentiles = np.linspace(0, 100, n_bins + 1)[1:-1]
    thresholds = np.percentile(feature_values, percentiles)
    discretized = np.digitize(feature_values, thresholds)
    return discretized

决策树的剪枝策略可以分为预剪枝和后剪枝两种:

  • 预剪枝 :在树构建过程中提前停止分裂
  • 后剪枝 :先构建完整树,然后自底向上剪枝

一个简单的预剪枝实现可以基于信息增益阈值:

def create_tree_with_pruning(features, labels, min_gain=0.01):
    """带预剪枝的决策树构建"""
    if len(np.unique(labels)) == 1:
        return labels[0]
    
    best_feature_idx = choose_best_feature(features, labels)
    gain = calc_info_gain(features, labels, best_feature_idx)
    
    # 如果信息增益小于阈值,停止分裂
    if gain < min_gain:
        return np.argmax(np.bincount(labels))
    
    tree = {best_feature_idx: {}}
    unique_values = np.unique(features[:, best_feature_idx])
    
    for value in unique_values:
        mask = features[:, best_feature_idx] == value
        sub_features = np.delete(features[mask], best_feature_idx, axis=1)
        sub_labels = labels[mask]
        
        tree[best_feature_idx][value] = create_tree_with_pruning(
            sub_features, sub_labels, min_gain
        )
    
    return tree

7. 决策树的优势与局限性

决策树算法具有以下优势:

  • 直观易懂 :决策过程可视化,容易解释
  • 数据准备简单 :不需要特征缩放,能处理混合类型数据
  • 非线性关系 :可以捕捉特征间的非线性关系

但同时也有其局限性:

  • 容易过拟合 :特别是当树很深时
  • 不稳定性 :数据的小变化可能导致完全不同的树结构
  • 局部最优 :基于贪心算法,不一定得到全局最优树

在实际项目中,决策树常常作为基础组件用于构建更强大的集成模型,如随机森林和梯度提升决策树(GBDT)。

更多推荐