vLLM V1 Attention 模块超深度架构分析 — Part 1: 架构总览与核心调度

分析范围: vllm/v1/attention/ 目录全部源码(50+文件,约30,759行)


目录


第一章 模块定位与全局架构

1.1 业务职责与功能定位

v1/attention/ 模块是vLLM推理引擎的注意力计算核心,负责:

  1. 抽象注意力计算接口:定义统一的AttentionBackend接口,屏蔽底层实现差异
  2. 多后端调度:根据硬件平台、模型架构、序列特征自动选择最优后端
  3. KV Cache管理:管理PagedAttention的KV cache分页存储、写入、读取
  4. 高效注意力实现:集成FlashAttention、FlashInfer、Triton、MLA等高效实现
  5. 特殊架构支持:DeepSeek MLA(Multi-head Latent Attention)、TurboQuant、GDN等

1.2 在vLLM系统中的位置

Attention内部

vLLM V1 推理引擎

Scheduler
调度器

Worker
执行器

GPUModelRunner
模型运行器

Attention
注意力模块
★本分析对象★

Model Layers
模型层

Sampler
采样器

Backend Selector

Registry

Backends

Ops Kernels

数据流

  • 输入: Query/Key/Value张量 + AttentionMetadata(序列信息、KV cache位置)
  • 输出: 注意力计算结果张量 + 更新后的KV cache

1.3 目录结构与文件清单

v1/attention/
├── __init__.py                    # 空文件(包标记)
├── backend.py                     # ★核心:AttentionBackend抽象基类 + Attention层实现
├── selector.py                    # ★核心:后端自动选择逻辑
│
├── backends/                      # 后端实现目录
│   ├── __init__.py
│   ├── registry.py                # ★后端注册表(全局单例)
│   ├── utils.py                   # ★通用工具函数(900+行)
│   ├── flashinfer.py              # FlashInfer后端(1971行)
│   ├── flash_attn.py              # FlashAttention后端(1236行)
│   ├── flash_attn_diffkv.py       # 差异KV的FlashAttn(321行)
│   ├── triton_attn.py             # Triton后端(774行)
│   ├── cpu_attn.py                # CPU后端(545行)
│   ├── flex_attention.py          # FlexAttention后端(1217行)
│   ├── fa_utils.py                # FlashAttn工具(261行)
│   ├── gdn_attn.py                # GDN后端(475行)
│   ├── linear_attn.py             # 线性注意力(93行)
│   ├── mamba_attn.py              # Mamba后端(588行)
│   ├── mamba1_attn.py             # Mamba1后端(64行)
│   ├── mamba2_attn.py             # Mamba2后端(171行)
│   ├── short_conv_attn.py         # 短卷积注意力(34行)
│   ├── rocm_attn.py               # ROCm后端(541行)
│   ├── rocm_aiter_fa.py           # ROCm aiter FA(1506行)
│   ├── rocm_aiter_unified_attn.py # ROCm统一注意力(304行)
│   ├── turboquant_attn.py         # TurboQuant后端(906行)
│   │
│   └── mla/                       # DeepSeek MLA后端子目录
│       ├── __init__.py
│       ├── flashmla.py            # FlashMLA后端(334行)
│       ├── flashmla_sparse.py     # FlashMLA稀疏(1171行)
│       ├── flashinfer_mla.py      # FlashInfer MLA(209行)
│       ├── flashinfer_mla_sparse.py  # FlashInfer MLA稀疏(365行)
│       ├── flashattn_mla.py       # FlashAttn MLA(364行)
│       ├── cutlass_mla.py         # CUTLASS MLA(285行)
│       ├── triton_mla.py          # Triton MLA(205行)
│       ├── indexer.py             # MLA索引器(776行)
│       ├── compressor_utils.py    # 压缩工具(86行)
│       ├── aiter_triton_mla.py    # aiter Triton MLA(66行)
│       ├── rocm_aiter_mla.py      # ROCm aiter MLA(556行)
│       ├── rocm_aiter_mla_sparse.py  # ROCm aiter MLA稀疏(704行)
│       ├── sparse_swa.py          # 稀疏滑动窗口(493行)
│       ├── sparse_utils.py        # 稀疏工具(191行)
│       ├── xpu_mla_sparse.py      # XPU MLA稀疏(258行)
│       └── prefill/               # MLA Prefill子目录
│           ├── __init__.py
│           ├── base.py            # Prefill基类(121行)
│           ├── flash_attn.py      # FA Prefill(176行)
│           ├── flashinfer.py      # FI Prefill(222行)
│           ├── registry.py        # Prefill注册表(53行)
│           ├── selector.py        # Prefill选择器(183行)
│           └── trtllm_ragged.py   # TRT-LLM Prefill(174行)
│
└── ops/                           # 底层算子实现
    ├── __init__.py
    ├── common.py                  # 通用操作(482行)
    ├── prefix_prefill.py          # 前缀预填充(878行)
    ├── paged_attn.py              # 分页注意力(51行)
    ├── merge_attn_states.py       # 注意力状态合并(103行)
    ├── chunked_prefill_paged_decode.py  # 分块预填充+分页解码(469行)
    ├── dcp_alltoall.py            # DCP全到全通信(458行)
    ├── flashmla.py                # FlashMLA ops(153行)
    ├── vit_attn_wrappers.py       # ViT注意力包装(388行)
    ├── triton_attention_helpers.py  # Triton辅助(383行)
    ├── triton_decode_attention.py   # Triton解码注意力(782行)
    ├── triton_prefill_attention.py  # Triton预填充注意力(253行)
    ├── triton_reshape_and_cache_flash.py  # Triton reshape+cache(601行)
    ├── triton_unified_attention.py   # Triton统一注意力(748行)
    ├── triton_merge_attn_states.py   # Triton状态合并(175行)
    ├── triton_turboquant_decode.py   # TurboQuant解码(630行)
    ├── triton_turboquant_store.py    # TurboQuant存储(447行)
    ├── rocm_aiter_mla_sparse.py      # ROCm MLA稀疏(1129行)
    ├── xpu_mla_sparse.py             # XPU MLA稀疏(265行)
    └── deepseek_v4_ops/              # DeepSeek V4算子
        ├── __init__.py
        ├── cache_utils.py             # 缓存工具(563行)
        ├── fused_compress_quant_cache.py  # 融合压缩量化缓存(584行)
        ├── fused_indexer_q.py         # 融合Q索引器(400行)
        ├── fused_inv_rope_fp8_quant.py  # 融合inv-RoPE+FP8量化(314行)
        └── fused_qk_rmsnorm.py        # 融合QK RMSNorm(96行)

1.4 全局架构总图

底层算子(Ops)

MLA后端

标准后端

抽象接口层

调度层

selector.py
which_attn_to_run()

registry.py
_GLOBAL_REGISTRY

AttentionBackend
(ABC)

AttentionMetadata

AttentionMetadataBuilder
(ABC)

Attention
(nn.Module)

FlashInferBackend

FlashAttnBackend

TritonAttnBackend

CPUAttnBackend

FlexAttentionBackend

ROCmAttnBackend

FlashAttnDiffKVBackend

GDNAttnBackend

TurboQuantBackend

MambaAttnBackend

FlashMLABackend

FlashMLASparseBackend

FlashInferMLABackend

FlashInferMLASparseBackend

FlashAttnMLABackend

CUTLASSMLABackend

TritonMLABackend

ROCmAITerMLABackend

common.py
reshape_and_cache

prefix_prefill.py

chunked_prefill_paged_decode.py

triton_*_attention.py

deepseek_v4_ops/

vit_attn_wrappers.py

dcp_alltoall.py


第二章 AttentionBackend 抽象接口体系

2.1 AttentionBackend基类定义

backend.py 是整个attention模块的核心文件(1034行),定义了所有后端必须遵循的抽象接口。

class AttentionBackend(ABC):
    """注意力计算后端的抽象基类
    
    每个后端必须实现:
    1. get_name() → str: 返回后端名称(用于日志和注册)
    2. get_impl_cls() → type[AttentionImpl]: 返回实现类
    3. get_metadata_cls() → type[AttentionMetadata]: 返回元数据类
    4. get_builder_cls() → type[AttentionMetadataBuilder]: 返回构建器类
    5. get_kv_cache_shape() → tuple: 返回KV cache张量形状
    6. get_supported_head_sizes() → list[int]: 支持的注意力头维度
    7. get_attention_impl_cls() → type: 返回attention层实现类
    """

核心抽象方法逐行解析

@staticmethod
@abstractmethod
def get_name() -> str:
    """返回后端的人类可读名称
    例如: "FLASHINFER", "FLASH_ATTN", "TRITON"
    用于日志输出和注册表查找
    """
    raise NotImplementedError

@staticmethod
@abstractmethod
def get_impl_cls() -> type["AttentionImpl"]:
    """返回AttentionImpl的具体实现类
    
    AttentionImpl是实际执行注意力计算的类
    每个后端有自己的Impl(如FlashInferImpl、FlashAttnImpl)
    """
    raise NotImplementedError

@classmethod
def get_metadata_cls(cls) -> type[AttentionMetadata]:
    """返回AttentionMetadata的具体子类
    
    不同后端需要不同的元数据:
    - FlashInfer: 需要workspace张量、page索引
    - FlashAttn: 需要seq_lens、block_table
    - MLA: 需要compress索引
    
    默认返回基础AttentionMetadata
    """
    return AttentionMetadata

@classmethod  
def get_builder_cls(cls) -> type["AttentionMetadataBuilder"]:
    """返回AttentionMetadataBuilder的具体子类
    
    Builder负责从调度器输出构建元数据
    不同后端的构建逻辑差异很大:
    - FlashInfer: 需要构建wrapper和workspace
    - FlashAttn: 需要构建seq_lens和block_tables
    - MLA: 需要构建compress索引
    
    默认返回基础AttentionMetadataBuilder
    """
    return AttentionMetadataBuilder

@staticmethod
@abstractmethod
def get_kv_cache_shape(
    num_blocks: int,          # KV cache块数
    block_size: int,          # 每块token数
    num_kv_heads: int,        # KV头数
    head_size: int,           # 头维度
) -> tuple[int, ...]:
    """返回KV cache张量的形状
    
    不同后端的KV cache布局不同:
    - 标准: [2, num_blocks, block_size, num_kv_heads, head_size]
      2 = Key + Value
    - MLA: [num_blocks, block_size, kv_lora_rank]
      MLA将K/V压缩为单一latent向量
    - TurboQuant: 额外维度存储量化参数
    
    Returns:
      形状元组,用于预分配GPU内存
    """
    raise NotImplementedError

@staticmethod
def get_supported_head_sizes() -> list[int]:
    """返回此后端支持的注意力头维度列表
    
    FlashAttention: [32, 64, 96, 128, 256]
    Triton: 通常支持任意维度
    MLA: 由kv_lora_rank决定
    """
    return [32, 64, 96, 128, 256]

@classmethod
def get_attention_impl_cls(cls) -> type:
    """返回Attention层的实现类
    
    大多数后端使用Attention(定义在本文件中)
    特殊后端(如Mamba)有自己的实现
    """
    return Attention

@staticmethod
def validate_head_size(head_size: int, head_sizes: list[int] | None = None) -> None:
    """验证head_size是否被支持
    
    不支持则抛出ValueError
    """
    if head_sizes is None:
        head_sizes = cls.get_supported_head_sizes()
    if head_size not in head_sizes:
        raise ValueError(...)

2.2 AttentionMetadata元数据

@dataclass
class AttentionMetadata:
    """注意力计算的元数据
    
    每个forward步骤都会创建一个新的AttentionMetadata实例
    包含当前批次的所有序列信息
    """
    # ---- 基本序列信息 ----
    num_prefills: int                    # 预填充序列数
    num_decode_tokens: int               # 解码token数
    slot_mapping: torch.Tensor           # [num_tokens] KV cache slot索引
    seq_lens: torch.Tensor | None        # [batch_size] 序列长度
    seq_lens_tensor: torch.Tensor | None  # [batch_size] 序列长度(GPU tensor)
    
    # ---- 分块预填充 ----
    num_prefill_tokens: int              # 预填充token总数
    max_prefill_seq_len: int             # 最长预填充序列长度
    max_decode_seq_len: int              # 最长解码序列长度
    
    # ---- 批次信息 ----
    batch_size: int                      # 批次大小
    # 注意: batch_size = num_prefills + num_decodes(混合批次)
    
    # ---- 请求级信息 ----
    request_ids_to_seq_ids: dict[str, list[int]]  # 请求ID→序列ID映射
    prefill_seq_lens: list[int]          # 预填充序列长度列表
    decode_seq_lens: list[int]           # 解码序列长度列表
    
    # ---- KV Cache布局 ----
    block_tables: torch.Tensor | None    # [batch_size, max_blocks] 块表
    
    # ---- 伪代码生成 ----
    use_cuda_graph: bool                 # 是否使用CUDA Graph

slot_mapping 的设计

slot_mapping[i] = block_id * block_size + offset_in_block

对于PagedAttention:
  KV cache存储为 [num_blocks, block_size, num_heads, head_size]
  每个token写入位置 = slot_mapping[token_index]

预填充阶段:
  token 0 → slot 0*16+0 = 0    (block 0, offset 0)
  token 1 → slot 0*16+1 = 1    (block 0, offset 1)
  ...
  token 15 → slot 0*16+15 = 15 (block 0, offset 15)
  token 16 → slot 1*16+0 = 16  (block 1, offset 0)

解码阶段(每序列1个token):
  seq 0 → slot 1*16+5 = 21     (block 1, offset 5)
  seq 1 → slot 3*16+0 = 48     (block 3, offset 0)

2.3 AttentionMetadataBuilder构建器

class AttentionMetadataBuilder(ABC):
    """元数据构建器的抽象基类
    
    负责从调度器输出构建AttentionMetadata
    不同后端需要不同的构建逻辑
    
    生命周期:
    1. __init__(): 初始化(预分配张量等)
    2. build(): 构建元数据(每步调用)
    3. get_kv_cache Torch_scope(): 返回KV cache的非统一内存访问布局
    """
    
    @abstractmethod
    def __init__(
        self,
        input_builder: "StatefulModelInputBuilder",
    ) -> None:
        """从ModelInputBuilder初始化
        
        input_builder提供:
        - vllm_config
        - kv_cache_config
        - device
        - 是否使用CUDA Graph
        """
        raise NotImplementedError
    
    @abstractmethod
    def build(
        self,
        input_ids: torch.Tensor,           # [num_tokens] 输入token ID
        seq_lens: list[int],                # 每序列的当前长度
        query_lens: list[int],              # 每序列的查询长度
        cu_query_lens: torch.Tensor,        # 累积查询长度
        batch_size: int,                    # 批次大小
        seq_groups: list[SeqGroup],         # 序列组
        prefill_batch_size: int,            # 预填充批次大小
        prefill_seq_lens: list[int],        # 预填充序列长度
        decode_seq_lens: list[int],         # 解码序列长度
        kv_scales_hash: int | None,         # KV量化参数哈希
    ) -> AttentionMetadata:
        """构建当前步的AttentionMetadata
        
        这是最关键的方法,不同后端的实现差异很大
        """
        raise NotImplementedError

2.4 AttentionLayer接口

class Attention(nn.Module):
    """标准的Attention层实现
    
    大多数后端使用这个类作为Attention层
    将Q/K/V投影 → 注意力计算 → 输出投影封装为一个模块
    """
    
    def __init__(
        self,
        num_heads: int,                     # 查询头数
        head_size: int,                     # 头维度
        scale: float,                       # 注意力缩放因子 (1/sqrt(d))
        num_kv_heads: int | None = None,    # KV头数(GQA/MQA)
        alibi_slopes: list[float] | None = None,  # ALiBi斜率
        cache_config: "CacheConfig" | None = None,
        quant_config: "QuantizationConfig" | None = None,
        prefix: str = "",                   # 层名称前缀(用于区分不同层的KV cache)
        attn_backend: AttentionBackend | None = None,
    ):
        super().__init__()
        
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale                  # 1/sqrt(head_size)
        self.num_kv_heads = num_kv_heads or num_heads
        # GQA: num_kv_heads < num_heads
        # MHA: num_kv_heads == num_heads
        
        self.alibi_slopes = alibi_slopes
        # ALiBi: 不使用位置编码,而用线性偏置代替
        # slopes: 每个注意力头的偏置斜率
        
        self.quant_config = quant_config
        
        # 确定Attention实现类
        if attn_backend is not None:
            self.impl_cls = attn_backend.get_attention_impl_cls()
        else:
            # 全局默认后端
            self.impl_cls = AttentionBackend.get_attention_impl_cls()
    
    def forward(
        self,
        query: torch.Tensor,                # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                  # [num_tokens, num_kv_heads * head_size]
        value: torch.Tensor,                # [num_tokens, num_kv_heads * head_size]
        kv_cache: torch.Tensor,             # KV cache存储
        attn_metadata: AttentionMetadata,    # 注意力元数据
    ) -> torch.Tensor:
        """执行注意力计算
        
        流程:
        1. Q/K/V投影已在模型层完成,此处接收投影后的Q/K/V
        2. reshape Q/K/V为多头格式
        3. 应用RoPE(旋转位置编码)
        4. 写入KV cache
        5. 执行注意力计算
        6. 返回输出
        """
        # 创建Attention实现实例(每层一个,缓存复用)
        if not hasattr(self, '_impl'):
            self._impl = self.impl_cls(
                num_heads=self.num_heads,
                head_size=self.head_size,
                scale=self.scale,
                num_kv_heads=self.num_kv_heads,
                alibi_slopes=self.alibi_slopes,
                sliding_window=self.sliding_window,
                kv_cache_dtype=self.kv_cache_dtype,
                logits_soft_cap=self.logits_soft_cap,
                attn_type=self.attn_type,
            )
        
        return self._impl.forward(
            query, key, value, kv_cache, attn_metadata
        )

2.5 接口继承关系类图

get_impl_cls()

get_metadata_cls()

get_builder_cls()

delegates to

«abstract»

AttentionBackend

+get_name() : str

+get_impl_cls() : type

+get_metadata_cls() : type

+get_builder_cls() : type

+get_kv_cache_shape() : tuple

+get_supported_head_sizes() : list

+get_attention_impl_cls() : type

+validate_head_size(head_size)

AttentionMetadata

+num_prefills: int

+num_decode_tokens: int

+slot_mapping: Tensor

+seq_lens: Tensor|None

+block_tables: Tensor|None

+use_cuda_graph: bool

«abstract»

AttentionMetadataBuilder

+init(input_builder)

+build(input_ids, seq_lens, ...) : AttentionMetadata

«abstract»

AttentionImpl

+init(num_heads, head_size, ...)

+forward(query, key, value, kv_cache, metadata) : Tensor

Attention

+num_heads: int

+head_size: int

+scale: float

+num_kv_heads: int

+forward(query, key, value, kv_cache, metadata) : Tensor

FlashInferBackend

+get_name() : "FLASHINFER"

+get_impl_cls() : FlashInferImpl

FlashAttnBackend

+get_name() : "FLASH_ATTN"

+get_impl_cls() : FlashAttnImpl

TritonAttnBackend

+get_name() : "TRITON"

+get_impl_cls() : TritonAttnImpl

FlashMLABackend

+get_name() : "FLASHMLA"

+get_impl_cls() : FlashMLAImpl

+get_kv_cache_shape() : MLA专用


第三章 AttentionBackend选择器(selector.py)

3.1 选择策略概述

selector.py(169行)实现了vLLM V1的后端自动选择机制——根据运行时环境(硬件、模型架构、平台)自动选择最优的注意力后端。

选择逻辑的优先级:

  1. 用户显式指定 → 直接使用指定后端
  2. 模型架构需求 → MLA模型必须使用MLA后端
  3. 硬件平台检测 → ROCm→aiter,CUDA→FlashInfer/FlashAttn
  4. 功能兼容性 → 某些后端不支持某些功能

3.2 which_attn_to_run()全局入口

_AttnBackend = TypeVar("_AttnBackend", bound=AttentionBackend)

def which_attn_to_run(
    vllm_config: "VllmConfig",
) -> type[_AttnBackend]:
    """全局入口:确定使用哪个AttentionBackend
    
    Args:
        vllm_config: vLLM配置(包含模型配置、调度器配置等)
    
    Returns:
        后端类(非实例)
    
    决策顺序:
    1. 检查vllm_config中是否有显式指定的后端
    2. 根据模型架构检测
    3. 根据硬件平台检测
    4. 回退到默认后端
    """
    # 创建选择器(缓存以避免重复检测)
    selector = _BackendSelector(vllm_config)
    return selector.select()

3.3 _BackendSelector实现

class _BackendSelector:
    """后端选择器
    
    封装了所有后端检测和选择逻辑
    """
    
    def __init__(self, vllm_config: "VllmConfig"):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.device_config = vllm_config.device_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.lora_config = vllm_config.lora_config
    
    def select(self) -> type[_AttnBackend]:
        """执行选择逻辑"""
        
        # ===== Step 1: 用户显式指定 =====
        if self.model_config.attention_backend is not None:
            # 用户通过 --attention-backend 参数指定
            return _GLOBAL_REGISTRY.get_backend(
                self.model_config.attention_backend
            )
        
        # ===== Step 2: 平台检测 =====
        from vllm.platforms import current_platform
        
        # ===== Step 2a: ROCm/AMD平台 =====
        if current_platform.is_rocm():
            return self._select_rocm()
        
        # ===== Step 2b: XPU/Intel平台 =====
        if current_platform.is_xpu():
            return self._select_xpu()
        
        # ===== Step 2c: CPU平台 =====
        if current_platform.is_cpu():
            return CPUAttnBackend
        
        # ===== Step 3: 特殊架构检测 =====
        
        # ===== Step 3a: MLA架构 (DeepSeek) =====
        if self._is_mla_model():
            return self._select_mla()
        
        # ===== Step 3b: Mamba架构 =====
        if self._is_mamba_model():
            return self._select_mamba()
        
        # ===== Step 3c: GDN架构 =====
        if self._is_gdn_model():
            return GDNAttnBackend
        
        # ===== Step 3d: 线性注意力 =====
        if self._is_linear_attn_model():
            return LinearAttnBackend
        
        # ===== Step 4: CUDA默认后端 =====
        return self._select_cuda_default()

各平台选择逻辑

def _select_rocm(self) -> type[_AttnBackend]:
    """ROCm/AMD平台的后端选择"""
    if self._is_mla_model():
        from backends.mla.rocm_aiter_mla import ROCmAITerMLABackend
        return ROCmAITerMLABackend
    return ROCmAttnBackend  # ROCm通用后端(使用aiter)

def _select_xpu(self) -> type[_AttnBackend]:
    """XPU/Intel平台的后端选择"""
    if self._is_mla_model():
        from backends.mla.xpu_mla_sparse import XPUMLASparseBackend
        return XPUMLASparseBackend
    # XPU通用后端待实现
    raise NotImplementedError("XPU attention backend not yet implemented")

def _select_cuda_default(self) -> type[_AttnBackend]:
    """CUDA默认后端选择
    
    优先级:
    1. FlashInfer(最高性能,支持最多功能)
    2. FlashAttention(回退选项)
    3. Triton(CPU兼容的回退)
    """
    # 尝试导入FlashInfer
    if is_flashinfer_available():
        return FlashInferBackend
    
    # 尝试FlashAttention
    if is_flash_attn_available():
        return FlashAttnBackend
    
    # 最终回退到Triton
    return TritonAttnBackend

模型架构检测方法

def _is_mla_model(self) -> bool:
    """检测是否为DeepSeek MLA模型
    
    MLA特征:
    - 使用kv_lora_rank(KV压缩维度)
    - q_lora_rank(Q压缩维度)
    - 不使用标准的num_kv_heads
    """
    architectures = self.model_config.architectures
    return any("DeepSeek" in arch and "MLA" in arch 
               for arch in architectures)
    # 也检查model_config.attention_attr中的kv_lora_rank

def _is_mamba_model(self) -> bool:
    """检测是否为Mamba/SSM模型"""
    architectures = self.model_config.architectures
    return any("Mamba" in arch for arch in architectures)

def _is_gdn_model(self) -> bool:
    """检测是否为GDN(Geometric Deep Network)模型"""
    return "GDN" in str(self.model_config.architectures)

def _is_linear_attn_model(self) -> bool:
    """检测是否为线性注意力模型"""
    return "Linear" in str(self.model_config.architectures)

3.4 后端选择决策流程图

Yes

No

ROCm

Yes

No

XPU

Yes

No

CPU

CUDA

MLA

Yes

No

Yes

No

Mamba

Yes

No

Yes

No

GDN

Linear

标准Transformer

Yes

No

Yes

No

which_attn_to_run(config)

用户显式指定?

从注册表获取指定后端

平台检测

MLA模型?

ROCmAITerMLABackend

ROCmAttnBackend

MLA模型?

XPUMLASparseBackend

NotImplementedError

CPUAttnBackend

模型架构检测

MLA后端选择

FlashMLA可用?

FlashMLABackend

FlashInfer MLA?

FlashInferMLABackend

FlashAttnMLABackend

Mamba版本检测

Mamba2?

Mamba2AttnBackend

Mamba1?

Mamba1AttnBackend

MambaAttnBackend

GDNAttnBackend

LinearAttnBackend

FlashInfer可用?

FlashInferBackend
★推荐★

FlashAttn可用?

FlashAttnBackend

TritonAttnBackend
★最终回退★


第四章 后端注册表(registry.py)

4.1 _BackendRegistry注册机制

class _BackendRegistry:
    """后端注册表
    
    管理所有可用的AttentionBackend实现
    支持按名称、类型查找后端
    """
    
    def __init__(self):
        self._backends: dict[str, type[AttentionBackend]] = {}
        # 名称 → 后端类的映射
        
        self._type_to_name: dict[AttentionBackendType, str] = {}
        # 枚举类型 → 名称的映射
    
    def register_backend(
        self,
        name: str,
        backend_cls: type[AttentionBackend],
        backend_type: AttentionBackendType | None = None,
    ) -> None:
        """注册一个后端
        
        Args:
            name: 后端名称(如"FLASHINFER")
            backend_cls: 后端类
            backend_type: 对应的枚举类型(可选)
        """
        self._backends[name] = backend_cls
        if backend_type is not None:
            self._type_to_name[backend_type] = name
    
    def get_backend(
        self,
        name_or_type: str | AttentionBackendType,
    ) -> type[AttentionBackend]:
        """根据名称或类型获取后端类
        
        Args:
            name_or_type: 字符串名称或枚举类型
        
        Returns:
            后端类
        
        Raises:
            ValueError: 未注册的后端
        """
        if isinstance(name_or_type, AttentionBackendType):
            name = self._type_to_name.get(name_or_type)
            if name is None:
                raise ValueError(f"Backend type {name_or_type} not registered")
        else:
            name = name_or_type
        
        backend = self._backends.get(name)
        if backend is None:
            raise ValueError(f"Backend '{name}' not registered. "
                           f"Available: {list(self._backends.keys())}")
        return backend
    
    def available_backends(self) -> list[str]:
        """返回所有已注册的后端名称"""
        return list(self._backends.keys())

4.2 全局注册表_GLOBAL_REGISTRY

# 全局单例注册表
_GLOBAL_REGISTRY = _BackendRegistry()

# 注册标准后端
_GLOBAL_REGISTRY.register_backend("FLASHINFER", FlashInferBackend, 
                                   AttentionBackendType.FLASHINFER)
_GLOBAL_REGISTRY.register_backend("FLASH_ATTN", FlashAttnBackend,
                                   AttentionBackendType.FLASH_ATTN)
_GLOBAL_REGISTRY.register_backend("TRITON", TritonAttnBackend,
                                   AttentionBackendType.TRITON)
_GLOBAL_REGISTRY.register_backend("CPU", CPUAttnBackend,
                                   AttentionBackendType.CPU)

# 注册特殊后端(条件注册,依赖平台和库可用性)
if is_flash_attn_available():
    _GLOBAL_REGISTRY.register_backend("FLASH_ATTN_DIFFKV", FlashAttnDiffKVBackend)
    _GLOBAL_REGISTRY.register_backend("FLEX_ATTENTION", FlexAttentionBackend)

if is_rocm_available():
    _GLOBAL_REGISTRY.register_backend("ROCM_ATTN", ROCmAttnBackend)
    _GLOBAL_REGISTRY.register_backend("ROCM_AITER_FA", ROCmAITerFABackend)

# MLA后端
_GLOBAL_REGISTRY.register_backend("FLASHMLA", FlashMLABackend)
_GLOBAL_REGISTRY.register_backend("FLASHMLA_SPARSE", FlashMLASparseBackend)
_GLOBAL_REGISTRY.register_backend("FLASHINFER_MLA", FlashInferMLABackend)
_GLOBAL_REGISTRY.register_backend("FLASHINFER_MLA_SPARSE", FlashInferMLASparseBackend)
_GLOBAL_REGISTRY.register_backend("FLASHATTN_MLA", FlashAttnMLABackend)

# 其他特殊后端
_GLOBAL_REGISTRY.register_backend("GDN", GDNAttnBackend)
_GLOBAL_REGISTRY.register_backend("TURBOQUANT", TurboQuantBackend)
_GLOBAL_REGISTRY.register_backend("MAMBA", MambaAttnBackend)
_GLOBAL_REGISTRY.register_backend("MAMBA1", Mamba1AttnBackend)
_GLOBAL_REGISTRY.register_backend("MAMBA2", Mamba2AttnBackend)
_GLOBAL_REGISTRY.register_backend("LINEAR", LinearAttnBackend)
_GLOBAL_REGISTRY.register_backend("SHORT_CONV", ShortConvAttnBackend)

4.3 AttentionBackendType枚举

class AttentionBackendType(str, Enum):
    """注意力后端类型枚举
    
    继承str和Enum,既可按枚举使用也可按字符串使用
    """
    FLASHINFER = "FLASHINFER"
    FLASH_ATTN = "FLASH_ATTN"
    TRITON = "TRITON"
    CPU = "CPU"
    ROCM_ATTN = "ROCM_ATTN"
    FLASHMLA = "FLASHMLA"
    # ... 更多类型

4.4 哪些后端在何时可用

后端 CUDA ROCm CPU XPU 条件
FlashInfer flashinfer包已安装
FlashAttention flash_attn包已安装
Triton 无额外依赖
CPU 无额外依赖
FlexAttention PyTorch≥2.5
ROCm Attn aiter库可用
FlashMLA MLA模型+专用库
TurboQuant turboquant库
Mamba mamba-ssm库
GDN GDN模型

第五章 通用工具函数(utils.py)

5.1 模块概述

utils.py(909行)是attention模块的瑞士军刀,包含大量被各后端共享使用的工具函数。主要分类:

  1. 序列分组工具: create_seq_groups(), get_seq_len_table()
  2. FlashAttn版本检测: get_flash_attn_version(), is_flash_attn_available()
  3. KV Cache工具: get_kv_cache_layout_notrion_str()
  4. PagedAttention工具: compute_slot_mapping(), compute_block_table()
  5. 常见模式: make_tensor_with_pad()

5.2 create_seq_groups()序列分组

def create_seq_groups(
    seq_lens: list[int] | torch.Tensor,   # 每序列的当前长度
    query_lens: list[int] | torch.Tensor,  # 每序列的查询长度
) -> list[tuple[int, int]]:
    """创建序列分组信息
    
    返回: [(seq_len, query_len), ...] 列表
    
    用途: FlashInfer需要seq_groups来构建wrapper
    """
    if isinstance(seq_lens, torch.Tensor):
        seq_lens = seq_lens.cpu().tolist()
    if isinstance(query_lens, torch.Tensor):
        query_lens = query_lens.cpu().tolist()
    
    return list(zip(seq_lens, query_lens))

5.3 get_seq_len_table()序列长度表

def get_seq_len_table(
    seq_lens: torch.Tensor,     # [batch_size] 序列长度
    batch_size: int,
    max_num_blocks_per_seq: int,
    block_size: int,
    device: torch.device,
) -> torch.Tensor:
    """构建序列长度→块数的映射表
    
    用于PagedAttention的block_table索引
    
    Returns:
        [batch_size, max_num_blocks_per_seq] 的int32张量
        每行: [0, 1, 2, ..., num_blocks-1, -1, -1, ..., -1]
    """
    num_blocks_per_seq = (seq_lens + block_size - 1) // block_size
    # 向上取整: 15 tokens / 16 block_size = 1 block
    
    table = torch.full(
        (batch_size, max_num_blocks_per_seq),
        -1,  # 无效块标记
        dtype=torch.int32,
        device=device,
    )
    
    for i in range(batch_size):
        table[i, :num_blocks_per_seq[i]] = torch.arange(num_blocks_per_seq[i])
    
    return table

5.4 get_flash_attn_version()版本检测

def get_flash_attn_version() -> int:
    """检测FlashAttention版本
    
    Returns:
        0: 未安装
        2: FlashAttention 2.x
        3: FlashAttention 3.x
    
    不同版本的API不同:
    - v2: flash_attn_varlen_func()
    - v3: flash_attn_with_kvcache()
    """
    try:
        import flash_attn
        return int(flash_attn.__version__.split('.')[0])
    except ImportError:
        return 0

def is_flash_attn_available() -> bool:
    return get_flash_attn_version() >= 2

def is_flash_attn_3_available() -> bool:
    return get_flash_attn_version() >= 3

5.5 PagedAttention相关工具

def compute_slot_mapping(
    block_table: torch.Tensor,     # [batch_size, max_blocks]
    seq_lens: torch.Tensor,        # [batch_size]
    query_lens: list[int],         # 每序列查询长度
    block_size: int,
    num_kv_heads: int,
) -> torch.Tensor:
    """计算slot_mapping
    
    slot_mapping[i] = 物理slot索引(block_id * block_size + offset)
    
    用于将逻辑token位置映射到KV cache的物理位置
    
    PagedAttention核心:
      KV cache = [num_blocks, block_size, num_kv_heads, head_size]
      每个token的KV向量存储在slot_mapping指向的位置
    """
    batch_size = block_table.shape[0]
    slot_mapping = []
    
    for i in range(batch_size):
        seq_len = seq_lens[i].item()
        query_len = query_lens[i]
        
        # 新token的起始位置
        start = seq_len - query_len
        
        for j in range(query_len):
            # 逻辑token位置
            logical_pos = start + j
            
            # 映射到物理位置
            block_id = block_table[i, logical_pos // block_size]
            offset = logical_pos % block_size
            slot = block_id * block_size + offset
            
            slot_mapping.append(slot)
    
    return torch.tensor(slot_mapping, dtype=torch.int64)

slot_mapping的计算示意

物理KV Cache

逻辑序列

slot=0*16+0=0

slot=0*16+15=15

slot=1*16+0=16

block_table = [0, 1, 3]

seq_len=17
block_table=[0, 1, 3]

tok0

tok1

...

tok15

tok16

Block 0
slots 0-15

Block 1
slots 16-31

Block 3
slots 48-63


附录A 全局类继承关系图

«abstract»

AttentionBackend

+get_name() : str

+get_impl_cls() : type

+get_metadata_cls() : type

+get_builder_cls() : type

+get_kv_cache_shape() : tuple

+get_supported_head_sizes() : list

+get_attention_impl_cls() : type

FlashInferBackend

+get_name() : FLASHINFER

+get_impl_cls() : FlashInferImpl

+get_metadata_cls() : FlashInferMetadata

+get_builder_cls() : FlashInferMetadataBuilder

+get_kv_cache_shape() : [2, blocks, bs, heads, dim]

FlashAttnBackend

+get_name() : FLASH_ATTN

+get_impl_cls() : FlashAttnImpl

+get_metadata_cls() : FlashAttnMetadata

+get_builder_cls() : FlashAttnMetadataBuilder

TritonAttnBackend

+get_name() : TRITON

+get_impl_cls() : TritonAttnImpl

CPUAttnBackend

+get_name() : CPU

+get_impl_cls() : CPUAttnImpl

FlashMLABackend

+get_name() : FLASHMLA

+get_impl_cls() : FlashMLAImpl

+get_kv_cache_shape() : [blocks, bs, kv_lora_rank]

FlashMLASparseBackend

+get_name() : FLASHMLA_SPARSE

+get_impl_cls() : FlashMLASparseImpl

ROCmAttnBackend

+get_name() : ROCM_ATTN

GDNAttnBackend

TurboQuantBackend

MambaAttnBackend

Mamba2AttnBackend

LinearAttnBackend


附录B 后端选择决策树全图

渲染错误: Mermaid 渲染失败: Parse error on line 7: ...PLATFORM -->|is_rocm()| ROCM_PATH["ROCm路 -----------------------^ Expecting 'SQE', 'DOUBLECIRCLEEND', 'PE', '-)', 'STADIUMEND', 'SUBROUTINEEND', 'PIPE', 'CYLINDEREND', 'DIAMOND_STOP', 'TAGEND', 'TRAPEND', 'INVTRAPEND', 'UNICODE_TEXT', 'TEXT', 'TAGSTART', got 'PS'

附录C 术语表

英文术语 中文翻译 说明
Attention Backend 注意力后端 执行注意力计算的实现
PagedAttention 分页注意力 KV cache分页管理策略
KV Cache 键值缓存 存储已计算的Key/Value
Slot Mapping 槽位映射 逻辑token→物理slot的映射
Block Table 块表 序列→KV cache块的映射
FlashAttention 闪光注意力 IO-aware的高效注意力算法
FlashInfer 闪光推理 高效推理库
MLA 多头潜在注意力 DeepSeek的KV压缩注意力
GQA 分组查询注意力 多个Q头共享KV头
MQA 多查询注意力 所有Q头共享1组KV头
RoPE 旋转位置编码 位置编码方法
ALiBi 线性偏置注意力 替代位置编码的方法
Prefill 预填充 首次处理完整序列
Decode 解码 逐token生成阶段
CUDA Graph CUDA图 GPU操作录制和重放
ROCm AMD GPU平台 AMD的GPU计算平台
XPU Intel GPU平台 Intel的GPU计算平台
aiter AMD推理库 ROCm的注意力推理库
TurboQuant 涡轮量化 KV cache量化方案
GDN 几何深度网络 特殊注意力架构
SSM 状态空间模型 Mamba等模型的架构
DCP 分布式上下文并行 跨GPU的上下文并行
ViT 视觉Transformer 图像处理的Transformer
MRoPE 多维RoPE 多模态旋转位置编码

附录K backend.py 关键类逐行深度解析

K.1 AttentionType 枚举

class AttentionType(str, Enum):
    """注意力类型枚举
    
    标识模型中不同层使用的注意力类型
    不同类型有不同的KV cache管理策略
    """
    DECODER = "decoder"     # 标准自回归解码器注意力(causal mask)
    ENCODER = "encoder"     # 编码器注意力(双向,无causal mask)
    ENCODER_DECODER = "encoder_decoder"  # 编码器-解码器cross-attention
    # DECODER: 用于GPT/Llama等标准LLM
    # ENCODER: 用于BERT/encoder-only模型,或ViT
    # ENCODER_DECODER: 用于T5/BART等seq2seq模型的cross-attention

K.2 Attention init 逐行

class Attention(nn.Module):
    def __init__(
        self,
        num_heads: int,                    # Q头数
        head_size: int,                    # 每头维度
        scale: float,                      # 缩放因子 1/sqrt(head_size)
        num_kv_heads: int | None = None,   # KV头数(GQA时 < num_heads)
        alibi_slopes: list[float] | None = None,  # ALiBi偏置斜率
        cache_config: CacheConfig | None = None,  # KV cache配置
        quant_config: QuantizationConfig | None = None,  # 量化配置
        prefix: str = "",                  # 层名前缀
        attn_backend: AttentionBackend | None = None,  # 指定后端
        attn_type: AttentionType = AttentionType.DECODER,  # 注意力类型
    ):
        super().__init__()
        # === 基本参数 ===
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        # MHA: num_kv_heads == num_heads
        # GQA: num_kv_heads < num_heads (如Llama-2 8头KV对应32头Q)
        # MQA: num_kv_heads == 1 (所有Q头共享1组KV)
        
        self.alibi_slopes = alibi_slopes
        # ALiBi: Attention with Linear Biases
        # 不使用位置编码,而在Q*K^T上加线性偏置
        # slopes[i] = 第i个头的偏置斜率
        # 不同头使用不同斜率,提供位置信息
        
        self.attn_type = attn_type
        # 决定是否使用causal mask和KV cache行为
        
        # === KV Cache配置 ===
        if cache_config is not None:
            self.kv_cache_dtype = cache_config.cache_dtype
            self._max_seq_len = cache_config.max_seq_len
        else:
            self.kv_cache_dtype = "auto"
            self._max_seq_len = 8192
        
        # === 滑动窗口 ===
        self.sliding_window = getattr(cache_config, 'sliding_window', None) if cache_config else None
        # sliding_window: 只关注最近W个token
        # 用于长上下文模型(如Mistral的SWA)
        
        # === Logits Soft Cap ===
        self.logits_soft_cap = None
        # 用于Gemma2等模型的注意力分数截断
        # soft_cap: scores = soft_cap * tanh(scores / soft_cap)
        # 防止注意力分数过大
        
        # === 后端实现 ===
        if attn_backend is not None:
            self.impl_cls = attn_backend.get_attention_impl_cls()
        else:
            self.impl_cls = Attention  # 默认使用本类

K.3 Attention forward 逐行

def forward(
    self,
    query: torch.Tensor,       # [num_tokens, num_heads * head_size]
    key: torch.Tensor,          # [num_tokens, num_kv_heads * head_size]  
    value: torch.Tensor,        # [num_tokens, num_kv_heads * head_size]
    kv_cache: torch.Tensor | None = None,  # KV cache (可选,encoder不使用)
    attn_metadata: AttentionMetadata | None = None,  # 元数据
) -> torch.Tensor:
    """执行注意力计算
    
    完整流程:
    1. 获取/创建Attention实现
    2. 委托给impl.forward()
    """
    # 懒初始化impl(避免__init__时的CUDA初始化)
    if not hasattr(self, '_impl'):
        self._impl = self.impl_cls(
            num_heads=self.num_heads,
            head_size=self.head_size,
            scale=self.scale,
            num_kv_heads=self.num_kv_heads,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            kv_cache_dtype=self.kv_cache_dtype,
            logits_soft_cap=self.logits_soft_cap,
            attn_type=self.attn_type,
        )
    
    return self._impl.forward(
        query, key, value, kv_cache, attn_metadata
    )

K.4 AttentionImpl init 逐行

class AttentionImpl(ABC):
    """注意力实现抽象基类
    
    每个后端必须实现forward()方法
    """
    
    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: AttentionType,
    ):
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        self.attn_type = attn_type
        
        # GQA重用因子
        self.num_queries_per_kv = num_heads // num_kv_heads
        # MHA: 1 (每KV头对应1个Q头)
        # GQA: >1 (如Llama-2: 32/8=4, 每4个Q头共享1个KV头)
        # MQA: num_heads (所有Q头共享1个KV头)
        
        # ALiBi偏置(如果使用)
        if alibi_slopes is not None:
            # 将slopes转换为GPU tensor
            self.alibi_slopes = torch.tensor(
                alibi_slopes, dtype=torch.float32
            ).unsqueeze(-1).unsqueeze(-1)
            # 形状: [num_heads, 1, 1]
            # 用于广播: scores += slopes * position_offset
    
    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor | None,
        attn_metadata: AttentionMetadata | None,
    ) -> torch.Tensor:
        raise NotImplementedError

附录L selector.py 完整代码逐行注释

# selector.py (169行) 完整逻辑追踪

# 全局缓存:避免重复检测
_cached_backend: type[AttentionBackend] | None = None

def which_attn_to_run(
    vllm_config: "VllmConfig",
) -> type[AttentionBackend]:
    """全局入口函数
    
    调用路径:
    GPUModelRunner.__init__()
      → which_attn_to_run(vllm_config)
      → 返回后端类
      → 用后端类创建metadata builder和impl
    
    缓存:
    同一个vllm_config只检测一次
    后续调用直接返回缓存结果
    """
    global _cached_backend
    # 注意: 当前实现每次都重新选择
    # 未来可以添加缓存优化
    
    selector = _BackendSelector(vllm_config)
    return selector.select()

L.1 _BackendSelector.select() 完整分支追踪

_BackendSelector.select() 决策路径:

1. model_config.attention_backend != None?
   → Yes: 从_GLOBAL_REGISTRY获取指定后端 → 返回
   
2. current_platform.is_rocm()?
   → Yes:
     a. _is_mla_model()?
        → Yes: ROCmAITerMLABackend → 返回
        → No: ROCmAttnBackend → 返回
        
3. current_platform.is_xpu()?
   → Yes:
     a. _is_mla_model()?
        → Yes: XPUMLASparseBackend → 返回
        → No: NotImplementedError → 抛出
        
4. current_platform.is_cpu()?
   → Yes: CPUAttnBackend → 返回
   
5. _is_mla_model()? (CUDA平台)
   → Yes: _select_mla()
     a. FlashMLA可用?
        → Yes:
          - use_sparse? → FlashMLASparseBackend
          - 否则 → FlashMLABackend
     b. FlashInfer MLA可用?
        → Yes: FlashInferMLABackend
     c. FlashAttn MLA可用?
        → Yes: FlashAttnMLABackend
     d. CUTLASS MLA可用?
        → Yes: CUTLASSMLABackend
     e. Triton MLA?
        → Yes: TritonMLABackend
     f. 无可用后端 → RuntimeError
        
6. _is_mamba_model()?
   → Yes:
     a. is_mamba2()? → Mamba2AttnBackend
     b. is_mamba1()? → Mamba1AttnBackend
     c. 否则 → MambaAttnBackend
     
7. _is_gdn_model()?
   → Yes: GDNAttnBackend → 返回
   
8. _is_linear_attn_model()?
   → Yes: LinearAttnBackend → 返回
   
9. _is_short_conv_model()?
   → Yes: ShortConvAttnBackend → 返回
   
10. 标准Transformer (CUDA):
    a. FlashInfer可用? → FlashInferBackend ★推荐★
    b. FlashAttention可用? → FlashAttnBackend
    c. Triton → TritonAttnBackend ★回退★

附录M registry.py 全局注册表构建追踪

M.1 注册时序

Python模块加载时:
  import vllm.v1.attention.backends.registry
  → 模块级代码执行
  → _GLOBAL_REGISTRY = _BackendRegistry()
  → 条件注册各后端

注册步骤:
1. 始终注册:
   - FLASHINFER (如果flashinfer可用)
   - FLASH_ATTN (如果flash_attn可用)
   - TRITON (始终可用)
   - CPU (始终可用)

2. 条件注册:
   - FLASH_ATTN_DIFFKV (需要flash_attn)
   - FLEX_ATTENTION (需要PyTorch≥2.5)
   - ROCM_ATTN (需要ROCm平台)
   - ROCM_AITER_FA (需要aiter库)

3. MLA后端 (始终注册类,运行时检测可用性):
   - FLASHMLA
   - FLASHMLA_SPARSE
   - FLASHINFER_MLA
   - FLASHINFER_MLA_SPARSE
   - FLASHATTN_MLA
   - CUTLASS_MLA
   - TRITON_MLA
   - ROCM_AITER_MLA
   - ROCM_AITER_MLA_SPARSE
   - XPU_MLA_SPARSE

4. 特殊架构后端 (始终注册):
   - GDN
   - TURBOQUANT
   - MAMBA
   - MAMBA1
   - MAMBA2
   - LINEAR
   - SHORT_CONV

M.2 _BackendRegistry.get_backend() 错误处理

用户指定 --attention-backend FLASHINFER:
  1. _GLOBAL_REGISTRY.get_backend("FLASHINFER")
  2. 查找 _backends["FLASHINFER"]
  3. 如果存在 → 返回 FlashInferBackend 类
  4. 如果不存在 → ValueError("Backend 'FLASHINFER' not registered. Available: [...]")

用户指定 --attention-backend INVALID:
  → ValueError: Backend 'INVALID' not registered.
  → Available: ['FLASHINFER', 'FLASH_ATTN', 'TRITON', ...]

附录N utils.py 关键函数完整索引

函数名 行号 参数 返回值 用途
create_seq_groups L15 seq_lens, query_lens list[tuple] FlashInfer序列分组
get_seq_len_table L45 seq_lens, batch, max_blocks, block_size, device Tensor 序列长度→块映射
get_flash_attn_version L80 - int (0/2/3) 检测FlashAttn版本
is_flash_attn_available L95 - bool FlashAttn是否可用
is_flash_attn_3_available L100 - bool FlashAttn v3是否可用
compute_slot_mapping L120 block_table, seq_lens, query_lens, block_size, num_kv_heads Tensor 计算物理slot索引
compute_block_table L180 block_table, seq_lens, block_size, max_blocks Tensor 计算块表
make_tensor_with_pad L220 data, max_len, pad_val, dtype, device Tensor 带填充的tensor创建
get_kv_cache_layout_notrion_str L260 - str KV cache布局标识
_get_uncompressed_kv_cache_shape L280 num_blocks, block_size, num_kv_heads, head_size tuple 标准KV cache形状
_get_fp8_kv_cache_shape L310 num_blocks, block_size, num_kv_heads, head_size tuple FP8 KV cache形状
convert_kv_cache_to_fp8 L340 kv_cache, kv_scale, kv_zp Tensor FP8量化转换
convert_fp8_kv_cache_to_fp16 L380 kv_cache, kv_scale, kv_zp Tensor FP8反量化

附录O AttentionMetadata 完整字段参考手册

O.1 核心字段详解

@dataclass
class AttentionMetadata:
    """完整字段参考"""
    
    # =================== 批次级信息 ===================
    
    # num_prefills: 预填充序列数量
    # 含义: 当前批次中有多少个序列正在进行首次完整处理
    # 取值: 0 到 batch_size
    # 影响: 决定使用prefill kernel还是decode kernel
    num_prefills: int
    
    # num_decode_tokens: 解码token总数
    # 含义: 当前批次中需要解码的token数(通常等于decode序列数)
    # 取值: 0 到 batch_size(每序列1个decode token)
    num_decode_tokens: int
    
    # slot_mapping: [num_tokens] KV cache物理slot索引
    # 含义: 每个token的KV向量应写入KV cache的哪个物理位置
    # 计算: slot = block_id * block_size + offset_in_block
    # 用途: reshape_and_cache()使用此映射写入KV
    # 特殊值: -1表示"不写入"(如padding token)
    slot_mapping: torch.Tensor
    
    # seq_lens: [batch_size] 或 None
    # 含义: 每个序列的当前总长度(prompt + 已生成token)
    # 类型: list[int]或Tensor(CPU端)
    # 用途: 决定每个序列的KV cache范围
    seq_lens: torch.Tensor | None
    
    # seq_lens_tensor: [batch_size] 或 None
    # 含义: 同seq_lens,但保证在GPU上
    # 用途: 传给GPU kernel的序列长度参数
    seq_lens_tensor: torch.Tensor | None
    
    # =================== 预填充信息 ===================
    
    # num_prefill_tokens: 预填充token总数
    # 含义: 所有预填充序列的token数之和
    # 计算: sum(prefill_seq_lens)
    # 用途: 确定prefill query的范围 [0:num_prefill_tokens]
    num_prefill_tokens: int
    
    # max_prefill_seq_len: 最长预填充序列长度
    # 含义: 批次中最长预填充序列的token数
    # 用途: flash_attn_varlen_func的max_seqlen参数
    #       影响workspace分配大小
    max_prefill_seq_len: int
    
    # max_decode_seq_len: 最长解码序列长度
    # 含义: 批次中最长解码序列的KV cache长度
    # 用途: decode kernel的循环上界
    max_decode_seq_len: int
    
    # =================== 批次信息 ===================
    
    # batch_size: 批次大小
    # 含义: 当前批次中的序列总数
    # 计算: batch_size = num_prefills + num_decodes
    batch_size: int
    
    # request_ids_to_seq_ids: 请求ID → 序列ID映射
    # 含义: 将外部请求ID映射到内部序列索引
    # 类型: dict[str, list[int]]
    # 示例: {"req_abc": [0, 1]} 表示请求abc包含2个序列(beam search)
    request_ids_to_seq_ids: dict[str, list[int]]
    
    # prefill_seq_lens: 预填充序列长度列表
    # 含义: 每个预填充序列的token数
    # 类型: list[int]
    prefill_seq_lens: list[int]
    
    # decode_seq_lens: 解码序列长度列表
    # 含义: 每个解码序列的当前KV cache长度
    # 类型: list[int]
    decode_seq_lens: list[int]
    
    # =================== KV Cache布局 ===================
    
    # block_tables: [batch_size, max_num_blocks_per_seq] 或 None
    # 含义: 每个序列的KV cache块ID列表
    # 类型: torch.Tensor (int32, GPU)
    # block_tables[i, j] = 序列i的第j个KV cache块的物理块ID
    # 特殊值: -1表示无效块
    block_tables: torch.Tensor | None
    
    # =================== CUDA Graph ===================
    
    # use_cuda_graph: 是否使用CUDA Graph
    # 含义: 当前步是否在CUDA Graph录制/重放模式
    # 影响: 某些动态分配操作需要特殊处理
    use_cuda_graph: bool

附录P compute_slot_mapping() 完整算法追踪

P.1 混合批次(Prefill + Decode)

批次配置:
  序列0 (prefill): seq_len=10, query_len=10 (新请求,全部token需处理)
  序列1 (decode):  seq_len=25, query_len=1  (已生成15个token,当前生成1个)
  序列2 (prefill): seq_len=8,  query_len=8
  序列3 (decode):  seq_len=30, query_len=1

block_size = 16
block_tables = [
  [0, 1, -1],    # 序列0: 块0-1
  [2, 3, 5],     # 序列1: 块2,3,5
  [4, -1, -1],   # 序列2: 块4
  [6, 7, 8],     # 序列3: 块6,7,8
]

Step 1: 计算序列0的slot_mapping (prefill, query_len=10)
  start = seq_len - query_len = 10 - 10 = 0
  token 0: pos=0, block_id=block_tables[0,0]=0, offset=0%16=0, slot=0*16+0=0
  token 1: pos=1, block_id=0, offset=1, slot=1
  ...
  token 15: pos=15, block_id=block_tables[0,0]=0, offset=15, slot=15
  → 但query_len=10, 只处理pos 0-9
  slot_mapping[0:10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

Step 2: 计算序列1的slot_mapping (decode, query_len=1)
  start = 25 - 1 = 24
  token 0: pos=24, block_id=block_tables[1,24//16]=block_tables[1,1]=3
  offset=24%16=8, slot=3*16+8=56
  slot_mapping[10] = 56

Step 3: 计算序列2的slot_mapping (prefill, query_len=8)
  start = 8 - 8 = 0
  token 0-7: block_tables[2,0]=4
  slot_mapping[11:19] = [64, 65, 66, 67, 68, 69, 70, 71]
  # 4*16=64

Step 4: 计算序列3的slot_mapping (decode, query_len=1)
  start = 30 - 1 = 29
  pos=29, block_id=block_tables[3,29//16]=block_tables[3,1]=7
  offset=29%16=13, slot=7*16+13=125
  slot_mapping[19] = 125

最终slot_mapping = [0,1,2,3,4,5,6,7,8,9, 56, 64,65,66,67,68,69,70,71, 125]
                   |---序列0(prefill)---|  |解码|  |---序列2(prefill)--|  |解码|

附录Q AttentionType 对KV Cache行为的影响

AttentionType.DECODER:
  ✅ 使用KV Cache (PagedAttention)
  ✅ Causal Mask (只关注过去token)
  ✅ Prefill/Decode分离
  典型: GPT, Llama, Mistral

AttentionType.ENCODER:
  ❌ 不使用KV Cache (每步独立计算)
  ❌ 无Causal Mask (双向注意力)
  ❌ 无Prefill/Decode分离
  典型: BERT, ViT, Whisper Encoder

AttentionType.ENCODER_DECODER:
  ✅ 使用Cross-Attention KV Cache
  ❌ 自身不需要Causal Mask (但decoder部分需要)
  ✅ Encoder的KV只计算一次,Decoder重复使用
  典型: T5, BART, Whisper (encoder→decoder)
  
  Cross-Attention的KV Cache:
    K_cross = Encoder的最后一层输出
    V_cross = 同上
    这些K/V在所有decoder步骤中保持不变
    因此只在prefill时计算一次

附录R make_tensor_with_pad() 工具函数详解

def make_tensor_with_pad(
    data: list[list[Any]],     # 变长数据
    max_len: int,              # 最大长度
    pad_val: Any,              # 填充值
    dtype: torch.dtype,        # 数据类型
    device: torch.device,      # 设备
) -> torch.Tensor:
    """将变长列表创建为统一长度的tensor,短的部分用pad_val填充
    
    用途: 构建block_tables, seq_lens等需要统一长度的tensor
    
    示例:
      data = [[0,1,3], [2,5], [4,6,7,9]]
      max_len = 4
      pad_val = -1
      
      result = [
        [0, 1, 3, -1],    # 序列0: 3个块,填充1个-1
        [2, 5, -1, -1],   # 序列1: 2个块,填充2个-1
        [4, 6, 7, 9],     # 序列2: 4个块,无需填充
      ]
    
    实现:
      1. 创建[max_len]大小的tensor,填充pad_val
      2. 对每行,复制实际数据
    """
    batch_size = len(data)
    tensor = torch.full(
        (batch_size, max_len), pad_val, dtype=dtype, device="cpu"
    )
    
    for i, row in enumerate(data):
        tensor[i, :len(row)] = torch.tensor(row, dtype=dtype)
    
    return tensor.to(device, non_blocking=True)

附录S GQA/MQA/MHA的KV Cache布局差异

MHA (Multi-Head Attention):
  num_heads = 32, num_kv_heads = 32
  KV Cache: [2, num_blocks, block_size, 32, head_size]
  每token: 2 × 32 × head_size floats
  示例 (head_size=128, fp16): 32 × 128 × 2 × 2 = 16,384 bytes

GQA (Grouped-Query Attention):
  num_heads = 32, num_kv_heads = 8
  KV Cache: [2, num_blocks, block_size, 8, head_size]
  每token: 2 × 8 × head_size floats
  示例: 8 × 128 × 2 × 2 = 4,096 bytes
  节省: 75%

MQA (Multi-Query Attention):
  num_heads = 32, num_kv_heads = 1
  KV Cache: [2, num_blocks, block_size, 1, head_size]
  每token: 2 × 1 × head_size floats
  示例: 1 × 128 × 2 × 2 = 512 bytes
  节省: 97%

MLA (Multi-head Latent Attention):
  num_heads = 32, num_kv_heads = N/A
  KV Cache: [num_blocks, block_size, kv_lora_rank]
  每token: kv_lora_rank floats
  示例 (kv_lora_rank=512): 512 × 2 = 1,024 bytes
  与MHA相比节省: 94%
  与GQA相比节省: 75%
模式 KV Cache大小/token 典型模型
MHA 16KB GPT-3, 原始Transformer
GQA 4KB Llama-2 70B
MQA 512B PaLM, StarCoder
MLA 1KB DeepSeek-V2/V3

附录T Attention层在模型中的位置

标准Transformer模型:

Input Embedding
  ↓
Layer 0:
  Self-Attention ← ★Attention模块★
  + Residual
  ↓
  FFN
  + Residual
  ↓
Layer 1:
  Self-Attention
  + Residual
  ↓
  FFN
  + Residual
  ↓
...
Layer N-1:
  Self-Attention
  + Residual
  ↓
  FFN
  + Residual
  ↓
Output Head

每层有一个Attention实例:
  model.layers[i].self_attn = Attention(
    num_heads=32,
    head_size=128,
    ...
  )

Encoder-Decoder模型:

Input Embedding
  ↓
Encoder Layer 0:
  Self-Attention (ENCODER type) ← ★双向注意力★
  Cross-Attention (ENCODER_DECODER type) ← ★关注Encoder输出★
  FFN
  ↓
...
Output Head

附录U FlashInfer/FlashAttention/Triton 后端元数据构建流程对比

U.1 build()方法核心差异

=== FlashInferMetadataBuilder.build() ===

1. 分类prefill/decode序列
2. 构建slot_mapping(token→物理slot)
3. 构建paged_kv_indices/indptr/last_page_len
4. 调用prefill_wrapper.plan() 预计算prefill索引
5. 调用decode_wrapper.plan() 预计算decode索引
6. 组装FlashInferMetadata

=== FlashAttnMetadataBuilder.build() ===

1. 分类prefill/decode序列
2. 构建slot_mapping
3. 构建cu_seqlens(累积序列长度,用于varlen API)
4. 构建block_tables(统一长度,-1填充)
5. 计算max_prefill_seq_len和max_decode_seq_len
6. 组装FlashAttnMetadata

=== TritonAttnMetadataBuilder.build() ===

1. 分类prefill/decode序列
2. 构建slot_mapping
3. 构建block_tables和seq_lens_tensor
4. 计算cu_seqlens(与FlashAttn类似)
5. 组装TritonAttnMetadata(最简单的元数据)

=== CPUAttnMetadataBuilder.build() ===

1. 分类prefill/decode序列
2. 构建slot_mapping
3. 构建block_tables(CPU端tensor)
4. 组装CPUAttnMetadata(无GPU特有字段)

U.2 元数据字段对比矩阵

字段 FlashInfer FlashAttention Triton CPU FlexAttention
slot_mapping
seq_lens
block_tables
paged_kv_indices
paged_kv_indptr
paged_kv_last_page_len
cu_seqlens
workspace
prefill_wrapper
decode_wrapper
block_mask
score_modify

附录V CUDA Graph与注意力后端兼容性

V.1 CUDA Graph工作原理

CUDA Graph是一个GPU操作的录制-重放机制:

1. 录制阶段:
   - 执行一次完整的forward
   - 记录所有CUDA kernel调用和内存操作
   - 生成CUDA Graph

2. 重放阶段:
   - 直接重放录制的Graph
   - 跳过CPU端调度开销(~10μs/kernel)
   - 对decode特别有效(每步kernel很少)

CUDA Graph的要求:
- 输入/输出tensor的地址不变
- kernel参数不变(或可预测变化)
- 不能有动态分支(如变长序列)

兼容性:
  ✅ FlashInfer Decode: 支持CUDA Graph
  ✅ FlashMLA Decode: 支持CUDA Graph
  ✅ FlashAttention v3: 支持CUDA Graph
  ❌ FlashAttention v2: 有限支持
  ❌ Triton Prefill: 不支持(动态序列长度)
  ❌ CPU: 不适用

V.2 CUDA Graph模式下的workspace预分配

正常模式:
  workspace在每次forward时动态分配
  大小取决于当前batch_size和seq_len

CUDA Graph模式:
  workspace必须预分配到最大可能大小
  因为重放时不能重新分配
  
  max_workspace_size = max_batch_size × max_seq_len × num_heads × head_size × 4
  
  例如: 256 × 8192 × 32 × 128 × 4 = 4GB
  
  这意味着CUDA Graph模式需要更多GPU内存
  但换来的是调度开销的消除

附录W KV Cache FP8量化深度分析

W.1 FP8格式详解

FP8 E4M3FN:
  符号位: 1 bit
  指数位: 4 bits (bias=7)
  尾数位: 3 bits
  范围: [-448, 448]
  精度: ~3位有效数字
  特殊值: 无NaN,有±inf

FP8 E5M2FN:
  符号位: 1 bit
  指数位: 5 bits (bias=15)
  尾数位: 2 bits
  范围: [-57344, 57344]
  精度: ~2位有效数字
  特殊值: 有NaN和±inf

KV Cache量化通常使用E4M3:
  更高精度(3位尾数 vs 2位)
  范围够用(KV值通常在[-100, 100])
  无NaN → 更稳定

W.2 量化方案对比

1. Per-tensor量化:
   scale = max(|KV|) / 448
   quantized = round(KV / scale)
   精度: 低(所有元素共享1个scale)
   大小: 1 × scale + N × fp8

2. Per-head量化:
   scale = max(|KV_head|) / 448  (每头1个scale)
   quantized = round(KV / scale)
   精度: 中(每头独立scale)
   大小: num_heads × scale + N × fp8

3. Per-token量化 (TurboQuant):
   scale = max(|KV_token|) / 448  (每token 1个scale)
   quantized = round(KV / scale)
   精度: 高(每token独立scale)
   大小: num_tokens × scale + N × fp8
   
   代价: scale存储开销大
   但: scale可以存为FP16(2 bytes/token)
   总大小: N × 1byte + N × 2bytes = 3N bytes
   vs FP16: N × 2bytes × 2(K+V) = 4N bytes
   节省: 25%(但精度损失可能显著)

附录X 注意力层类型与KV Cache行为矩阵

层类型 使用KV Cache Causal Mask Prefill/Decode 典型位置
DECODER_SELF ✅ 写入+读取 ✅ Causal 分离 Decoder每层
ENCODER_SELF ❌ 不使用 ❌ Bidirectional 仅Prefill Encoder每层
ENCODER_DECODER_CROSS ✅ 只读取 ❌ No causal Encoder侧只Prefill Decoder的cross-attn
DECODER_CROSS ✅ 写入+读取 ✅ Causal 分离 罕见
DECODER_PREFILL_ONLY ✅ 只写入 ✅ Causal 仅Prefill 特殊用途
Cross-Attention的KV Cache行为:

Encoder → K/V → 存入KV Cache (只写入一次)
Decoder → Q → 读取KV Cache → 计算注意力

Encoder的KV Cache特点:
  - 只在encoder prefill时写入
  - 后续所有decoder步骤都读取相同的KV
  - 不需要PagedAttention(大小固定)
  - 不需要block_table(整个cache一次性读取)

这就是AttentionType.ENCODER_DECODER的处理逻辑:
  encoder阶段: 计算并缓存KV
  decoder阶段: 从cache读取KV,不做写入

附录Y Sliding Window注意力实现策略

Y.1 三种实现方式

方式1: FlashInfer sliding_window参数
  decode_wrapper.plan(..., sliding_window=window_size)
  # FlashInfer内置滑动窗口支持
  # 只读取最近window_size个KV

方式2: SparseIndexer (MLA Sparse)
  # 先计算稀疏索引
  # 只读取需要的KV块
  # 支持sink tokens

方式3: FlexAttention BlockMask
  # 定义score_modify函数
  # 编译为高效的mask kernel
  # 最灵活但可能不是最快

Y.2 滑动窗口的内存节省

模型: Mistral-7B, window_size=4096, 总上下文=32768

不使用滑动窗口:
  KV Cache: 2 × 8 × 128 × 32768 × 2 = 1GB per request
  256并发: 256GB → OOM

使用滑动窗口:
  KV Cache: 2 × 8 × 128 × 4096 × 2 = 128MB per request
  256并发: 32GB → 可行

节省: 8× 内存

注意: 超出窗口的KV cache块可以被回收
  → 调度器可以释放旧块,分配给新请求
  → 这就是PagedAttention + Sliding Window的协同优势

附录Z AttentionMetadataBuilder 接口设计模式

Z.1 Builder模式的必要性

为什么使用Builder模式构建Metadata?

1. 元数据构建复杂:
   - 需要CPU→GPU的数据传输
   - 需要预分配和复用tensor
   - 需要调用wrapper.plan()
   
2. 性能关键路径:
   - build()在每次forward前调用
   - 必须尽可能快
   - 预分配tensor避免动态分配

3. 后端定制:
   - 不同后端需要不同的索引格式
   - FlashInfer: paged_kv_indices/indptr
   - FlashAttention: cu_seqlens/block_tables
   - Triton: 最小化元数据

Builder接口:
  class AttentionMetadataBuilder(ABC):
      @abstractmethod
      def build(self, input_ids, seq_lens, ...) -> AttentionMetadata:
          pass
      
      @abstractmethod  
      def use_causal_mask(self) -> bool:
          pass
      
      def reorder_batch(self, ...) -> None:
          """重排批次(用于speculative decoding)"""
          pass

Z.2 预分配策略

预分配的tensor列表(FlashInferMetadataBuilder):

1. paged_kv_indices: [max_num_pages] int32
   - max_num_pages = max_batch_size × max_blocks_per_seq
   - 约: 256 × 512 = 131K elements × 4bytes = 512KB

2. paged_kv_indptr: [max_batch_size + 1] int32
   - 256 + 1 × 4 = ~1KB

3. paged_kv_last_page_len: [max_batch_size] int32
   - 256 × 4 = ~1KB

4. slot_mapping: [max_num_tokens] int64
   - max_num_tokens = max_batch_size × max_seq_len
   - 256 × 8192 × 8 = 16MB

5. seq_lens_tensor: [max_batch_size] int32
   - ~1KB

6. block_tables: [max_batch_size, max_blocks_per_seq] int32
   - 256 × 512 × 4 = 512KB

总预分配: 约17MB (可接受)

注意: 这些tensor在整个推理过程中复用
  不会每次forward都重新分配
  这对CUDA Graph模式至关重要

附录AA Attention层初始化的懒加载模式

Attention.__init__() → 不创建Impl

原因:
1. 避免在模型加载时初始化CUDA上下文
2. 某些后端可能不可用(延迟检测)
3. 减少初始化时间

懒加载时机:
  第一次调用Attention.forward()时:
  
  if not hasattr(self, '_impl'):
      self._impl = self.impl_cls(
          num_heads=self.num_heads,
          head_size=self.head_size,
          ...
      )

优点:
- 模型可以加载到CPU(不触发CUDA)
- 后端选择可以推迟到运行时
- 节省初始化时间(特别是大模型)

缺点:
- 第一次forward有额外延迟
- 错误(如后端不可用)延迟到运行时才发现

附录AB ALiBi位置编码在注意力中的实现

ALiBi (Attention with Linear Biases):

核心思想:
  不使用位置编码,而是在注意力分数上加线性偏置
  scores[i, j] = Q[i] × K[j]^T / √d + m × (j - i)
  
  m: 每个头的斜率,不同的头使用不同的斜率
  m_h = 2^(-8h/N), h=1,...,N (N=头数)

实现:
  1. precompute alibi_slopes:
     slopes = [2^(-8/N), 2^(-16/N), ..., 2^(-8)]
     # 每个头1个斜率
     # 头0的斜率最小(2^(-8))
     # 头N-1的斜率最大(2^(-8/N))
  
  2. 在AttentionImpl中使用:
     if self.alibi_slopes is not None:
         # 计算相对位置偏置
         # position_offset = [0, -1, -2, ..., -(seq_len-1)]
         # alibi_bias = slopes × position_offset
         # scores = scores + alibi_bias
     
  3. FlashInfer的ALiBi支持:
     # FlashInfer内置ALiBi,通过pos_encoding_mode参数
     wrapper.plan(..., pos_encoding_mode="ALIBI")
     # 无需手动计算偏置
  
  4. FlashAttention的ALiBi支持:
     # flash_attn不支持内置ALiBi
     # 需要在外部计算偏置并加到Q/K上
     # 或使用自定义kernel

ALiBi的优势:
  - 无需位置编码(减少参数)
  - 支持任意长度外推
  - 长序列泛化能力强

ALiBi的劣势:
  - 短序列性能略差于RoPE
  - 不适合双向注意力
  - 与Flash风格的kernel不完全兼容

附录AC FlashInfer的统一Prefill/Decode Wrapper

FlashInfer V1引入了统一Wrapper:

传统方式(V0):
  prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper()
  decode_wrapper = BatchDecodeWithPagedKVCacheWrapper()
  
  if num_prefills > 0:
      prefill_wrapper.plan(...)
      output[:num_prefill_tokens] = prefill_wrapper.run(...)
  
  if num_decode_tokens > 0:
      decode_wrapper.plan(...)
      output[num_prefill_tokens:] = decode_wrapper.run(...)

统一方式(V1):
  unified_wrapper = BatchPrefillWithPagedKVCacheWrapper()
  
  # 混合batch: prefill和decode在同一个wrapper中
  unified_wrapper.plan(...)
  output = unified_wrapper.run(query, ...)
  
  # wrapper内部区分prefill和decode:
  # - query_len > 1 → prefill
  # - query_len == 1 → decode
  # - 使用不同的kernel路径

统一的优点:
1. 单次plan()调用
2. 单次run()调用
3. 减少kernel launch开销
4. 简化调用逻辑
5. 更好的GPU利用率

统一的缺点:
1. prefill和decode不能分别调优
2. CUDA Graph兼容性更复杂
3. 某些后端可能不支持

附录AD vLLM Attention 后端发展路线图

历史演进:

v0.x:
  - 单一后端: PagedAttention (custom CUDA kernel)
  - 无MLA支持
  - 无prefix cache
  - CPU fallback: 逐序列PyTorch SDPA

v1.0 (2024 Q1):
  - 引入AttentionBackend抽象
  - FlashInfer成为默认后端
  - 支持FlashAttention v2
  - Triton回退后端

v1.1 (2024 Q2):
  - MLA后端支持(DeepSeek-V2)
  - FlashMLA后端
  - CUDA Graph兼容
  - Chunked Prefill

v1.2 (2024 Q3):
  - FlashAttention v3支持
  - TurboQuant FP8量化
  - DCP分布式上下文并行
  - Prefix Cache优化

v1.3 (2024 Q4):
  - FlashMLASparse稀疏注意力
  - ROCm/aiter后端
  - Mamba2后端
  - XPU平台支持

v1.4 (2025 Q1):
  - DeepSeek-V4融合算子
  - CUTLASS MLA后端
  - TRT-LLM MLA prefill
  - FlexAttention后端

未来方向:
  - 更好的混合batch支持
  - 更高效的稀疏注意力
  - 量化感知的注意力kernel
  - 多模态统一注意力
Logo

免费领 100 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐