vLLM V1 Attention 模块超深度架构分析 — Part 3: MLA后端体系

分析范围: v1/attention/backends/mla/ 目录全部源码(21个文件,约6,200行)


目录


第十四章 MLA架构原理与总体设计

14.1 Multi-head Latent Attention原理

DeepSeek-V2/V3引入的MLA(Multi-head Latent Attention) 是一种KV压缩注意力机制:

核心思想:将多头的Key和Value投影到一个低维latent空间,大幅减少KV Cache的内存占用。

标准MHA:
  K = W_K × H     → [seq_len, num_kv_heads, head_size]
  V = W_V × H     → [seq_len, num_kv_heads, head_size]
  KV Cache: 2 × num_kv_heads × head_size per token

MLA:
  C = W_DKV × H   → [seq_len, kv_lora_rank]  # 压缩的latent
  K = W_UK × C    → [seq_len, num_kv_heads, head_size]  # 解压的K
  V = W_UV × C    → [seq_len, num_kv_heads, head_size]  # 解压的V
  KV Cache: kv_lora_rank per token  (远小于 2 × num_kv_heads × head_size)

数学推导

标准: Cache大小 = 2 × n_kv × d × seq_len × batch × 2bytes
MLA:  Cache大小 = d_lora × seq_len × batch × 2bytes

压缩比: 2 × n_kv × d / d_lora
例如: n_kv=64, d=128, d_lora=512 → 压缩比 = 2×64×128/512 = 32x

14.2 MLA与标准MHA的对比

MLA

标准MHA

Hidden State H

K = W_K × H
[seq, n_kv, d]

V = W_V × H
[seq, n_kv, d]

KV Cache
2 × n_kv × d

Hidden State H

C = W_DKV × H
[seq, d_lora] ★压缩★

KV Cache
d_lora ★极小★

K = W_UK × C
[seq, n_kv, d]

V = W_UV × C
[seq, n_kv, d]

Q × K^T → softmax → × V

Q × K^T → softmax → × V

14.3 MLA后端总体架构

MLA核心组件

FlashMLA变体

MLA后端选择

Yes

Yes

No

Yes

No

No

Yes

No

Yes

MLA后端选择

CUDA平台?

FlashMLA可用?

FlashMLABackend

FlashInfer?

FlashInferMLABackend

FlashAttnMLABackend

ROCm?

ROCmAITerMLABackend

XPU?

XPUMLASparseBackend

FlashMLABackend
(Dense)

FlashMLASparseBackend
(Sliding Window)

Indexer
稀疏索引计算

Compressor
KV压缩/解压

Prefill Selector
预填充后端选择


第十五章 FlashMLABackend核心实现

15.1 FlashMLABackend类结构

class FlashMLABackend(AttentionBackend):
    """FlashMLA后端
    
    使用FlashMLA库实现DeepSeek MLA
    特点:
    - KV Cache存储压缩的latent向量
    - Prefill时解压K/V
    - Decode时使用latent直接计算
    - 支持CUDA Graph
    """
    
    @staticmethod
    def get_name() -> str:
        return "FLASHMLA"
    
    @staticmethod
    def get_impl_cls() -> type:
        return FlashMLAImpl
    
    @classmethod
    def get_metadata_cls(cls) -> type:
        return FlashMLAMetadata
    
    @classmethod
    def get_builder_cls(cls) -> type:
        return FlashMLAMetadataBuilder
    
    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # MLA中不使用,但接口要求
        head_size: int,     # MLA中不使用
    ) -> tuple[int, ...]:
        # MLA KV Cache只存储压缩的latent向量
        # 形状: [num_blocks, block_size, kv_lora_rank]
        # 不再是 [2, num_blocks, block_size, num_kv_heads, head_size]
        # 这是因为MLA将K和V压缩到同一个latent空间
        kv_lora_rank = ...  # 从模型配置获取
        return (num_blocks, block_size, kv_lora_rank)
    
    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        # MLA的head_size由kv_lora_rank决定
        # 不再是传统的[32, 64, 128, 256]
        return [512]  # DeepSeek-V2/V3的kv_lora_rank

15.2 FlashMLAMetadata元数据

@dataclass
class FlashMLAMetadata(AttentionMetadata):
    """FlashMLA专用元数据"""
    
    # ---- MLA特有的压缩索引 ----
    # Prefill时需要将latent解压为K和V
    # 但不需要为每个token单独存储K和V
    
    # ---- Q的RoPE处理 ----
    # MLA中Q仍使用RoPE,但K不使用(K在latent空间中)
    # 因此需要单独处理Q的RoPE
    q_rope_inv: torch.Tensor | None = None     # Q的RoPE逆频率
    q_rope_cos: torch.Tensor | None = None     # Q的RoPE余弦
    
    # ---- Prefill信息 ----
    # MLA prefill使用选定的prefill后端
    prefill_backend_name: str | None = None    # "flash_attn" / "flashinfer" / "trtllm"
    
    # ---- 解压参数 ----
    # 从latent解压K和V的投影矩阵
    w_uk: torch.Tensor | None = None  # [num_kv_heads * head_size, kv_lora_rank]
    w_uv: torch.Tensor | None = None  # [num_kv_heads * head_size, kv_lora_rank]

15.3 FlashMLAImpl实现

class FlashMLAImpl(AttentionImpl):
    """FlashMLA注意力执行层"""
    
    def forward(self, query, key, value, kv_cache, attn_metadata):
        num_tokens = query.shape[0]
        
        # MLA的forward流程与标准MHA不同:
        # 1. Q经过RoPE处理(标准处理)
        # 2. K/V不需要单独存储,只存储latent C
        # 3. Prefill时: 解压C→K,V → 标准注意力计算
        # 4. Decode时: 使用latent直接计算(flash_mla专用kernel)
        
        # 1. 处理Query
        query = query.view(num_tokens, self.num_heads, self.head_size)
        # Q的RoPE在外层已处理
        
        # 2. 写入latent到KV cache
        # key实际上已经是压缩的latent C(经过W_DKV投影)
        latent = key  # [num_tokens, kv_lora_rank]
        self._write_latent_cache(latent, kv_cache, attn_metadata.slot_mapping)
        
        # 3. 执行注意力
        if attn_metadata.num_prefills > 0:
            output = self._run_prefill(query, kv_cache, attn_metadata)
        else:
            output = self._run_decode(query, kv_cache, attn_metadata)
        
        return output.view(num_tokens, -1)
    
    def _run_prefill(self, query, kv_cache, attn_metadata):
        """MLA Prefill: 解压latent → 标准注意力"""
        
        # 从KV cache读取latent
        latent_cache = kv_cache  # [num_blocks, block_size, kv_lora_rank]
        
        # 解压: C → K, V
        # K = W_UK × C, V = W_UV × C
        key = torch.matmul(latent_for_prefill, self.w_uk.T)
        value = torch.matmul(latent_for_prefill, self.w_uv.T)
        
        # 使用选定的prefill后端执行标准注意力
        # (flash_attn / flashinfer / trtllm_ragged)
        output = self.prefill_impl.forward(
            query, key, value, ...
        )
        return output
    
    def _run_decode(self, query, kv_cache, attn_metadata):
        """MLA Decode: 使用FlashMLA专用kernel"""
        
        # FlashMLA decode kernel直接在latent空间计算注意力
        # 不需要解压K/V,减少计算量
        from flash_mla import flash_mla_with_kvcache
        
        output = flash_mla_with_kvcache(
            q=query,                          # [batch, 1, num_heads, head_size]
            k_cache=kv_cache,                 # latent cache
            v_cache=kv_cache,                 # 同一个latent cache
            cache_seqlens=attn_metadata.seq_lens_tensor,
            block_table=attn_metadata.block_tables,
            head_dim_v=self.kv_lora_rank,     # V的维度是kv_lora_rank
            head_dim_k=self.head_size,        # K的维度是head_size(RoPE后)
            ...
        )
        return output

15.4 KV Cache压缩存储

MLA KV Cache (per token)

标准KV Cache (per token)

K: n_kv × d
= 64 × 128 = 8192 floats
= 32KB (fp16)

V: n_kv × d
= 64 × 128 = 8192 floats
= 32KB (fp16)

总计: 64KB/token

C: d_lora
= 512 floats
= 1KB (fp16)

总计: 1KB/token

压缩比: 64x


第十六章 FlashMLASparseBackend稀疏注意力

16.1 稀疏滑动窗口动机

DeepSeek-V3使用稀疏注意力:每个token只关注最近的W个token(滑动窗口)+ 少数sink tokens。

class FlashMLASparseBackend(FlashMLABackend):
    """FlashMLA稀疏注意力后端
    
    特点:
    - 滑动窗口注意力
    - 使用SparseIndexer计算稀疏索引
    - 减少不必要的KV读取
    """
    
    @staticmethod
    def get_name() -> str:
        return "FLASHMLA_SPARSE"
    
    @classmethod
    def get_metadata_cls(cls) -> type:
        return FlashMLASparseMetadata
    
    @classmethod
    def get_builder_cls(cls) -> type:
        return FlashMLASparseMetadataBuilder

16.2 SparseIndexer索引器

class SparseIndexer:
    """稀疏注意力索引计算器
    
    计算每个decode token需要关注哪些KV位置
    
    滑动窗口规则:
    - 每个token关注最近的window_size个token
    - 加上开头的sink_size个token
    
    索引输出:
    - kv_indices: [batch, max_num_kv] 需要读取的KV索引
    - kv_indptr: [batch+1] CSR指针
    - num_kv: [batch] 每序列需要的KV数量
    """
    
    def __init__(
        self,
        window_size: int,       # 滑动窗口大小
        sink_size: int = 0,     # 开头sink token数
        block_size: int = 16,   # KV cache块大小
    ):
        self.window_size = window_size
        self.sink_size = sink_size
        self.block_size = block_size
    
    def compute_index(
        self,
        seq_lens: torch.Tensor,     # [batch] 序列长度
        block_tables: torch.Tensor, # [batch, max_blocks] 块表
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """计算稀疏索引"""
        
        batch_size = seq_lens.shape[0]
        
        # 对每个序列,计算需要关注的KV位置
        kv_indices_list = []
        kv_indptr = [0]
        
        for i in range(batch_size):
            seq_len = seq_lens[i].item()
            
            # 滑动窗口: 最近window_size个token
            window_start = max(0, seq_len - self.window_size)
            
            # Sink tokens: 开头的sink_size个token
            sink_end = min(self.sink_size, seq_len)
            
            # 合并: sink + window(去重)
            if window_start < sink_end:
                # 窗口和sink重叠 → 直接取[0, seq_len]
                indices = list(range(seq_len))
            else:
                # 不重叠 → sink + gap + window
                indices = list(range(sink_end)) + list(range(window_start, seq_len))
            
            # 映射到物理KV cache位置
            for pos in indices:
                block_id = block_tables[i, pos // self.block_size]
                offset = pos % self.block_size
                kv_indices_list.append(block_id * self.block_size + offset)
            
            kv_indptr.append(len(kv_indices_list))
        
        # 转换为张量
        kv_indices = torch.tensor(kv_indices_list, dtype=torch.int32)
        kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
        num_kv = torch.diff(kv_indptr)
        
        return kv_indices, kv_indptr, num_kv

稀疏索引示意

seq_len = 100, window_size = 32, sink_size = 4

关注的KV位置:
  Sink:    [0, 1, 2, 3]                → 4个token
  Gap:     [4, 5, ..., 67]             → 不关注(64个token跳过)
  Window:  [68, 69, ..., 99]           → 32个token
  总计:    36个token(而非100个)

内存节省: 64%  (36/100)
计算节省: 64%  (Q×K^T从100次减为36次)

16.3 稀疏注意力执行流程

Yes

No: Prefill

FlashMLASparseImpl.forward()

Decode模式?

计算稀疏索引
SparseIndexer.compute_index()

只读取稀疏位置的KV
kv_indices[kv_indptr[i]:kv_indptr[i+1]]

flash_mla_with_sparse_kv()
稀疏注意力计算

读取完整KV
(prefill总是dense)

标准prefill注意力

返回output


第十七章 MLA Indexer深度解析

17.1 设计目的

indexer.py(776行)是MLA后端中最复杂的组件之一,负责:

  1. Prefill索引:确定每个prefill token的KV cache写入位置
  2. Decode索引:确定每个decode token需要读取的KV位置
  3. 稀疏索引:在滑动窗口模式下计算稀疏KV索引
  4. CUDA Graph兼容:索引计算需要支持CUDA Graph录制

17.2 索引计算流程

class MLAIndexer:
    """MLA索引计算器
    
    核心方法:
    - compute_prefill_index(): 计算prefill的KV cache索引
    - compute_decode_index(): 计算decode的KV cache索引  
    - compute_sparse_decode_index(): 计算稀疏decode的KV cache索引
    """
    
    def compute_decode_index(
        self,
        seq_lens: torch.Tensor,         # [batch] 当前序列长度
        block_tables: torch.Tensor,      # [batch, max_blocks] 块表
        kv_cache_dtype: str,
    ) -> MLADecodeIndexResult:
        """计算decode索引
        
        返回:
          slot_mapping: [batch] 每序列新token的写入slot
          block_table:  [batch, max_blocks] 块表
          seq_lens:     [batch] 序列长度
        """
        batch_size = seq_lens.shape[0]
        
        # 每个decode token的写入位置 = 序列末尾
        slot_mapping = torch.empty(batch_size, dtype=torch.int64)
        
        for i in range(batch_size):
            seq_len = seq_lens[i].item()
            # 新token的位置 = seq_len(0-indexed)
            pos = seq_len  # 因为当前token还未写入
            
            block_id = block_tables[i, pos // self.block_size]
            offset = pos % self.block_size
            slot_mapping[i] = block_id * self.block_size + offset
        
        return MLADecodeIndexResult(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            seq_lens=seq_lens,
        )

17.3 索引数据结构

@dataclass
class MLAPrefillIndexResult:
    """Prefill索引结果"""
    slot_mapping: torch.Tensor         # [num_prefill_tokens] 写入slot
    cu_seqlens: torch.Tensor           # [num_prefills + 1] 累积长度
    max_seq_len: int                   # 最长prefill序列

@dataclass  
class MLADecodeIndexResult:
    """Decode索引结果"""
    slot_mapping: torch.Tensor         # [num_decode] 写入slot
    block_tables: torch.Tensor         # [batch, max_blocks]
    seq_lens: torch.Tensor             # [batch]
    # 稀疏模式额外字段:
    sparse_kv_indices: torch.Tensor | None = None  # 稀疏KV索引
    sparse_kv_indptr: torch.Tensor | None = None   # 稀疏KV指针
    sparse_num_kv: torch.Tensor | None = None       # 每序列KV数

第十八章 MLA Prefill子系统

18.1 Prefill后端选择

MLA在Prefill阶段需要解压latent为K/V,然后执行标准注意力。prefill可以使用不同的后端:

Prefill后端 特点 适用场景
FlashAttention 高效、稳定 默认选择
FlashInfer 支持更多功能 需要变长batch时
TRT-LLM Ragged NVIDIA优化 H100等新硬件

18.2 PrefillRegistry注册表

class MLAPrefillRegistry:
    """MLA Prefill后端注册表"""
    
    _registry: dict[str, type[MLAPrefillBackend]] = {}
    
    @classmethod
    def register(cls, name: str):
        """装饰器: 注册prefill后端"""
        def decorator(backend_cls):
            cls._registry[name] = backend_cls
            return backend_cls
        return decorator
    
    @classmethod
    def get_backend(cls, name: str) -> type[MLAPrefillBackend]:
        return cls._registry[name]

# 注册后端
@MLAPrefillRegistry.register("flash_attn")
class FlashAttnMLAPrefill(MLAPrefillBackend): ...

@MLAPrefillRegistry.register("flashinfer")  
class FlashInferMLAPrefill(MLAPrefillBackend): ...

@MLAPrefillRegistry.register("trtllm_ragged")
class TRTLLMRaggedMLAPrefill(MLAPrefillBackend): ...

18.3 PrefillSelector选择器

class MLAPrefillSelector:
    """MLA Prefill后端选择器
    
    根据硬件和库可用性选择最优prefill后端
    """
    
    def select(self, vllm_config) -> str:
        # 优先级: flash_attn > flashinfer > trtllm_ragged
        if is_flash_attn_available():
            return "flash_attn"
        if is_flashinfer_available():
            return "flashinfer"
        if is_trtllm_available():
            return "trtllm_ragged"
        raise RuntimeError("No MLA prefill backend available")

第十九章 其他MLA后端

19.1 FlashInferMLABackend

class FlashInferMLABackend(AttentionBackend):
    """FlashInfer MLA后端
    
    使用FlashInfer库实现MLA
    适用于FlashMLA不可用的场景
    """
    
    @staticmethod
    def get_name() -> str:
        return "FLASHINFER_MLA"
    
    # 与FlashMLABackend的主要区别:
    # - 使用FlashInfer的Paged KV Cache API
    # - 需要在decode时解压latent为K/V
    # - 性能略低于FlashMLA(因为需要解压)

19.2 FlashAttnMLABackend

class FlashAttnMLABackend(AttentionBackend):
    """FlashAttention MLA后端
    
    使用FlashAttention库实现MLA
    适用于FlashMLA和FlashInfer都不可用的场景
    """
    
    @staticmethod
    def get_name() -> str:
        return "FLASHATTN_MLA"

19.3 CUTLASSMLABackend

class CUTLASSMLABackend(AttentionBackend):
    """CUTLASS MLA后端
    
    使用NVIDIA CUTLASS库实现MLA
    专注于decode阶段的优化
    """
    
    @staticmethod
    def get_name() -> str:
        return "CUTLASS_MLA"

19.4 TritonMLABackend

class TritonMLABackend(AttentionBackend):
    """Triton MLA后端
    
    使用自定义Triton kernel实现MLA
    适用于无专用MLA库的场景
    """
    
    @staticmethod
    def get_name() -> str:
        return "TRITON_MLA"

19.5 ROCm/XPU MLA后端

# ROCm平台
class ROCmAITerMLABackend(AttentionBackend):
    """ROCm aiter MLA后端"""
    @staticmethod
    def get_name() -> str:
        return "ROCM_AITER_MLA"

class ROCmAITerMLASparseBackend(AttentionBackend):
    """ROCm aiter MLA稀疏后端"""
    @staticmethod
    def get_name() -> str:
        return "ROCM_AITER_MLA_SPARSE"

# XPU平台
class XPUMLASparseBackend(AttentionBackend):
    """XPU MLA稀疏后端"""
    @staticmethod
    def get_name() -> str:
        return "XPU_MLA_SPARSE"

附录F MLA KV Cache内存节省计算

典型DeepSeek-V3配置:
  num_kv_heads = 64
  head_size = 128
  kv_lora_rank = 512

标准KV Cache (per token, fp16):
  K: 64 × 128 × 2 = 16,384 bytes
  V: 64 × 128 × 2 = 16,384 bytes
  总计: 32,768 bytes = 32KB

MLA KV Cache (per token, fp16):
  C: 512 × 2 = 1,024 bytes = 1KB

压缩比: 32KB / 1KB = 32×

对于4K上下文、256并发:
  标准: 32KB × 4096 × 256 = 32GB
  MLA:  1KB × 4096 × 256 = 1GB
  
  节省: 31GB GPU内存!

附录G MLA后端选择决策树

CUDA

Yes

Yes

No

Yes

No

No: Dense

Yes

No

Yes

No

Yes

No

Yes

No

ROCm

Yes

No

XPU

MLA后端选择

硬件平台?

需要稀疏注意力
(滑动窗口)?

FlashMLA_Sparse
可用?

FlashMLASparseBackend
★推荐★

FlashInfer_MLA_Sparse?

FlashInferMLASparseBackend

FlashAttnMLA + 稀疏逻辑

FlashMLA可用?

FlashMLABackend
★推荐★

FlashInfer?

FlashInferMLABackend

FlashAttention?

FlashAttnMLABackend

CUTLASS?

CUTLASSMLABackend

TritonMLABackend
★最终回退★

aiter MLA?

ROCmAITerMLABackend

ROCm Triton MLA

XPUMLASparseBackend


附录U FlashMLAImpl.forward() 完整流程追踪

U.1 Decode路径详解

FlashMLAImpl._run_decode(query, kv_cache, attn_metadata):

Step 1: 准备查询
  query = query.view(num_decode, num_heads, head_size)
  # [batch, heads, dim]

Step 2: 准备KV cache
  # MLA的KV cache只存储latent向量
  # 形状: [num_blocks, block_size, kv_lora_rank]
  # 不需要分别读取K和V

Step 3: 调用flash_mla_with_kvcache
  output = flash_mla_with_kvcache(
    q=query,                          # [batch, 1, heads, dim]
    k_cache=kv_cache,                 # latent cache
    cache_seqlens=attn_metadata.seq_lens,
    block_table=attn_metadata.block_tables,
    head_dim_v=kv_lora_rank,          # V维度=压缩维度
    head_dim_k=head_size,             # K维度=原始维度
    softmax_scale=scale,
  )

Step 4: 处理输出
  output = output.squeeze(1)  # [batch, heads, dim]
  output = output.view(num_decode, -1)  # [batch, heads * dim]

U.2 Prefill路径详解

FlashMLAImpl._run_prefill(query, kv_cache, attn_metadata):

Step 1: 从latent cache解压K和V
  # 读取prefill序列的所有latent
  latent = read_from_cache(kv_cache, attn_metadata)
  # latent: [num_prefill_tokens, kv_lora_rank]
  
  # 解压K: K = latent × W_UK^T
  key = torch.matmul(latent, w_uk)
  # key: [num_prefill_tokens, num_kv_heads * head_size]
  
  # 解压V: V = latent × W_UV^T
  value = torch.matmul(latent, w_uv)
  # value: [num_prefill_tokens, num_kv_heads * head_size]

Step 2: 选择prefill后端
  # 根据MLAPrefillSelector的选择
  # 可能是: flash_attn, flashinfer, 或 trtllm_ragged
  
  if prefill_backend == "flash_attn":
    output = flash_attn_varlen_func(
      q=query, k=key, v=value,
      cu_seqlens_q=cu_seqlens,
      cu_seqlens_k=cu_seqlens,
      max_seqlen_q=max_seq_len,
      max_seqlen_k=max_seq_len,
      softmax_scale=scale,
      causal=True,
    )
  elif prefill_backend == "flashinfer":
    output = flashinfer_prefill(query, key, value, ...)
  elif prefill_backend == "trtllm_ragged":
    output = trtllm_ragged_prefill(query, key, value, ...)

U.3 Q的RoPE处理(MLA特殊)

MLA中的RoPE处理与标准MHA不同:

标准MHA:
  Q = apply_rope(Q, position)
  K = apply_rope(K, position)
  → Q和K都使用RoPE

MLA:
  Q = apply_rope(Q, position)         # Q使用RoPE
  K = NO_ROPE(K)                      # K不使用RoPE!
  C = compress(K, V) → latent cache   # 压缩存储
  
  为什么K不用RoPE?
  因为RoPE会破坏K的低秩结构
  MLA依赖K的低秩结构实现压缩
  如果K经过RoPE,就不能用W_DKV压缩到latent空间
  
  解决方案: 将K分为两部分
  K = K_nope + K_rope
  K_nope: 不使用RoPE的部分 → 可以压缩
  K_rope: 使用RoPE的部分 → 单独存储(很小)
  
  在decode时:
  从latent解压出K_nope
  从单独的cache读取K_rope
  合并为完整的K: K = K_nope + K_rope

附录V SparseIndexer 稀疏索引算法深度分析

V.1 滑动窗口+Sink的索引计算

参数:
  window_size = 4096   # 滑动窗口大小
  sink_size = 4        # 开头保留的sink token数
  block_size = 16      # KV cache块大小

场景: seq_len = 10000

Step 1: 确定关注范围
  window_start = max(0, 10000 - 4096) = 5904
  sink_end = min(4, 10000) = 4

  关注范围: [0, 3] ∪ [5904, 9999]
  跳过范围: [4, 5903]  → 5900个token不需读取

Step 2: 计算KV块覆盖
  关注的token范围 → 需要读取的KV块
  
  Sink部分: [0, 3]
    → 块0: [0, 15] (只读取前4个token)
  
  Window部分: [5904, 9999]
    → 块369: [5904, 5919]
    → 块370: [5920, 5935]
    → ...
    → 块624: [9984, 9999]
  
  总块数: 1 (sink) + 256 (window) = 257 块
  不读取的块: 625 - 257 = 368 块 (59%跳过)

Step 3: 构建稀疏索引
  paged_kv_indices = [0, 369, 370, ..., 624]  # 257个块ID
  paged_kv_indptr = [0, 257]                   # 1个请求
  paged_kv_last_page_len = [10000 % 16] = [0]  # 最后一页满
  # 特殊处理: 当last_page_len=0时,使用block_size

Step 4: 执行稀疏注意力
  只对257个块执行注意力计算
  跳过368个块 → 减少59%的KV读取和计算

V.2 多序列批次的索引

batch_size = 3
seq_lens = [10000, 5000, 1000]
window_size = 4096, sink_size = 4

请求0: seq_len=10000
  关注: [0,3] ∪ [5904,9999] → 257 blocks
  
请求1: seq_len=5000
  window_start = max(0, 5000-4096) = 904
  关注: [0,3] ∪ [904,4999] → 1 + 256 = 257 blocks
  
请求2: seq_len=1000
  window_start = max(0, 1000-4096) = 0
  关注: [0,999] → 全部(因为seq_len < window_size)
  blocks: ceil(1000/16) = 63 blocks

paged_kv_indices = [
  0, 369, 370, ..., 624,     # 请求0: 257 blocks
  1000, 1369, ..., 1624,     # 请求1: 257 blocks  
  2000, 2001, ..., 2062,     # 请求2: 63 blocks
]
# 总计: 577 blocks

paged_kv_indptr = [0, 257, 514, 577]
# 请求0: indices[0:257]
# 请求1: indices[257:514]
# 请求2: indices[514:577]

附录W MLA Prefill后端深度对比

W.1 FlashAttnMLAPrefill

特点:
  - 使用flash_attn_varlen_func
  - 最稳定的实现
  - 支持所有head_size
  - 不需要额外库
  
限制:
  - 需要先解压latent→K,V
  - 解压计算量: num_tokens × kv_lora_rank × (num_kv_heads × head_size)
  - 内存峰值: 需要存储完整K和V
  
适用: 小到中等序列长度

W.2 FlashInferMLAPrefill

特点:
  - 使用FlashInfer的prefill API
  - 支持Paged KV Cache
  - 更好的batch利用率
  
限制:
  - 需要FlashInfer库
  - 某些head_size可能不支持
  
适用: 混合batch(prefill+decode同时)

W.3 TRTLLMRaggedMLAPrefill

特点:
  - 使用TensorRT-LLM的ragged tensor API
  - NVIDIA H100优化
  - 最高性能(H100上)
  
限制:
  - 需要TRT-LLM库
  - 仅支持NVIDIA GPU
  - 配置复杂
  
适用: H100/A100等NVIDIA高端GPU

W.4 性能对比

Prefill后端 A100性能 H100性能 内存峰值 兼容性
FlashAttn ★★★★ ★★★★ 高(需解压) ★★★★★
FlashInfer ★★★★ ★★★★ ★★★
TRT-LLM ★★★ ★★★★★ ★★

附录X2 FlashMLASparseMetadata 构建流程

X2.1 稀疏索引与dense索引的构建差异

FlashMLASparseMetadataBuilder.build():

1. 构建dense索引(与FlashMLA相同)
   - slot_mapping: token→物理slot映射
   - block_tables: 序列→块映射
   - paged_kv_indices/indptr/last_page_len

2. 额外构建稀疏索引
   - sparse_kv_indices: 需要读取的稀疏KV位置
   - sparse_kv_indptr: CSR指针
   - sparse_num_kv: 每序列需要的KV数量
   
   使用SparseIndexer计算:
     indexer = SparseIndexer(
       window_size=sliding_window,
       sink_size=sink_size,
       block_size=block_size,
     )
     sparse_indices = indexer.compute_index(
       seq_lens=decode_seq_lens,
       block_tables=block_tables,
     )

3. 构建complete metadata
   return FlashMLASparseMetadata(
     # 标准字段
     ...,
     # 稀疏字段
     sparse_kv_indices=sparse_indices.kv_indices,
     sparse_kv_indptr=sparse_indices.kv_indptr,
     sparse_num_kv=sparse_indices.num_kv,
   )

X2.2 稀疏decode的注意力计算

FlashMLASparseImpl._run_decode():

1. 读取稀疏KV
   # 不是读取所有KV块
   # 只读取sparse_kv_indices指定的块
   
   for i in range(batch_size):
     start = sparse_kv_indptr[i]
     end = sparse_kv_indptr[i+1]
     # 请求i的KV块 = kv_cache[sparse_kv_indices[start:end]]
   
2. 执行稀疏注意力
   # 使用flash_mla的稀疏模式
   output = flash_mla_with_sparse_kv(
     q=query,
     kv_cache=kv_cache,
     sparse_indices=sparse_kv_indices,
     sparse_indptr=sparse_kv_indptr,
     num_kv=sparse_num_kv,
     ...
   )
   
   # 内部只读取指定的KV块
   # 跳过不需要的块
   # 减少内存带宽和计算量

附录Y2 CompressorUtils KV压缩工具详解

Y2.1 KV压缩流程

compressor_utils.py 提供:

1. compress_kv_to_latent():
   输入: K [num_tokens, num_kv_heads, head_size]
         V [num_tokens, num_kv_heads, head_size]
   操作: C = W_DKV × concat(K, V)
   输出: C [num_tokens, kv_lora_rank]
   
2. decompress_latent_to_kv():
   输入: C [num_tokens, kv_lora_rank]
   操作: K = W_UK × C, V = W_UV × C
   输出: K [num_tokens, num_kv_heads, head_size]
         V [num_tokens, num_kv_heads, head_size]

3. compress_and_cache():
   融合操作: 压缩 → 写入cache
   避免中间latent张量的内存分配

权重矩阵:
  W_DKV: [kv_lora_rank, 2 * num_kv_heads * head_size]  # 下投影
  W_UK:  [num_kv_heads * head_size, kv_lora_rank]       # K上投影
  W_UV:  [num_kv_heads * head_size, kv_lora_rank]       # V上投影
  
  压缩: C = W_DKV × [K; V]  (concat后投影)
  解压: K = W_UK × C, V = W_UV × C (分别投影)

Y2.2 压缩的数学等价性

标准MHA的注意力计算:
  scores = Q × K^T / √d
  attn = softmax(scores)
  output = attn × V

MLA的注意力计算:
  K = W_UK × C, V = W_UV × C
  scores = Q × (W_UK × C)^T / √d
         = Q × C^T × W_UK^T / √d     (结合律)
  
  注意: Q × W_UK 可以预先计算(Q的投影已经包含了这部分)
  因此 MLA 的注意力计算可以在压缩空间中完成
  不需要显式解压 K 和 V
  
  这就是 FlashMLA decode kernel 高效的原因:
  直接在 latent 空间计算注意力,无需解压

附录Z2 MLA各后端decode路径对比

FlashMLABackend decode:
  → flash_mla_with_kvcache()
  → 直接在latent空间计算注意力
  → 不解压K/V
  → 最快

FlashMLASparseBackend decode:
  → flash_mla_with_sparse_kv()
  → 只读取部分latent(滑动窗口)
  → 不解压K/V
  → 内存带宽最优

FlashInferMLABackend decode:
  → 需要解压C→K,V
  → 使用FlashInfer的decode API
  → 额外解压开销

FlashAttnMLABackend decode:
  → 需要解压C→K,V
  → 使用flash_attn_with_kvcache
  → 额外解压开销

CUTLASSMLABackend decode:
  → 使用CUTLASS kernel
  → 可能不需要完整解压
  → 中等性能

TritonMLABackend decode:
  → 自定义Triton kernel
  → 可能需要解压
  → 回退方案

附录AA2 FlashMLA KV Cache写入详解

AA2.1 latent写入 vs 标准KV写入

标准MHA的KV写入:
  key = W_K × hidden    # [num_tokens, num_kv_heads * head_size]
  value = W_V × hidden  # [num_tokens, num_kv_heads * head_size]
  
  reshape_and_cache(key, value, kv_cache, slot_mapping)
  # kv_cache[0][slot] = key  → K缓存
  # kv_cache[1][slot] = value → V缓存

MLA的latent写入:
  compressed = W_DKV × hidden  # [num_tokens, kv_lora_rank]
  
  # 只写入latent,不写入K和V!
  write_latent_cache(compressed, kv_cache, slot_mapping)
  # kv_cache[slot] = compressed → 压缩缓存
  
  # K和V在需要时从latent解压:
  # K = W_UK × compressed  → [num_tokens, num_kv_heads * head_size]
  # V = W_UV × compressed  → [num_tokens, num_kv_heads * head_size]

MLA的额外K_rope存储:
  k_rope = W_KR × hidden  # [num_tokens, num_rope_heads * rope_dim]
  # k_rope需要单独存储(因为K的RoPE部分不能压缩)
  
  完整的MLA cache:
  - latent_cache: [num_blocks, block_size, kv_lora_rank]  → 主缓存
  - k_rope_cache: [num_blocks, block_size, num_rope_heads, rope_dim] → RoPE部分

AA2.2 MLA decode的KV读取

MLA Decode时,需要从cache读取:
1. latent: 压缩的KV表示 → 直接用于latent注意力(FlashMLA kernel)
2. k_rope: K的RoPE部分 → 与latent解压的K_nope拼接

完整的K重建:
  K_nope = W_UK × latent    → [seq_len, num_kv_heads, nope_dim]
  K_rope = k_rope_cache      → [seq_len, num_rope_heads, rope_dim]
  K = concat(K_nope, K_rope) → [seq_len, num_kv_heads, head_size]

V重建:
  V = W_UV × latent → [seq_len, num_kv_heads, head_size]

注意: FlashMLA kernel不需要显式重建K和V
  它直接在latent空间计算注意力
  这是FlashMLA比其他MLA后端快的关键原因

附录BB2 MLA各后端KV Cache形状对比

=== 标准MHA ===
kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
示例: (2, 1024, 16, 64, 128) → 2×1024×16×64×128×2 = 512MB

=== MLA (FlashMLA) ===
latent_cache_shape = (num_blocks, block_size, kv_lora_rank)
示例: (1024, 16, 512) → 1024×16×512×2 = 16MB
k_rope_cache_shape = (num_blocks, block_size, num_rope_heads, rope_dim)
示例: (1024, 16, 1, 64) → 1024×16×1×64×2 = 2MB
总计: 18MB (vs 512MB → 28×节省)

=== MLA (FlashInfer) ===
# FlashInfer MLA存储完整K和V(因为需要解压)
kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
# 与标准MHA相同!没有压缩优势
# 只有在decode时才知道压缩
# → FlashInfer MLA不是最优选择

=== MLA (Triton) ===
# 与FlashMLA相同,使用latent cache
latent_cache_shape = (num_blocks, block_size, kv_lora_rank)
# 但decode时需要解压 → 性能不如FlashMLA

附录CC2 MLA RoPE的特殊处理

CC2.1 DeepSeek-V2/V3的RoPE分离

标准RoPE:
  K_full = apply_rope(K, position, freqs)
  → 整个K向量都经过RoPE变换
  → 旋转部分与原始部分混合
  → 无法压缩(RoPE破坏低秩结构)

DeepSeek MLA的RoPE分离:
  K = K_nope + K_rope
  
  K_nope: [num_kv_heads, nope_dim] → 不使用RoPE → 可以压缩
  K_rope: [1, rope_dim] → 使用RoPE → 单独存储
  
  分离方式:
  hidden → W_K_nope → K_nope → 压缩到latent
  hidden → W_K_rope → K_rope → 直接存储(很小)
  
  注意: K_rope的head数通常为1(1组RoPE编码)
  → K_rope的cache极小: [num_blocks, block_size, 1, rope_dim]

数学等价性:
  标准K的RoPE:
    K_rope_full = RoPE(K, pos, freq)
    = [K[:, :nope_dim] + K[:, nope_dim:] ⊗ RoPE(pos, freq)]
  
  MLA的分离:
    K = K_nope + concat(0, K_rope ⊗ RoPE(pos, freq))
    K_nope部分不做RoPE
    K_rope部分单独做RoPE
  
  等价条件: nope_dim + rope_dim = head_size
  K_nope = K[:, :nope_dim]
  K_rope = K[:, nope_dim:]

CC2.2 MLA RoPE在decode中的处理

Decode时的K重建流程:

1. 从latent cache读取latent:
   latent = latent_cache[block, offset]  # [kv_lora_rank]

2. 解压K_nope:
   K_nope = W_UK × latent  # [num_kv_heads, nope_dim]
   
3. 从k_rope cache读取K_rope:
   k_rope = k_rope_cache[block, offset]  # [1, rope_dim]
   # 注意: k_rope_cache在写入时已经应用了RoPE
   
4. 拼接:
   K = concat(K_nope, K_rope_broadcast)  # [num_kv_heads, head_size]
   # K_rope需要广播到所有KV头
   
5. FlashMLA kernel的优化:
   直接在latent空间计算Q×K_nope部分
   单独处理Q×K_rope部分
   两部分合并 → 等价于完整的Q×K计算

附录DD2 MLA后端与标准后端的代码复用关系

MLA后端的代码复用策略:

FlashMLABackend:
  - 继承: AttentionBackend (抽象接口)
  - 复用: slot_mapping计算 (from utils.py)
  - 复用: reshape_and_cache (latent写入)
  - 独有: flash_mla_with_kvcache (专用decode kernel)
  - 独有: latent→KV解压逻辑 (prefill)
  
FlashMLASparseBackend:
  - 继承: FlashMLABackend
  - 复用: 所有FlashMLABackend的逻辑
  - 独有: SparseIndexer (稀疏索引计算)
  - 独有: flash_mla_with_sparse_kv (稀疏decode kernel)

FlashInferMLABackend:
  - 继承: AttentionBackend
  - 复用: FlashInfer的Paged KV Cache API
  - 复用: FlashInfer的prefill/decode wrapper
  - 独有: latent解压→标准KV的转换层
  - 注意: 不使用FlashInfer内置MLA(如果有的话)

FlashAttnMLABackend:
  - 继承: AttentionBackend
  - 复用: flash_attn_varlen_func (prefill)
  - 复用: flash_attn_with_kvcache (decode)
  - 独有: latent解压→标准KV的转换层

CUTLASSMLABackend:
  - 继承: AttentionBackend
  - 复用: CUTLASS库的MLA kernel
  - 独有: 特定的cache管理逻辑

TritonMLABackend:
  - 继承: AttentionBackend
  - 复用: Triton decode/prefill kernel
  - 独有: latent解压的Triton实现

附录EE2 MLA模型检测逻辑

selector.py中的_is_mla_model()检测:

def _is_mla_model(vllm_config) -> bool:
    """检测模型是否使用MLA架构"""
    model_cls = vllm_config.model_config.model_cls
    
    # 方法1: 类名检测
    mla_model_names = [
        "DeepseekV2ForCausalLM",
        "DeepseekV3ForCausalLM", 
        "DeepseekV4ForCausalLM",
        "DeepSeekV2ForCausalLM",
        "DeepSeekV3ForCausalLM",
    ]
    if any(name in model_cls.__name__ for name in mla_model_names):
        return True
    
    # 方法2: 配置参数检测
    hf_config = vllm_config.model_config.hf_config
    if hasattr(hf_config, 'kv_lora_rank') and hf_config.kv_lora_rank > 0:
        # kv_lora_rank > 0 表示使用MLA
        return True
    
    return False

_is_sparse_mla_model()检测:
    # 检查是否使用稀疏注意力(滑动窗口)
    if hasattr(hf_config, 'sliding_window') and hf_config.sliding_window > 0:
        return True
    if hasattr(hf_config, 'attending_scope'):
        return True
    return False
Logo

免费领 100 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐