【vllm】(v1 Attention)vLLM V1 Attention— Part3 MLA后端体系
vLLM V1 Attention 模块超深度架构分析 — Part 3: MLA后端体系
·
vLLM V1 Attention 模块超深度架构分析 — Part 3: MLA后端体系
分析范围:
v1/attention/backends/mla/目录全部源码(21个文件,约6,200行)
目录
- 第十四章 MLA架构原理与总体设计
- 第十五章 FlashMLABackend核心实现
- 第十六章 FlashMLASparseBackend稀疏注意力
- 第十七章 MLA Indexer深度解析
- 第十八章 MLA Prefill子系统
- 第十九章 其他MLA后端
- 附录F MLA KV Cache内存节省计算
- 附录G MLA后端选择决策树
第十四章 MLA架构原理与总体设计
14.1 Multi-head Latent Attention原理
DeepSeek-V2/V3引入的MLA(Multi-head Latent Attention) 是一种KV压缩注意力机制:
核心思想:将多头的Key和Value投影到一个低维latent空间,大幅减少KV Cache的内存占用。
标准MHA:
K = W_K × H → [seq_len, num_kv_heads, head_size]
V = W_V × H → [seq_len, num_kv_heads, head_size]
KV Cache: 2 × num_kv_heads × head_size per token
MLA:
C = W_DKV × H → [seq_len, kv_lora_rank] # 压缩的latent
K = W_UK × C → [seq_len, num_kv_heads, head_size] # 解压的K
V = W_UV × C → [seq_len, num_kv_heads, head_size] # 解压的V
KV Cache: kv_lora_rank per token (远小于 2 × num_kv_heads × head_size)
数学推导:
标准: Cache大小 = 2 × n_kv × d × seq_len × batch × 2bytes
MLA: Cache大小 = d_lora × seq_len × batch × 2bytes
压缩比: 2 × n_kv × d / d_lora
例如: n_kv=64, d=128, d_lora=512 → 压缩比 = 2×64×128/512 = 32x
14.2 MLA与标准MHA的对比
14.3 MLA后端总体架构
第十五章 FlashMLABackend核心实现
15.1 FlashMLABackend类结构
class FlashMLABackend(AttentionBackend):
"""FlashMLA后端
使用FlashMLA库实现DeepSeek MLA
特点:
- KV Cache存储压缩的latent向量
- Prefill时解压K/V
- Decode时使用latent直接计算
- 支持CUDA Graph
"""
@staticmethod
def get_name() -> str:
return "FLASHMLA"
@staticmethod
def get_impl_cls() -> type:
return FlashMLAImpl
@classmethod
def get_metadata_cls(cls) -> type:
return FlashMLAMetadata
@classmethod
def get_builder_cls(cls) -> type:
return FlashMLAMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # MLA中不使用,但接口要求
head_size: int, # MLA中不使用
) -> tuple[int, ...]:
# MLA KV Cache只存储压缩的latent向量
# 形状: [num_blocks, block_size, kv_lora_rank]
# 不再是 [2, num_blocks, block_size, num_kv_heads, head_size]
# 这是因为MLA将K和V压缩到同一个latent空间
kv_lora_rank = ... # 从模型配置获取
return (num_blocks, block_size, kv_lora_rank)
@staticmethod
def get_supported_head_sizes() -> list[int]:
# MLA的head_size由kv_lora_rank决定
# 不再是传统的[32, 64, 128, 256]
return [512] # DeepSeek-V2/V3的kv_lora_rank
15.2 FlashMLAMetadata元数据
@dataclass
class FlashMLAMetadata(AttentionMetadata):
"""FlashMLA专用元数据"""
# ---- MLA特有的压缩索引 ----
# Prefill时需要将latent解压为K和V
# 但不需要为每个token单独存储K和V
# ---- Q的RoPE处理 ----
# MLA中Q仍使用RoPE,但K不使用(K在latent空间中)
# 因此需要单独处理Q的RoPE
q_rope_inv: torch.Tensor | None = None # Q的RoPE逆频率
q_rope_cos: torch.Tensor | None = None # Q的RoPE余弦
# ---- Prefill信息 ----
# MLA prefill使用选定的prefill后端
prefill_backend_name: str | None = None # "flash_attn" / "flashinfer" / "trtllm"
# ---- 解压参数 ----
# 从latent解压K和V的投影矩阵
w_uk: torch.Tensor | None = None # [num_kv_heads * head_size, kv_lora_rank]
w_uv: torch.Tensor | None = None # [num_kv_heads * head_size, kv_lora_rank]
15.3 FlashMLAImpl实现
class FlashMLAImpl(AttentionImpl):
"""FlashMLA注意力执行层"""
def forward(self, query, key, value, kv_cache, attn_metadata):
num_tokens = query.shape[0]
# MLA的forward流程与标准MHA不同:
# 1. Q经过RoPE处理(标准处理)
# 2. K/V不需要单独存储,只存储latent C
# 3. Prefill时: 解压C→K,V → 标准注意力计算
# 4. Decode时: 使用latent直接计算(flash_mla专用kernel)
# 1. 处理Query
query = query.view(num_tokens, self.num_heads, self.head_size)
# Q的RoPE在外层已处理
# 2. 写入latent到KV cache
# key实际上已经是压缩的latent C(经过W_DKV投影)
latent = key # [num_tokens, kv_lora_rank]
self._write_latent_cache(latent, kv_cache, attn_metadata.slot_mapping)
# 3. 执行注意力
if attn_metadata.num_prefills > 0:
output = self._run_prefill(query, kv_cache, attn_metadata)
else:
output = self._run_decode(query, kv_cache, attn_metadata)
return output.view(num_tokens, -1)
def _run_prefill(self, query, kv_cache, attn_metadata):
"""MLA Prefill: 解压latent → 标准注意力"""
# 从KV cache读取latent
latent_cache = kv_cache # [num_blocks, block_size, kv_lora_rank]
# 解压: C → K, V
# K = W_UK × C, V = W_UV × C
key = torch.matmul(latent_for_prefill, self.w_uk.T)
value = torch.matmul(latent_for_prefill, self.w_uv.T)
# 使用选定的prefill后端执行标准注意力
# (flash_attn / flashinfer / trtllm_ragged)
output = self.prefill_impl.forward(
query, key, value, ...
)
return output
def _run_decode(self, query, kv_cache, attn_metadata):
"""MLA Decode: 使用FlashMLA专用kernel"""
# FlashMLA decode kernel直接在latent空间计算注意力
# 不需要解压K/V,减少计算量
from flash_mla import flash_mla_with_kvcache
output = flash_mla_with_kvcache(
q=query, # [batch, 1, num_heads, head_size]
k_cache=kv_cache, # latent cache
v_cache=kv_cache, # 同一个latent cache
cache_seqlens=attn_metadata.seq_lens_tensor,
block_table=attn_metadata.block_tables,
head_dim_v=self.kv_lora_rank, # V的维度是kv_lora_rank
head_dim_k=self.head_size, # K的维度是head_size(RoPE后)
...
)
return output
15.4 KV Cache压缩存储
第十六章 FlashMLASparseBackend稀疏注意力
16.1 稀疏滑动窗口动机
DeepSeek-V3使用稀疏注意力:每个token只关注最近的W个token(滑动窗口)+ 少数sink tokens。
class FlashMLASparseBackend(FlashMLABackend):
"""FlashMLA稀疏注意力后端
特点:
- 滑动窗口注意力
- 使用SparseIndexer计算稀疏索引
- 减少不必要的KV读取
"""
@staticmethod
def get_name() -> str:
return "FLASHMLA_SPARSE"
@classmethod
def get_metadata_cls(cls) -> type:
return FlashMLASparseMetadata
@classmethod
def get_builder_cls(cls) -> type:
return FlashMLASparseMetadataBuilder
16.2 SparseIndexer索引器
class SparseIndexer:
"""稀疏注意力索引计算器
计算每个decode token需要关注哪些KV位置
滑动窗口规则:
- 每个token关注最近的window_size个token
- 加上开头的sink_size个token
索引输出:
- kv_indices: [batch, max_num_kv] 需要读取的KV索引
- kv_indptr: [batch+1] CSR指针
- num_kv: [batch] 每序列需要的KV数量
"""
def __init__(
self,
window_size: int, # 滑动窗口大小
sink_size: int = 0, # 开头sink token数
block_size: int = 16, # KV cache块大小
):
self.window_size = window_size
self.sink_size = sink_size
self.block_size = block_size
def compute_index(
self,
seq_lens: torch.Tensor, # [batch] 序列长度
block_tables: torch.Tensor, # [batch, max_blocks] 块表
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""计算稀疏索引"""
batch_size = seq_lens.shape[0]
# 对每个序列,计算需要关注的KV位置
kv_indices_list = []
kv_indptr = [0]
for i in range(batch_size):
seq_len = seq_lens[i].item()
# 滑动窗口: 最近window_size个token
window_start = max(0, seq_len - self.window_size)
# Sink tokens: 开头的sink_size个token
sink_end = min(self.sink_size, seq_len)
# 合并: sink + window(去重)
if window_start < sink_end:
# 窗口和sink重叠 → 直接取[0, seq_len]
indices = list(range(seq_len))
else:
# 不重叠 → sink + gap + window
indices = list(range(sink_end)) + list(range(window_start, seq_len))
# 映射到物理KV cache位置
for pos in indices:
block_id = block_tables[i, pos // self.block_size]
offset = pos % self.block_size
kv_indices_list.append(block_id * self.block_size + offset)
kv_indptr.append(len(kv_indices_list))
# 转换为张量
kv_indices = torch.tensor(kv_indices_list, dtype=torch.int32)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
num_kv = torch.diff(kv_indptr)
return kv_indices, kv_indptr, num_kv
稀疏索引示意:
seq_len = 100, window_size = 32, sink_size = 4
关注的KV位置:
Sink: [0, 1, 2, 3] → 4个token
Gap: [4, 5, ..., 67] → 不关注(64个token跳过)
Window: [68, 69, ..., 99] → 32个token
总计: 36个token(而非100个)
内存节省: 64% (36/100)
计算节省: 64% (Q×K^T从100次减为36次)
16.3 稀疏注意力执行流程
第十七章 MLA Indexer深度解析
17.1 设计目的
indexer.py(776行)是MLA后端中最复杂的组件之一,负责:
- Prefill索引:确定每个prefill token的KV cache写入位置
- Decode索引:确定每个decode token需要读取的KV位置
- 稀疏索引:在滑动窗口模式下计算稀疏KV索引
- CUDA Graph兼容:索引计算需要支持CUDA Graph录制
17.2 索引计算流程
class MLAIndexer:
"""MLA索引计算器
核心方法:
- compute_prefill_index(): 计算prefill的KV cache索引
- compute_decode_index(): 计算decode的KV cache索引
- compute_sparse_decode_index(): 计算稀疏decode的KV cache索引
"""
def compute_decode_index(
self,
seq_lens: torch.Tensor, # [batch] 当前序列长度
block_tables: torch.Tensor, # [batch, max_blocks] 块表
kv_cache_dtype: str,
) -> MLADecodeIndexResult:
"""计算decode索引
返回:
slot_mapping: [batch] 每序列新token的写入slot
block_table: [batch, max_blocks] 块表
seq_lens: [batch] 序列长度
"""
batch_size = seq_lens.shape[0]
# 每个decode token的写入位置 = 序列末尾
slot_mapping = torch.empty(batch_size, dtype=torch.int64)
for i in range(batch_size):
seq_len = seq_lens[i].item()
# 新token的位置 = seq_len(0-indexed)
pos = seq_len # 因为当前token还未写入
block_id = block_tables[i, pos // self.block_size]
offset = pos % self.block_size
slot_mapping[i] = block_id * self.block_size + offset
return MLADecodeIndexResult(
slot_mapping=slot_mapping,
block_tables=block_tables,
seq_lens=seq_lens,
)
17.3 索引数据结构
@dataclass
class MLAPrefillIndexResult:
"""Prefill索引结果"""
slot_mapping: torch.Tensor # [num_prefill_tokens] 写入slot
cu_seqlens: torch.Tensor # [num_prefills + 1] 累积长度
max_seq_len: int # 最长prefill序列
@dataclass
class MLADecodeIndexResult:
"""Decode索引结果"""
slot_mapping: torch.Tensor # [num_decode] 写入slot
block_tables: torch.Tensor # [batch, max_blocks]
seq_lens: torch.Tensor # [batch]
# 稀疏模式额外字段:
sparse_kv_indices: torch.Tensor | None = None # 稀疏KV索引
sparse_kv_indptr: torch.Tensor | None = None # 稀疏KV指针
sparse_num_kv: torch.Tensor | None = None # 每序列KV数
第十八章 MLA Prefill子系统
18.1 Prefill后端选择
MLA在Prefill阶段需要解压latent为K/V,然后执行标准注意力。prefill可以使用不同的后端:
| Prefill后端 | 特点 | 适用场景 |
|---|---|---|
| FlashAttention | 高效、稳定 | 默认选择 |
| FlashInfer | 支持更多功能 | 需要变长batch时 |
| TRT-LLM Ragged | NVIDIA优化 | H100等新硬件 |
18.2 PrefillRegistry注册表
class MLAPrefillRegistry:
"""MLA Prefill后端注册表"""
_registry: dict[str, type[MLAPrefillBackend]] = {}
@classmethod
def register(cls, name: str):
"""装饰器: 注册prefill后端"""
def decorator(backend_cls):
cls._registry[name] = backend_cls
return backend_cls
return decorator
@classmethod
def get_backend(cls, name: str) -> type[MLAPrefillBackend]:
return cls._registry[name]
# 注册后端
@MLAPrefillRegistry.register("flash_attn")
class FlashAttnMLAPrefill(MLAPrefillBackend): ...
@MLAPrefillRegistry.register("flashinfer")
class FlashInferMLAPrefill(MLAPrefillBackend): ...
@MLAPrefillRegistry.register("trtllm_ragged")
class TRTLLMRaggedMLAPrefill(MLAPrefillBackend): ...
18.3 PrefillSelector选择器
class MLAPrefillSelector:
"""MLA Prefill后端选择器
根据硬件和库可用性选择最优prefill后端
"""
def select(self, vllm_config) -> str:
# 优先级: flash_attn > flashinfer > trtllm_ragged
if is_flash_attn_available():
return "flash_attn"
if is_flashinfer_available():
return "flashinfer"
if is_trtllm_available():
return "trtllm_ragged"
raise RuntimeError("No MLA prefill backend available")
第十九章 其他MLA后端
19.1 FlashInferMLABackend
class FlashInferMLABackend(AttentionBackend):
"""FlashInfer MLA后端
使用FlashInfer库实现MLA
适用于FlashMLA不可用的场景
"""
@staticmethod
def get_name() -> str:
return "FLASHINFER_MLA"
# 与FlashMLABackend的主要区别:
# - 使用FlashInfer的Paged KV Cache API
# - 需要在decode时解压latent为K/V
# - 性能略低于FlashMLA(因为需要解压)
19.2 FlashAttnMLABackend
class FlashAttnMLABackend(AttentionBackend):
"""FlashAttention MLA后端
使用FlashAttention库实现MLA
适用于FlashMLA和FlashInfer都不可用的场景
"""
@staticmethod
def get_name() -> str:
return "FLASHATTN_MLA"
19.3 CUTLASSMLABackend
class CUTLASSMLABackend(AttentionBackend):
"""CUTLASS MLA后端
使用NVIDIA CUTLASS库实现MLA
专注于decode阶段的优化
"""
@staticmethod
def get_name() -> str:
return "CUTLASS_MLA"
19.4 TritonMLABackend
class TritonMLABackend(AttentionBackend):
"""Triton MLA后端
使用自定义Triton kernel实现MLA
适用于无专用MLA库的场景
"""
@staticmethod
def get_name() -> str:
return "TRITON_MLA"
19.5 ROCm/XPU MLA后端
# ROCm平台
class ROCmAITerMLABackend(AttentionBackend):
"""ROCm aiter MLA后端"""
@staticmethod
def get_name() -> str:
return "ROCM_AITER_MLA"
class ROCmAITerMLASparseBackend(AttentionBackend):
"""ROCm aiter MLA稀疏后端"""
@staticmethod
def get_name() -> str:
return "ROCM_AITER_MLA_SPARSE"
# XPU平台
class XPUMLASparseBackend(AttentionBackend):
"""XPU MLA稀疏后端"""
@staticmethod
def get_name() -> str:
return "XPU_MLA_SPARSE"
附录F MLA KV Cache内存节省计算
典型DeepSeek-V3配置:
num_kv_heads = 64
head_size = 128
kv_lora_rank = 512
标准KV Cache (per token, fp16):
K: 64 × 128 × 2 = 16,384 bytes
V: 64 × 128 × 2 = 16,384 bytes
总计: 32,768 bytes = 32KB
MLA KV Cache (per token, fp16):
C: 512 × 2 = 1,024 bytes = 1KB
压缩比: 32KB / 1KB = 32×
对于4K上下文、256并发:
标准: 32KB × 4096 × 256 = 32GB
MLA: 1KB × 4096 × 256 = 1GB
节省: 31GB GPU内存!
附录G MLA后端选择决策树
附录U FlashMLAImpl.forward() 完整流程追踪
U.1 Decode路径详解
FlashMLAImpl._run_decode(query, kv_cache, attn_metadata):
Step 1: 准备查询
query = query.view(num_decode, num_heads, head_size)
# [batch, heads, dim]
Step 2: 准备KV cache
# MLA的KV cache只存储latent向量
# 形状: [num_blocks, block_size, kv_lora_rank]
# 不需要分别读取K和V
Step 3: 调用flash_mla_with_kvcache
output = flash_mla_with_kvcache(
q=query, # [batch, 1, heads, dim]
k_cache=kv_cache, # latent cache
cache_seqlens=attn_metadata.seq_lens,
block_table=attn_metadata.block_tables,
head_dim_v=kv_lora_rank, # V维度=压缩维度
head_dim_k=head_size, # K维度=原始维度
softmax_scale=scale,
)
Step 4: 处理输出
output = output.squeeze(1) # [batch, heads, dim]
output = output.view(num_decode, -1) # [batch, heads * dim]
U.2 Prefill路径详解
FlashMLAImpl._run_prefill(query, kv_cache, attn_metadata):
Step 1: 从latent cache解压K和V
# 读取prefill序列的所有latent
latent = read_from_cache(kv_cache, attn_metadata)
# latent: [num_prefill_tokens, kv_lora_rank]
# 解压K: K = latent × W_UK^T
key = torch.matmul(latent, w_uk)
# key: [num_prefill_tokens, num_kv_heads * head_size]
# 解压V: V = latent × W_UV^T
value = torch.matmul(latent, w_uv)
# value: [num_prefill_tokens, num_kv_heads * head_size]
Step 2: 选择prefill后端
# 根据MLAPrefillSelector的选择
# 可能是: flash_attn, flashinfer, 或 trtllm_ragged
if prefill_backend == "flash_attn":
output = flash_attn_varlen_func(
q=query, k=key, v=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seq_len,
max_seqlen_k=max_seq_len,
softmax_scale=scale,
causal=True,
)
elif prefill_backend == "flashinfer":
output = flashinfer_prefill(query, key, value, ...)
elif prefill_backend == "trtllm_ragged":
output = trtllm_ragged_prefill(query, key, value, ...)
U.3 Q的RoPE处理(MLA特殊)
MLA中的RoPE处理与标准MHA不同:
标准MHA:
Q = apply_rope(Q, position)
K = apply_rope(K, position)
→ Q和K都使用RoPE
MLA:
Q = apply_rope(Q, position) # Q使用RoPE
K = NO_ROPE(K) # K不使用RoPE!
C = compress(K, V) → latent cache # 压缩存储
为什么K不用RoPE?
因为RoPE会破坏K的低秩结构
MLA依赖K的低秩结构实现压缩
如果K经过RoPE,就不能用W_DKV压缩到latent空间
解决方案: 将K分为两部分
K = K_nope + K_rope
K_nope: 不使用RoPE的部分 → 可以压缩
K_rope: 使用RoPE的部分 → 单独存储(很小)
在decode时:
从latent解压出K_nope
从单独的cache读取K_rope
合并为完整的K: K = K_nope + K_rope
附录V SparseIndexer 稀疏索引算法深度分析
V.1 滑动窗口+Sink的索引计算
参数:
window_size = 4096 # 滑动窗口大小
sink_size = 4 # 开头保留的sink token数
block_size = 16 # KV cache块大小
场景: seq_len = 10000
Step 1: 确定关注范围
window_start = max(0, 10000 - 4096) = 5904
sink_end = min(4, 10000) = 4
关注范围: [0, 3] ∪ [5904, 9999]
跳过范围: [4, 5903] → 5900个token不需读取
Step 2: 计算KV块覆盖
关注的token范围 → 需要读取的KV块
Sink部分: [0, 3]
→ 块0: [0, 15] (只读取前4个token)
Window部分: [5904, 9999]
→ 块369: [5904, 5919]
→ 块370: [5920, 5935]
→ ...
→ 块624: [9984, 9999]
总块数: 1 (sink) + 256 (window) = 257 块
不读取的块: 625 - 257 = 368 块 (59%跳过)
Step 3: 构建稀疏索引
paged_kv_indices = [0, 369, 370, ..., 624] # 257个块ID
paged_kv_indptr = [0, 257] # 1个请求
paged_kv_last_page_len = [10000 % 16] = [0] # 最后一页满
# 特殊处理: 当last_page_len=0时,使用block_size
Step 4: 执行稀疏注意力
只对257个块执行注意力计算
跳过368个块 → 减少59%的KV读取和计算
V.2 多序列批次的索引
batch_size = 3
seq_lens = [10000, 5000, 1000]
window_size = 4096, sink_size = 4
请求0: seq_len=10000
关注: [0,3] ∪ [5904,9999] → 257 blocks
请求1: seq_len=5000
window_start = max(0, 5000-4096) = 904
关注: [0,3] ∪ [904,4999] → 1 + 256 = 257 blocks
请求2: seq_len=1000
window_start = max(0, 1000-4096) = 0
关注: [0,999] → 全部(因为seq_len < window_size)
blocks: ceil(1000/16) = 63 blocks
paged_kv_indices = [
0, 369, 370, ..., 624, # 请求0: 257 blocks
1000, 1369, ..., 1624, # 请求1: 257 blocks
2000, 2001, ..., 2062, # 请求2: 63 blocks
]
# 总计: 577 blocks
paged_kv_indptr = [0, 257, 514, 577]
# 请求0: indices[0:257]
# 请求1: indices[257:514]
# 请求2: indices[514:577]
附录W MLA Prefill后端深度对比
W.1 FlashAttnMLAPrefill
特点:
- 使用flash_attn_varlen_func
- 最稳定的实现
- 支持所有head_size
- 不需要额外库
限制:
- 需要先解压latent→K,V
- 解压计算量: num_tokens × kv_lora_rank × (num_kv_heads × head_size)
- 内存峰值: 需要存储完整K和V
适用: 小到中等序列长度
W.2 FlashInferMLAPrefill
特点:
- 使用FlashInfer的prefill API
- 支持Paged KV Cache
- 更好的batch利用率
限制:
- 需要FlashInfer库
- 某些head_size可能不支持
适用: 混合batch(prefill+decode同时)
W.3 TRTLLMRaggedMLAPrefill
特点:
- 使用TensorRT-LLM的ragged tensor API
- NVIDIA H100优化
- 最高性能(H100上)
限制:
- 需要TRT-LLM库
- 仅支持NVIDIA GPU
- 配置复杂
适用: H100/A100等NVIDIA高端GPU
W.4 性能对比
| Prefill后端 | A100性能 | H100性能 | 内存峰值 | 兼容性 |
|---|---|---|---|---|
| FlashAttn | ★★★★ | ★★★★ | 高(需解压) | ★★★★★ |
| FlashInfer | ★★★★ | ★★★★ | 中 | ★★★ |
| TRT-LLM | ★★★ | ★★★★★ | 低 | ★★ |
附录X2 FlashMLASparseMetadata 构建流程
X2.1 稀疏索引与dense索引的构建差异
FlashMLASparseMetadataBuilder.build():
1. 构建dense索引(与FlashMLA相同)
- slot_mapping: token→物理slot映射
- block_tables: 序列→块映射
- paged_kv_indices/indptr/last_page_len
2. 额外构建稀疏索引
- sparse_kv_indices: 需要读取的稀疏KV位置
- sparse_kv_indptr: CSR指针
- sparse_num_kv: 每序列需要的KV数量
使用SparseIndexer计算:
indexer = SparseIndexer(
window_size=sliding_window,
sink_size=sink_size,
block_size=block_size,
)
sparse_indices = indexer.compute_index(
seq_lens=decode_seq_lens,
block_tables=block_tables,
)
3. 构建complete metadata
return FlashMLASparseMetadata(
# 标准字段
...,
# 稀疏字段
sparse_kv_indices=sparse_indices.kv_indices,
sparse_kv_indptr=sparse_indices.kv_indptr,
sparse_num_kv=sparse_indices.num_kv,
)
X2.2 稀疏decode的注意力计算
FlashMLASparseImpl._run_decode():
1. 读取稀疏KV
# 不是读取所有KV块
# 只读取sparse_kv_indices指定的块
for i in range(batch_size):
start = sparse_kv_indptr[i]
end = sparse_kv_indptr[i+1]
# 请求i的KV块 = kv_cache[sparse_kv_indices[start:end]]
2. 执行稀疏注意力
# 使用flash_mla的稀疏模式
output = flash_mla_with_sparse_kv(
q=query,
kv_cache=kv_cache,
sparse_indices=sparse_kv_indices,
sparse_indptr=sparse_kv_indptr,
num_kv=sparse_num_kv,
...
)
# 内部只读取指定的KV块
# 跳过不需要的块
# 减少内存带宽和计算量
附录Y2 CompressorUtils KV压缩工具详解
Y2.1 KV压缩流程
compressor_utils.py 提供:
1. compress_kv_to_latent():
输入: K [num_tokens, num_kv_heads, head_size]
V [num_tokens, num_kv_heads, head_size]
操作: C = W_DKV × concat(K, V)
输出: C [num_tokens, kv_lora_rank]
2. decompress_latent_to_kv():
输入: C [num_tokens, kv_lora_rank]
操作: K = W_UK × C, V = W_UV × C
输出: K [num_tokens, num_kv_heads, head_size]
V [num_tokens, num_kv_heads, head_size]
3. compress_and_cache():
融合操作: 压缩 → 写入cache
避免中间latent张量的内存分配
权重矩阵:
W_DKV: [kv_lora_rank, 2 * num_kv_heads * head_size] # 下投影
W_UK: [num_kv_heads * head_size, kv_lora_rank] # K上投影
W_UV: [num_kv_heads * head_size, kv_lora_rank] # V上投影
压缩: C = W_DKV × [K; V] (concat后投影)
解压: K = W_UK × C, V = W_UV × C (分别投影)
Y2.2 压缩的数学等价性
标准MHA的注意力计算:
scores = Q × K^T / √d
attn = softmax(scores)
output = attn × V
MLA的注意力计算:
K = W_UK × C, V = W_UV × C
scores = Q × (W_UK × C)^T / √d
= Q × C^T × W_UK^T / √d (结合律)
注意: Q × W_UK 可以预先计算(Q的投影已经包含了这部分)
因此 MLA 的注意力计算可以在压缩空间中完成
不需要显式解压 K 和 V
这就是 FlashMLA decode kernel 高效的原因:
直接在 latent 空间计算注意力,无需解压
附录Z2 MLA各后端decode路径对比
FlashMLABackend decode:
→ flash_mla_with_kvcache()
→ 直接在latent空间计算注意力
→ 不解压K/V
→ 最快
FlashMLASparseBackend decode:
→ flash_mla_with_sparse_kv()
→ 只读取部分latent(滑动窗口)
→ 不解压K/V
→ 内存带宽最优
FlashInferMLABackend decode:
→ 需要解压C→K,V
→ 使用FlashInfer的decode API
→ 额外解压开销
FlashAttnMLABackend decode:
→ 需要解压C→K,V
→ 使用flash_attn_with_kvcache
→ 额外解压开销
CUTLASSMLABackend decode:
→ 使用CUTLASS kernel
→ 可能不需要完整解压
→ 中等性能
TritonMLABackend decode:
→ 自定义Triton kernel
→ 可能需要解压
→ 回退方案
附录AA2 FlashMLA KV Cache写入详解
AA2.1 latent写入 vs 标准KV写入
标准MHA的KV写入:
key = W_K × hidden # [num_tokens, num_kv_heads * head_size]
value = W_V × hidden # [num_tokens, num_kv_heads * head_size]
reshape_and_cache(key, value, kv_cache, slot_mapping)
# kv_cache[0][slot] = key → K缓存
# kv_cache[1][slot] = value → V缓存
MLA的latent写入:
compressed = W_DKV × hidden # [num_tokens, kv_lora_rank]
# 只写入latent,不写入K和V!
write_latent_cache(compressed, kv_cache, slot_mapping)
# kv_cache[slot] = compressed → 压缩缓存
# K和V在需要时从latent解压:
# K = W_UK × compressed → [num_tokens, num_kv_heads * head_size]
# V = W_UV × compressed → [num_tokens, num_kv_heads * head_size]
MLA的额外K_rope存储:
k_rope = W_KR × hidden # [num_tokens, num_rope_heads * rope_dim]
# k_rope需要单独存储(因为K的RoPE部分不能压缩)
完整的MLA cache:
- latent_cache: [num_blocks, block_size, kv_lora_rank] → 主缓存
- k_rope_cache: [num_blocks, block_size, num_rope_heads, rope_dim] → RoPE部分
AA2.2 MLA decode的KV读取
MLA Decode时,需要从cache读取:
1. latent: 压缩的KV表示 → 直接用于latent注意力(FlashMLA kernel)
2. k_rope: K的RoPE部分 → 与latent解压的K_nope拼接
完整的K重建:
K_nope = W_UK × latent → [seq_len, num_kv_heads, nope_dim]
K_rope = k_rope_cache → [seq_len, num_rope_heads, rope_dim]
K = concat(K_nope, K_rope) → [seq_len, num_kv_heads, head_size]
V重建:
V = W_UV × latent → [seq_len, num_kv_heads, head_size]
注意: FlashMLA kernel不需要显式重建K和V
它直接在latent空间计算注意力
这是FlashMLA比其他MLA后端快的关键原因
附录BB2 MLA各后端KV Cache形状对比
=== 标准MHA ===
kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
示例: (2, 1024, 16, 64, 128) → 2×1024×16×64×128×2 = 512MB
=== MLA (FlashMLA) ===
latent_cache_shape = (num_blocks, block_size, kv_lora_rank)
示例: (1024, 16, 512) → 1024×16×512×2 = 16MB
k_rope_cache_shape = (num_blocks, block_size, num_rope_heads, rope_dim)
示例: (1024, 16, 1, 64) → 1024×16×1×64×2 = 2MB
总计: 18MB (vs 512MB → 28×节省)
=== MLA (FlashInfer) ===
# FlashInfer MLA存储完整K和V(因为需要解压)
kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
# 与标准MHA相同!没有压缩优势
# 只有在decode时才知道压缩
# → FlashInfer MLA不是最优选择
=== MLA (Triton) ===
# 与FlashMLA相同,使用latent cache
latent_cache_shape = (num_blocks, block_size, kv_lora_rank)
# 但decode时需要解压 → 性能不如FlashMLA
附录CC2 MLA RoPE的特殊处理
CC2.1 DeepSeek-V2/V3的RoPE分离
标准RoPE:
K_full = apply_rope(K, position, freqs)
→ 整个K向量都经过RoPE变换
→ 旋转部分与原始部分混合
→ 无法压缩(RoPE破坏低秩结构)
DeepSeek MLA的RoPE分离:
K = K_nope + K_rope
K_nope: [num_kv_heads, nope_dim] → 不使用RoPE → 可以压缩
K_rope: [1, rope_dim] → 使用RoPE → 单独存储
分离方式:
hidden → W_K_nope → K_nope → 压缩到latent
hidden → W_K_rope → K_rope → 直接存储(很小)
注意: K_rope的head数通常为1(1组RoPE编码)
→ K_rope的cache极小: [num_blocks, block_size, 1, rope_dim]
数学等价性:
标准K的RoPE:
K_rope_full = RoPE(K, pos, freq)
= [K[:, :nope_dim] + K[:, nope_dim:] ⊗ RoPE(pos, freq)]
MLA的分离:
K = K_nope + concat(0, K_rope ⊗ RoPE(pos, freq))
K_nope部分不做RoPE
K_rope部分单独做RoPE
等价条件: nope_dim + rope_dim = head_size
K_nope = K[:, :nope_dim]
K_rope = K[:, nope_dim:]
CC2.2 MLA RoPE在decode中的处理
Decode时的K重建流程:
1. 从latent cache读取latent:
latent = latent_cache[block, offset] # [kv_lora_rank]
2. 解压K_nope:
K_nope = W_UK × latent # [num_kv_heads, nope_dim]
3. 从k_rope cache读取K_rope:
k_rope = k_rope_cache[block, offset] # [1, rope_dim]
# 注意: k_rope_cache在写入时已经应用了RoPE
4. 拼接:
K = concat(K_nope, K_rope_broadcast) # [num_kv_heads, head_size]
# K_rope需要广播到所有KV头
5. FlashMLA kernel的优化:
直接在latent空间计算Q×K_nope部分
单独处理Q×K_rope部分
两部分合并 → 等价于完整的Q×K计算
附录DD2 MLA后端与标准后端的代码复用关系
MLA后端的代码复用策略:
FlashMLABackend:
- 继承: AttentionBackend (抽象接口)
- 复用: slot_mapping计算 (from utils.py)
- 复用: reshape_and_cache (latent写入)
- 独有: flash_mla_with_kvcache (专用decode kernel)
- 独有: latent→KV解压逻辑 (prefill)
FlashMLASparseBackend:
- 继承: FlashMLABackend
- 复用: 所有FlashMLABackend的逻辑
- 独有: SparseIndexer (稀疏索引计算)
- 独有: flash_mla_with_sparse_kv (稀疏decode kernel)
FlashInferMLABackend:
- 继承: AttentionBackend
- 复用: FlashInfer的Paged KV Cache API
- 复用: FlashInfer的prefill/decode wrapper
- 独有: latent解压→标准KV的转换层
- 注意: 不使用FlashInfer内置MLA(如果有的话)
FlashAttnMLABackend:
- 继承: AttentionBackend
- 复用: flash_attn_varlen_func (prefill)
- 复用: flash_attn_with_kvcache (decode)
- 独有: latent解压→标准KV的转换层
CUTLASSMLABackend:
- 继承: AttentionBackend
- 复用: CUTLASS库的MLA kernel
- 独有: 特定的cache管理逻辑
TritonMLABackend:
- 继承: AttentionBackend
- 复用: Triton decode/prefill kernel
- 独有: latent解压的Triton实现
附录EE2 MLA模型检测逻辑
selector.py中的_is_mla_model()检测:
def _is_mla_model(vllm_config) -> bool:
"""检测模型是否使用MLA架构"""
model_cls = vllm_config.model_config.model_cls
# 方法1: 类名检测
mla_model_names = [
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekV4ForCausalLM",
"DeepSeekV2ForCausalLM",
"DeepSeekV3ForCausalLM",
]
if any(name in model_cls.__name__ for name in mla_model_names):
return True
# 方法2: 配置参数检测
hf_config = vllm_config.model_config.hf_config
if hasattr(hf_config, 'kv_lora_rank') and hf_config.kv_lora_rank > 0:
# kv_lora_rank > 0 表示使用MLA
return True
return False
_is_sparse_mla_model()检测:
# 检查是否使用稀疏注意力(滑动窗口)
if hasattr(hf_config, 'sliding_window') and hf_config.sliding_window > 0:
return True
if hasattr(hf_config, 'attending_scope'):
return True
return False
更多推荐


所有评论(0)