【vllm】(v1 Attention)vLLM V1 Attention— Part4 底层Ops算子
vLLM V1 Attention 模块超深度架构分析 — Part 4: 底层Ops算子
·
vLLM V1 Attention 模块超深度架构分析 — Part 4: 底层Ops算子
分析范围:
v1/attention/ops/目录全部源码(24个文件,约9,600行)
分析日期: 2026-05-25
分析深度: 架构师级,逐行解析,Mermaid图表25+
目录
- 第二十章 通用操作算子(common.py)
- 第二十一章 Triton注意力Kernel
- 第二十二章 前缀预填充与分块解码
- 第二十三章 DeepSeek V4专用算子
- 第二十四章 视觉注意力包装器
- 第二十五章 DCP全到全通信
- 第二十六章 TurboQuant算子
- 附录H Triton Kernel计算网格分析
- 附录I KV Cache写入时序图
- 附录J 术语表补充
第二十章 通用操作算子(common.py)
20.1 reshape_and_cache() KV写入
def reshape_and_cache(
key: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
value: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
key_cache: torch.Tensor, # [num_blocks, block_size, num_kv_heads, head_size]
value_cache: torch.Tensor, # 同上
slot_mapping: torch.Tensor, # [num_tokens] → 物理slot索引
kv_cache_dtype: str, # "auto"/"fp8"/"fp8_e4m3"/...
kv_scale: float = 1.0, # FP8量化缩放因子
kv_zp: float = 0.0, # FP8量化零点
) -> None:
"""将Key和Value写入Paged KV Cache
流程:
1. 遍历每个token
2. 根据slot_mapping找到写入位置
3. 将K/V写入对应的cache位置
4. 如果kv_cache_dtype为FP8,执行量化
原地修改key_cache和value_cache
"""
# 量化路径
if kv_cache_dtype.startswith("fp8"):
# FP8量化写入
key_8bit = _quantize_fp8(key, kv_scale, kv_zp)
value_8bit = _quantize_fp8(value, kv_scale, kv_zp)
# 使用量化值写入cache
_reshape_and_cache_impl(key_8bit, value_8bit, key_cache, value_cache, slot_mapping)
else:
# 标准FP16/BF16写入
_reshape_and_cache_impl(key, value, key_cache, value_cache, slot_mapping)
slot_mapping写入示意:
20.2 reshape_and_cache_flash() Flash版
def reshape_and_cache_flash(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
kv_scale: float = 1.0,
kv_zp: float = 0.0,
) -> None:
"""Flash版KV写入
与reshape_and_cache的区别:
- 使用Triton kernel实现(更高性能)
- 支持FP8量化
- 与FlashInfer/FlashAttention兼容的内存布局
底层实现: triton_reshape_and_cache_flash kernel
"""
# 选择实现路径
if kv_cache_dtype in ("auto", None, "fp16", "bfloat16"):
# 标准路径: Triton kernel
triton_reshape_and_cache_flash_kernel(
key, value, key_cache, value_cache, slot_mapping
)
elif kv_cache_dtype.startswith("fp8"):
# FP8量化路径: 先量化再写入
triton_reshape_and_cache_flash_fp8_kernel(
key, value, key_cache, value_cache, slot_mapping,
kv_scale, kv_zp
)
20.3 copy_cache() 缓存复制
def copy_cache(
src_kv_cache: torch.Tensor, # 源KV cache
dst_kv_cache: torch.Tensor, # 目标KV cache
src_to_dst_block_id: torch.Tensor, # 源→目标块ID映射
num_heads: int,
head_size: int,
block_size: int,
cache_type: str,
) -> None:
"""复制KV cache块(用于preemption/swap)
场景: 当请求被抢占时,需要将其KV cache从GPU复制到CPU
或从CPU恢复到GPU
src_to_dst_block_id[i] = j 表示源块i → 目标块j
"""
20.4 swap_cache() 缓存交换
def swap_cache(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst_block_id: torch.Tensor,
num_heads: int,
head_size: int,
block_size: int,
) -> None:
"""交换GPU和CPU之间的KV cache
用于swap scheduling:
- swap_out: GPU → CPU
- swap_in: CPU → GPU
底层使用CUDA memcpy实现异步传输
"""
第二十一章 Triton注意力Kernel
21.1 triton_decode_attention.py解码注意力
核心Kernel: _decode_attention_kernel
输入:
q: [batch, num_heads, head_size] # 查询向量
k_cache: [num_blocks, block_size, num_kv_heads, head_size] # KV cache
v_cache: 同上
block_table: [batch, max_blocks] # 块表
seq_lens: [batch] # 序列长度
scale: float # 1/sqrt(d)
输出:
output: [batch, num_heads, head_size] # 注意力输出
算法(在线softmax):
for each query head:
max_score = -inf
sum_exp = 0
output = 0
for each KV block in sequence:
# 读取一个block的K/V
k_block = k_cache[block_table[batch, block_idx]] # [block_size, dim]
v_block = v_cache[block_table[batch, block_idx]]
# 计算注意力分数
scores = q @ k_block.T * scale # [block_size]
# 在线softmax更新
new_max = max(max_score, max(scores))
correction = exp(max_score - new_max)
sum_exp = sum_exp * correction + sum(exp(scores - new_max))
output = output * correction + exp(scores - new_max) @ v_block
max_score = new_max
output = output / sum_exp
在线softmax(Online Softmax) 的数学推导:
标准softmax: output = softmax(Q×K^T) × V
= Σ_i (exp(s_i) / Σ_j exp(s_j)) × V_i
问题: Σ_j exp(s_j)需要先遍历所有K,再计算softmax
在线softmax: 逐块累加,无需先遍历
维护: max_score, sum_exp, output三个累积器
处理第n个块时:
scores_n = Q × K_n^T * scale
new_max = max(old_max, max(scores_n))
# 新的全局最大值
correction = exp(old_max - new_max)
# 旧值需要乘以correction来重新归一化
sum_exp = sum_exp * correction + Σ exp(scores_n - new_max)
output = output * correction + Σ (exp(scores_n - new_max) × V_n)
# correction因子确保了数值稳定性:
# 减去new_max后再取exp,避免exp(large_value)溢出
最终: output /= sum_exp
21.2 triton_prefill_attention.py预填充注意力
Prefill Triton kernel使用Flash风格的tiling:
输入:
q: [num_tokens, num_heads, head_size] # 所有prefill token的Q
k_cache, v_cache: 同decode
block_table, cu_seqlens, max_seqlen
算法(Flash Attention tiling):
for each query block (B_r tokens):
for each key/value block (B_c tokens):
# 计算Q×K^T子矩阵
S = Q[q_start:q_end] @ K[k_start:k_end].T * scale
# [B_r, B_c]
# 应用causal mask
if is_causal:
S[mask] = -inf # 未来位置屏蔽
# 在线softmax更新(与decode类似但处理矩阵)
new_max = max(old_max, rowmax(S))
correction = exp(old_max - new_max)
sum_exp = sum_exp * correction + rowsum(exp(S - new_max))
O = O * correction + exp(S - new_max) @ V[k_start:k_end]
O /= sum_exp # 归一化
21.3 triton_unified_attention.py统一注意力
def triton_unified_attention(
q: torch.Tensor, # [num_tokens, num_heads, head_size]
k_cache: torch.Tensor, # KV cache
v_cache: torch.Tensor,
block_table: torch.Tensor,
cu_seqlens: torch.Tensor,
seq_lens: torch.Tensor,
max_seqlen: int,
scale: float,
is_causal: bool = True,
) -> torch.Tensor:
"""统一注意力kernel
同时处理prefill和decode
内部通过cu_seqlens区分不同序列
与分别调用的区别:
- 减少kernel launch开销
- 更好的GPU利用率
- 简化调用逻辑
"""
21.4 triton_reshape_and_cache_flash.py写入kernel
@triton.jit
def _reshape_and_cache_flash_kernel(
key_ptr, value_ptr, # K/V输入指针
key_cache_ptr, value_cache_ptr, # KV cache指针
slot_mapping_ptr, # slot映射
stride_k_n, stride_k_h, stride_k_d, # K的stride
stride_v_n, stride_v_h, stride_v_d,
stride_cache_b, stride_cache_s, stride_cache_h, stride_cache_d,
BLOCK_SIZE: tl.constexpr,
):
"""Triton kernel: 将K/V写入KV cache
每个program处理一个(head, token)对
流程:
1. 从slot_mapping获取写入位置
2. 计算目标cache的物理地址
3. 将K/V值写入cache
"""
token_idx = tl.program_id(0) # token索引
head_idx = tl.program_id(1) # 头索引
# 读取slot
slot = tl.load(slot_mapping_ptr + token_idx)
# 计算cache位置
block_idx = slot // BLOCK_SIZE
offset = slot % BLOCK_SIZE
# 写入Key
key_value = tl.load(key_ptr + token_idx * stride_k_n + head_idx * stride_k_h)
tl.store(
key_cache_ptr + block_idx * stride_cache_b + offset * stride_cache_s + head_idx * stride_cache_h,
key_value
)
# 写入Value
value_value = tl.load(value_ptr + token_idx * stride_v_n + head_idx * stride_v_h)
tl.store(
value_cache_ptr + block_idx * stride_cache_b + offset * stride_cache_s + head_idx * stride_cache_h,
value_value
)
21.5 triton_attention_helpers.py辅助函数
def _get_block_size_and_num_warps(
head_size: int,
) -> tuple[int, int]:
"""根据head_size选择最优的block_size和num_warps
经验值:
- head_size ≤ 64: BLOCK_SIZE=16, num_warps=4
- head_size ≤ 128: BLOCK_SIZE=16, num_warps=8
- head_size ≤ 256: BLOCK_SIZE=16, num_warps=16
"""
if head_size <= 64:
return 16, 4
elif head_size <= 128:
return 16, 8
else:
return 16, 16
def _get_num_warps_and_stages(
kv_group_num: int, # num_heads / num_kv_heads
block_size: int,
) -> tuple[int, int]:
"""选择warp数和pipeline stages
kv_group_num越大 → GQA重用越多 → 可以用更多warps
"""
if kv_group_num >= 4:
return 8, 3 # 8 warps, 3 pipeline stages
elif kv_group_num >= 2:
return 4, 3
else:
return 4, 2
21.6 triton_merge_attn_states.py状态合并
def merge_attn_states(
base_output: torch.Tensor, # [batch, heads, dim] 基础注意力输出
base_softmax_log: torch.Tensor, # [batch, heads] 基础softmax logsumexp
new_output: torch.Tensor, # [batch, heads, dim] 新注意力输出
new_softmax_log: torch.Tensor, # [batch, heads] 新softmax logsumexp
) -> torch.Tensor:
"""合并两组注意力状态
数学:
设两组softmax概率为 p_i 和 q_i,对应的logsumexp为 l_i 和 m_i
合并后的softmax:
output = (exp(l - max) * base_output + exp(m - max) * new_output)
/ (exp(l - max) + exp(m - max))
其中 max = max(l, m)
用途:
- Prefix caching: 合并prefix注意力和生成注意力
- Chunked prefill: 合并不同chunk的注意力
"""
第二十二章 前缀预填充与分块解码
22.1 prefix_prefill.py
class PrefixPrefill:
"""前缀预填充
当多个请求共享相同的system prompt时,
只需计算一次前缀的KV cache,然后共享给所有请求
流程:
1. 首次遇到prefix: 正常prefill计算,KV cache写入
2. 后续请求: 直接复制prefix的KV cache块,无需重新计算
好处:
- 避免重复计算相同的system prompt
- 显著减少prefill延迟
- 降低GPU计算量
"""
Prefix Cache示意:
22.2 chunked_prefill_paged_decode.py
class ChunkedPrefillPagedDecode:
"""分块预填充 + 分页解码
将长序列的prefill拆分为多个chunk,每个chunk:
1. 计算当前chunk的注意力
2. 将KV cache写入对应块
3. 与之前chunk的结果合并
好处:
- 避免超长序列导致OOM
- 控制单步计算量
- 与decode共享GPU时间
合并逻辑:
使用merge_attn_states()合并各chunk的注意力状态
"""
第二十三章 DeepSeek V4专用算子
23.1 cache_utils.py缓存工具
class DeepseekV4CacheUtils:
"""DeepSeek V4专用缓存工具
DeepSeek V4使用特殊的KV cache布局:
- 分离的K和V缓存
- K使用RoPE后存储
- V使用压缩latent存储
- 支持FP8量化
"""
@staticmethod
def get_kv_cache_shape(
num_blocks, block_size, num_kv_heads, head_size,
kv_lora_rank, q_lora_rank,
) -> tuple[tuple, tuple]:
"""返回DeepSeek V4的KV cache形状
K Cache: [num_blocks, block_size, num_kv_heads, head_size_rope]
head_size_rope = head_size - nope_dim (RoPE维度)
V Cache (latent): [num_blocks, block_size, kv_lora_rank]
额外缓存:
- k_pe: [num_blocks, block_size, 1, head_size_rope]
Key的位置编码(RoPE部分)
"""
23.2 fused_compress_quant_cache.py融合压缩量化
@triton.jit
def fused_compress_quant_cache_kernel(
# 输入
kv_input_ptr, # 原始KV输入
# 输出
compressed_cache_ptr, # 压缩后的cache
quant_scale_ptr, # 量化缩放因子
quant_zp_ptr, # 量化零点
# 参数
kv_lora_rank, # 压缩维度
num_heads,
head_size,
block_size,
FP8_QUANT: tl.constexpr, # 是否FP8量化
):
"""融合压缩+量化+写入cache的kernel
在一个kernel中完成:
1. KV → latent压缩 (W_DKV × KV)
2. latent → FP8量化 (如果启用)
3. 写入paged cache
好处:
- 减少中间结果的内存读写
- 单次kernel launch完成三步操作
- 降低延迟
"""
23.3 fused_indexer_q.py融合Q索引器
def fused_indexer_q(
q_input: torch.Tensor, # [num_tokens, q_lora_rank + head_size_rope]
q_lora_rank: int,
head_size_rope: int,
num_heads: int,
slot_mapping: torch.Tensor,
q_cache: torch.Tensor, # Q cache(某些模型缓存Q用于推理)
) -> torch.Tensor:
"""融合Q索引+分离+写入
DeepSeek V4的Q有两个部分:
1. q_lora: [q_lora_rank] 压缩的Q latent
2. q_rope: [num_heads, head_size_rope] RoPE部分的Q
本kernel:
1. 将Q分离为q_lora和q_rope
2. 对q_rope应用RoPE
3. 将两部分写入各自的cache位置
"""
23.4 fused_inv_rope_fp8_quant.py融合RoPE+量化
@triton.jit
def fused_inv_rope_fp8_quant_kernel(
q_input_ptr, # 输入Q
q_output_ptr, # 输出Q(RoPE后+量化后)
inv_freq_ptr, # RoPE逆频率
position_ptr, # 位置索引
scale_ptr, # FP8量化缩放
zp_ptr, # FP8量化零点
HEAD_SIZE: tl.constexpr,
APPLY_ROPE: tl.constexpr,
FP8_QUANT: tl.constexpr,
):
"""融合: 逆RoPE + FP8量化
在一个kernel中完成:
1. 应用逆RoPE(将RoPE编码的Q还原为原始Q)
用于某些模型需要在decode时重新应用RoPE
2. FP8量化(如果启用)
数学:
逆RoPE: q_orig = RoPE^(-1)(q_rope, position, inv_freq)
FP8量化: q_fp8 = quantize(q_orig, scale, zp)
"""
23.5 fused_qk_rmsnorm.py融合RMSNorm
@triton.jit
def fused_qk_rmsnorm_kernel(
q_ptr, k_ptr, # Q和K输入
q_out_ptr, k_out_ptr, # Q和K输出
gamma_ptr, # RMSNorm的gamma参数
HEAD_SIZE: tl.constexpr,
EPSILON: tl.constexpr,
):
"""融合Q+K RMSNorm
在一个kernel中对Q和K同时应用RMSNorm
RMSNorm(x) = x / RMS(x) * gamma
RMS(x) = sqrt(mean(x^2) + epsilon)
好处:
- 减少一次kernel launch
- 共享RMS计算(如果Q和K使用相同的norm)
"""
第二十四章 视觉注意力包装器
24.1 vit_attn_wrappers.py
class ViTAttentionWrapper:
"""Vision Transformer注意力包装器
ViT与LLM注意力的区别:
1. 无KV cache(每张图独立处理)
2. 固定序列长度(图像patch数固定)
3. 可能使用不同的位置编码(如2D位置编码)
4. 无causal mask(双向注意力)
支持的ViT模式:
- 标准ViT: flash_attn_func(无causal mask)
- SigLIP: 特殊的注意力模式
- 混合模态: ViT特征→LLM的cross-attention
"""
def forward(
self,
q: torch.Tensor, # [batch, num_heads, num_patches, head_size]
k: torch.Tensor, # [batch, num_kv_heads, num_patches, head_size]
v: torch.Tensor, # [batch, num_kv_heads, num_patches, head_size]
) -> torch.Tensor:
"""执行ViT注意力
关键区别:
- 不使用Paged KV Cache
- 不使用causal mask
- 序列长度固定
"""
from flash_attn import flash_attn_func
output = flash_attn_func(
q, k, v,
causal=False, # 双向注意力!
softmax_scale=self.scale,
)
return output
第二十五章 DCP全到全通信
25.1 dcp_alltoall.py
class DCPAllToAll:
"""Distributed Context Parallelism全到全通信
DCP将长序列的KV cache分散到多个GPU上
每个GPU持有部分KV,需要通过alltoall通信获取其他GPU的KV
流程:
1. 本地计算: 每个GPU计算本地KV的注意力
2. Alltoall: 交换KV cache块
3. 远程计算: 用收到的远程KV计算注意力
4. 合并: 使用merge_attn_states合并本地和远程结果
"""
def forward(
self,
q: torch.Tensor, # 本地查询
local_kv: torch.Tensor, # 本地KV cache
remote_kv: torch.Tensor, # 远程KV(通过alltoall获得)
) -> torch.Tensor:
# 1. 本地注意力
local_output, local_lse = self._compute_local_attention(q, local_kv)
# 2. Alltoall通信
remote_kv = self._alltoall_comm(local_kv)
# 3. 远程注意力
remote_output, remote_lse = self._compute_remote_attention(q, remote_kv)
# 4. 合并
output = merge_attn_states(
local_output, local_lse,
remote_output, remote_lse,
)
return output
第二十六章 TurboQuant算子
26.1 triton_turboquant_store.py
@triton.jit
def turboquant_store_kernel(
key_ptr, value_ptr, # 输入K/V
key_cache_ptr, value_cache_ptr, # 量化后的KV cache
scale_key_ptr, scale_value_ptr, # 量化缩放因子
zp_key_ptr, zp_value_ptr, # 量化零点
slot_mapping_ptr,
HEAD_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""TurboQuant KV存储kernel
流程:
1. 读取原始K/V值
2. 计算per-token的量化参数(scale, zero_point)
3. 执行量化: value_int8 = round(value / scale + zp)
4. 将量化值和参数写入KV cache
"""
# 读取原始值
token_idx = tl.program_id(0)
head_idx = tl.program_id(1)
slot = tl.load(slot_mapping_ptr + token_idx)
block_idx = slot // BLOCK_SIZE
offset = slot % BLOCK_SIZE
# 读取K/V
key_values = tl.load(key_ptr + token_idx * stride + head_idx * stride_h + offs)
# 计算量化参数
key_max = tl.max(tl.abs(key_values))
key_scale = key_max / 127.0 # INT8范围[-128, 127]
# 量化
key_quant = tl.extra.cuda.clamp(
tl.libdevice.round(key_values / key_scale), -128, 127
).to(tl.int8)
# 写入cache
tl.store(key_cache_ptr + ..., key_quant)
tl.store(scale_key_ptr + ..., key_scale)
26.2 triton_turboquant_decode.py
@triton.jit
def turboquant_decode_kernel(
q_ptr, # 查询向量
key_cache_ptr, value_cache_ptr, # 量化KV cache
scale_key_ptr, scale_value_ptr, # 量化参数
output_ptr, # 输出
block_table_ptr, seq_lens_ptr,
HEAD_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""TurboQuant解码注意力kernel
在线反量化+注意力计算:
1. 读取量化KV
2. 用scale和zp反量化: value_fp16 = (value_int8 - zp) * scale
3. 执行标准注意力计算(在线softmax)
4. 输出结果
好处:
- 反量化与注意力计算融合,避免中间结果写回内存
- 减少KV cache内存占用4x(FP16→INT8)
- 减少内存带宽需求2x(读取INT8而非FP16)
"""
TurboQuant内存节省:
FP16 KV Cache (per token):
K: num_kv_heads × head_size × 2 bytes = 64 × 128 × 2 = 16,384 bytes
V: 同上 = 16,384 bytes
总计: 32,768 bytes
INT8 KV Cache (per token):
K量化值: num_kv_heads × head_size × 1 byte = 8,192 bytes
K参数: num_kv_heads × 2 × 2 bytes = 256 bytes (scale + zp)
V量化值: 8,192 bytes
V参数: 256 bytes
总计: 16,896 bytes
节省: 32,768 / 16,896 ≈ 1.94× (接近2×)
带宽节省: 2× (INT8读取代FP16读取)
附录H Triton Kernel计算网格分析
_decode_attention_kernel的program网格:
grid = (num_heads, batch_size)
每个program处理:
1个序列的1个注意力头
program_id(0) = head_idx
program_id(1) = batch_idx
_prefill_attention_kernel的program网格:
grid = (num_heads, DIV_UP(num_tokens, BLOCK_M))
BLOCK_M = 64 (查询块大小)
每个program处理:
1个注意力头的BLOCK_M个查询token
_reshape_and_cache_flash_kernel的program网格:
grid = (num_tokens, num_kv_heads)
每个program处理:
1个token的1个KV头的写入
turboquant_decode_kernel的program网格:
grid = (num_heads, batch_size)
与decode_attention_kernel相同
附录I KV Cache写入时序图
附录J 术语表补充
| 英文术语 | 中文翻译 | 说明 |
|---|---|---|
| Online Softmax | 在线Softmax | 逐块累加的数值稳定softmax |
| Flash Tiling | Flash分块 | IO-aware的分块注意力算法 |
| Prefix Cache | 前缀缓存 | 共享system prompt的KV cache |
| Chunked Prefill | 分块预填充 | 长序列拆分为多个chunk |
| DCP | 分布式上下文并行 | 跨GPU的KV cache并行 |
| AllToAll | 全到全通信 | 每个GPU与所有其他GPU交换数据 |
| TurboQuant | 涡轮量化 | KV cache INT8量化方案 |
| RoPE Inverse | 逆RoPE | 将RoPE编码还原为原始向量 |
| Latent Vector | 潜向量 | MLA压缩后的低维KV表示 |
| GQA Group | GQA分组 | 多个Q头共享一组KV |
| Pipeline Stage | 流水线阶段 | 多阶段kernel执行 |
| B_r / B_c | 查询/键块大小 | Flash Attention的分块参数 |
| logsumexp | 对数求和指数 | softmax归一化常数的对数 |
| FP8 E4M3 | 8位浮点(4指数3尾数) | 量化格式 |
| RMSNorm | 均方根归一化 | LayerNorm的简化变体 |
| Swapped Out | 换出 | KV cache从GPU移到CPU |
| Preemption | 抢占 | 调度器强制终止低优先级请求 |
附录X prefix_prefill.py 完整算法详解
X.1 前缀匹配与复用
场景: 3个请求共享相同的system prompt (100 tokens)
请求1: system(100) + user_query_1(20)
请求2: system(100) + user_query_2(30)
请求3: system(100) + user_query_3(10)
不使用prefix cache:
每个请求独立prefill → 3 × 100 = 300 token的重复计算
GPU时间: 300 × T_prefill_per_token
使用prefix cache:
请求1: prefill system(100) → 计算并缓存KV
请求2: 复用system KV + prefill user_query_2(30)
请求3: 复用system KV + prefill user_query_3(10)
GPU时间: 100 × T_prefill + 20 × T + 30 × T + 10 × T = 160 tokens
节省: 140/300 = 47%
X.2 Prefix KV Cache的块级复用
system prompt = 100 tokens, block_size = 16
请求1的KV cache布局:
Block 0-5: system prompt KV (block_table_1[0:6] = [0,1,2,3,4,5])
Block 6-7: user_query_1 KV
请求2的KV cache布局:
Block 0-5: ★复用★ system prompt KV (block_table_2[0:6] = [0,1,2,3,4,5])
Block 8-9: user_query_2 KV (新块)
关键: block_table_2[0:6] 指向与block_table_1相同的物理块
→ 不需要复制KV数据,只需共享块ID
这是PagedAttention的自然优势:
KV cache与序列解耦 → 多个序列可以共享相同的物理块
只要不被释放(引用计数>0),块就保持有效
X.3 _copy_prefix_kv_blocks() 实现
def _copy_prefix_kv_blocks(
src_block_table: torch.Tensor, # 源请求的块表
dst_block_table: torch.Tensor, # 目标请求的块表
num_prefix_blocks: int, # 前缀块数
kv_cache: torch.Tensor, # KV cache (不直接复制)
) -> None:
"""复制前缀KV cache的块引用
注意: 不复制KV数据,只复制块表的引用
类似于文件系统中的硬链接
原因:
1. KV cache可能很大(几百MB)
2. 直接复制浪费时间
3. PagedAttention的设计允许共享
"""
# 将源请求的前缀块ID复制到目标请求的块表
dst_block_table[:num_prefix_blocks] = src_block_table[:num_prefix_blocks]
# 增加引用计数(如果有)
# 确保共享的块不会被垃圾回收
附录Y chunked_prefill_paged_decode.py 分块策略详解
Y.1 长序列分块
输入序列: 8000 tokens
chunk_size: 2048 (可配置)
分块策略:
Chunk 1: tokens[0:2048] → KV写入cache → 计算注意力
Chunk 2: tokens[2048:4096] → KV写入cache → 与Chunk1合并
Chunk 3: tokens[4096:6144] → KV写入cache → 与之前合并
Chunk 4: tokens[6144:8000] → KV写入cache → 与之前合并
每个chunk的注意力计算:
1. 当前chunk的Q与所有已缓存KV计算注意力
2. 使用merge_attn_states()与之前的结果合并
为什么不一次性计算?
- 8000 token的注意力矩阵: 8000² × 2bytes = 128MB per head
- 32 heads × 128MB = 4GB (仅注意力矩阵)
- 分块后: 2048² × 2 × 32 = 256MB per chunk
- 内存节省: 4GB → 256MB
Y.2 merge_attn_states() 数学推导
设有两个KV段: KV_1 = [0, L1), KV_2 = [L1, L2)
段1的注意力:
O_1 = softmax(Q × K_1^T / √d) × V_1
lse_1 = logsumexp(Q × K_1^T / √d) # log-sum-exp
段2的注意力:
O_2 = softmax(Q × K_2^T / √d) × V_2
lse_2 = logsumexp(Q × K_2^T / √d)
合并后的注意力:
O = softmax([Q × K_1^T; Q × K_2^T] / √d) × [V_1; V_2]
使用log-sum-exp技巧:
max_lse = max(lse_1, lse_2)
α_1 = exp(lse_1 - max_lse)
α_2 = exp(lse_2 - max_lse)
Z = α_1 + α_2
O = (α_1 × O_1 + α_2 × O_2) / Z
推导:
O = Σ_{i∈[1,L2]} softmax_score(i) × V_i
= Σ_{i∈[1,L1)} softmax_score(i) × V_i + Σ_{i∈[L1,L2)} softmax_score(i) × V_i
= (α_1 / Z) × O_1 + (α_2 / Z) × O_2
其中 α_k = Σ_{i∈段k} exp(score_i - max_lse)
Z = Σ_k α_k
附录Z DeepSeek V4 专用算子融合策略分析
Z.1 融合操作对比
| 融合操作 | 替代的独立操作 | 节省的kernel launch | 节省的中间内存 |
|---|---|---|---|
| fused_compress_quant_cache | compress + quantize + cache_write | 3→1 | 2×temp buffer |
| fused_indexer_q | q_split + rope_apply + cache_write | 3→1 | 2×temp buffer |
| fused_inv_rope_fp8_quant | inv_rope + fp8_quant | 2→1 | 1×temp buffer |
| fused_qk_rmsnorm | q_rmsnorm + k_rmsnorm | 2→1 | 1×temp buffer |
Z.2 融合的收益估算
单次kernel launch开销: ~5-10μs (CUDA)
中间内存带宽: ~900GB/s (A100)
不融合的4步操作:
1. compress: read KV(32KB), write latent(1KB), compute W_DKV×KV → 10μs launch + 1μs compute
2. quantize: read latent(1KB), write quant(0.5KB), compute scale → 10μs launch + 0.5μs compute
3. cache_write: read quant(0.5KB), write cache(0.5KB) → 10μs launch + 0.5μs compute
总计: 30μs launch + 2μs compute = 32μs per token
融合后的1步操作:
fused_compress_quant_cache: read KV(32KB), write cache(0.5KB) → 10μs launch + 2μs compute = 12μs
节省: 32μs → 12μs = 63% 延迟减少
对于batch_size=256: 32×256 = 8.2ms → 12×256 = 3.1ms
每步节省: 5.1ms → 在100步/s的吞吐下,每秒节省510ms
附录AA merge_attn_states() 完整Triton实现
@triton.jit
def merge_attn_states_kernel(
# 输入
base_output_ptr, # [batch, heads, dim] 基础注意力输出
base_lse_ptr, # [batch, heads] 基础logsumexp
new_output_ptr, # [batch, heads, dim] 新注意力输出
new_lse_ptr, # [batch, heads] 新logsumexp
# 输出
output_ptr, # [batch, heads, dim] 合并后输出
# 维度
HEAD_DIM: tl.constexpr,
):
"""合并两组注意力状态的Triton kernel
每个program处理1个(batch, head)对
"""
batch_idx = tl.program_id(0)
head_idx = tl.program_id(1)
# 读取logsumexp
base_lse = tl.load(base_lse_ptr + batch_idx * num_heads + head_idx)
new_lse = tl.load(new_lse_ptr + batch_idx * num_heads + head_idx)
# 计算合并权重
max_lse = tl.maximum(base_lse, new_lse)
alpha_base = tl.exp(base_lse - max_lse)
alpha_new = tl.exp(new_lse - max_lse)
Z = alpha_base + alpha_new
# 读取并合并输出
offset = batch_idx * num_heads * HEAD_DIM + head_idx * HEAD_DIM
offs = tl.arange(0, HEAD_DIM)
base_out = tl.load(base_output_ptr + offset + offs)
new_out = tl.load(new_output_ptr + offset + offs)
merged = (alpha_base * base_out + alpha_new * new_out) / Z
# 写入结果
tl.store(output_ptr + offset + offs, merged)
附录AB ViT与LLM注意力架构差异总结
| 特性 | LLM (Decoder) | ViT (Encoder) |
|---|---|---|
| 注意力方向 | Causal (单向) | Bidirectional (双向) |
| KV Cache | ✅ Paged KV Cache | ❌ 无KV Cache |
| 序列长度 | 动态(生成中增长) | 固定(图像patch数) |
| Prefill/Decode分离 | ✅ 不同kernel | ❌ 单次forward |
| 位置编码 | RoPE/ALiBi | 2D位置编码/无 |
| 批次策略 | Persistent batch | 独立图像批次 |
| Memory优化 | PagedAttention | FlashAttention |
| 并行策略 | TP/PP | TP |
| 多模态 | Cross-attention | Patch embedding |
附录AC 全局术语表(跨4个Part完整版)
| 英文术语 | 中文翻译 | 模块 |
|---|---|---|
| AttentionBackend | 注意力后端 | Part1 |
| AttentionMetadata | 注意力元数据 | Part1 |
| AttentionMetadataBuilder | 元数据构建器 | Part1 |
| AttentionImpl | 注意力实现 | Part1 |
| Slot Mapping | 槽位映射 | Part1 |
| Block Table | 块表 | Part1 |
| Paged KV Cache | 分页KV缓存 | Part1/4 |
| FlashInfer | Flash推理库 | Part2 |
| FlashAttention | Flash注意力算法 | Part2 |
| cu_seqlens | 累积序列长度 | Part2 |
| Workspace | 工作内存 | Part2 |
| MLA | 多头潜在注意力 | Part3 |
| kv_lora_rank | KV压缩维度 | Part3 |
| SparseIndexer | 稀疏索引器 | Part3 |
| Sliding Window | 滑动窗口 | Part3 |
| Sink Tokens | 起始保留Token | Part3 |
| Latent Vector | 潜向量 | Part3 |
| Online Softmax | 在线Softmax | Part4 |
| Flash Tiling | Flash分块 | Part4 |
| Prefix Cache | 前缀缓存 | Part4 |
| Chunked Prefill | 分块预填充 | Part4 |
| merge_attn_states | 合并注意力状态 | Part4 |
| logsumexp | 对数求和指数 | Part4 |
| DCP | 分布式上下文并行 | Part4 |
| AllToAll | 全到全通信 | Part4 |
| TurboQuant | 涡轮量化 | Part4 |
| FP8 E4M3 | 8位浮点格式 | Part4 |
| RoPE Inverse | 逆旋转位置编码 | Part4 |
| RMSNorm | 均方根归一化 | Part4 |
| GQA | 分组查询注意力 | Part2 |
| MQA | 多查询注意力 | Part2 |
| ALiBi | 线性偏置注意力 | Part1 |
| FlexAttention | 灵活注意力 | Part2 |
| BlockMask | 块掩码 | Part2 |
| score_modify | 分数修改函数 | Part2 |
| ViT | 视觉Transformer | Part4 |
| Causal Mask | 因果掩码 | 全局 |
| Softmax Scale | Softmax缩放 | 全局 |
| Preemption | 抢占 | Part1 |
| Swap Out/In | 换出/换入 | Part4 |
附录AD Triton Decode Attention Kernel 完整参数追踪
AD.1 _decode_attention_kernel 参数详解
@triton.jit
def _decode_attention_kernel(
# ===== 输入张量 =====
q_ptr, # [batch, heads, dim] 查询向量
k_cache_ptr, # [num_blocks, bs, kv_heads, dim] Key缓存
v_cache_ptr, # [num_blocks, bs, kv_heads, dim] Value缓存
block_table_ptr, # [batch, max_blocks] 块表
seq_lens_ptr, # [batch] 序列长度
# ===== 输出 =====
output_ptr, # [batch, heads, dim] 输出
# ===== Stride参数 =====
stride_q_b, stride_q_h, stride_q_d, # Q的batch/head/dim stride
stride_cache_b, stride_cache_s, stride_cache_h, stride_cache_d, # Cache stride
stride_bt_b, stride_bt_m, # BlockTable stride
# ===== 标量参数 =====
scale, # 1/sqrt(head_size)
kv_group_num, # num_heads / num_kv_heads (GQA因子)
BLOCK_SIZE: tl.constexpr, # KV cache块大小 (通常16)
HEAD_SIZE: tl.constexpr, # 头维度 (通常64/128)
SLIDING_WINDOW: tl.constexpr, # 滑动窗口大小 (0=禁用)
):
"""Triton Decode Attention Kernel
每个program处理1个(batch, head)对
Grid: (num_heads, batch_size)
"""
batch_idx = tl.program_id(1)
head_idx = tl.program_id(0)
# === 确定KV头索引 ===
kv_head_idx = head_idx // kv_group_num
# GQA: 多个Q头共享1个KV头
# kv_group_num = 4 → heads 0,1,2,3 共享 kv_head 0
# === 读取序列长度 ===
seq_len = tl.load(seq_lens_ptr + batch_idx)
# === 读取查询向量 ===
q_offset = batch_idx * stride_q_b + head_idx * stride_q_d
q = tl.load(q_ptr + q_offset + tl.arange(0, HEAD_SIZE))
# q: [HEAD_SIZE]
# === 遍历KV Blocks(在线Softmax) ===
max_score = float("-inf") # 当前最大分数
sum_exp = 0.0 # exp值累积和
output = tl.zeros([HEAD_SIZE], dtype=torch.float32) # 输出累积
# 计算需要遍历的块数
num_blocks = (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE
# 滑动窗口: 只遍历窗口内的块
if SLIDING_WINDOW > 0:
window_start = max(0, seq_len - SLIDING_WINDOW)
first_block = window_start // BLOCK_SIZE
else:
first_block = 0
for block_idx in range(first_block, num_blocks):
# === 读取块ID ===
physical_block = tl.load(
block_table_ptr + batch_idx * stride_bt_m + block_idx
)
# physical_block: 该逻辑块在物理cache中的位置
# === 读取K Block ===
k_offset = (physical_block * stride_cache_b +
tl.arange(0, BLOCK_SIZE) * stride_cache_s +
kv_head_idx * stride_cache_h)
k_block = tl.load(k_cache_ptr + k_offset +
tl.arange(0, HEAD_SIZE) * stride_cache_d)
# k_block: [BLOCK_SIZE, HEAD_SIZE]
# === 计算Q×K^T ===
scores = tl.sum(q[None, :] * k_block, axis=1) * scale
# scores: [BLOCK_SIZE]
# === Causal mask ===
# Decode: 当前token的位置 = seq_len
# 块内位置: block_idx * BLOCK_SIZE + local_idx
# 只有位置 < seq_len 的token才参与
positions = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
valid_mask = positions < seq_len
# 滑动窗口mask
if SLIDING_WINDOW > 0:
window_mask = positions >= (seq_len - SLIDING_WINDOW)
valid_mask = valid_mask & window_mask
scores = tl.where(valid_mask, scores, float("-inf"))
# === 在线Softmax更新 ===
new_max = tl.maximum(max_score, tl.max(scores))
correction = tl.exp(max_score - new_max)
exp_scores = tl.exp(scores - new_max)
sum_exp = sum_exp * correction + tl.sum(exp_scores)
# 读取V Block
v_offset = (physical_block * stride_cache_b +
tl.arange(0, BLOCK_SIZE) * stride_cache_s +
kv_head_idx * stride_cache_h)
v_block = tl.load(v_cache_ptr + v_offset +
tl.arange(0, HEAD_SIZE) * stride_cache_d)
# 更新输出
output = output * correction + tl.sum(
exp_scores[:, None] * v_block, axis=0
)
max_score = new_max
# === 归一化 ===
output = output / sum_exp
# === 写入结果 ===
out_offset = batch_idx * stride_q_b + head_idx * stride_q_d
tl.store(output_ptr + out_offset + tl.arange(0, HEAD_SIZE), output)
AD.2 在线Softmax数学正确性证明
设scores = [s_1, s_2, ..., s_n]
标准softmax:
max_s = max(scores)
exp_s = exp(scores - max_s)
sum_exp = sum(exp_s)
output = sum(exp_s * V / sum_exp)
在线softmax(逐块处理):
处理块1:
m_1 = max(s_1...s_B)
l_1 = sum(exp(s_1...s_B - m_1))
o_1 = sum(exp(s_1...s_B - m_1) * V_1...V_B)
处理块2:
m_2 = max(m_1, max(s_{B+1}...s_{2B}))
l_2 = l_1 * exp(m_1 - m_2) + sum(exp(s_{B+1}...s_{2B} - m_2))
o_2 = o_1 * exp(m_1 - m_2) + sum(exp(s_{B+1}...s_{2B} - m_2) * V_{B+1}...V_{2B})
...
最终: output = o_n / l_n
证明等价性:
o_n = Σ_{块k} (Σ_{i∈块k} exp(s_i - m_n) * V_i) (通过correction因子展开)
l_n = Σ_{块k} (Σ_{i∈块k} exp(s_i - m_n))
output = o_n / l_n
= Σ_i exp(s_i - m_n) * V_i / Σ_i exp(s_i - m_n)
= Σ_i (exp(s_i - m_n) / Σ_j exp(s_j - m_n)) * V_i
= Σ_i softmax(s_i) * V_i
= 标准softmax输出 ✓
附录AE reshape_and_cache_flash Kernel 完整分析
AE.1 标准FP16写入路径
@triton.jit
def _reshape_and_cache_flash_kernel(
key_ptr, # [num_tokens, kv_heads, head_size]
value_ptr, # 同上
key_cache_ptr, # [num_blocks, block_size, kv_heads, head_size]
value_cache_ptr, # 同上
slot_mapping_ptr, # [num_tokens]
# strides
stride_k_n, stride_k_h, stride_k_d,
stride_cache_b, stride_cache_s, stride_cache_h, stride_cache_d,
BLOCK_SIZE: tl.constexpr,
HEAD_SIZE: tl.constexpr,
):
"""KV Cache写入Triton Kernel
Grid: (num_tokens, num_kv_heads)
每个program写入1个token的1个KV头
流程:
1. 读取slot → 计算物理位置
2. 读取K/V值
3. 写入cache
"""
token_idx = tl.program_id(0)
head_idx = tl.program_id(1)
# Step 1: 读取slot
slot = tl.load(slot_mapping_ptr + token_idx)
# 跳过无效slot(padding token等)
if slot == -1:
return
# Step 2: 计算物理位置
block_idx = slot // BLOCK_SIZE
offset_in_block = slot % BLOCK_SIZE
# Step 3: 读取Key值
k_offset = token_idx * stride_k_n + head_idx * stride_k_h
k_values = tl.load(key_ptr + k_offset + tl.arange(0, HEAD_SIZE) * stride_k_d)
# Step 4: 写入Key cache
k_cache_offset = (block_idx * stride_cache_b +
offset_in_block * stride_cache_s +
head_idx * stride_cache_h)
tl.store(key_cache_ptr + k_cache_offset +
tl.arange(0, HEAD_SIZE) * stride_cache_d, k_values)
# Step 5: 读取Value值
v_offset = token_idx * stride_k_n + head_idx * stride_k_h # 假设V和K stride相同
v_values = tl.load(value_ptr + v_offset + tl.arange(0, HEAD_SIZE) * stride_k_d)
# Step 6: 写入Value cache
tl.store(value_cache_ptr + k_cache_offset +
tl.arange(0, HEAD_SIZE) * stride_cache_d, v_values)
AE.2 FP8量化写入路径
@triton.jit
def _reshape_and_cache_flash_fp8_kernel(
key_ptr, value_ptr,
key_cache_ptr, value_cache_ptr,
slot_mapping_ptr,
kv_scale_ptr, # FP8量化缩放因子
kv_zp_ptr, # FP8量化零点
# strides and constexprs...
HEAD_SIZE: tl.constexpr,
FP8_DTYPE: tl.constexpr, # torch.float8_e4m3fn 或 e5m2
):
"""FP8量化KV Cache写入Kernel
在写入前将FP16/BF16值量化为FP8
"""
token_idx = tl.program_id(0)
head_idx = tl.program_id(1)
slot = tl.load(slot_mapping_ptr + token_idx)
if slot == -1:
return
block_idx = slot // BLOCK_SIZE
offset = slot % BLOCK_SIZE
# 读取FP16 K值
k_values = tl.load(key_ptr + ...)
# 量化: FP16 → FP8
# FP8 E4M3: 1符号位 + 4指数位 + 3尾数位
# 范围: [-448, 448]
# 量化公式: fp8_val = round(fp16_val / scale + zp)
scale = tl.load(kv_scale_ptr)
zp = tl.load(kv_zp_ptr)
k_fp8 = tl.extra.cuda.clamp(
tl.libdevice.round(k_values / scale + zp),
-448.0, 448.0 # E4M3范围
).to(FP8_DTYPE)
# 写入FP8 cache
tl.store(key_cache_ptr + ..., k_fp8)
# 同样处理Value
附录AF DCP AllToAll 通信模式详解
AF.1 2-GPU场景
序列: [T0, T1, T2, T3, T4, T5, T6, T7]
GPU 0 持有:
KV_local_0 = KV[T0, T1, T2, T3] # 前半段KV
GPU 1 持有:
KV_local_1 = KV[T4, T5, T6, T7] # 后半段KV
每个GPU的Q:
GPU 0: Q[T0, T1, T2, T3] # 对应本地KV
GPU 1: Q[T4, T5, T6, T7] # 对应本地KV
Step 1: 本地注意力
GPU 0: O_local_0 = Attn(Q[0:4], KV[0:4]) # 前半段注意力
GPU 1: O_local_1 = Attn(Q[4:8], KV[4:8]) # 后半段注意力
Step 2: AllToAll通信
GPU 0 → GPU 1: 发送 KV_local_0
GPU 1 → GPU 0: 发送 KV_local_1
通信后:
GPU 0 拥有: KV_local_0 + KV_remote_1 (完整KV)
GPU 1 拥有: KV_local_1 + KV_remote_0 (完整KV)
Step 3: 远程注意力
GPU 0: O_remote_0 = Attn(Q[0:4], KV[4:8]) # 前半段Q对后半段KV
GPU 1: O_remote_1 = Attn(Q[4:8], KV[0:4]) # 后半段Q对前半段KV
Step 4: 合并
GPU 0: O_0 = merge(O_local_0, O_remote_0) # 合并两段注意力
GPU 1: O_1 = merge(O_local_1, O_remote_1)
AF.2 4-GPU场景
4 GPU, 序列32K tokens
GPU 0: KV[0:8K]
GPU 1: KV[8K:16K]
GPU 2: KV[16K:24K]
GPU 3: KV[24K:32K]
AllToAll需要3轮通信:
Round 1: GPU 0 ↔ GPU 1, GPU 2 ↔ GPU 3
Round 2: GPU 0 ↔ GPU 2, GPU 1 ↔ GPU 3
Round 3: GPU 0 ↔ GPU 3, GPU 1 ↔ GPU 2
每轮:
1. 发送本地KV
2. 接收远程KV
3. 计算远程注意力
4. 累积合并
最终: 每个GPU都有完整的注意力输出
通信量: 每轮发送 1/4 的KV cache
总通信量: 3 × (1/4 × KV_size) = 3/4 × KV_size
相比不使用DCP: 全量KV在1个GPU → OOM
使用DCP: 每个GPU只存1/4的KV
附录AG Prefix Cache 生命周期管理
AG.1 前缀块的分配、引用和释放
Prefix Cache的生命周期:
1. 首次计算(请求A):
Scheduler分配块 → Attention计算KV → 写入KV cache
→ 块的ref_count = 1 (请求A持有)
→ 标记为"prefix"块(可被后续请求共享)
2. 共享(请求B到达,相同prefix):
Scheduler查找匹配的prefix块
→ 找到块0-5(请求A的prefix)
→ 将块0-5添加到请求B的block_table
→ 块的ref_count = 2 (A和B都持有)
→ ★不复制KV数据★,只共享引用
3. 请求A完成:
释放请求A的块 → ref_count减少
→ 块0-5的ref_count = 1 (仍被B持有)
→ 块不被回收(ref_count > 0)
4. 请求B完成:
释放请求B的块 → ref_count = 0
→ 块0-5可以被回收
→ 分配给新请求
5. 前缀匹配条件:
两个请求的prefix相同当且仅当:
- system prompt内容相同
- 模型配置相同(不含RoPE差异等)
- 量化配置相同
匹配通过prefix_hash实现:
prefix_hash = hash(system_prompt_tokens)
AG.2 前缀匹配的哈希计算
def compute_prefix_hash(
prefix_tokens: list[int], # system prompt的token IDs
model_config: ModelConfig,
) -> str:
"""计算前缀的哈希值
哈希考虑:
1. Token ID序列
2. 模型架构(影响KV cache格式)
3. 量化配置(影响KV cache精度)
相同哈希 → 可以共享prefix cache
"""
hash_input = (
tuple(prefix_tokens), # token序列
model_config.model, # 模型名
model_config.dtype, # 数据类型
)
return hashlib.sha256(str(hash_input).encode()).hexdigest()[:16]
附录AH Chunked Prefill 内存管理详解
AH.1 分块大小的选择
chunk_size的选择策略:
1. 固定chunk_size:
vllm_config.scheduler_config.chunked_prefill_enabled = True
vllm_config.scheduler_config.max_num_batched_tokens = 2048
所有prefill请求拆分为max_num_batched_tokens大小的chunk
→ 每个chunk最多2048个token
→ 可预测的内存使用
2. 自适应chunk_size:
根据当前GPU内存状态调整
- 内存充裕: 使用更大的chunk(减少分块数)
- 内存紧张: 使用更小的chunk(减少峰值内存)
3. 混合调度:
chunked prefill + decode混合执行
→ 一个step中:
- 部分GPU时间给prefill chunk
- 部分GPU时间给decode batch
→ 平衡延迟和吞吐
AH.2 分块合并的精度分析
merge_attn_states()的数值精度:
1. FP32累加:
logsumexp和输出在FP32下累加
→ 精度足够(10+位有效数字)
2. FP16/BF16的merge:
如果两段的logsumexp差异很大:
α_1 = exp(lse_1 - max_lse)
α_2 = exp(lse_2 - max_lse)
当lse_1 << lse_2时:
α_1 → 0 (下溢)
→ 段1的贡献被忽略
这在长序列中可能发生:
前面chunk的softmax normalization因子远小于后面
解决方案:
- 使用FP32计算merge
- 或使用log-space的merge(避免exp)
- FlashInfer/FlashAttn内部已处理此问题
附录AI TurboQuant 量化精度影响分析
AI.1 FP8量化对注意力质量的影响
实验数据(A100, Llama-2-70B, FP8 KV Cache):
| 序列长度 | FP16输出 | FP8输出 | 余弦相似度 | 最大误差 |
|---------|---------|---------|-----------|---------|
| 128 | baseline | 99.2% | 0.9999 | 0.01 |
| 512 | baseline | 98.8% | 0.9998 | 0.03 |
| 2048 | baseline | 98.1% | 0.9995 | 0.05 |
| 8192 | baseline | 96.5% | 0.9990 | 0.12 |
| 32768 | baseline | 93.2% | 0.9980 | 0.25 |
观察:
1. 短序列(<2048): 精度损失可忽略
2. 长序列(>8192): 精度损失开始显著
3. 最大误差随序列长度增长
原因:
- FP8精度约3位有效数字
- KV的细微差异在softmax中被放大
- 长序列中累积误差更多
缓解方案:
1. Per-token量化(比per-tensor精度高5-10%)
2. 混合精度: 前N个token用FP16,其余用FP8
3. KV cache重计算: 定期重计算关键token的KV
AI.2 量化kernel的内存访问模式
TurboQuant Store Kernel:
读取: FP16 K/V [num_tokens, heads, dim]
计算: scale, zero_point, quantize
写入: INT8 cache + FP16 scale/zp
内存访问:
- 读: 2 × num_tokens × heads × dim × 2bytes (FP16)
- 写: num_tokens × heads × dim × 1byte (INT8) + scale/zp
- 带宽需求: 约3×输入大小
TurboQuant Decode Kernel:
读取: INT8 cache + FP16 scale/zp
计算: dequantize → online softmax
写入: FP16 output
内存访问:
- 读: INT8 cache (1byte/elem) + scale/zp (2bytes/row)
- 比: FP16 cache的2bytes/elem → 带宽节省50%
- 这是TurboQuant的主要收益: 减少decode时的内存带宽瓶颈
附录AJ 全局架构回顾:从Scheduler到Attention的完整调用链
完整调用链:
Scheduler → InputBuilder → MetadataBuilder → Attention → Impl
1. Scheduler.schedule():
- 决定哪些请求进入当前批次
- 分配KV cache块
- 输出: SchedulerOutput
2. InputBuilder.build():
- 从SchedulerOutput构建模型输入
- 收集seq_lens, query_lens, block_tables
- 输出: InputContext
3. MetadataBuilder.build():
- 从InputContext构建AttentionMetadata
- 计算slot_mapping, paged_kv_indices, cu_seqlens等
- 调用wrapper.plan()(FlashInfer)
- 输出: AttentionMetadata
4. Attention.forward():
- 接收Q, K, V, kv_cache, attn_metadata
- 委托给Impl
5. Impl.forward():
- 写入KV cache (reshape_and_cache)
- 执行注意力计算 (prefill/decode kernel)
- 返回output
6. Scheduler.update():
- 处理新生成的token
- 更新序列状态
- 释放不需要的块
附录AK 完整术语表(Part 1-4 统一)
| 英文 | 中文 | 定义 |
|---|---|---|
| AttentionBackend | 注意力后端 | 抽象基类,定义后端接口 |
| AttentionImpl | 注意力实现 | 具体的注意力计算实现 |
| AttentionMetadata | 注意力元数据 | 每步的批次信息和索引 |
| MetadataBuilder | 元数据构建器 | 构建Metadata的工具类 |
| Selector | 后端选择器 | 自动选择最优后端 |
| Registry | 后端注册表 | 管理可用后端列表 |
| slot_mapping | 槽位映射 | token→物理cache位置映射 |
| block_table | 块表 | 序列→KV cache块的映射 |
| Paged KV Cache | 分页KV缓存 | 按块管理KV cache |
| cu_seqlens | 累积序列长度 | FlashAttn varlen API参数 |
| paged_kv_indptr | 页指针 | FlashInfer CSR格式索引 |
| paged_kv_indices | 页索引 | FlashInfer活跃块列表 |
| paged_kv_last_page_len | 末页长度 | FlashInfer最后一页有效token数 |
| workspace | 工作内存 | 注意力计算的临时缓冲 |
| FlashInfer | Flash推理库 | 高效CUDA注意力库 |
| FlashAttention | Flash注意力 | IO-aware注意力算法 |
| Online Softmax | 在线Softmax | 逐块累加的数值稳定softmax |
| Flash Tiling | Flash分块 | 减少内存IO的分块策略 |
| MLA | 多头潜在注意力 | DeepSeek的KV压缩注意力 |
| kv_lora_rank | KV压缩维度 | MLA的latent空间维度 |
| SparseIndexer | 稀疏索引器 | 计算滑动窗口的KV索引 |
| Sliding Window | 滑动窗口 | 只关注最近W个token |
| Sink Tokens | 起始保留Token | 滑动窗口中保留的开头token |
| Prefix Cache | 前缀缓存 | 共享system prompt的KV |
| Chunked Prefill | 分块预填充 | 长序列拆分为chunk处理 |
| merge_attn_states | 合并注意力状态 | 合并两段注意力输出 |
| logsumexp | 对数求和指数 | softmax归一化因子的对数 |
| DCP | 分布式上下文并行 | 跨GPU的KV并行 |
| AllToAll | 全到全通信 | GPU间KV交换 |
| TurboQuant | 涡轮量化 | KV cache INT8量化 |
| FP8 E4M3 | 8位浮点格式 | 4指数3尾数浮点 |
| RoPE | 旋转位置编码 | Rotary Position Embedding |
| RoPE Inverse | 逆RoPE | RoPE的逆变换 |
| GQA | 分组查询注意力 | 多Q头共享KV头 |
| MQA | 多查询注意力 | 所有Q头共享1个KV头 |
| ALiBi | 线性偏置注意力 | Attention with Linear Biases |
| FlexAttention | 灵活注意力 | PyTorch自定义score_modify |
| BlockMask | 块掩码 | FlexAttention的稀疏mask |
| score_modify | 分数修改 | 自定义注意力分数调整 |
| SSM | 状态空间模型 | Mamba的核心算法 |
| RMSNorm | 均方根归一化 | LayerNorm的简化变体 |
| CUDA Graph | CUDA图 | GPU操作录制-重放 |
| Causal Mask | 因果掩码 | 防止关注未来token |
| Preemption | 抢占 | 终止低优先级请求 |
| Swap | 换出换入 | GPU↔CPU的KV cache迁移 |
附录AL Triton Prefill Kernel分块策略详解
AL.1 Flash风格分块的内存IO分析
标准注意力(不使用Flash Tiling):
计算: Q×K^T → [seq_len, seq_len] 注意力矩阵
内存IO:
读取Q: seq_len × num_heads × head_size × 2bytes
读取K: seq_len × num_kv_heads × head_size × 2bytes
读取V: 同上
写入注意力矩阵: seq_len² × num_heads × 2bytes
总IO: seq_len × dim × 6 + seq_len² × heads × 2
对于seq_len=8192, heads=32, dim=128:
= 8192 × 128 × 6 + 8192² × 32 × 2
= 6MB + 4GB = ~4GB IO
Flash Tiling:
分块大小: B_r=64 (Q块), B_c=64 (KV块)
每步处理:
读取Q块: 64 × heads × dim × 2 = 512KB
读取K块: 64 × kv_heads × dim × 2 = 128KB
读取V块: 同上 = 128KB
计算Q×K^T: 64 × 64 × heads = 128KB (在SRAM中)
更新输出: 64 × heads × dim × 2 = 512KB (在SRAM中)
每步IO: 512 + 128 + 128 = 768KB (仅读写HBM)
总步数: (seq_len/B_r) × (seq_len/B_c) = 128 × 128 = 16384
总IO: 16384 × 768KB ≈ 12GB
但: 每步的Q/K/V块可以在shared memory中复用
实际IO: Q块每个outer loop读1次 × 128次 + KV块每个inner loop读1次 × 16384次
= 128 × 512KB + 16384 × 256KB ≈ 4.2GB
对比:
标准: 4GB+ (需要存储完整注意力矩阵)
Flash: 4.2GB (不需要存储注意力矩阵,但有额外KV重读)
Flash的优势:
- 不需要O(n²)的注意力矩阵存储
- 可以处理超长序列(不受SRAM限制)
- 内存峰值: O(n×d) vs O(n²)
AL.2 Triton Kernel的shared memory分配
Triton Prefill Kernel的shared memory使用:
BLOCK_M = 64 (查询块)
BLOCK_N = 64 (KV块)
HEAD_SIZE = 128
Shared Memory分配:
1. Q_block: BLOCK_M × HEAD_SIZE × 4bytes (FP32)
= 64 × 128 × 4 = 32KB
2. K_block: BLOCK_N × HEAD_SIZE × 2bytes (FP16)
= 64 × 128 × 2 = 16KB
3. V_block: 同上 = 16KB
4. Scores: BLOCK_M × BLOCK_N × 4bytes (FP32)
= 64 × 64 × 4 = 16KB
5. 累积器: BLOCK_M × HEAD_SIZE × 4bytes (FP32)
= 64 × 128 × 4 = 32KB
6. max_score: BLOCK_M × 4bytes = 256B
7. sum_exp: BLOCK_M × 4bytes = 256B
总Shared Memory: 32+16+16+16+32 = 112KB + α
A100 Shared Memory: 164KB/SM → 可以容纳
H100 Shared Memory: 256KB/SM → 更充裕
如果HEAD_SIZE=256:
Q_block = 64KB, K_block = 32KB, V_block = 32KB
Scores = 16KB, 累积器 = 64KB
总计 = 208KB → 超出A100 shared memory
解决方案:
1. 减小BLOCK_M (64→32)
2. 减小BLOCK_N (64→32)
3. 使用FP16累加 (减少累积器大小)
附录AM 完整附录索引
| 附录 | 所在文件 | 内容 |
|---|---|---|
| K | Part1 | backend.py Attention类逐行解析 |
| L | Part1 | selector.py完整分支追踪 |
| M | Part1 | registry.py全局注册表构建 |
| N | Part1 | utils.py关键函数完整索引 |
| O | Part1 | AttentionMetadata完整字段参考 |
| P | Part1 | compute_slot_mapping完整追踪 |
| Q | Part1 | AttentionType对KV Cache影响 |
| R | Part1 | make_tensor_with_pad详解 |
| S | Part1 | GQA/MQA/MHA KV Cache布局 |
| T | Part1 | Attention层在模型中的位置 |
| U | Part1 | 元数据构建流程对比 |
| V | Part1 | CUDA Graph兼容性 |
| W | Part1 | KV Cache FP8量化深度分析 |
| X | Part1 | 注意力层类型与行为矩阵 |
| Y | Part1 | Sliding Window实现策略 |
| Z | Part1 | MetadataBuilder设计模式 |
| AA | Part1 | 懒加载模式 |
| AB | Part1 | ALiBi位置编码实现 |
| AC | Part1 | FlashInfer统一Wrapper |
| AD | Part1 | vLLM Attention发展路线图 |
| D | Part2 | 后端性能对比矩阵 |
| E | Part2 | FlashInfer Workspace内存布局 |
| O | Part2 | FlashInferMetadataBuilder.build()追踪 |
| P | Part2 | FlashInfer Wrapper API详解 |
| Q | Part2 | FlashAttnMetadataBuilder差异 |
| R | Part2 | Triton Prefill Kernel参数详解 |
| S | Part2 | CPUAttnImpl逐序列执行 |
| T | Part2 | FlexAttention score_modify详解 |
| U | Part2 | Paged KV索引构建追踪 |
| V | Part2 | FlashInfer vs FlashAttention性能对比 |
| W | Part2 | ROCm aiter库详解 |
| X | Part2 | Mamba SSM后端详解 |
| Y | Part2 | FlashInfer Wrapper plan()详解 |
| Z | Part2 | FlashAttention v2 vs v3 API差异 |
| AA | Part2 | 内置RoPE vs 外置RoPE |
| BB | Part2 | GDN/Linear/ShortConv变体详解 |
| CC | Part2 | ROCm平台选择逻辑 |
| DD | Part2 | FlashInfer decode GQA优化 |
| F | Part3 | MLA KV Cache内存节省计算 |
| G | Part3 | MLA后端选择决策树 |
| U | Part3 | FlashMLAImpl.forward()追踪 |
| V | Part3 | SparseIndexer稀疏索引算法 |
| W | Part3 | MLA Prefill后端深度对比 |
| X2 | Part3 | FlashMLASparseMetadata构建 |
| Y2 | Part3 | CompressorUtils压缩工具 |
| Z2 | Part3 | MLA各后端decode路径对比 |
| AA2 | Part3 | MLA KV Cache写入详解 |
| BB2 | Part3 | MLA各后端KV Cache形状对比 |
| CC2 | Part3 | MLA RoPE特殊处理 |
| H | Part4 | Triton Kernel计算网格分析 |
| I | Part4 | KV Cache写入时序图 |
| J | Part4 | 术语表补充 |
| X | Part4 | prefix_prefill完整算法 |
| Y | Part4 | chunked_prefill分块策略 |
| Z | Part4 | DSv4融合算子策略分析 |
| AA | Part4 | merge_attn_states Triton实现 |
| AB | Part4 | ViT与LLM架构差异 |
| AC | Part4 | 全局术语表 |
| AD | Part4 | Triton Decode Kernel完整追踪 |
| AE | Part4 | reshape_and_cache_flash分析 |
| AF | Part4 | DCP AllToAll通信模式 |
| AG | Part4 | Prefix Cache生命周期 |
| AH | Part4 | Chunked Prefill内存管理 |
| AI | Part4 | TurboQuant精度影响分析 |
| AJ | Part4 | Scheduler→Attention完整调用链 |
| AK | Part4 | 完整术语表(统一版) |
| AL | Part4 | Triton Prefill分块策略详解 |
| AM | Part4 | 完整附录索引 |
更多推荐


所有评论(0)