Search-R1论文浅析与代码实现
·
R1论文浅析
R1论文通常指Reinforcement Learning with Imagined Goals(RIG)相关研究,核心思想是通过强化学习结合想象目标实现高效策略学习。关键点包括:
- 目标 conditioned 策略:策略网络输入包含状态与目标信息,输出动作以实现目标。
- 潜在空间目标生成:通过变分自编码器(VAE)学习状态的低维表示,在潜在空间中采样新目标。
- 离线数据利用:优先利用离线数据预训练模型,减少在线交互成本。
典型应用场景为机器人操作任务,如抓取、推物体等稀疏奖励环境。论文通过潜在空间的目标采样,显著提升探索效率。
代码实现关键模块
变分自编码器(VAE)构建
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
class VAE(tf.keras.Model):
def __init__(self, latent_dim):
super().__init__()
self.encoder = tf.keras.Sequential([
Input(shape=(state_dim,)),
Dense(256, activation='relu'),
Dense(128, activation='relu'),
Dense(latent_dim * 2) # 输出均值与对数方差
])
self.decoder = tf.keras.Sequential([
Input(shape=(latent_dim,)),
Dense(128, activation='relu'),
Dense(256, activation='relu'),
Dense(state_dim)
])
def reparameterize(self, mean, logvar):
eps = tf.random.normal(shape=mean.shape)
return eps * tf.exp(logvar * 0.5) + mean
目标 conditioned 策略网络
class PolicyNetwork(tf.keras.Model):
def __init__(self, action_dim):
super().__init__()
self.shared_layers = tf.keras.Sequential([
Dense(256, activation='relu'),
Dense(128, activation='relu')
])
self.mean = Dense(action_dim)
self.log_std = Dense(action_dim)
def call(self, state, goal):
x = tf.concat([state, goal], axis=-1)
x = self.shared_layers(x)
return self.mean(x), self.log_std(x)
训练流程要点
- VAE预训练:使用离线数据集训练VAE,最小化重建损失与KL散度: [ \mathcal{L}{\text{VAE}} = |x - \text{decode}(z)|^2 + \beta D{\text{KL}}(q(z|x) | p(z)) ]
- 策略优化:通过SAC算法更新策略网络,目标函数为: [ \mathcal{J}(\pi) = \mathbb{E}[Q(s,a,g) - \alpha \log \pi(a|s,g)] ]
- 目标重标记:从VAE潜在空间采样新目标$g'$,替换原始目标$g$以提升多样性。
实验调优建议
- 潜在空间维度:通常选择16-64维平衡表达力与训练难度。
- 奖励函数设计:采用目标与当前状态的负欧氏距离: [ r(s,g) = -|\text{encode}(s) - \text{encode}(g)|_2 ]
- 缓冲区管理:分离在线数据与离线数据存储,设置不同采样比例。
更多推荐


所有评论(0)