告别巨型Q表!用PyTorch手把手实现价值函数逼近(VFA),搞定CartPole游戏

当你在Gymnasium的CartPole环境中第一次尝试Q-Learning时,是否曾被那个不断膨胀的Q表格吓到?状态空间稍微复杂些,内存占用就会指数级增长。这就是传统表格型强化学习方法的致命伤——维度灾难(Curse of Dimensionality)。今天我们将用PyTorch实现价值函数逼近(Value Function Approximation),用神经网络这个万能函数逼近器来替代笨重的Q表格。

1. 为什么需要价值函数逼近

在经典CartPole问题中,小车的状态由四个连续变量构成:

  • 小车位置(x)
  • 小车速度(v)
  • 杆角度(θ)
  • 杆角速度(ω)

如果用离散化方法处理,假设每个维度分成20个区间,动作空间有2个动作(左/右),那么Q表大小将是:

20^4 * 2 = 320,000 个条目

这种存储方式存在三个致命缺陷:

  1. 内存爆炸:状态维度增加时,存储需求呈指数增长
  2. 泛化性差:相似状态无法共享经验
  3. 效率低下:查表操作在连续空间变得极其低效

函数逼近的核心思想是用参数化函数$Q(s,a;w)$代替Q表,其中w是可训练参数。PyTorch实现的优势在于:

  • 自动微分简化梯度计算
  • GPU加速提升训练速度
  • 灵活的神经网络架构设计

提示:VFA不仅适用于离散动作空间,稍加修改就能扩展到连续控制问题

2. 环境准备与特征工程

首先建立我们的实验环境:

import gymnasium as gym
import torch
import torch.nn as nn

env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]  # 4
action_dim = env.action_space.n  # 2

对于线性逼近器,特征设计至关重要。我们采用多项式特征增强表现力:

def polynomial_features(state, degree=2):
    """将4维状态转换为多项式特征"""
    x, v, theta, omega = state
    features = [
        1, x, v, theta, omega,
        x*v, x*theta, x*omega,
        v*theta, v*omega,
        theta*omega,
        x**2, v**2, theta**2, omega**2
    ]
    return torch.FloatTensor(features)

这种特征工程比原始状态更适合线性模型捕捉非线性关系。不同特征处理方式对比:

特征类型 维度 优点 缺点
原始状态 4 简单直接 无法捕捉非线性
多项式特征 15 增强非线性能力 维度增长快
神经网络 自定义 自动学习特征 需要更多数据

3. 构建PyTorch逼近器

我们实现两种典型的函数逼近器:

3.1 线性逼近器

class LinearVFA(nn.Module):
    def __init__(self, feature_dim, action_dim):
        super().__init__()
        self.weights = nn.Linear(feature_dim, action_dim, bias=False)
        
    def forward(self, features):
        return self.weights(features)

3.2 神经网络逼近器

class NeuralVFA(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_dim)
        )
        
    def forward(self, state):
        return self.net(state)

关键参数初始化技巧:

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0.01)

model.apply(init_weights)  # 应用Xavier初始化

4. 训练流程实现

完整的训练循环包含这些关键步骤:

  1. 经验收集:使用ε-greedy策略与环境交互
def select_action(state, epsilon):
    if random.random() < epsilon:
        return env.action_space.sample()
    else:
        with torch.no_grad():
            q_values = model(state)
            return q_values.argmax().item()
  1. TD目标计算:实现Sarsa风格的更新
current_q = model(current_state)[action]
next_q = model(next_state)[next_action]  # Sarsa风格
target = reward + gamma * next_q * (1 - done)
loss = F.mse_loss(current_q, target)
  1. 参数更新:PyTorch标准优化流程
optimizer.zero_grad()
loss.backward()
optimizer.step()

完整的训练参数配置:

参数 推荐值 作用
γ (gamma) 0.99 折扣因子
ε初始值 1.0 探索率
ε衰减率 0.995 线性衰减
学习率 0.001 Adam优化器
批次大小 32 每次更新样本数
隐藏层大小 64 神经网络宽度

5. 效果对比与调优

经过2000轮训练后,两种方法的性能对比:

指标 线性VFA 神经网络VFA
收敛步数 ~800 ~400
最高得分 200 500
训练速度 快(1x) 慢(3x)
稳定性 需要调参

常见问题解决方案:

  • 震荡不收敛:减小学习率,增加批次大小
  • 得分卡在200:调整γ接近1,增强长期回报考虑
  • 探索不足:采用ε衰减策略,如:
    epsilon = max(0.01, epsilon * 0.995)  # 指数衰减
    

进阶技巧:

  • 使用经验回放打破数据相关性
  • 实现Double DQN减少过高估计
  • 添加优先级采样提升重要经验利用率
# 示例优先级回放缓冲区
class PriorityBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.priorities = np.zeros(capacity)
        
    def add(self, experience, priority):
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            idx = np.argmin(self.priorities)
            self.buffer[idx] = experience
        self.priorities[len(self.buffer)-1] = priority

在实际测试中,当杆子快要倒下时,神经网络VFA能捕捉到更细微的状态变化。比如当θ>0.2且ω>0.5时,模型会强烈建议向相反方向移动,而线性模型对这种非线性关系的反应要迟钝许多。

Logo

免费领 100 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐