【RL Latest Tech】分层强化学习:Option-Critic架构算法
分层强化学习(Hierarchical Reinforcement Learning, HRL)通过将复杂问题分解为更小的子问题,显著提高了强化学习算法在解决高维状态空间和长期目标任务中的效率。Option-Critic架构是分层强化学习中一种非常有影响力的方法,专门用于自动发现和优化子策略(称为“Option”)。它是在经典的Options框架基础上提出的,用来处理分层决策问题,特别是可以在没有
📢本篇文章是博主强化学习RL领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在👉强化学习专栏:
【强化学习】(22)---《分层强化学习:Option-Critic架构算法》
分层强化学习:Option-Critic架构算法
目录
分层强化学习(Hierarchical Reinforcement Learning, HRL)通过将复杂问题分解为更小的子问题,显著提高了强化学习算法在解决高维状态空间和长期目标任务中的效率。Option-Critic架构是分层强化学习中一种非常有影响力的方法,专门用于自动发现和优化子策略(称为“Option”)。它是在经典的Options框架基础上提出的,用来处理分层决策问题,特别是可以在没有明确的子目标定义的情况下自动学习子策略。
1. 基本概念
在Option-Critic架构中,最核心的思想是使用 “选项” 来建模高级行为策略。每个选项代表一段策略或行为,负责特定的子任务。具体地说,选项包括三个部分:
- 初始条件(Initiation set):智能体可以选择该选项的状态集合。
- 内部策略(Intra-option policy):智能体在选项激活时的行动策略,即在特定状态下采取的行动。
- 终止条件(Termination condition):定义选项何时终止或结束。
在这种架构下,智能体不再直接学习每个状态下的单一动作,而是学习 何时选择选项以及如何在选项内行动。这使得学习更加抽象化和高效化。
2. Option-Critic框架的核心要素
2.1 选项(Option)
- 选项 是智能体可以执行的一个序列动作,这些动作组成一个子策略。选项不仅包含具体的操作步骤,还包含 何时开始选项 以及 何时终止选项。
- 在每一个时间步,智能体可以选择:执行一个基础动作或选择一个更高级别的选项。在选项内,智能体执行动作直到达到终止条件,然后选择新的选项。
2.2 Intra-Option Q-Learning
Option-Critic框架引入了 Intra-Option Q-Learning 算法,用于更新选项的内部策略。在传统的强化学习中,Q-learning用于评估在特定状态下选择某个动作的价值。而在Option-Critic中,Intra-Option Q-Learning则用于评估 在选项内 如何选择动作。具体步骤如下:
- 对于每个选项,有一个内部策略,它定义了在选项激活时的动作选择方式。
- Q函数代表在状态选择选项的期望回报。
- 当智能体在执行选项时,内部策略根据当前状态选择具体动作,并且根据Intra-Option Q-Learning规则更新Q值。
2.3 选项的终止条件
选项的终止条件决定了何时退出当前选项并返回上层策略。例如,一个选项可能在达到某个目标状态时终止,或者在经过一定的时间步数后自动终止。终止条件的优化同样是Option-Critic框架中的一个关键部分。
3. Option-Critic算法工作流程
Option-Critic算法通过以下步骤完成学习过程:
- 选项初始化:初始化选项的初始条件、内部策略和终止条件。
- 执行过程:
- 智能体基于当前状态和Q值函数选择某个选项。
- 执行选项的内部策略,在环境中采取具体行动,直到终止条件触发。
- Q值更新:使用Intra-Option Q-Learning更新每个选项的Q值函数。
- 策略更新:同时更新选项的内部策略和终止条件,使得智能体能够更好地进行决策。
- 重复以上过程,直至收敛。
[Python] Option-Critic实现
🔥若是下面代码复现困难或者有问题,欢迎评论区留言;需要以整个项目形式的代码,请在评论区留下您的邮箱📌,以便于及时分享给您(私信难以及时回复)。
Option-Critic架构在CartPole环境中的Python代码,使用了OpenAI Gym和PyTorch进行实现。
步骤:
- 设置 CartPole 环境。
- 定义选项的策略网络。
- 定义用于估计每个选项值的Q 网络。
- 实现选项内策略梯度和终止梯度。
- 创建经验回放缓冲区,用于存储和采样经验。
- 使用时间差分学习和策略梯度训练智能体。
- 测试和可视化智能体的表现。
算法训练代码
"""《Option-Critic实现》
时间:2024.10.01
环境:CartPole
作者:不去幼儿园
"""
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
# 超参数调整
GAMMA = 0.98 # 稍微减少折扣因子,减少远期奖励影响
LEARNING_RATE = 0.0005 # 降低学习率,提高训练的稳定性
BATCH_SIZE = 128 # 增加批次大小以稳定训练
MEMORY_SIZE = 20000 # 增大经验回放缓冲区
EPSILON_DECAY = 0.99 # 减慢 epsilon 的衰减速度
MIN_EPSILON = 0.05 # 增大最小 epsilon,保持一定探索
NUM_OPTIONS = 2
NUM_EPISODES = 1000 # 增加训练回合数
# 修改后的网络结构
class CriticNetwork(nn.Module):
def __init__(self, state_dim, num_options):
super(CriticNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, 256) # 增加隐藏层神经元数量
self.fc2 = nn.Linear(256, num_options)
def forward(self, state):
x = torch.relu(self.fc1(state))
return self.fc2(x)
class IntraOptionPolicyNetwork(nn.Module):
def __init__(self, state_dim, num_options, action_dim):
super(IntraOptionPolicyNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, 256) # 增加神经元数量
self.fc2 = nn.Linear(256, num_options * action_dim)
self.num_options = num_options
self.action_dim = action_dim
def forward(self, state, option):
x = torch.relu(self.fc1(state))
policy_logits = self.fc2(x)
option_policy = policy_logits.view(-1, self.num_options, self.action_dim)
return option_policy[:, option, :]
class TerminationNetwork(nn.Module):
def __init__(self, state_dim, num_options):
super(TerminationNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, 256) # 增加神经元数量
self.fc2 = nn.Linear(256, num_options)
def forward(self, state):
x = torch.relu(self.fc1(state))
termination_probs = torch.sigmoid(self.fc2(x))
return termination_probs
# 经验回放缓冲区
class ReplayBuffer:
def __init__(self, capacity):
self.buffer = deque(maxlen=capacity)
def push(self, state, option, action, reward, next_state, done):
self.buffer.append((state, option, action, reward, next_state, done))
def sample(self, batch_size):
states, options, actions, rewards, next_states, dones = zip(*random.sample(self.buffer, batch_size))
return np.stack(states), options, actions, rewards, np.stack(next_states), dones
def size(self):
return len(self.buffer)
# 选择动作
def select_action(policy_net, state, option, epsilon):
if random.random() < epsilon:
return random.choice([0, 1]) # CartPole 动作空间为 2(0 或 1)
else:
state = torch.FloatTensor(state).unsqueeze(0)
action_probs = torch.softmax(policy_net(state, option), dim=-1)
return torch.argmax(action_probs).item()
# Option-Critic 智能体
class OptionCriticAgent:
def __init__(self, state_dim, action_dim, num_options):
self.policy_net = IntraOptionPolicyNetwork(state_dim, num_options, action_dim)
self.q_net = CriticNetwork(state_dim, num_options)
self.termination_net = TerminationNetwork(state_dim, num_options)
self.optimizer_policy = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
self.optimizer_q = optim.Adam(self.q_net.parameters(), lr=LEARNING_RATE)
self.optimizer_term = optim.Adam(self.termination_net.parameters(), lr=LEARNING_RATE)
self.epsilon = 1.0
self.num_options = num_options
self.memory = ReplayBuffer(MEMORY_SIZE)
def train(self, batch_size):
if self.memory.size() < batch_size:
return
states, options, actions, rewards, next_states, dones = self.memory.sample(batch_size)
states = torch.FloatTensor(states)
next_states = torch.FloatTensor(next_states)
rewards = torch.FloatTensor(rewards)
options = torch.LongTensor(options)
actions = torch.LongTensor(actions)
dones = torch.FloatTensor(dones)
# 更新 Q 函数
q_values = self.q_net(states)
next_q_values = self.q_net(next_states).detach()
target_q_values = rewards + GAMMA * next_q_values.max(1)[0] * (1 - dones)
loss_q = nn.functional.mse_loss(q_values.gather(1, options.unsqueeze(1)).squeeze(), target_q_values)
self.optimizer_q.zero_grad()
loss_q.backward()
self.optimizer_q.step()
# 更新选项内策略
for option in range(self.num_options):
policy_logits = self.policy_net(states, option)
action_probs = torch.softmax(policy_logits, dim=-1)
log_action_probs = torch.log(action_probs)
policy_loss = -log_action_probs.gather(1, actions.unsqueeze(1)).mean()
self.optimizer_policy.zero_grad()
policy_loss.backward()
self.optimizer_policy.step()
# 更新终止概率
terminations = self.termination_net(states)
termination_loss = nn.functional.binary_cross_entropy(terminations.gather(1, options.unsqueeze(1)).squeeze(),
dones)
self.optimizer_term.zero_grad()
termination_loss.backward()
self.optimizer_term.step()
def remember(self, state, option, action, reward, next_state, done):
self.memory.push(state, option, action, reward, next_state, done)
# 训练智能体
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = OptionCriticAgent(state_dim, action_dim, NUM_OPTIONS)
for episode in range(NUM_EPISODES):
state, _ = env.reset()
option = random.choice(range(NUM_OPTIONS)) # 随机选择一个选项
done = False
episode_reward = 0
while not done:
action = select_action(agent.policy_net, state, option, agent.epsilon)
next_state, reward, done, _, __ = env.step(action)
agent.remember(state, option, action, reward, next_state, done)
agent.train(BATCH_SIZE)
# 选项终止时选择新选项
if random.random() < agent.termination_net(torch.FloatTensor(state))[option].item():
option = random.choice(range(NUM_OPTIONS))
state = next_state
episode_reward += reward
agent.epsilon = max(MIN_EPSILON, agent.epsilon * EPSILON_DECAY)
print(f"Episode {episode + 1}: Total Reward: {episode_reward}")
env.close()
# Option-Critic 智能体
class OptionCriticAgent:
def __init__(self, state_dim, action_dim, num_options):
self.policy_net = IntraOptionPolicyNetwork(state_dim, num_options, action_dim)
self.q_net = CriticNetwork(state_dim, num_options)
self.termination_net = TerminationNetwork(state_dim, num_options)
self.optimizer_policy = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
self.optimizer_q = optim.Adam(self.q_net.parameters(), lr=LEARNING_RATE)
self.optimizer_term = optim.Adam(self.termination_net.parameters(), lr=LEARNING_RATE)
self.epsilon = 1.0
self.num_options = num_options
self.memory = ReplayBuffer(MEMORY_SIZE)
def train(self, batch_size):
if self.memory.size() < batch_size:
return
states, options, actions, rewards, next_states, dones = self.memory.sample(batch_size)
states = torch.FloatTensor(states)
next_states = torch.FloatTensor(next_states)
rewards = torch.FloatTensor(rewards)
options = torch.LongTensor(options)
actions = torch.LongTensor(actions)
dones = torch.FloatTensor(dones)
# 更新 Q 函数
q_values = self.q_net(states)
next_q_values = self.q_net(next_states).detach()
target_q_values = rewards + GAMMA * next_q_values.max(1)[0] * (1 - dones)
loss_q = nn.functional.mse_loss(q_values.gather(1, options.unsqueeze(1)).squeeze(), target_q_values)
self.optimizer_q.zero_grad()
loss_q.backward()
self.optimizer_q.step()
# 更新选项内策略
for option in range(self.num_options):
policy_logits = self.policy_net(states, option)
action_probs = torch.softmax(policy_logits, dim=-1)
log_action_probs = torch.log(action_probs)
policy_loss = -log_action_probs.gather(1, actions.unsqueeze(1)).mean()
self.optimizer_policy.zero_grad()
policy_loss.backward()
self.optimizer_policy.step()
# 更新终止概率
terminations = self.termination_net(states)
termination_loss = nn.functional.binary_cross_entropy(terminations.gather(1, options.unsqueeze(1)).squeeze(),
dones)
self.optimizer_term.zero_grad()
termination_loss.backward()
self.optimizer_term.step()
def remember(self, state, option, action, reward, next_state, done):
self.memory.push(state, option, action, reward, next_state, done)
# 在 CartPole 环境中训练智能体
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = OptionCriticAgent(state_dim, action_dim, NUM_OPTIONS)
for episode in range(NUM_EPISODES):
state, _ = env.reset()
option = random.choice(range(NUM_OPTIONS)) # 初始化为随机选项
done = False
episode_reward = 0
while not done:
action = select_action(agent.policy_net, state, option, agent.epsilon)
next_state, reward, done, _, __ = env.step(action)
agent.remember(state, option, action, reward, next_state, done)
agent.train(BATCH_SIZE)
if random.random() < agent.termination_net(torch.FloatTensor(state))[option].item():
option = random.choice(range(NUM_OPTIONS)) # 终止当前选项并选择新选项
state = next_state
episode_reward += reward
agent.epsilon = max(MIN_EPSILON, agent.epsilon * EPSILON_DECAY)
print(f"Episode {episode + 1}: Total Reward: {episode_reward}")
env.close()
测试和可视化代码
要在 CartPole 环境中测试 Option-Critic 模型并显示动画,需要利用 gym
库中的 render()
方法。以下代码演示了如何在训练完模型后进行测试,并实时显示动画。
import gym
import torch
# 测试 Option-Critic 模型并显示动画
def test_option_critic(agent, env, num_episodes=5):
for episode in range(num_episodes):
state, _ = env.reset()
option = random.choice(range(agent.num_options)) # 随机选择一个选项
done = False
episode_reward = 0
env.render() # 初始化渲染环境
while not done:
env.render() # 渲染环境,显示动画
action = select_action(agent.policy_net, state, option, epsilon=0.0) # 使用已学策略选择动作
next_state, reward, done, _, __ = env.step(action)
# 检查选项是否应终止,并在终止时重新选择新选项
if random.random() < agent.termination_net(torch.FloatTensor(state))[option].item():
option = random.choice(range(agent.num_options))
state = next_state
episode_reward += reward
print(f"测试 Episode {episode + 1}: Total Reward: {episode_reward}")
env.close()
# 创建 CartPole 环境并调用测试函数
env = gym.make('CartPole-v1', render_mode='human')
test_option_critic(agent, env)
[Notice] 注意事项
确保您使用的环境和库版本支持 render_mode
参数。如果使用的 Gym 版本较旧,可能不支持此选项。在这种情况下,建议更新 gym
库。
上述采用的是gym == 0.26.2
pip install --upgrade gym
注意 :env.step(action)的返回参数数量,多则删除__
由于博文主要为了介绍相关算法的原理和应用的方法,缺乏对于实际效果的关注,算法可能在上述环境中的效果不佳,一是算法不适配上述环境,二是算法未调参和优化,三是等。上述代码用于了解和学习算法足够了,但若是想直接将上面代码应用于实际项目中,还需要进行修改。
4. Option-Critic的优势
Option-Critic架构的优点主要体现在以下几方面:
- 层次化决策:通过分层的结构,智能体能够进行更高级别的决策,不必在每一步都从头开始学习。
- 策略的可重用性:通过选项,智能体可以学习可重复使用的策略片段,用于不同的任务场景。
- 提升效率:在复杂任务中,分层策略减少了动作空间的复杂度,使得学习过程更加高效。
5. 相关的挑战
尽管Option-Critic架构在理论上具备优势,但在实际应用中仍面临一些挑战:
- 选项的设计与优化:选择合适的选项数量和复杂度对于模型性能有很大影响。
- 探索与利用的平衡:在不同的选项之间,如何平衡探索新选项与利用已有的选项是一个难题。
- 终止条件的优化:如何合理地学习选项的终止条件,避免选项过早或过晚终止,同样是一个挑战。
6. 总结
Option-Critic架构通过引入 选项 这一中间层次,将复杂问题分解为多个子任务,并通过学习这些子任务的策略与终止条件来实现有效的分层强化学习。它是一种非常有前景的强化学习框架,适用于处理复杂、长期依赖的任务。
参考文献
完整项目代码:
【RL Latest Tech】分层强化学习:Option-Critic架构算法项目代码实现
Option-Critic架构通过整合选项策略和Q学习算法,提供了一种优雅的解决方案,使得智能体能够高效地在复杂环境中进行学习和决策。
文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者关注VX公众号:Rain21321,联系作者。✨
更多推荐
所有评论(0)