从零构建ID3决策树:用Python实现经典分类算法

决策树是机器学习中最直观的算法之一,它模拟人类做决策的过程,通过一系列规则对数据进行分类。ID3算法作为决策树家族的早期成员,以其简洁的理论基础和清晰的构建逻辑,成为入门机器学习的绝佳选择。本文将抛开数学公式的抽象表达,带你用纯代码理解信息增益、节点分裂等核心概念,最终实现一个可处理真实数据集的分类器。

1. 环境准备与数据理解

在开始编写决策树之前,我们需要明确两个关键点:开发环境和数据格式。Python的科学计算栈为我们提供了必要工具:

import numpy as np
from collections import Counter

决策树处理的数据通常是二维表格形式。以经典的鸢尾花数据集为例,每行代表一个样本,前几列是特征(如花瓣长度、花萼宽度),最后一列是类别标签。在代码中,我们用NumPy数组表示:

# 示例数据结构
features = np.array([
    [5.1, 3.5, 1.4],  # 样本1特征
    [4.9, 3.0, 1.4],  # 样本2特征
    # ...更多样本
])
labels = np.array([0, 0, 1, 1])  # 对应类别

注意:ID3算法要求离散型特征。若使用连续值特征,需要先进行离散化处理(如等宽分箱)。

2. 信息论基础实现

ID3算法的核心是信息增益,这需要我们先实现信息熵的计算。信息熵度量了数据的混乱程度:

def entropy(labels):
    """计算标签的信息熵"""
    counts = Counter(labels)
    probs = [count / len(labels) for count in counts.values()]
    return -sum(p * np.log2(p) for p in probs if p > 0)

理解这个函数的关键点:

  • Counter 统计每个类别出现的次数
  • 列表推导式计算各类别概率
  • 最后求和时忽略零概率项(因为lim p→0 p log p = 0)

接下来实现条件熵,它表示在已知某个特征条件下标签的不确定性:

def conditional_entropy(features, labels, feature_idx):
    """计算指定特征的条件熵"""
    feature_values = features[:, feature_idx]
    total = len(labels)
    cond_entropy = 0.0
    
    for value in set(feature_values):
        mask = feature_values == value
        sub_labels = labels[mask]
        weight = len(sub_labels) / total
        cond_entropy += weight * entropy(sub_labels)
    
    return cond_entropy

信息增益就是熵与条件熵的差值:

def information_gain(features, labels, feature_idx):
    """计算指定特征的信息增益"""
    return entropy(labels) - conditional_entropy(features, labels, feature_idx)

3. 决策树构建过程

有了信息增益的计算能力,我们就可以开始构建决策树了。树的每个节点需要存储以下信息:

  • 如果是叶节点:类别标签
  • 如果是内部节点:划分特征及其分支
def find_best_split(features, labels):
    """找到信息增益最大的特征"""
    gains = [information_gain(features, labels, i) 
             for i in range(features.shape[1])]
    return np.argmax(gains)

递归构建决策树的核心逻辑:

def build_tree(features, labels, depth=0, max_depth=10):
    # 终止条件1:所有样本属于同一类别
    if len(set(labels)) == 1:
        return labels[0]
    
    # 终止条件2:没有特征可用或达到最大深度
    if features.shape[1] == 0 or depth >= max_depth:
        return Counter(labels).most_common(1)[0][0]
    
    # 选择最佳分裂特征
    best_feature = find_best_split(features, labels)
    tree = {'feature': best_feature, 'branches': {}}
    
    # 按特征值划分数据集
    feature_values = features[:, best_feature]
    for value in set(feature_values):
        mask = feature_values == value
        sub_features = np.delete(features[mask], best_feature, axis=1)
        sub_labels = labels[mask]
        
        # 递归构建子树
        tree['branches'][value] = build_tree(
            sub_features, sub_labels, depth+1, max_depth)
    
    return tree

提示:实际应用中应添加预剪枝逻辑,如设置最小样本数、信息增益阈值等,防止过拟合。

4. 决策树的预测与应用

构建好的决策树是一个嵌套字典,预测时需要从根节点开始遍历:

def predict(tree, sample):
    """使用决策树预测单个样本"""
    if not isinstance(tree, dict):
        return tree  # 到达叶节点
    
    feature_value = sample[tree['feature']]
    if feature_value not in tree['branches']:
        return None  # 处理未见过的特征值
    
    return predict(tree['branches'][feature_value], sample)

测试整个流程:

# 示例:西瓜数据集
features = np.array([
    ['青绿', '蜷缩', '浊响'],
    ['乌黑', '蜷缩', '沉闷'],
    # ...更多样本
])
labels = np.array(['好瓜', '好瓜', '坏瓜', '坏瓜'])

tree = build_tree(features, labels)
test_sample = ['青绿', '稍蜷', '浊响']
print(predict(tree, test_sample))  # 输出预测类别

5. 算法优化与实用技巧

基础ID3实现有几个可以改进的地方:

  1. 连续值处理
def discretize_continuous(feature_col, n_bins=5):
    """将连续特征离散化为n_bins个区间"""
    bins = np.linspace(min(feature_col), max(feature_col), n_bins+1)
    return np.digitize(feature_col, bins[:-1])
  1. 缺失值处理策略
  • 填充该特征的最常见值
  • 按照当前节点样本的比例分配
  1. 可视化决策树 (需要graphviz库):
from graphviz import Digraph

def visualize_tree(tree, dot=None, parent=None, edge_label=None):
    if dot is None:
        dot = Digraph()
    
    node_id = str(id(tree))
    if isinstance(tree, dict):
        dot.node(node_id, f"Feature {tree['feature']}")
        for value, branch in tree['branches'].items():
            visualize_tree(branch, dot, node_id, str(value))
    else:
        dot.node(node_id, f"Class: {tree}")
    
    if parent is not None:
        dot.edge(parent, node_id, label=edge_label)
    
    return dot

6. 从ID3到C4.5的演进

虽然我们实现了ID3,但了解其改进版C4.5的特性也很重要:

特性 ID3 C4.5
分裂标准 信息增益 信息增益比
连续值处理 不支持 自动离散化
缺失值处理 不支持 支持
剪枝方式 悲观剪枝
多叉树

实现信息增益比:

def split_info(features, feature_idx):
    """计算特征的固有信息(用于信息增益比)"""
    feature_values = features[:, feature_idx]
    counts = Counter(feature_values)
    probs = [count / len(feature_values) for count in counts.values()]
    return -sum(p * np.log2(p) for p in probs if p > 0)

def gain_ratio(features, labels, feature_idx):
    """计算信息增益比"""
    gain = information_gain(features, labels, feature_idx)
    si = split_info(features, feature_idx)
    return gain / si if si != 0 else 0

在实际项目中,我通常会在数据预处理阶段做好特征工程,包括处理缺失值、离散化连续特征等。对于小型数据集,决策树的训练速度很快,但要注意控制树的深度防止过拟合。当特征数量很多时,可以先用随机森林确定特征重要性,再用重要特征构建单个决策树提高可解释性。

更多推荐