TWISTER模块

论文《LEARNING TRANSFORMER-BASED WORLD MODELS WITH CONTRASTIVE PREDICTIVE CODING》
论文地址: https://openreview.net/forum?id=YK9G4Htdew
深度交流Q裙:1051849847
全网同名 【大嘴带你水论文】 B站定时发布详细讲解视频
详细代码见文章最后

1、作用

TWISTER (Transformer-based World model wIth contraSTivE Representations) 是一个基于Transformer的世界模型,它利用动作条件下的对比预测编码(Contrastive Predictive Coding)来学习高级别的时间特征表示,从而显著提升智能体在复杂环境中的性能。该模型在Atari 100k基准测试中取得了162%的人类归一化平均分,创下了在不使用前瞻搜索的方法中的新纪录。

图1. TWISTER结构图

2、机制

1、Transformer状态空间模型 (TSSM)
TWISTER的核心是一个Transformer状态空间模型(TSSM),它使用掩码自注意力机制来预测未来的随机状态。与基于RNN的模型相比,TSSM能够更有效地处理长距离依赖关系,从而更准确地模拟世界。

2、动作条件下的对比预测编码 (AC-CPC)
为了学习更深层次的环境理解,TWISTER引入了动作条件下的对比预测编码(AC-CPC)。该机制通过最大化模型状态与未来随机状态之间的互信息来学习时间特征表示。这种方法迫使模型不仅预测下一个状态,还要理解更长时间范围内的动态变化。

3、Actor-Critic学习
与Dreamer系列算法类似,TWISTER也在学习到的世界模型之上训练一个Actor网络和一个Critic网络。这两个网络在由世界模型生成的想象轨迹上进行训练,以学习最大化预期未来奖励总和的策略。

3、代码

1. 计算机视觉领域
轻量化视觉识别:QSD-Transformer特别适合移动端和边缘设备
实时目标检测:脉冲神经网络的稀疏计算大幅降低功耗
视频序列分析:TWISTER的时序建模能力解决视频理解问题

2. 强化学习领域
模型基RL:TWISTER直接提升世界模型的学习效率
游戏AI与机器人控制:在资源受限环境下实现智能决策
自动驾驶:基于视觉的实时决策系统

3. 边缘计算与IoT
智能传感器网络:超低功耗的本地AI处理
工业4.0:实时质量检测和预测性维护
医疗设备:便携式诊断和监护设备

4. 神经形态计算
脑启发芯片:QSD-Transformer天然适配神经形态硬件
极低功耗AI:突破传统计算的能效瓶颈

5. 跨模态学习
多模态融合:视觉-语言-传感器数据的联合建模
时序数据分析:金融、气象、生物信号等时间序列预测

import torch
import torch.nn as nn
from torch.distributions import Normal, kl_divergence

class TSSM(nn.Module):
    """Transformer状态空间模型 (TSSM)"""
    def __init__(self, num_actions, stoch_size=32, hidden_size=512, num_blocks=4, **kwargs):
        super().__init__()
        self.num_actions = num_actions
        self.stoch_size = stoch_size
        self.hidden_size = hidden_size
        self.action_mixer = nn.Sequential(
            nn.Linear(stoch_size + num_actions, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU()
        )
        self.transformer = TransformerNetwork(hidden_size, num_blocks=num_blocks, **kwargs)
        self.dynamics_predictor = nn.Sequential(
            nn.Linear(hidden_size, stoch_size),
        )

    def get_stoch(self, sample):
        return sample.reshape(*sample.shape[:-2], self.stoch_size, 32).softmax(-1).reshape(*sample.shape[:-2], -1)

    def initial(self, batch_size):
        return torch.zeros(batch_size, self.stoch_size, device=self.parameters().__next__().device)

    def observe(self, embed, action, is_first, state=None):
        swap = lambda x: x.permute(1, 0, *range(2, len(x.shape)))
        embed, action, is_first = swap(embed), swap(action), swap(is_first)
        post, prior = self.observe_step(embed, action, is_first, state)
        post = {k: swap(v) for k, v in post.items()}
        prior = {k: swap(v) for k, v in prior.items()}
        return post, prior

    def imagine(self, action, state=None):
        swap = lambda x: x.permute(1, 0, *range(2, len(x.shape)))
        action = swap(action)
        state = self.imagine_step(action, state)
        state = {k: swap(v) for k, v in state.items()}
        return state

    def get_feat(self, state):
        return state['stoch']

    def get_dist(self, state, dtype=None):
        return Normal(state['mean'], state['std'])

    def slice_hidden(self, hidden, h_slice):
        return hidden[:, :, h_slice]

    def get_hidden_len(self, hidden):
        return hidden.shape[-1]

    def forward_img(self, h, z):
        x = self.action_mixer(torch.cat([z, torch.zeros_like(action)], dim=-1))
        h, x = self.transformer(h, x)
        return h, x

    def forward_obs(self, h, z, action):
        x = self.action_mixer(torch.cat([z, action], dim=-1))
        h, x = self.transformer(h, x)
        return h, x

    def forward(self, h, z, action):
        x = self.action_mixer(torch.cat([z, action], dim=-1))
        h, x = self.transformer(h, x)
        logits = self.dynamics_predictor(x)
        return h, logits

    def set_inference_mode(self):
        """为了与示例脚本对齐,添加一个虚拟的推理模式切换方法。"""
        # 此模型的前向传播在训练和评估模式下行为一致
        pass

if __name__ == '__main__':
    # 模型参数
    num_actions = 4
    stoch_size = 32
    hidden_size = 512
    num_blocks = 4
    batch_size = 2
    seq_length = 10

    # 创建模型
    model = TSSM(num_actions=num_actions, stoch_size=stoch_size, hidden_size=hidden_size, num_blocks=num_blocks)

    # 创建输入
    h = torch.randn(batch_size, seq_length, hidden_size)
    z = torch.randn(batch_size, seq_length, stoch_size)
    action = torch.randn(batch_size, seq_length, num_actions)

    # 训练模式
    _, logits_train = model(h, z, action)

    # 推理模式
    model.set_inference_mode()
    _, logits_inference = model(h, z, action)

    # 打印结果
    print('输入 (h) 尺寸:', h.size())
    print('输入 (z) 尺寸:', z.size())
    print('输入 (action) 尺寸:', action.size())
    print('训练输出 (logits) 尺寸:', logits_train.size())
    print('推理输出 (logits) 尺寸:', logits_inference.size())
    print(f"参数数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

详细代码 gitcode地址:https://gitcode.com/2301_80107842/research

Logo

惟楚有才,于斯为盛。欢迎来到长沙!!! 茶颜悦色、臭豆腐、CSDN和你一个都不能少~

更多推荐