别再死记硬背Sarsa公式了!用Python手搓一个走迷宫AI,5分钟搞懂On-Policy和Off-Policy的区别
用Python实现迷宫AI:5分钟可视化理解Sarsa与Q-learning的本质差异
当你第一次接触强化学习时,是否曾被各种算法术语弄得晕头转向?On-Policy和Off-Policy的区别听起来像天书,而Sarsa和Q-learning的公式对比更是让人望而生畏。今天,我们将打破这种枯燥的学习方式——用不到50行Python代码构建一个会自主探索迷宫的AI,通过实时可视化的路径选择,让你亲眼见证不同策略算法在实际决策中的行为差异。
1. 从零搭建迷宫训练场
在开始算法实现前,我们需要一个合适的训练环境。使用Python的 pygame 库可以快速创建交互式网格世界:
import numpy as np
import pygame
class MazeEnv:
def __init__(self, size=5):
self.size = size
self.grid = np.zeros((size, size))
self.start = (0, 0)
self.goal = (size-1, size-1)
self.obstacles = [(1, 1), (2, 3), (3, 1)] # 障碍物位置
self.state = self.start
def reset(self):
self.state = self.start
return self.state
def step(self, action):
x, y = self.state
if action == 0: x = max(0, x-1) # 上
elif action == 1: x = min(self.size-1, x+1) # 下
elif action == 2: y = max(0, y-1) # 左
elif action == 3: y = min(self.size-1, y+1) # 右
if (x, y) in self.obstacles:
return self.state, -10, False # 撞到障碍物
self.state = (x, y)
done = (self.state == self.goal)
reward = 100 if done else -1 # 每步小惩罚,到达终点大奖
return self.state, reward, done
这个5x5的迷宫中设置了三个障碍物,AI智能体需要从左上角(0,0)移动到右下角(4,4)。环境设计有两个关键点:
- 稀疏奖励 :只有到达终点才有正奖励,其他情况都是负奖励
- 即时惩罚 :撞到障碍物会获得-10的大惩罚,每走一步也有-1的小惩罚
提示:这种奖励结构能有效防止AI在原地打转,鼓励其尽快找到目标
2. Sarsa算法实现与可视化
Sarsa作为典型的On-Policy算法,其核心特点是 学习当前策略下的动作价值 。让我们用Python实现一个表格版的Sarsa:
def sarsa(env, episodes=1000, alpha=0.1, gamma=0.9, epsilon=0.1):
Q = np.zeros((env.size, env.size, 4)) # 状态-动作价值表
for _ in range(episodes):
state = env.reset()
action = epsilon_greedy(Q, state, epsilon)
done = False
while not done:
next_state, reward, done = env.step(action)
next_action = epsilon_greedy(Q, next_state, epsilon)
# Sarsa更新公式
Q[state][action] += alpha * (reward + gamma*Q[next_state][next_action] - Q[state][action])
state, action = next_state, next_action
return Q
def epsilon_greedy(Q, state, epsilon):
if np.random.random() < epsilon:
return np.random.randint(4) # 随机探索
return np.argmax(Q[state]) # 选择当前最优动作
关键实现细节:
- 五元组更新 :(s, a, r, s', a')的完整序列才能更新Q值
- 策略一致性 :选择动作和更新动作使用相同的ε-greedy策略
- 保守探索 :始终考虑下一步实际会执行的动作来更新当前值
运行算法后,我们可以观察到AI在迷宫中典型的探索路径:
路径示例:
■ → ■ → ■ → ■ → ■
■ → X ■ → X ■
■ → ■ → ■ → X ■
■ → X ■ → ■ → ■
■ → ■ → ■ → ■ → ★
注意:Sarsa通常会绕开障碍物走更安全的路线,即使路径稍长
3. Q-learning的Off-Policy特性对比
与Sarsa不同,Q-learning作为Off-Policy算法,其特点是 学习最优策略的价值函数 。实现差异主要体现在更新规则:
def q_learning(env, episodes=1000, alpha=0.1, gamma=0.9, epsilon=0.1):
Q = np.zeros((env.size, env.size, 4))
for _ in range(episodes):
state = env.reset()
done = False
while not done:
action = epsilon_greedy(Q, state, epsilon)
next_state, reward, done = env.step(action)
# Q-learning更新公式
Q[state][action] += alpha * (reward + gamma*np.max(Q[next_state]) - Q[state][action])
state = next_state
return Q
两种算法的核心差异可以用以下表格清晰对比:
| 特性 | Sarsa (On-Policy) | Q-learning (Off-Policy) |
|---|---|---|
| 更新目标 | 当前策略下的Q值 | 最优Q值 |
| 动作选择一致性 | 学习与执行策略相同 | 学习与执行策略可不同 |
| 探索风险 | 更保守,规避危险路径 | 更激进,可能靠近危险 |
| 适用场景 | 需要安全探索的环境 | 追求最优解的环境 |
| 更新公式关键差异 | 使用下一步实际动作a' | 使用下一步最大Q值动作 |
在相同迷宫环境中,Q-learning的典型路径表现更为直接:
路径示例:
■ → ■ → ■ → ■ → ■
■ → X ■ ↓ X ■
■ ↓ X ■ ↓ X ■
■ ↓ X ■ → ■ → ■
■ → ■ → ■ → ■ → ★
4. 策略差异的底层原理剖析
为什么两种算法会产生不同的路径选择?这要从它们的数学本质来分析:
Sarsa的更新公式 :
Q(s,a) ← Q(s,a) + α[r + γQ(s',a') - Q(s,a)]
其中a'是根据当前ε-greedy策略选择的实际下一步动作。这意味着:
- 会考虑探索带来的风险(如靠近障碍物时可能随机撞上)
- 学习的是 执行策略的真实价值 ,包含探索的不确定性
Q-learning的更新公式 :
Q(s,a) ← Q(s,a) + α[r + γmaxₐ'Q(s',a') - Q(s,a)]
它总是假设下一步会采取最优动作:
- 忽略实际探索中可能的风险动作
- 直接学习 最优策略的价值 ,不考虑探索过程
这种差异在悬崖网格问题中表现尤为明显。假设我们的迷宫有一个"悬崖"区域,跌落会有巨大惩罚:
- Sarsa会学习绕远路的安全路径
- Q-learning会学习贴悬崖走的最短路径(但实际探索时可能跌落)
# 悬崖迷宫环境设置示例
cliff_env = MazeEnv(size=4)
cliff_env.obstacles = [(i, 1) for i in range(1, 3)] # 第2列为悬崖
cliff_env.goal = (3, 1) # 目标在悬崖尽头
5. 工程实践中的选择建议
在实际项目中如何选择这两种算法?以下是我的经验总结:
优先使用Sarsa的场景 :
- 安全关键型应用(如机器人控制)
- 探索成本高昂的环境
- 需要稳定策略的在线学习系统
优先使用Q-learning的场景 :
- 追求最优性能的场景
- 可以安全模拟的环境
- 结合经验回放(experience replay)的深度强化学习
一个实用的技巧是 动态调整ε值 :训练初期用较高探索率(如ε=0.3),后期逐渐降低(如ε=0.01)。这在Sarsa中尤为重要,可以平衡探索与利用:
def dynamic_epsilon(episode, total_episodes):
return max(0.01, 0.3 * (1 - episode/total_episodes))
最后分享一个调试技巧:可视化Q值矩阵能直观理解AI的决策逻辑。以下代码可以绘制每个状态下各动作的Q值:
def plot_q_values(Q):
directions = ['↑', '↓', '←', '→']
for i in range(Q.shape[0]):
for j in range(Q.shape[1]):
best_action = np.argmax(Q[i,j])
print(f"{directions[best_action]}:{Q[i,j,best_action]:.1f}", end=' ')
print()
输出示例:
↑:0.0 →:5.1 →:7.3 →:9.8 →:0.0
↓:0.0 X →:0.0 →:6.2 →:0.0
↓:3.1 →:4.5 →:0.0 X →:0.0
↓:0.0 X ↓:2.3 →:5.7 →:8.2
→:0.0 →:0.0 →:0.0 →:0.0 ★
更多推荐

所有评论(0)