【vllm】(v1 Attention)vLLM V1 Attention— Part2 标准Attention后端实现
·
vLLM V1 Attention 模块超深度架构分析 — Part 2: 标准Attention后端实现
分析范围:
v1/attention/backends/目录的标准后端(FlashInfer/FlashAttn/Triton/CPU/Flex/ROCm/TurboQuant)
目录
- 第六章 FlashInfer后端深度解析
- 第七章 FlashAttention后端深度解析
- 第八章 Triton后端深度解析
- 第九章 CPU后端
- 第十章 FlexAttention后端
- 第十一章 ROCm后端
- 第十二章 TurboQuant后端
- 第十三章 特殊架构后端
- 附录D 后端性能对比矩阵
- 附录E FlashInfer Workspace内存布局
第六章 FlashInfer后端深度解析
6.1 FlashInfer概述与架构
FlashInfer是vLLM V1中主要的CUDA注意力后端之一(默认后端随GPU架构与vLLM版本而定:例如在Blackwell上倾向选择FlashInfer,而在Ampere/Hopper上通常默认选择FlashAttention),提供:
- 统一的Prefill/Decode API
- 高效的Batched Prefill(支持变长序列)
- Paged KV Cache管理
- CUDA Graph兼容
- Sliding Window支持
- 多种RoPE实现
6.2 FlashInferBackend类结构
class FlashInferBackend(AttentionBackend):
    """Attention backend built on the FlashInfer library.

    Characteristics:
    - the KV cache goes through FlashInfer's paged-KV API
    - CUDA Graph friendly, because the workspace is pre-allocated
    - prefill and decode are driven by separate wrapper objects
    """

    @staticmethod
    def get_name() -> str:
        # Registry key of this backend.
        return "FLASHINFER"

    @staticmethod
    def get_impl_cls() -> type:
        return FlashInferImpl

    @classmethod
    def get_metadata_cls(cls) -> type:
        return FlashInferMetadata

    @classmethod
    def get_builder_cls(cls) -> type:
        return FlashInferMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        # K and V share one allocation: leading axis 0 selects K, 1 selects V.
        # Each half is laid out as [num_blocks, block_size, num_kv_heads, head_size].
        per_tensor = (num_blocks, block_size, num_kv_heads, head_size)
        return (2, *per_tensor)

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [32, 64, 96, 128, 256]
6.3 FlashInferMetadata元数据
@dataclass
class FlashInferMetadata(AttentionMetadata):
    """Attention metadata consumed by the FlashInfer backend.

    Extends the base AttentionMetadata with FlashInfer-only fields.
    """

    # -- Workspace -----------------------------------------------------------
    # Scratch memory for FlashInfer's intermediate results; its required size
    # depends on the maximum batch size and sequence length.
    workspace: torch.Tensor | None = None

    # -- Prefill / decode split ----------------------------------------------
    # FlashInfer drives prefill and decode through separate wrapper objects.
    prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper | None" = None
    decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper | None" = None

    # -- Sequence groups -----------------------------------------------------
    # Describes how the batch is partitioned into sequences.
    prefill_seq_groups: list[tuple[int, int]] | None = None
    decode_seq_groups: list[tuple[int, int]] | None = None

    # -- Paged KV cache index (CSR-like sparse layout) -----------------------
    paged_kv_indices: torch.Tensor | None = None  # [num_pages] flattened page ids
    paged_kv_indptr: torch.Tensor | None = None  # [batch+1] per-request offsets into indices
    paged_kv_last_page_len: torch.Tensor | None = None  # [batch] valid tokens in last page

    # -- RoPE ----------------------------------------------------------------
    # Inputs for RoPE applied inside FlashInfer, positions supplied with Q/K.
    rotary_dim: int | None = None
    rotary_inv_q: torch.Tensor | None = None  # inverse frequencies for Q
    rotary_inv_k: torch.Tensor | None = None  # inverse frequencies for K
    rotary_cos_q: torch.Tensor | None = None  # cosine table for Q
    rotary_cos_k: torch.Tensor | None = None  # cosine table for K
Paged KV Cache索引结构:
# Worked example: turning per-request block tables into FlashInfer's
# CSR-style paged-KV index (indptr / indices / last_page_len).
block_size = 16
batch_size = 3
# Per-request block lists, padded with -1 to a common width. Each request
# needs ceil(seq_len / block_size) pages: 3, 2 and 4 respectively.
block_table = [[0, 1, 3, -1], [2, 5, -1, -1], [4, 6, 7, 8]]
seq_lens = [40, 25, 50]
paged_kv_indptr = [0, 3, 5, 9]
# Request 0: pages[0:3] = [0, 1, 3]
# Request 1: pages[3:5] = [2, 5]
# Request 2: pages[5:9] = [4, 6, 7, 8]
paged_kv_indices = [0, 1, 3, 2, 5, 4, 6, 7, 8]
# All active page ids, flattened in request order (the -1 padding is dropped).
paged_kv_last_page_len = [8, 9, 2]
# Request 0: 40 tokens / 16 -> 2.5   -> 3 pages, last_page_len = 40 % 16 = 8
# Request 1: 25 tokens / 16 -> 1.56  -> 2 pages, last_page_len = 25 % 16 = 9
# Request 2: 50 tokens / 16 -> 3.125 -> 4 pages, last_page_len = 50 % 16 = 2
# (When seq_len % block_size == 0 the last page is full: last_page_len = block_size.)
6.4 FlashInferMetadataBuilder构建器
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
    """Per-step builder of FlashInferMetadata.

    Responsibilities:
    1. Create / re-plan the prefill and decode wrappers.
    2. Build paged_kv_indices / paged_kv_indptr / paged_kv_last_page_len.
    3. Provide workspace memory.
    """

    def __init__(self, input_builder):
        self.input_builder = input_builder
        self.vllm_config = input_builder.vllm_config
        self.device = input_builder.device

        # Pre-create the FlashInfer wrappers; the workspace is allocated lazily.
        self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
            workspace=None,  # allocated on first use
        )
        self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
            workspace=None,
        )

        # Pre-allocate the index tensors for the worst case.
        max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
        # Page size of the paged KV cache (also read when building the index).
        self.block_size = self.vllm_config.cache_config.block_size
        max_model_len = self.vllm_config.model_config.max_model_len
        # One request can never span more than ceil(max_model_len / block_size)
        # pages, so that is the per-request bound. Sizing by
        # max_batch_size * num_gpu_blocks over-allocates by orders of magnitude,
        # and num_gpu_blocks may not be known yet when builders are created
        # (NOTE(review): confirm initialisation order).
        max_pages_per_seq = -(-max_model_len // self.block_size)  # ceil division
        max_num_pages = max_batch_size * max_pages_per_seq

        self.paged_kv_indices = torch.zeros(
            max_num_pages, dtype=torch.int32, device=self.device
        )
        self.paged_kv_indptr = torch.zeros(
            max_batch_size + 1, dtype=torch.int32, device=self.device
        )
        self.paged_kv_last_page_len = torch.zeros(
            max_batch_size, dtype=torch.int32, device=self.device
        )
build()方法核心流程:
def build(self, input_ids, seq_lens, query_lens, *args, **kwargs) -> FlashInferMetadata:
    """Build this step's FlashInferMetadata.

    NOTE: further inputs (block_tables, block_size, num_heads, num_kv_heads,
    head_size, ...) are elided in this excerpt and appear below as free names.
    """
    # 1. Split requests: query_len > 1 -> prefill, query_len == 1 -> decode.
    prefill_indices = [i for i, ql in enumerate(query_lens) if ql > 1]
    decode_indices = [i for i, ql in enumerate(query_lens) if ql == 1]

    # 2. Prefill: build the paged-KV index, then plan the prefill wrapper.
    if prefill_indices:
        self._build_paged_kv_index(
            prefill_indices, seq_lens, query_lens,
            is_prefill=True,
        )
        num_prefill_reqs = len(prefill_indices)
        # Page count = last indptr entry written by _build_paged_kv_index.
        total_pages = int(self.paged_kv_indptr[num_prefill_reqs])
        # Step 3 rewrites the same shared index buffers. A wrapper's plan() may
        # keep references to the tensors it is handed instead of copying them
        # (FlashInfer's non-CUDA-graph path stores `t.to(device)`, which aliases
        # a tensor already on that device) -- so hand it private copies.
        # NOTE(review): cloning is redundant if the wrapper copies; confirm
        # against the FlashInfer version in use.
        self.prefill_wrapper.plan(
            paged_kv_indptr=self.paged_kv_indptr[: num_prefill_reqs + 1].clone(),
            paged_kv_indices=self.paged_kv_indices[:total_pages].clone(),
            paged_kv_last_page_len=self.paged_kv_last_page_len[:num_prefill_reqs].clone(),
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            # ... remaining plan() arguments elided
        )

    # 3. Decode: build the paged-KV index, then plan the decode wrapper.
    if decode_indices:
        self._build_paged_kv_index(
            decode_indices, seq_lens, query_lens,
            is_prefill=False,
        )
        num_decode_reqs = len(decode_indices)
        total_pages = int(self.paged_kv_indptr[num_decode_reqs])
        self.decode_wrapper.plan(
            paged_kv_indptr=self.paged_kv_indptr[: num_decode_reqs + 1],
            paged_kv_indices=self.paged_kv_indices[:total_pages],
            paged_kv_last_page_len=self.paged_kv_last_page_len[:num_decode_reqs],
            # ... remaining plan() arguments elided
        )

    # 4. slot_mapping: where each new token's K/V is written in the cache.
    slot_mapping = compute_slot_mapping(
        block_tables, seq_lens, query_lens, block_size
    )

    # 5. Assemble the metadata.
    return FlashInferMetadata(
        num_prefills=len(prefill_indices),
        # Each decode request contributes exactly one query token.
        num_decode_tokens=len(decode_indices),
        slot_mapping=slot_mapping,
        seq_lens=seq_lens,
        prefill_wrapper=self.prefill_wrapper,
        decode_wrapper=self.decode_wrapper,
        paged_kv_indices=self.paged_kv_indices,
        paged_kv_indptr=self.paged_kv_indptr,
        paged_kv_last_page_len=self.paged_kv_last_page_len,
        # ... remaining fields elided
    )
6.5 FlashInferImpl注意力实现
class FlashInferImpl(AttentionImpl):
    """Attention execution layer that calls into the FlashInfer library.

    Token order (from the slicing in the ``_run_*`` helpers): prefill tokens
    first (``query[:num_prefill_tokens]``), then one token per decode request.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
    ):
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype
        # Kept so the configuration is inspectable; not consumed below.
        self.alibi_slopes = alibi_slopes
        self.logits_soft_cap = logits_soft_cap
        self.attn_type = attn_type

    def forward(
        self,
        query: torch.Tensor,  # [num_tokens, num_heads * head_size]
        key: torch.Tensor,  # [num_tokens, num_kv_heads * head_size]
        value: torch.Tensor,  # [num_tokens, num_kv_heads * head_size]
        kv_cache: torch.Tensor,  # [2, num_blocks, block_size, num_kv_heads, head_size]
        attn_metadata: FlashInferMetadata,
    ) -> torch.Tensor:
        """Write new K/V into the cache, then run FlashInfer attention.

        Returns:
            [num_tokens, num_heads * head_size] attention output.
        """
        num_tokens = query.shape[0]

        # 1. Reshape Q/K/V into per-head form.
        query = query.view(num_tokens, self.num_heads, self.head_size)
        key = key.view(num_tokens, self.num_kv_heads, self.head_size)
        value = value.view(num_tokens, self.num_kv_heads, self.head_size)

        # 2. Scatter K/V into the paged cache at the positions in slot_mapping.
        key_cache = kv_cache[0]  # [num_blocks, block_size, num_kv_heads, head_size]
        value_cache = kv_cache[1]
        # The *_flash variant writes the [num_blocks, block_size, num_kv_heads,
        # head_size] layout (shared with the FlashAttention backend), as opposed
        # to the legacy PagedAttention cache layout.
        ops.reshape_and_cache_flash(
            key, value, key_cache, value_cache,
            attn_metadata.slot_mapping, self.kv_cache_dtype,
            kv_scale=..., kv_zp=...,  # quantisation arguments elided in this excerpt
        )

        # 3. Dispatch on the batch composition.
        has_prefill = attn_metadata.num_prefills > 0
        has_decode = attn_metadata.num_decode_tokens > 0
        if has_prefill and has_decode:
            output = self._run_mixed_batch(query, kv_cache, attn_metadata)
        elif has_prefill:
            output = self._run_prefill(query, kv_cache, attn_metadata)
        else:
            output = self._run_decode(query, kv_cache, attn_metadata)
        return output.view(num_tokens, self.num_heads * self.head_size)

    # Three execution modes:

    def _run_prefill(self, query, kv_cache, attn_metadata):
        """Run prefill attention with BatchPrefillWithPagedKVCacheWrapper.

        Handles a batch of variable-length prefill requests.
        """
        prefill_query = query[:attn_metadata.num_prefill_tokens]
        output = attn_metadata.prefill_wrapper.run(
            prefill_query,
            paged_kv_cache=kv_cache,  # holds both K and V
            rotary_inv_q=attn_metadata.rotary_inv_q,
            rotary_inv_k=attn_metadata.rotary_inv_k,
            sm_scale=self.scale,
        )
        return output

    def _run_decode(self, query, kv_cache, attn_metadata):
        """Run decode attention with BatchDecodeWithPagedKVCacheWrapper.

        One query token per request.
        """
        decode_query = query[attn_metadata.num_prefill_tokens:]
        output = attn_metadata.decode_wrapper.run(
            decode_query,
            paged_kv_cache=kv_cache,
            sm_scale=self.scale,
        )
        return output

    def _run_mixed_batch(self, query, kv_cache, attn_metadata):
        """Run a batch containing both prefill and decode requests.

        The metadata builder plans ``prefill_wrapper`` over the prefill
        requests only and ``decode_wrapper`` over the decode requests only,
        so running the prefill wrapper over every token would leave the
        decode tokens without a planned KV range. Each wrapper therefore runs
        on its own slice and the outputs are concatenated in token order
        (prefill first).
        """
        prefill_output = self._run_prefill(query, kv_cache, attn_metadata)
        decode_output = self._run_decode(query, kv_cache, attn_metadata)
        return torch.cat([prefill_output, decode_output], dim=0)
6.6 Workspace内存管理
FlashInfer需要workspace内存来存储中间计算结果(如softmax的exp值累积)。workspace的大小取决于批次大小和序列长度。
def _prepare_workspace(
    self,
    batch_size: int,
    max_seq_len: int,
) -> torch.Tensor:
    """Return a FlashInfer workspace buffer big enough for this step.

    The buffer is cached on ``self.workspace``. It is reallocated (as raw
    uint8 bytes on ``self.device``) only when the size reported by
    ``self._estimate_workspace_size`` exceeds its current element count;
    otherwise the cached buffer is returned unchanged.

    Intended contents:
    - softmax accumulators (running exp-sum and max tracking)
    - partial attention results
    - general scratch space

    Rough sizing:
    - per decode token: about num_heads * head_size * 4 bytes
    - per prefill token: roughly an extra seq_len * head_size * 4 bytes
    - overall: about batch_size * max_seq_len * num_heads * head_size * 4
    """
    needed = self._estimate_workspace_size(batch_size, max_seq_len)
    current = self.workspace
    if current is not None and current.numel() >= needed:
        # Existing buffer is large enough: reuse it.
        return current
    self.workspace = torch.empty(needed, dtype=torch.uint8, device=self.device)
    return self.workspace
6.7 Prefill/Decode双模式切换
第七章 FlashAttention后端深度解析
7.1 FlashAttention版本差异
| 特性 | FlashAttention v2 | FlashAttention v3 |
|---|---|---|
| Prefill API | flash_attn_varlen_func | flash_attn_varlen_func(支持block_table) |
| Decode API | flash_attn_with_kvcache | flash_attn_with_kvcache |
| KV Cache管理 | 外部管理 | 内置支持 |
| CUDA Graph | 需要手动管理 | 原生支持 |
| 性能 | 高 | 更高(H100优化) |
7.2 FlashAttnBackend类结构
class FlashAttnBackend(AttentionBackend):
    """Attention backend that calls FlashAttention 2/3 directly.

    Compared with the FlashInfer backend:
    - there are no wrapper objects
    - the flash_attn functions are invoked directly
    - seq_lens and block_tables are built by hand and passed through
    """

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> type:
        return FlashAttnImpl

    @classmethod
    def get_metadata_cls(cls) -> type:
        return FlashAttnMetadata

    @classmethod
    def get_builder_cls(cls) -> type:
        return FlashAttnMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks, block_size, num_kv_heads, head_size
    ) -> tuple[int, ...]:
        # Same two-tensor layout as FlashInfer: axis 0 picks K (0) or V (1).
        return (2,) + (num_blocks, block_size, num_kv_heads, head_size)
7.3 FlashAttnMetadata元数据
@dataclass
class FlashAttnMetadata(AttentionMetadata):
    """Metadata specific to the FlashAttention backend.

    Unlike FlashInfer, no paged_kv_indices / indptr are built. The
    per-request ``seq_lens_tensor`` and ``block_tables`` are handed to the
    flash_attn functions directly.
    """

    # Per-request sequence lengths on the GPU; passed as ``cache_seqlens``
    # to flash_attn_with_kvcache during decode.
    seq_lens_tensor: torch.Tensor | None = None  # [batch_size]
    # Cumulative query lengths for the varlen prefill API: request i owns
    # query tokens [cu_seqlens[i], cu_seqlens[i + 1]).
    cu_seqlens: torch.Tensor | None = None  # [batch_size + 1]
    # Longest prefill / decode sequence in the batch (for max_seqlen_* args).
    max_prefill_seq_len: int = 0
    max_decode_seq_len: int = 0
    # Per-request block ids into the paged KV cache (GPU tensor), padded to a
    # common width.
    block_tables: torch.Tensor | None = None  # [batch_size, max_blocks]
    # NOTE(review): no code in this document reads these two fields. FA v3's
    # flash_attn_varlen_func / flash_attn_with_kvcache still take a
    # block_table argument, so the description "v3 uses kvcache pointers
    # instead of block_table" does not hold for those APIs. Confirm what these
    # fields are actually used for.
    kvcache_start_idx: int = 0
    kvcache_end_idx: int = 0
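cu_seqlens的设计:
prefill_batch:
序列0: 10 tokens → cu_seqlens[0] = 0, cu_seqlens[1] = 10
序列1: 20 tokens → cu_seqlens[1] = 10, cu_seqlens[2] = 30
序列2: 5 tokens → cu_seqlens[2] = 30, cu_seqlens[3] = 35
即 cu_seqlens = [0, 10, 30, 35],序列i的token范围为[cu_seqlens[i], cu_seqlens[i+1])。
flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,  # 仅当各序列K/V长度等于查询长度(无已缓存前缀、无chunked prefill上下文)时成立
max_seqlen_q=max_prefill_seq_len,
max_seqlen_k=max_prefill_seq_len,
)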
7.4 FlashAttnImpl实现
class FlashAttnImpl(AttentionImpl):
    """FlashAttention execution layer.

    Token layout (from the slicing below): prefill tokens come first
    (``query[:num_prefill_tokens]``), followed by one token per decode
    request. Per-request tensors (``seq_lens_tensor``, ``block_tables``)
    therefore list the ``num_prefills`` prefill requests first.
    """

    def forward(self, query, key, value, kv_cache, attn_metadata):
        """Write new K/V into the cache, then run prefill and/or decode.

        Returns:
            [num_tokens, num_heads * head_size] attention output.
        """
        num_tokens = query.shape[0]

        # Reshape into per-head form.
        query = query.view(num_tokens, self.num_heads, self.head_size)
        key = key.view(num_tokens, self.num_kv_heads, self.head_size)
        value = value.view(num_tokens, self.num_kv_heads, self.head_size)

        # Write K/V into the paged cache.
        self._write_kv_cache(key, value, kv_cache, attn_metadata.slot_mapping)

        # Prefill outputs first, then decode, matching the token order.
        outputs = []
        if attn_metadata.num_prefills > 0:
            outputs.append(self._run_prefill(query, kv_cache, attn_metadata))
        if attn_metadata.num_decode_tokens > 0:
            outputs.append(self._run_decode(query, kv_cache, attn_metadata))
        if not outputs:
            # Empty batch: return an empty result of the expected shape.
            return query.new_empty((num_tokens, self.num_heads * self.head_size))
        # Mixed batch: concatenate the prefill and decode outputs.
        output = outputs[0] if len(outputs) == 1 else torch.cat(outputs, dim=0)
        return output.view(num_tokens, -1)

    def _run_prefill(self, query, kv_cache, attn_metadata):
        """Run prefill with flash_attn_varlen_func over the paged KV cache."""
        from flash_attn import flash_attn_varlen_func

        # Only the prefill tokens.
        prefill_query = query[:attn_metadata.num_prefill_tokens]

        # K/V are read from the paged cache via block_table.
        key_cache = kv_cache[0]  # [num_blocks, block_size, ...]
        value_cache = kv_cache[1]

        # NOTE(review): cu_seqlens / max_prefill_seq_len are used for both Q
        # and K, which assumes each prefill request's KV length equals its
        # query length (no cached prefix, no chunked-prefill context). Confirm
        # this holds, or pass separate K lengths.
        output = flash_attn_varlen_func(
            q=prefill_query,
            k=key_cache,
            v=value_cache,
            cu_seqlens_q=attn_metadata.cu_seqlens,
            cu_seqlens_k=attn_metadata.cu_seqlens,
            max_seqlen_q=attn_metadata.max_prefill_seq_len,
            max_seqlen_k=attn_metadata.max_prefill_seq_len,
            softmax_scale=self.scale,
            causal=True,
            block_table=attn_metadata.block_tables,
        )
        return output

    def _run_decode(self, query, kv_cache, attn_metadata):
        """Run decode with flash_attn_with_kvcache (one query per request)."""
        from flash_attn import flash_attn_with_kvcache

        decode_query = query[attn_metadata.num_prefill_tokens:]
        # Decode query b belongs to request num_prefills + b. Slice the
        # per-request tensors so that row b lines up with decode query b;
        # passing the full tensors would pair decode queries with the prefill
        # requests' rows in a mixed batch.
        first_decode = attn_metadata.num_prefills
        output = flash_attn_with_kvcache(
            q=decode_query.unsqueeze(1),  # [num_decodes, 1, heads, dim]
            k_cache=kv_cache[0],
            v_cache=kv_cache[1],
            cache_seqlens=attn_metadata.seq_lens_tensor[first_decode:],
            block_table=attn_metadata.block_tables[first_decode:],
            softmax_scale=self.scale,
            causal=True,
        )
        return output.squeeze(1)  # [num_decodes, heads, dim]
7.5 FlashAttnDiffKV后端
class FlashAttnDiffKVBackend(FlashAttnBackend):
    """FlashAttention variant for models whose K and V head sizes differ.

    Example: layers where the Key dimension differs from the Value dimension.

    Relative to the standard FlashAttention backend:
    - K and V may use different head_size values
    - K and V must therefore be reshaped separately
    """

    @staticmethod
    def get_name() -> str:
        # Registry key distinguishing this variant from FLASH_ATTN.
        return "FLASH_ATTN_DIFFKV"
第八章 Triton后端深度解析
8.1 TritonAttnBackend类结构
class TritonAttnBackend(AttentionBackend):
    """Backend implemented with custom Triton kernels.

    Acts as the fallback when neither FlashInfer nor FlashAttention is
    available.
    """

    @staticmethod
    def get_name() -> str:
        return "TRITON"

    @staticmethod
    def get_impl_cls() -> type:
        return TritonAttnImpl

    @staticmethod
    def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size):
        # Standard two-tensor layout: index 0 holds K, index 1 holds V.
        per_kv = (num_blocks, block_size, num_kv_heads, head_size)
        return (2, *per_kv)
8.2 TritonAttnImpl实现
class TritonAttnImpl(AttentionImpl):
    """Triton attention execution layer.

    Kernels used:
    - prefill: triton_prefill_attention
    - decode: triton_decode_attention
    - cache write: triton_reshape_and_cache_flash
    """

    def forward(self, query, key, value, kv_cache, attn_metadata):
        n_tok = query.shape[0]

        # Split hidden dims into heads.
        q = query.view(n_tok, self.num_heads, self.head_size)
        k = key.view(n_tok, self.num_kv_heads, self.head_size)
        v = value.view(n_tok, self.num_kv_heads, self.head_size)

        # Scatter the new K/V into the paged cache.
        key_cache, value_cache = kv_cache[0], kv_cache[1]
        ops.triton_reshape_and_cache_flash(
            k, v, key_cache, value_cache,
            attn_metadata.slot_mapping,
        )

        # NOTE(review): any batch containing a prefill is routed entirely to
        # _triton_prefill, including its decode tokens. This assumes that
        # kernel handles query_len == 1 requests too; confirm for mixed batches.
        if attn_metadata.num_prefills > 0:
            run = self._triton_prefill
        else:
            run = self._triton_decode
        out = run(q, kv_cache, attn_metadata)
        return out.view(n_tok, -1)
8.3 TritonPrefillImpl与TritonDecodeImpl
第九章 CPU后端
9.1 CPUAttnBackend
class CPUAttnBackend(AttentionBackend):
    """Backend that computes attention on the CPU.

    Implemented with native PyTorch operations.
    """

    @staticmethod
    def get_name() -> str:
        return "CPU"

    @staticmethod
    def get_impl_cls() -> type:
        return CPUAttnImpl

    @staticmethod
    def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size):
        # Standard PagedAttention-style layout with K and V stacked on axis 0.
        return (2,) + (num_blocks, block_size, num_kv_heads, head_size)
9.2 CPU注意力实现
class CPUAttnImpl(AttentionImpl):
    """CPU attention execution layer.

    Runs attention one sequence at a time with PyTorch's native
    ``torch.nn.functional.scaled_dot_product_attention``. No external
    attention library and no FlashAttention-style tiling are used.
    """

    def forward(self, query, key, value, kv_cache, attn_metadata):
        """Write new K/V into the paged cache, then attend per sequence.

        Each sequence's queries are its last ``query_len`` positions, and they
        attend to the ``seq_len`` cached keys up to and including their own
        position. SDPA's ``is_causal=True`` aligns its mask to the top-left
        corner when ``query_len != seq_len``, which would let a decode query
        see only the first key. An explicit bottom-right-aligned causal mask
        is therefore built instead.

        NOTE(review): assumes ``query`` is [num_tokens, num_heads, head_size]
        with tokens grouped sequence by sequence, and that ``attn_metadata``
        exposes ``seq_lens``, ``query_lens`` and ``block_tables`` (row i =
        block ids of sequence i). Confirm against the CPU metadata class.

        Returns:
            [num_tokens, num_heads, head_size] attention output.
        """
        # 1. Write the new K/V into the KV cache.
        self._write_kv_cache_cpu(key, value, kv_cache, attn_metadata)

        # Cache layout (see CPUAttnBackend.get_kv_cache_shape):
        # [2, num_blocks, block_size, num_kv_heads, head_size]
        key_cache, value_cache = kv_cache[0], kv_cache[1]
        block_size, num_kv_heads, head_size = key_cache.shape[1:]
        # GQA: number of query heads sharing each KV head.
        num_queries_per_kv = query.shape[1] // num_kv_heads

        # 2. Attend sequence by sequence.
        outputs = []
        start = 0
        for i, (seq_len, query_len) in enumerate(
            zip(attn_metadata.seq_lens, attn_metadata.query_lens)
        ):
            seq_len, query_len = int(seq_len), int(query_len)
            end = start + query_len
            q = query[start:end]  # [query_len, heads, dim]

            # Gather this sequence's cached K/V from its blocks.
            num_blocks = (seq_len + block_size - 1) // block_size
            blocks = attn_metadata.block_tables[i, :num_blocks].long()
            k = key_cache[blocks].reshape(-1, num_kv_heads, head_size)[:seq_len]
            v = value_cache[blocks].reshape(-1, num_kv_heads, head_size)[:seq_len]
            if num_queries_per_kv > 1:
                k = k.repeat_interleave(num_queries_per_kv, dim=1)
                v = v.repeat_interleave(num_queries_per_kv, dim=1)

            # Query t sits at absolute position seq_len - query_len + t.
            causal_mask = torch.ones(
                query_len, seq_len, dtype=torch.bool, device=query.device
            ).tril(diagonal=seq_len - query_len)

            attn_output = torch.nn.functional.scaled_dot_product_attention(
                q.transpose(0, 1),  # [heads, query_len, dim]
                k.transpose(0, 1),  # [heads, seq_len, dim]
                v.transpose(0, 1),
                attn_mask=causal_mask,
                scale=self.scale,
            )
            outputs.append(attn_output.transpose(0, 1))
            start = end

        if not outputs:
            # Empty batch: torch.cat would fail on an empty list.
            return torch.empty_like(query)
        return torch.cat(outputs, dim=0)
第十章 FlexAttention后端
10.1 FlexAttention概述
FlexAttention是PyTorch 2.5+引入的灵活注意力API,允许用户自定义score_mod函数(逐元素修改注意力分数,本文下文称score_modify)并配合BlockMask描述稀疏模式,支持:
- Sliding Window
- Document Masking
- ALiBi
- 自定义注意力模式
class FlexAttentionBackend(AttentionBackend):
    """Backend built on PyTorch's FlexAttention API."""

    @staticmethod
    def get_name() -> str:
        # Registry key of this backend.
        return "FLEX_ATTENTION"

    @staticmethod
    def get_impl_cls() -> type:
        return FlexAttentionImpl
10.2 FlexAttentionMetadata
@dataclass
class FlexAttentionMetadata(AttentionMetadata):
    """Metadata specific to the FlexAttention backend.

    Adds the mask and score-modification hooks that FlexAttention consumes.
    """

    # BlockMask describing which (query, key) positions must be computed.
    block_mask: "BlockMask | None" = None
    # Callable that rewrites attention scores before the softmax
    # (PyTorch's flex_attention names this hook ``score_mod``).
    score_modify: "Callable | None" = None
第十一章 ROCm后端
11.1 ROCmAttnBackend
class ROCmAttnBackend(AttentionBackend):
    """Backend for ROCm / AMD GPUs.

    Uses AMD's aiter library for efficient attention.
    """

    @staticmethod
    def get_name() -> str:
        return "ROCM_ATTN"
11.2 ROCmAITerFABackend
class ROCmAITerFABackend(AttentionBackend):
    """ROCm backend using aiter's FlashAttention implementation.

    Built on aiter's flash_attention_forward and tuned for AMD GPUs.
    """

    @staticmethod
    def get_name() -> str:
        # Registry key of this backend.
        return "ROCM_AITER_FA"
第十二章 TurboQuant后端
12.1 TurboQuant概述
TurboQuant是一种KV Cache量化方案,将KV值量化为低精度(如FP8/INT8),减少内存占用和带宽需求。
class TurboQuantBackend(AttentionBackend):
    """TurboQuant backend with a quantized KV cache.

    As described here:
    - the KV cache is stored in FP8/INT8
    - values are quantized on store and dequantized for decode
    - quantization / dequantization run as Triton kernels
    """

    @staticmethod
    def get_name() -> str:
        return "TURBOQUANT"

    @staticmethod
    def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size):
        # The returned shape is the same 5-D layout as the unquantized
        # backends; no extra dimension for quantization parameters is
        # included. TurboQuantMetadata carries kv_scale / kv_zp tensors.
        # NOTE(review): confirm whether scale / zero-point also need space
        # inside the cache allocation itself.
        return (
            2,  # K + V
            num_blocks,
            block_size,
            num_kv_heads,
            head_size,
        )
12.2 TurboQuantMetadata
@dataclass
class TurboQuantMetadata(AttentionMetadata):
    """Metadata specific to the TurboQuant backend."""

    # -- Quantization parameters -----------------------------------------
    kv_scale: torch.Tensor | None = None  # per-KV scale factors
    kv_zp: torch.Tensor | None = None  # per-KV zero points

    # -- Scratch buffers holding dequantized K / V -------------------------
    dequant_key: torch.Tensor | None = None
    dequant_value: torch.Tensor | None = None
第十三章 特殊架构后端
13.1 GDNAttnBackend
class GDNAttnBackend(AttentionBackend):
    """GDN (Gated DeltaNet) linear-attention backend.

    Characteristics:
    - serves Gated DeltaNet layers, which use a gated delta-rule recurrent
      state update instead of softmax attention (NOTE(review): e.g. hybrid
      models such as Qwen3-Next; confirm the model list)
    - does not compute standard softmax(Q*K^T)V attention
    - implemented through a custom AttentionImpl
    """

    @staticmethod
    def get_name() -> str:
        return "GDN"
13.2 MambaAttnBackend
class MambaAttnBackend(AttentionBackend):
    """Backend for Mamba state-space-model (SSM) layers.

    Characteristics:
    - no standard attention is computed; attention layers are replaced by
      SSM layers
    - the "KV cache" holds SSM state rather than K/V tensors
    """

    @staticmethod
    def get_attention_impl_cls() -> type:
        # Mamba has no standard Attention layer; return its own implementation.
        return MambaImpl

    @staticmethod
    def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size):
        # SSM-specific cache shape; the trailing dimensions are elided here
        # (intended to hold the A, B, C, D parameters and intermediate state).
        leading = (num_blocks, block_size)
        return (*leading, ...)
Mamba1AttnBackend vs Mamba2AttnBackend:
| 特性 | Mamba1 | Mamba2 |
|---|---|---|
| SSM算法 | 选择性SSM | 并行SSD |
| 实现 | mamba_ssm库 | mamba_ssm v2 |
| Cache | 离散步长+状态 | 连续参数+状态 |
| 并行化 | 有限 | 原生并行 |
13.3 LinearAttnBackend
class LinearAttnBackend(AttentionBackend):
    """Linear-attention backend.

    Characteristics:
    - O(n) rather than O(n^2) complexity in sequence length
    - never materialises the full attention matrix
    - approximates attention with a kernel trick
    - suited to long sequences
    """

    @staticmethod
    def get_name() -> str:
        return "LINEAR"
附录D 后端性能对比矩阵
| 后端 | Prefill性能 | Decode性能 | 内存效率 | 功能完备性 | CUDA Graph | 推荐场景 |
|---|---|---|---|---|---|---|
| FlashInfer | ★★★★★ | ★★★★★ | ★★★★ | ★★★★★ | ✅ | 通用CUDA |
| FlashAttention | ★★★★ | ★★★★ | ★★★★ | ★★★ | ✅(v3) | 无FlashInfer时 |
| Triton | ★★★ | ★★★ | ★★★ | ★★★ | ❌ | 回退/调试 |
| CPU | ★ | ★ | ★★ | ★★ | ❌ | CPU推理 |
| FlexAttention | ★★★ | ★★ | ★★ | ★★★★ | ❌ | 自定义mask |
| ROCm aiter | ★★★★ | ★★★★ | ★★★★ | ★★★ | ✅ | AMD GPU |
| TurboQuant | ★★★★ | ★★★★ | ★★★★★ | ★★ | ✅ | 量化KV |
| FlashMLA | ★★★★★ | ★★★★★ | ★★★★★ | ★★★★ | ✅ | DeepSeek |
| Mamba | N/A | ★★★★★ | ★★★★★ | ★★ | ✅ | SSM模型 |
附录E FlashInfer Workspace内存布局
附录O FlashInferMetadataBuilder.build() 完整流程追踪
O.1 构建步骤时序
FlashInferMetadataBuilder.build() 完整流程:
# FlashInferMetadataBuilder.build() step-by-step trace (pseudo-code).
# Assumed in scope: seq_lens, query_lens, num_tokens, block_tables,
# block_size, paged_kv_indices / paged_kv_indptr / paged_kv_last_page_len
# buffers, prefill_wrapper, decode_wrapper, num_heads, num_kv_heads,
# head_size, seq_lens_tensor, block_tables_tensor.

# Step 1: classify requests.
# query_len > 1  -> prefill (many tokens processed at once)
# query_len == 1 -> decode (one token per step)
prefill_indices = [i for i, ql in enumerate(query_lens) if ql > 1]
decode_indices = [i for i, ql in enumerate(query_lens) if ql == 1]

# Step 2: build slot_mapping (token -> flat KV-cache slot).
slot_mapping = torch.empty(num_tokens, dtype=torch.int64)
token_offset = 0  # index of the current request's first token in the batch
for i, (seq_len, query_len) in enumerate(zip(seq_lens, query_lens)):
    # The new tokens are the last query_len positions of the sequence.
    start = seq_len - query_len
    for j in range(query_len):
        logical_pos = start + j
        block_id = block_tables[i, logical_pos // block_size]
        offset = logical_pos % block_size
        slot_mapping[token_offset + j] = block_id * block_size + offset
    token_offset += query_len

# Step 3: build the prefill paged-KV index.
if prefill_indices:
    total_pages = 0
    paged_kv_indptr[0] = 0
    # prefill_idx = position among the prefill requests (CSR row);
    # i = index of the request in the whole batch.
    for prefill_idx, i in enumerate(prefill_indices):
        seq_len = seq_lens[i]
        num_pages = (seq_len + block_size - 1) // block_size
        # Copy this request's block ids into paged_kv_indices.
        for j in range(num_pages):
            paged_kv_indices[total_pages + j] = block_tables[i, j]
        total_pages += num_pages
        paged_kv_indptr[prefill_idx + 1] = total_pages
        # Valid tokens in the last page; a full last page counts block_size.
        paged_kv_last_page_len[prefill_idx] = seq_len % block_size or block_size

    # plan() precomputes the wrapper's per-batch scheduling data.
    prefill_wrapper.plan(
        paged_kv_indptr=paged_kv_indptr[:len(prefill_indices) + 1],
        paged_kv_indices=paged_kv_indices[:total_pages],
        paged_kv_last_page_len=paged_kv_last_page_len[:len(prefill_indices)],
        num_qo_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_size,
        page_size=block_size,
        causal=True,
    )

# Step 4: build the decode paged-KV index (same as Step 3, but query_len = 1).

# Step 5: assemble the metadata (build() returns this object).
metadata = FlashInferMetadata(
    num_prefills=len(prefill_indices),
    num_decode_tokens=len(decode_indices),
    slot_mapping=slot_mapping,
    seq_lens=seq_lens_tensor,
    block_tables=block_tables_tensor,
    prefill_wrapper=prefill_wrapper,
    decode_wrapper=decode_wrapper,
    paged_kv_indices=paged_kv_indices,
    paged_kv_indptr=paged_kv_indptr,
    paged_kv_last_page_len=paged_kv_last_page_len,
    # ... remaining fields elided
)
附录P FlashInfer Wrapper API详解
P.1 BatchPrefillWithPagedKVCacheWrapper
# FlashInfer Prefill Wrapper API
class BatchPrefillWithPagedKVCacheWrapper:
    """FlashInfer prefill wrapper (API sketch).

    Lifecycle:
    1. Create: wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace)
    2. Plan:   wrapper.plan(indptr, indices, last_page_len, ...)
    3. Run:    output = wrapper.run(query, paged_kv_cache, ...)
    4. Repeat steps 2-3 for every new batch.

    What plan() does:
    - precomputes the indices and metadata the attention kernels need
    - lays out its scratch data inside the workspace buffer supplied at
      construction (step 1) rather than allocating new memory
    - prepares the CUDA kernel launch parameters
    - this work happens once per batch; every run() on that batch (e.g. one
      per attention layer in the same step) reuses it

    What run() does:
    - performs the actual attention computation
    - launches FlashInfer's CUDA kernels
    """
P.2 BatchDecodeWithPagedKVCacheWrapper
class BatchDecodeWithPagedKVCacheWrapper:
    """FlashInfer decode wrapper (API sketch).

    Differences from the prefill wrapper:
    - exactly one query token per request
    - decode-specialised kernels
    - different workspace size requirements
    - the usual target of CUDA Graph capture. The prefill wrapper can also
      run under CUDA Graphs, but only with a fixed batch size, which is why
      graphs are typically applied to decode.
    """
附录Q FlashAttnMetadataBuilder.build() 关键差异
FlashAttnMetadataBuilder vs FlashInferMetadataBuilder:
1. 不使用paged_kv_indices/indptr
→ 使用seq_lens_tensor和block_tables直接传给flash_attn
2. 需要计算cu_seqlens
# Prefix sum of query lengths: request i owns tokens [cu_seqlens[i], cu_seqlens[i+1]).
running_total = 0
cu_seqlens[0] = running_total
for i in range(batch_size):
    running_total += query_lens[i]
    cu_seqlens[i + 1] = running_total
# With query_lens = [10, 20, 5] this yields cu_seqlens = [0, 10, 30, 35]:
#   sequence 0 -> tokens 0-9
#   sequence 1 -> tokens 10-29
#   sequence 2 -> tokens 30-34
3. block_tables需要padding到统一长度
FlashAttention要求block_tables的每行长度相同
→ 用-1填充无效位置
4. max_prefill_seq_len必须预先计算
→ 用于flash_attn_varlen_func的max_seqlen参数
附录R TritonAttnImpl Prefill Kernel参数详解
triton_prefill_attention() 参数:
输入:
q: [num_tokens, num_heads, head_size]
k_cache: [num_blocks, block_size, num_kv_heads, head_size]
v_cache: 同上
cu_seqlens_q: [batch_size + 1] 查询的累积长度
cu_seqlens_k: [batch_size + 1] Key的累积长度
max_seqlen_q: 最长查询序列
max_seqlen_k: 最长Key序列
block_table: [batch_size, max_blocks] 块表
scale: 1/sqrt(head_size)
causal: bool 是否causal mask
Kernel网格:
grid = (num_heads, DIV_UP(num_tokens, BLOCK_M))
BLOCK_M = 64 (查询块大小)
BLOCK_N = 64 (Key/Value块大小)
每个program处理:
- 1个注意力头
- BLOCK_M个连续查询token
Tiling策略 (Flash Attention):
外层(由kernel网格并行): 每个program负责一个查询块 (BLOCK_M个连续查询token)
内循环: program内顺序遍历KV块 (BLOCK_N个token/块)
对每个(查询块, KV块)对:
1. 从cache读取K_block和V_block
2. 计算S = Q_block × K_block^T × scale
3. 应用causal mask (key位置j > 查询位置i 时mask)
4. 在线softmax更新 (维护max_score, sum_exp, O)
最终: O /= sum_exp
附录S CPUAttnImpl 逐序列执行详解
CPU注意力计算流程:
# Per-sequence CPU attention over the paged KV cache (walkthrough).
# Assumed in scope: batch_size, seq_lens, query_lens, block_tables,
# block_size, key_cache, value_cache, query, num_kv_heads, head_size,
# num_queries_per_kv, attn_type, DECODER, scale, F (torch.nn.functional).
outputs = []
start = 0  # offset of the current sequence's first query token in `query`
for i in range(batch_size):
    seq_len = seq_lens[i]
    query_len = query_lens[i]

    # Gather every cached K/V entry of this sequence from the paged cache.
    # Position j lives at [block_tables[i, j // block_size], j % block_size].
    num_blocks = (seq_len + block_size - 1) // block_size
    blocks = block_tables[i, :num_blocks].long()
    k_seq = key_cache[blocks].reshape(-1, num_kv_heads, head_size)[:seq_len]
    v_seq = value_cache[blocks].reshape(-1, num_kv_heads, head_size)[:seq_len]

    # Standard SDPA computation.
    q_seq = query[start:start + query_len]  # [query_len, num_heads, head_size]

    # GQA: replicate each KV head across the query heads that share it.
    # With num_queries_per_kv == 1 the gathered K/V are used unchanged.
    if num_queries_per_kv > 1:
        k_seq = k_seq.repeat_interleave(num_queries_per_kv, dim=1)
        v_seq = v_seq.repeat_interleave(num_queries_per_kv, dim=1)

    # Decoder attention is causal, and the queries are the LAST query_len
    # positions, so the mask is aligned to the bottom-right. SDPA's
    # is_causal=True aligns to the top-left when query_len != seq_len, which
    # would let a decode query attend only to key 0.
    attn_mask = None
    if attn_type == DECODER:
        attn_mask = torch.ones(query_len, seq_len, dtype=torch.bool).tril(
            diagonal=seq_len - query_len
        )

    attn_output = F.scaled_dot_product_attention(
        q_seq.transpose(0, 1),  # [num_heads, query_len, head_size]
        k_seq.transpose(0, 1),  # [num_heads, seq_len, head_size]
        v_seq.transpose(0, 1),
        attn_mask=attn_mask,
        scale=scale,
    )
    outputs.append(attn_output.transpose(0, 1))
    start += query_len

# Concatenate every sequence's output.
output = torch.cat(outputs, dim=0)  # [num_tokens, num_heads, head_size]
附录T FlexAttentionImpl Score Modify详解
FlexAttention的核心是score_modify函数:
标准注意力:
scores = Q × K^T × scale
probs = softmax(scores)
FlexAttention:
scores = Q × K^T × scale  # 与标准注意力相同,缩放在修改之前完成
scores = score_modify(scores, positions) # ★自定义修改★(PyTorch API中为逐元素的 score_mod(score, b, h, q_idx, kv_idx))
probs = softmax(scores)
score_modify示例:
# score_modify examples. These operate on a whole score matrix; PyTorch's
# flex_attention expresses the same ideas as a per-element
# score_mod(score, b, h, q_idx, kv_idx).

# 1. ALiBi:
def alibi_modify(scores, q_pos, k_pos):
    # slopes: per-head bias slope (assumed defined in the enclosing scope).
    # (k_pos - q_pos): relative position of key vs query.
    return scores + slopes * (k_pos - q_pos)


# 2. Sliding Window:
def sliding_window_modify(scores, q_pos, k_pos):
    import torch

    # Attend only to keys in [q_pos - window_size, q_pos]; window_size is
    # assumed defined in the enclosing scope. torch.where returns a new
    # tensor instead of mutating the caller's `scores` in place.
    mask = (k_pos < q_pos - window_size) | (k_pos > q_pos)
    return torch.where(mask, float("-inf"), scores)


# 3. Document Masking:
def document_modify(scores, doc_ids, q_pos=None, k_pos=None):
    """Mask attention between tokens of different documents.

    Tokens may attend within their own document only. ``doc_ids`` maps each
    token position to its document id. ``q_pos`` / ``k_pos`` default to
    broadcastable row / column position grids over ``scores``'s last two
    dimensions.
    """
    import torch

    if q_pos is None:
        q_pos = torch.arange(scores.shape[-2], device=scores.device)[:, None]
    if k_pos is None:
        k_pos = torch.arange(scores.shape[-1], device=scores.device)[None, :]
    cross_doc = doc_ids[q_pos] != doc_ids[k_pos]
    return torch.where(cross_doc, float("-inf"), scores)
FlexAttention将这些score_modify编译为高效的CUDA kernel
通过torch.compile + BlockMask实现
附录U FlashInfer Paged KV Cache索引构建完整追踪
U.1 从block_table到FlashInfer索引的转换
输入: block_tables (PagedAttention格式)
# Input: block tables in PagedAttention layout (-1 marks an unused slot).
block_tables = [
    [0, 1, 3, -1],   # sequence 0 -> 3 blocks: 0, 1, 3
    [2, 5, -1, -1],  # sequence 1 -> 2 blocks: 2, 5
    [4, 6, 7, 8],    # sequence 2 -> 4 blocks: 4, 6, 7, 8
]
seq_lens = [40, 25, 50]
block_size = 16

# Output: the same information in FlashInfer's CSR-style format.
# Row pointers: seq 0 -> blocks[0:3], seq 1 -> blocks[3:5], seq 2 -> blocks[5:9].
paged_kv_indptr = [0, 3, 5, 9]
# Flattened block ids: [0, 1, 3] + [2, 5] + [4, 6, 7, 8].
paged_kv_indices = [0, 1, 3, 2, 5, 4, 6, 7, 8]
# Valid tokens in each last page: 40 % 16 = 8, 25 % 16 = 9, 50 % 16 = 2.
paged_kv_last_page_len = [8, 9, 2]
U.2 构建代码逐行
def _build_paged_kv_index(
    self,
    indices: list[int],  # batch indices of the requests to index, in order
    seq_lens: list[int],  # sequence lengths of all requests
    query_lens: list[int],  # query lengths (unused here; kept for the interface)
    is_prefill: bool,  # prefill vs decode (unused here; kept for the interface)
) -> int:
    """Fill FlashInfer's paged-KV index buffers from ``self.block_tables``.

    For the requests in ``indices`` (in that order) this writes:
    - ``self.paged_kv_indptr[:len(indices) + 1]``: CSR row pointers,
    - ``self.paged_kv_indices[:total_pages]``: flattened block ids,
    - ``self.paged_kv_last_page_len[:len(indices)]``: valid tokens in each
      request's last page (``block_size`` when that page is exactly full).

    The persistent buffers are filled in place and are NOT rebound to trimmed
    views. Rebinding would permanently shrink the pre-allocated buffers, so
    a later call with a larger batch would index past their end. Callers
    slice with ``len(indices)`` and the returned page count instead.

    Returns:
        total_pages: number of entries written to ``paged_kv_indices``.
    """
    block_size = self.block_size
    total_pages = 0
    self.paged_kv_indptr[0] = 0
    for row, seq_idx in enumerate(indices):
        seq_len = seq_lens[seq_idx]
        # ceil(seq_len / block_size) pages for this request.
        num_pages = (seq_len + block_size - 1) // block_size
        # Copy this request's block ids into the flattened index.
        self.paged_kv_indices[total_pages:total_pages + num_pages] = (
            self.block_tables[seq_idx, :num_pages]
        )
        total_pages += num_pages
        self.paged_kv_indptr[row + 1] = total_pages
        # A zero remainder means the last page is exactly full.
        self.paged_kv_last_page_len[row] = seq_len % block_size or block_size
    return total_pages
附录V FlashInfer vs FlashAttention 性能特征深度对比
V.1 Decode性能
FlashInfer Decode:
- 专门的decode kernel (per-token optimized)
- 每序列1个Q token → K/V扫描整个cache
- 支持Paged KV Cache(零散块读取)
- Workspace预分配(CUDA Graph友好)
- 吞吐: ~1000 tok/s per A100 (batch=256)
FlashAttention Decode:
- flash_attn_with_kvcache
- 同样per-token模式
- 支持Paged KV Cache
- 需要block_table参数
- 吞吐: ~900 tok/s per A100 (batch=256)
# 略低于FlashInfer,因为FlashInfer有更多decode-specific优化
V.2 Prefill性能
FlashInfer Prefill:
- BatchPrefillWithPagedKVCacheWrapper
- 支持变长batch(不同序列长度)
- 支持CUDA Graph(需要固定batch_size)
- 吞吐: ~15000 tok/s per A100 (prefill_only)
FlashAttention Prefill:
- flash_attn_varlen_func
- 支持变长batch
- CUDA Graph支持有限
- 吞吐: ~14000 tok/s per A100 (prefill_only)
V.3 混合批次性能
FlashInfer:
- 原生支持混合batch(prefill+decode同时执行)
- 使用统一wrapper
- 吞吐: 比分离执行高10-20%
FlashAttention:
- 混合batch需要分别调用prefill和decode
- 然后拼接输出
- 吞吐: 比分离执行略低(拼接开销)
附录W ROCm后端aiter库详解
W.1 aiter库概述
aiter是AMD ROCm平台的高效推理库
核心组件:
1. flash_attention_forward: AMD GPU优化的FlashAttention
2. paged_attention: AMD GPU的分页注意力
3. mla_attention: AMD GPU的MLA注意力
4. fused_ops: 融合操作(RoPE, RMSNorm等)
与CUDA后端的对应关系:
CUDA FlashInfer → ROCm aiter FA
CUDA FlashAttention → ROCm aiter FA
CUDA FlashMLA → ROCm aiter MLA
W.2 ROCmAITerFABackend实现
class ROCmAITerFABackend(AttentionBackend):
    """ROCm backend built on aiter's FlashAttention kernels."""

    @staticmethod
    def get_name() -> str:
        # Registry key of this backend.
        return "ROCM_AITER_FA"

    @staticmethod
    def get_impl_cls() -> type:
        return ROCmAITerFAImpl
class ROCmAITerFAImpl(AttentionImpl):
    """Attention execution layer backed by aiter's flash_attention_forward."""

    def forward(self, query, key, value, kv_cache, attn_metadata):
        """Run causal varlen attention with aiter.

        NOTE(review): as written, this attends over the current step's
        key/value using prefill varlen metadata only. ``kv_cache`` is neither
        written nor read, so cached context and decode requests are not
        handled here. Confirm whether this excerpt omits that logic.
        """
        from aiter import flash_attention_forward

        meta = attn_metadata
        return flash_attention_forward(
            query,
            key,
            value,
            cu_seqlens=meta.cu_seqlens,
            max_seqlen=meta.max_prefill_seq_len,
            causal=True,
            softmax_scale=self.scale,
        )
附录X Mamba SSM后端详解
X.1 Mamba状态空间模型
Mamba不使用标准注意力,而是使用SSM:
标准注意力:
y = softmax(Q × K^T / √d) × V
复杂度: O(n²) 序列长度
Mamba SSM:
h_t = A × h_{t-1} + B × x_t # 状态递推
y_t = C × h_t # 输出
复杂度: O(n) 序列长度
KV Cache替代:
Mamba的"cache"是SSM的状态h_t
h_t的维度远小于KV cache
不需要PagedAttention式的分页管理
X.2 MambaAttnBackend实现
class MambaAttnBackend(AttentionBackend):
    """Backend for Mamba state-space-model layers."""

    @staticmethod
    def get_attention_impl_cls() -> type:
        # Mamba does not use the standard Attention layer.
        return MambaImpl

    @staticmethod
    def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size):
        # SSM-state cache shape, intended to hold the SSM parameters
        # (A, B, C, D, dt) and the recurrent state. NOTE(review): state_dim
        # is not defined in this excerpt.
        shape = (num_blocks, block_size, state_dim)
        return shape

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        # SSM layers have no head_size, so no sizes are advertised.
        return list()
X.3 Mamba1 vs Mamba2
| 特性 | Mamba1 | Mamba2 (SSD) |
|---|---|---|
| 算法 | 选择性SSM | 结构化状态空间对偶 |
| 并行化 | 有限 (scan) | 原生并行 (matrix) |
| 吞吐 | 中 | 高 (~2-5×) |
| 实现 | mamba_ssm | mamba_ssm v2 |
| 状态大小 | d_model × d_state | d_model × d_state |
| Flash风格 | 无 | 支持类似Flash的tiling |
附录Y FlashInfer Wrapper plan()详解
Y.1 BatchPrefillWithPagedKVCacheWrapper.plan()
def plan(
self,
paged_kv_indptr: torch.Tensor, # [batch+1] 页指针
paged_kv_indices: torch.Tensor, # [total_pages] 页索引
paged_kv_last_page_len: torch.Tensor, # [batch] 最后一页长度
num_qo_heads: int, # Q头数
num_kv_heads: int, # KV头数
head_dim: int, # 头维度
page_size: int, # 页大小(block_size)
causal: bool = True, # 是否causal
pos_encoding_mode: str = "NONE", # 位置编码模式
qo_indptr: torch.Tensor | None = None, # Q的CSR指针
custom_mask: torch.Tensor | None = None, # 自定义mask
) -> None:
"""预计算prefill所需的索引和元数据
plan()做什么:
1. 构建CUDA kernel的launch grid参数
grid = (num_heads, num_qo_tiles, batch_size)
num_qo_tiles = sum(ceil(query_len / tile_size))
2. 计算每个序列的KV范围
对于每个序列i:
kv_start = paged_kv_indptr[i] * page_size
kv_end = (paged_kv_indptr[i+1]-1) * page_size + last_page_len[i]
3. 分配workspace
workspace_size = batch_size * num_heads * head_dim * sizeof(float)
+ 累积器空间
4. 将索引复制到GPU
paged_kv_indptr_gpu = paged_kv_indptr.to(device)
paged_kv_indices_gpu = paged_kv_indices.to(device)
为什么需要plan():
- FlashInfer的kernel是预编译的
- plan()提供了kernel需要的所有运行时参数
- 避免在run()时重复计算
- 支持CUDA Graph(plan一次,run多次)
"""
Y.2 BatchDecodeWithPagedKVCacheWrapper.plan()
def plan(
    self,
    paged_kv_indptr: torch.Tensor,
    paged_kv_indices: torch.Tensor,
    paged_kv_last_page_len: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    sliding_window: int | None = None,  # sliding-window size; presumably None disables it
) -> None:
    """Precompute the indices and metadata needed by the decode kernel.

    NOTE(review): recent FlashInfer releases expose the sliding window as
    ``window_left`` (default -1), and the prefill plan() accepts it as well,
    so it is not decode-specific; confirm against the installed version.

    Differences from the prefill plan():
    1. Each sequence has exactly one query token
       -> no qo_indptr is needed
       -> the launch grid is simpler
    2. Sliding-window support
       -> only the most recent window_size KV entries are read
       -> less memory bandwidth
    3. Different workspace size
       The decode workspace is described as smaller (only one output is
       accumulated per token). NOTE(review): split-KV decode still needs
       buffers for partial outputs; confirm the actual sizes.
    """
附录Z FlashAttention v2 vs v3 API差异详解
Z.1 Prefill API
FlashAttention v2:
flash_attn_varlen_func(
q, k, v, # Q/K/V
cu_seqlens_q, cu_seqlens_k, # 累积长度
max_seqlen_q, max_seqlen_k, # 最大长度
softmax_scale,
causal=True,
)
→ 返回: output [num_tokens, heads, dim]
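特点:
- 变长batch API
- 经典用法: K/V作为连续张量从输入参数传入(不是从cache读取)
- 此用法下需要先将KV cache中的K/V读取出来; 但较新的flash-attn 2.x的varlen接口同样接受block_table(Z.2中v2的flash_attn_with_kvcache也带block_table), vLLM V1的FlashAttention后端在FA2与FA3下都直接传入block_table读取分页KV
FlashAttention v3:
flash_attn_varlen_func(
q, k, v, # 同v2
cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k,
softmax_scale,
causal=True,
block_table=block_tables, # 直接从Paged KV Cache读取(新版FA2同样支持, 并非v3独有)
)
→ 返回: output
特点:
- 支持Paged KV Cache(与新版FA2一致)
- 不需要手动读取KV
- 主要提升来自面向Hopper(SM90)的kernel实现(warp specialization、TMA/WGMMA异步流水、FP8), 内存访问与计算更高效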
Z.2 Decode API
FlashAttention v2:
flash_attn_with_kvcache(
q, # [batch, 1, heads, dim]
k_cache, v_cache, # KV cache
cache_seqlens, # [batch] 缓存序列长度
block_table, # [batch, max_blocks] 块表
softmax_scale,
causal=True,
)
FlashAttention v3:
flash_attn_with_kvcache(
q,
k_cache, v_cache,
cache_seqlens,
block_table,
softmax_scale,
causal=True,
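# ★与v2对比的参数★(注: cache_batch_idx / cache_leftpad 在flash-attn 2.x的flash_attn_with_kvcache中已存在, 并非v3新增)
cache_batch_idx=None, # 批次索引(用于非连续cache)
q_batch_idx=None, # Q的批次索引(官方flash-attn接口中未见此参数, 请按所用版本核实)
q_start_idx=0, # Q的起始索引(同上, 请核实)
cache_leftpad=None, # 左填充(特定格式)
)
v3改进(参数细节请以所用版本的flash_attn_interface为准):
- 支持非连续的Q和KV cache布局(cache_batch_idx等在v2中已可用)
- q_start_idx(混合batch中Q的偏移)等参数是否存在需按版本核实
- 主要收益在于面向Hopper GPU的kernel实现; 更好的CUDA Graph兼容性的说法亦待核实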
附录AA 内置RoPE (FlashInfer) vs 外置RoPE (FlashAttention/Triton)
FlashInfer的RoPE处理方式:
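FlashInfer内置了RoPE,通过以下参数开启:
pos_encoding_mode="ROPE_LLAMA": 在kernel内对Q/K施加RoPE(默认为"NONE", 即不施加)
rope_scale / rope_theta: RoPE的缩放系数与基频(vLLM通常在模型层先完成RoPE, 以pos_encoding_mode="NONE"调用FlashInfer)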
在kernel内部完成RoPE计算:
q_rope = apply_rope(q, positions, inv_freq_q)
k_rope = apply_rope(k, positions, inv_freq_k)
优点: 减少一次kernel launch(RoPE+Attn融合)
缺点: 只支持标准RoPE,不支持自定义变体
FlashAttention的RoPE处理方式:
在调用flash_attn之前,在外部计算RoPE:
q = apply_rope(q, positions, freqs) # 独立kernel
k = apply_rope(k, positions, freqs)
output = flash_attn(q, k, v, ...)
优点: 支持任意RoPE变体(YaRN, LongRoPE等)
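缺点: 额外的RoPE kernel launch(上面拆成两行只是示意, vLLM的rotary_embedding算子一次launch同时处理Q和K); 另外flash_attn_with_kvcache本身也提供rotary_cos/rotary_sin参数, 可选用内置RoPE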
Triton的RoPE处理方式:
同FlashAttention,外置RoPE
可以将RoPE融合到Q/K投影kernel中
附录BB GDN/Linear/ShortConv 注意力变体详解
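BB.1 GDN (Gated DeltaNet, 门控Delta规则) 注意力
# GDN是一种线性注意力(如Qwen3-Next的线性注意力层), 不计算 Q*K^T 注意力矩阵
# 而是维护固定大小的状态矩阵 S, 逐token用门控delta规则更新:
# S_t = S_{t-1} · α_t (I − β_t k_t k_t^T) + β_t v_t k_t^T   (S 为 d_v × d_k), 输出 o_t = S_t q_t
# α_t ∈ (0,1] 为衰减门(遗忘旧记忆), β_t ∈ [0,1] 控制 v_t 覆盖 k_t 对应旧值的强度
# 状态大小与序列长度无关, 因此其"cache"是每个请求的递归状态而非KV cache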
class GDNImpl(AttentionImpl):
    """Gated DeltaNet (GDN) linear attention: single-sequence reference sketch.

    GDN does not form a Q*K^T score matrix. It keeps a fixed-size state
    matrix S (stored here as [d_k, d_v]) and updates it token by token with
    the gated delta rule:

        S <- alpha_t * S                          # decay old memory
        S <- S + beta_t * k_t (v_t - S^T k_t)^T   # delta-rule correction
        o_t = S^T q_t

    alpha_t = exp(g_t) lies in (0, 1] for g_t <= 0; beta_t in [0, 1] sets how
    strongly v_t overwrites the value currently associated with key k_t.
    """

    def forward(
        self, query, key, value, kv_cache, attn_metadata, g=None, beta=None
    ):
        """Run the gated delta rule over one sequence, one head.

        Args:
            query, key: [n, d_k]. GDN models usually L2-normalize q/k before
                this step; that is not done here.
            value: [n, d_v].
            kv_cache, attn_metadata: unused in this sketch. A real backend
                would load and store the recurrent state S per request.
            g: optional [n] log-decay gates; None means no decay (alpha = 1).
            beta: optional [n] write strengths; None means beta = 1.
                With both defaults this reduces to plain (ungated) DeltaNet.

        Returns:
            [n, d_v] outputs.
        """
        num_tokens = query.shape[0]
        state = value.new_zeros((key.shape[-1], value.shape[-1]))  # S: [d_k, d_v]
        output = value.new_empty((num_tokens, value.shape[-1]))
        for t in range(num_tokens):
            k_t = key[t]
            if g is not None:
                state = state * g[t].exp()
            beta_t = 1.0 if beta is None else beta[t]
            # Error between the target value and what S currently recalls for k_t.
            delta = beta_t * (value[t] - state.T @ k_t)  # [d_v]
            # Rank-1 update writes the correction under key k_t.
            state = state + k_t.unsqueeze(-1) * delta.unsqueeze(0)
            output[t] = state.T @ query[t]
        return output
BB.2 Linear Attention
# 线性注意力: 对序列长度n为线性复杂度
# 核心思想: 使用kernel trick避免显式计算注意力矩阵
#
# 标准注意力: O(n²×d)
# O = softmax(Q×K^T/√d) × V
#
# 线性注意力: O(n×d²)
# O = φ(Q) × (φ(K)^T × V)
# 其中φ是特征映射函数
# 由于矩阵乘法结合律: 先计算 φ(K)^T × V [d×d],再乘以 φ(Q) [n×d]
# 复杂度: n×d² (vs n²×d); 自回归(因果)场景下改用前缀和 S_t = Σ_{i≤t} φ(k_i)^T v_i
# 常见的φ函数:
# 1. ELU+1: φ(x) = elu(x) + 1
# 2. 逐元素指数: φ(x) = exp(x) (简单的正值映射, 并非softmax核的无偏近似)
# 3. 正随机特征 (Performer的FAVOR+): φ(x) = exp(Wx − ‖x‖²/2) / √m, W为m个随机高斯投影, 对softmax核无偏
class LinearAttnImpl(AttentionImpl):
    """Causal linear attention: single-sequence, single-head reference sketch.

    Computes o_t = phi(q_t) S_t / (phi(q_t) . z_t) with the running sums
    S_t = sum_{i<=t} phi(k_i)^T v_i and z_t = sum_{i<=t} phi(k_i), so token t
    only attends to tokens 0..t. This prefix-sum form is what makes the state
    incrementally updatable during decoding. The original version summed
    over all n tokens and leaked future keys into every output.
    """

    def forward(self, query, key, value, kv_cache, attn_metadata):
        # Assumes query/key are [n, d] and value is [n, d_v] for one sequence
        # and one head; kv_cache / attn_metadata are unused in this sketch.
        phi_q = self.feature_map(query)  # [n, d]
        phi_k = self.feature_map(key)  # [n, d]
        # Prefix sums of outer products: kv_state[t] = S_t, shape [n, d, d_v].
        # Materializing every S_t costs O(n * d * d_v) memory. That is fine
        # for illustration; optimized kernels typically process chunks.
        kv_state = (phi_k.unsqueeze(-1) * value.unsqueeze(-2)).cumsum(dim=0)
        # Normalizer z_t = sum_{i<=t} phi(k_i), shape [n, d].
        k_state = phi_k.cumsum(dim=0)
        numerator = (phi_q.unsqueeze(-1) * kv_state).sum(dim=1)  # [n, d_v]
        denominator = (phi_q * k_state).sum(dim=-1, keepdim=True)  # [n, 1]
        return numerator / denominator
BB.3 ShortConv (短卷积替代注意力)
# 短卷积: 用1D卷积替代注意力
# 感受野: 固定窗口大小(kernel_size)
# 复杂度: O(n × kernel_size)
#
# 适用场景: 局部依赖(不需要全局上下文)
# 常见于: 长序列模型的早期层
class ShortConvImpl(AttentionImpl):
    """Short causal 1D convolution used in place of attention.

    Each output depends only on the current input and the previous
    kernel_size - 1 inputs, so compute is O(n * kernel_size). There is no
    per-token KV cache, but incremental decoding still needs a small rolling
    conv state: the last kernel_size - 1 inputs of every sequence.
    """

    def forward(self, query, key, value, kv_cache, attn_metadata):
        # Only `value` is convolved; query/key/kv_cache/attn_metadata are
        # unused in this sketch. kernel_size is small (e.g. 3 or 5).
        output = self.causal_conv1d(value, self.weight, self.bias)
        return output

    @staticmethod
    def get_kv_cache_shape(*args, conv_dim=None, kernel_size=3):
        """Return the shape of the conv-state cache.

        Args:
            *args: backend-style positional arguments; args[0] is taken as
                the number of cache slots (num_blocks in other backends).
            conv_dim: number of convolved channels.
            kernel_size: convolution window; the state keeps the previous
                kernel_size - 1 inputs.

        Returns:
            (num_slots, conv_dim, kernel_size - 1). If conv_dim or the slot
            count is not supplied, returns the legacy empty shape (0,). That
            is only valid when every forward pass sees the whole sequence
            (no incremental decoding).
        """
        if conv_dim is None or not args:
            return (0,)
        num_slots = args[0]
        return (num_slots, conv_dim, kernel_size - 1)
附录CC ROCm平台注意力后端完整选择逻辑
ROCm平台的注意力后端选择:
1. 检测ROCm版本和GPU型号
rocm_version = get_rocm_version()
gpu_arch = get_gpu_arch() # gfx90a, gfx942, etc.
2. 检测aiter库可用性
aiter_available = importlib.util.find_spec("aiter") is not None
3. 后端选择逻辑(简化示意; vLLM中aiter路径还需通过环境变量 VLLM_ROCM_USE_AITER=1 等开关显式启用, 不仅取决于库是否已安装):
if aiter_available:
if is_mla_model():
→ ROCmAITerMLABackend 或 ROCmAITerMLASparseBackend
else:
→ ROCmAITerFABackend # aiter的FlashAttention实现
else:
→ ROCmAttnBackend # 基于composable_kernel的实现
4. aiter库的性能特点:
- 专门为AMD GPU优化
- 支持MI200/MI300系列
- 使用MFMA (Matrix Fused Multiply-Add) 矩阵指令(CDNA架构的MI200/MI300; WMMA是RDNA3消费级GPU上的指令)
- 与CUDA FlashInfer的API对齐
5. 不支持的后端:
❌ FlashInfer (仅CUDA)
❌ FlashAttention v3 (仅CUDA)
(注: Triton不在此列 —— Triton支持AMD GPU, vLLM的Triton注意力后端可在ROCm上使用)
❌ CUTLASS (仅CUDA)
附录DD FlashInfer decode kernel GQA优化详解
DD.1 GQA在decode中的内存访问模式
GQA decode的关键: 多个Q头共享1组KV头
假设: num_heads=32, num_kv_heads=8, kv_group_num=4
标准实现:
for head in range(32):
kv_head = head // 4 # 0,0,0,0,1,1,1,1,...,7,7,7,7
k = k_cache[kv_head] # 读取KV头
scores = q[head] @ k.T
output[head] = softmax(scores) @ v_cache[kv_head]
问题: KV被重复读取4次(每4个Q头读取同一组KV)
优化实现 (FlashInfer):
# 将Q按KV头分组
for kv_head in range(8):
k = k_cache[kv_head] # 只读1次!
v = v_cache[kv_head] # 只读1次!
for q_head in range(kv_head*4, (kv_head+1)*4):
scores = q[q_head] @ k.T
output[q_head] = softmax(scores) @ v
优化: KV读取从32次减少到8次
内存带宽节省: 75%
DD.2 FlashInfer的KV cache共享机制
FlashInfer decode kernel的KV读取策略:
1. 所有Q头先共享读取KV:
# 第一步: 将KV从cache加载到shared memory(示意; 实际kernel通常让每个thread block负责一个KV头及其对应的Q头组, 按tile分块流式加载KV, 而非一次装入全部KV)
for kv_head in range(num_kv_heads):
k_shared[kv_head] = load_from_cache(kv_head)
v_shared[kv_head] = load_from_cache(kv_head)
# 第二步: 每个Q头使用shared memory中的KV
for q_head in range(num_heads):
kv_head = q_head // kv_group_num
scores = q[q_head] @ k_shared[kv_head].T
output[q_head] = softmax(scores) @ v_shared[kv_head]
2. shared memory利用:
KV在shared memory中 → 所有Q头共享
→ 全局内存读取量: num_kv_heads × head_size × 2
→ 而非: num_heads × head_size × 2
GQA节省: KV的全局内存读取量降为 num_kv_heads/num_heads = 8/32 = 1/4, 即减少4×(与上文的75%一致)
更多推荐


所有评论(0)