vLLM V1 Attention 模块超深度架构分析 — Part 4: 底层Ops算子

分析范围: v1/attention/ops/ 目录全部源码(24个文件,约9,600行)
分析日期: 2026-05-25
分析深度: 架构师级,逐行解析,Mermaid图表25+


目录


第二十章 通用操作算子(common.py)

20.1 reshape_and_cache() KV写入

def reshape_and_cache(
    key: torch.Tensor,      # [num_tokens, num_kv_heads, head_size]
    value: torch.Tensor,    # [num_tokens, num_kv_heads, head_size]
    key_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads, head_size]
    value_cache: torch.Tensor,  # 同上
    slot_mapping: torch.Tensor,  # [num_tokens] → 物理slot索引
    kv_cache_dtype: str,    # "auto"/"fp8"/"fp8_e4m3"/...
    kv_scale: float = 1.0,  # FP8量化缩放因子
    kv_zp: float = 0.0,     # FP8量化零点
) -> None:
    """将Key和Value写入Paged KV Cache
    
    流程:
    1. 遍历每个token
    2. 根据slot_mapping找到写入位置
    3. 将K/V写入对应的cache位置
    4. 如果kv_cache_dtype为FP8,执行量化
    
    原地修改key_cache和value_cache
    """
    # 量化路径
    if kv_cache_dtype.startswith("fp8"):
        # FP8量化写入
        key_8bit = _quantize_fp8(key, kv_scale, kv_zp)
        value_8bit = _quantize_fp8(value, kv_scale, kv_zp)
        # 使用量化值写入cache
        _reshape_and_cache_impl(key_8bit, value_8bit, key_cache, value_cache, slot_mapping)
    else:
        # 标准FP16/BF16写入
        _reshape_and_cache_impl(key, value, key_cache, value_cache, slot_mapping)

slot_mapping写入示意

KV Cache (物理存储)

Input K/V

slot=5 → B0[5]

slot=6 → B0[6]

slot=7 → B0[7]

slot=32 → B2[0]

slot_mapping

[5, 6, 7, 32]

K[0]

K[1]

K[2]

K[3]

Block 0: [0,1,2,...,15]

Block 1: [16,17,...,31]

Block 2: [32,33,...,47]

20.2 reshape_and_cache_flash() Flash版

def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    kv_scale: float = 1.0,
    kv_zp: float = 0.0,
) -> None:
    """Flash版KV写入
    
    与reshape_and_cache的区别:
    - 使用Triton kernel实现(更高性能)
    - 支持FP8量化
    - 与FlashInfer/FlashAttention兼容的内存布局
    
    底层实现: triton_reshape_and_cache_flash kernel
    """
    # 选择实现路径
    if kv_cache_dtype in ("auto", None, "fp16", "bfloat16"):
        # 标准路径: Triton kernel
        triton_reshape_and_cache_flash_kernel(
            key, value, key_cache, value_cache, slot_mapping
        )
    elif kv_cache_dtype.startswith("fp8"):
        # FP8量化路径: 先量化再写入
        triton_reshape_and_cache_flash_fp8_kernel(
            key, value, key_cache, value_cache, slot_mapping,
            kv_scale, kv_zp
        )

20.3 copy_cache() 缓存复制

def copy_cache(
    src_kv_cache: torch.Tensor,     # 源KV cache
    dst_kv_cache: torch.Tensor,     # 目标KV cache
    src_to_dst_block_id: torch.Tensor,  # 源→目标块ID映射
    num_heads: int,
    head_size: int,
    block_size: int,
    cache_type: str,
) -> None:
    """复制KV cache块(用于preemption/swap)
    
    场景: 当请求被抢占时,需要将其KV cache从GPU复制到CPU
    或从CPU恢复到GPU
    
    src_to_dst_block_id[i] = j 表示源块i → 目标块j
    """

20.4 swap_cache() 缓存交换

def swap_cache(
    src_kv_cache: torch.Tensor,
    dst_kv_cache: torch.Tensor,
    src_to_dst_block_id: torch.Tensor,
    num_heads: int,
    head_size: int,
    block_size: int,
) -> None:
    """交换GPU和CPU之间的KV cache
    
    用于swap scheduling:
    - swap_out: GPU → CPU
    - swap_in: CPU → GPU
    
    底层使用CUDA memcpy实现异步传输
    """

第二十一章 Triton注意力Kernel

21.1 triton_decode_attention.py解码注意力

核心Kernel: _decode_attention_kernel

输入:
  q: [batch, num_heads, head_size]       # 查询向量
  k_cache: [num_blocks, block_size, num_kv_heads, head_size]  # KV cache
  v_cache: 同上
  block_table: [batch, max_blocks]       # 块表
  seq_lens: [batch]                      # 序列长度
  scale: float                           # 1/sqrt(d)

输出:
  output: [batch, num_heads, head_size]  # 注意力输出

算法(在线softmax):
  for each query head:
    max_score = -inf
    sum_exp = 0
    output = 0
    
    for each KV block in sequence:
      # 读取一个block的K/V
      k_block = k_cache[block_table[batch, block_idx]]  # [block_size, dim]
      v_block = v_cache[block_table[batch, block_idx]]
      
      # 计算注意力分数
      scores = q @ k_block.T * scale  # [block_size]
      
      # 在线softmax更新
      new_max = max(max_score, max(scores))
      correction = exp(max_score - new_max)
      sum_exp = sum_exp * correction + sum(exp(scores - new_max))
      output = output * correction + exp(scores - new_max) @ v_block
      max_score = new_max
    
    output = output / sum_exp

在线softmax(Online Softmax) 的数学推导:

标准softmax: output = softmax(Q×K^T) × V
  = Σ_i (exp(s_i) / Σ_j exp(s_j)) × V_i

问题: Σ_j exp(s_j)需要先遍历所有K,再计算softmax

在线softmax: 逐块累加,无需先遍历
  维护: max_score, sum_exp, output三个累积器

处理第n个块时:
  scores_n = Q × K_n^T * scale
  
  new_max = max(old_max, max(scores_n))
  # 新的全局最大值
  
  correction = exp(old_max - new_max)
  # 旧值需要乘以correction来重新归一化
  
  sum_exp = sum_exp * correction + Σ exp(scores_n - new_max)
  output = output * correction + Σ (exp(scores_n - new_max) × V_n)
  
  # correction因子确保了数值稳定性:
  # 减去new_max后再取exp,避免exp(large_value)溢出

最终: output /= sum_exp

Yes

No

输入: Q, K_cache, V_cache, block_table, seq_lens

遍历KV Blocks

读取一个Block的K/V
K_block = K_cache[block_table[i,j]]

计算分数
scores = Q × K_block^T × scale

在线Softmax更新
new_max, sum_exp, output

还有更多Block?

归一化
output /= sum_exp

返回output

21.2 triton_prefill_attention.py预填充注意力

Prefill Triton kernel使用Flash风格的tiling

输入:
  q: [num_tokens, num_heads, head_size]  # 所有prefill token的Q
  k_cache, v_cache: 同decode
  block_table, cu_seqlens, max_seqlen

算法(Flash Attention tiling):
  for each query block (B_r tokens):
    for each key/value block (B_c tokens):
      # 计算Q×K^T子矩阵
      S = Q[q_start:q_end] @ K[k_start:k_end].T * scale
      # [B_r, B_c]
      
      # 应用causal mask
      if is_causal:
        S[mask] = -inf  # 未来位置屏蔽
      
      # 在线softmax更新(与decode类似但处理矩阵)
      new_max = max(old_max, rowmax(S))
      correction = exp(old_max - new_max)
      sum_exp = sum_exp * correction + rowsum(exp(S - new_max))
      O = O * correction + exp(S - new_max) @ V[k_start:k_end]
    
    O /= sum_exp  # 归一化

21.3 triton_unified_attention.py统一注意力

def triton_unified_attention(
    q: torch.Tensor,          # [num_tokens, num_heads, head_size]
    k_cache: torch.Tensor,    # KV cache
    v_cache: torch.Tensor,
    block_table: torch.Tensor,
    cu_seqlens: torch.Tensor,
    seq_lens: torch.Tensor,
    max_seqlen: int,
    scale: float,
    is_causal: bool = True,
) -> torch.Tensor:
    """统一注意力kernel
    
    同时处理prefill和decode
    内部通过cu_seqlens区分不同序列
    
    与分别调用的区别:
    - 减少kernel launch开销
    - 更好的GPU利用率
    - 简化调用逻辑
    """

21.4 triton_reshape_and_cache_flash.py写入kernel

@triton.jit
def _reshape_and_cache_flash_kernel(
    key_ptr, value_ptr,           # K/V输入指针
    key_cache_ptr, value_cache_ptr,  # KV cache指针
    slot_mapping_ptr,             # slot映射
    stride_k_n, stride_k_h, stride_k_d,  # K的stride
    stride_v_n, stride_v_h, stride_v_d,
    stride_cache_b, stride_cache_s, stride_cache_h, stride_cache_d,
    BLOCK_SIZE: tl.constexpr,
):
    """Triton kernel: 将K/V写入KV cache
    
    每个program处理一个(head, token)对
    
    流程:
    1. 从slot_mapping获取写入位置
    2. 计算目标cache的物理地址
    3. 将K/V值写入cache
    """
    token_idx = tl.program_id(0)  # token索引
    head_idx = tl.program_id(1)   # 头索引
    
    # 读取slot
    slot = tl.load(slot_mapping_ptr + token_idx)
    
    # 计算cache位置
    block_idx = slot // BLOCK_SIZE
    offset = slot % BLOCK_SIZE
    
    # 写入Key
    key_value = tl.load(key_ptr + token_idx * stride_k_n + head_idx * stride_k_h)
    tl.store(
        key_cache_ptr + block_idx * stride_cache_b + offset * stride_cache_s + head_idx * stride_cache_h,
        key_value
    )
    
    # 写入Value
    value_value = tl.load(value_ptr + token_idx * stride_v_n + head_idx * stride_v_h)
    tl.store(
        value_cache_ptr + block_idx * stride_cache_b + offset * stride_cache_s + head_idx * stride_cache_h,
        value_value
    )

21.5 triton_attention_helpers.py辅助函数

def _get_block_size_and_num_warps(
    head_size: int,
) -> tuple[int, int]:
    """根据head_size选择最优的block_size和num_warps
    
    经验值:
    - head_size ≤ 64: BLOCK_SIZE=16, num_warps=4
    - head_size ≤ 128: BLOCK_SIZE=16, num_warps=8
    - head_size ≤ 256: BLOCK_SIZE=16, num_warps=16
    """
    if head_size <= 64:
        return 16, 4
    elif head_size <= 128:
        return 16, 8
    else:
        return 16, 16

def _get_num_warps_and_stages(
    kv_group_num: int,   # num_heads / num_kv_heads
    block_size: int,
) -> tuple[int, int]:
    """选择warp数和pipeline stages
    
    kv_group_num越大 → GQA重用越多 → 可以用更多warps
    """
    if kv_group_num >= 4:
        return 8, 3  # 8 warps, 3 pipeline stages
    elif kv_group_num >= 2:
        return 4, 3
    else:
        return 4, 2

21.6 triton_merge_attn_states.py状态合并

def merge_attn_states(
    base_output: torch.Tensor,   # [batch, heads, dim] 基础注意力输出
    base_softmax_log: torch.Tensor,  # [batch, heads] 基础softmax logsumexp
    new_output: torch.Tensor,    # [batch, heads, dim] 新注意力输出
    new_softmax_log: torch.Tensor,   # [batch, heads] 新softmax logsumexp
) -> torch.Tensor:
    """合并两组注意力状态
    
    数学:
    设两组softmax概率为 p_i 和 q_i,对应的logsumexp为 l_i 和 m_i
    
    合并后的softmax:
    output = (exp(l - max) * base_output + exp(m - max) * new_output) 
           / (exp(l - max) + exp(m - max))
    其中 max = max(l, m)
    
    用途:
    - Prefix caching: 合并prefix注意力和生成注意力
    - Chunked prefill: 合并不同chunk的注意力
    """

第二十二章 前缀预填充与分块解码

22.1 prefix_prefill.py

class PrefixPrefill:
    """前缀预填充
    
    当多个请求共享相同的system prompt时,
    只需计算一次前缀的KV cache,然后共享给所有请求
    
    流程:
    1. 首次遇到prefix: 正常prefill计算,KV cache写入
    2. 后续请求: 直接复制prefix的KV cache块,无需重新计算
    
    好处:
    - 避免重复计算相同的system prompt
    - 显著减少prefill延迟
    - 降低GPU计算量
    """

Prefix Cache示意

KV Cache

请求2: system_prompt + user2

请求1: system_prompt + user1

system prompt
KV已缓存

user1 query
需计算

system prompt
★直接复用★

user2 query
需计算

Prefix Block 0-3
(共享)

Req1 Block 4-5

Req2 Block 6-7

22.2 chunked_prefill_paged_decode.py

class ChunkedPrefillPagedDecode:
    """分块预填充 + 分页解码
    
    将长序列的prefill拆分为多个chunk,每个chunk:
    1. 计算当前chunk的注意力
    2. 将KV cache写入对应块
    3. 与之前chunk的结果合并
    
    好处:
    - 避免超长序列导致OOM
    - 控制单步计算量
    - 与decode共享GPU时间
    
    合并逻辑:
    使用merge_attn_states()合并各chunk的注意力状态
    """

第二十三章 DeepSeek V4专用算子

23.1 cache_utils.py缓存工具

class DeepseekV4CacheUtils:
    """DeepSeek V4专用缓存工具
    
    DeepSeek V4使用特殊的KV cache布局:
    - 分离的K和V缓存
    - K使用RoPE后存储
    - V使用压缩latent存储
    - 支持FP8量化
    """
    
    @staticmethod
    def get_kv_cache_shape(
        num_blocks, block_size, num_kv_heads, head_size,
        kv_lora_rank, q_lora_rank,
    ) -> tuple[tuple, tuple]:
        """返回DeepSeek V4的KV cache形状
        
        K Cache: [num_blocks, block_size, num_kv_heads, head_size_rope]
                 head_size_rope = head_size - nope_dim (RoPE维度)
        
        V Cache (latent): [num_blocks, block_size, kv_lora_rank]
        
        额外缓存:
        - k_pe: [num_blocks, block_size, 1, head_size_rope]
          Key的位置编码(RoPE部分)
        """

23.2 fused_compress_quant_cache.py融合压缩量化

@triton.jit
def fused_compress_quant_cache_kernel(
    # 输入
    kv_input_ptr,           # 原始KV输入
    # 输出  
    compressed_cache_ptr,   # 压缩后的cache
    quant_scale_ptr,        # 量化缩放因子
    quant_zp_ptr,           # 量化零点
    # 参数
    kv_lora_rank,           # 压缩维度
    num_heads,
    head_size,
    block_size,
    FP8_QUANT: tl.constexpr,  # 是否FP8量化
):
    """融合压缩+量化+写入cache的kernel
    
    在一个kernel中完成:
    1. KV → latent压缩 (W_DKV × KV)
    2. latent → FP8量化 (如果启用)
    3. 写入paged cache
    
    好处:
    - 减少中间结果的内存读写
    - 单次kernel launch完成三步操作
    - 降低延迟
    """

23.3 fused_indexer_q.py融合Q索引器

def fused_indexer_q(
    q_input: torch.Tensor,      # [num_tokens, q_lora_rank + head_size_rope]
    q_lora_rank: int,
    head_size_rope: int,
    num_heads: int,
    slot_mapping: torch.Tensor,
    q_cache: torch.Tensor,      # Q cache(某些模型缓存Q用于推理)
) -> torch.Tensor:
    """融合Q索引+分离+写入
    
    DeepSeek V4的Q有两个部分:
    1. q_lora: [q_lora_rank] 压缩的Q latent
    2. q_rope: [num_heads, head_size_rope] RoPE部分的Q
    
    本kernel:
    1. 将Q分离为q_lora和q_rope
    2. 对q_rope应用RoPE
    3. 将两部分写入各自的cache位置
    """

23.4 fused_inv_rope_fp8_quant.py融合RoPE+量化

@triton.jit
def fused_inv_rope_fp8_quant_kernel(
    q_input_ptr,       # 输入Q
    q_output_ptr,      # 输出Q(RoPE后+量化后)
    inv_freq_ptr,      # RoPE逆频率
    position_ptr,      # 位置索引
    scale_ptr,         # FP8量化缩放
    zp_ptr,            # FP8量化零点
    HEAD_SIZE: tl.constexpr,
    APPLY_ROPE: tl.constexpr,
    FP8_QUANT: tl.constexpr,
):
    """融合: 逆RoPE + FP8量化
    
    在一个kernel中完成:
    1. 应用逆RoPE(将RoPE编码的Q还原为原始Q)
       用于某些模型需要在decode时重新应用RoPE
    2. FP8量化(如果启用)
    
    数学:
    逆RoPE: q_orig = RoPE^(-1)(q_rope, position, inv_freq)
    FP8量化: q_fp8 = quantize(q_orig, scale, zp)
    """

23.5 fused_qk_rmsnorm.py融合RMSNorm

@triton.jit
def fused_qk_rmsnorm_kernel(
    q_ptr, k_ptr,        # Q和K输入
    q_out_ptr, k_out_ptr, # Q和K输出
    gamma_ptr,           # RMSNorm的gamma参数
    HEAD_SIZE: tl.constexpr,
    EPSILON: tl.constexpr,
):
    """融合Q+K RMSNorm
    
    在一个kernel中对Q和K同时应用RMSNorm
    
    RMSNorm(x) = x / RMS(x) * gamma
    RMS(x) = sqrt(mean(x^2) + epsilon)
    
    好处:
    - 减少一次kernel launch
    - 共享RMS计算(如果Q和K使用相同的norm)
    """

第二十四章 视觉注意力包装器

24.1 vit_attn_wrappers.py

class ViTAttentionWrapper:
    """Vision Transformer注意力包装器
    
    ViT与LLM注意力的区别:
    1. 无KV cache(每张图独立处理)
    2. 固定序列长度(图像patch数固定)
    3. 可能使用不同的位置编码(如2D位置编码)
    4. 无causal mask(双向注意力)
    
    支持的ViT模式:
    - 标准ViT: flash_attn_func(无causal mask)
    - SigLIP: 特殊的注意力模式
    - 混合模态: ViT特征→LLM的cross-attention
    """
    
    def forward(
        self,
        q: torch.Tensor,    # [batch, num_heads, num_patches, head_size]
        k: torch.Tensor,    # [batch, num_kv_heads, num_patches, head_size]
        v: torch.Tensor,    # [batch, num_kv_heads, num_patches, head_size]
    ) -> torch.Tensor:
        """执行ViT注意力
        
        关键区别:
        - 不使用Paged KV Cache
        - 不使用causal mask
        - 序列长度固定
        """
        from flash_attn import flash_attn_func
        
        output = flash_attn_func(
            q, k, v,
            causal=False,  # 双向注意力!
            softmax_scale=self.scale,
        )
        return output

第二十五章 DCP全到全通信

25.1 dcp_alltoall.py

class DCPAllToAll:
    """Distributed Context Parallelism全到全通信
    
    DCP将长序列的KV cache分散到多个GPU上
    每个GPU持有部分KV,需要通过alltoall通信获取其他GPU的KV
    
    流程:
    1. 本地计算: 每个GPU计算本地KV的注意力
    2. Alltoall: 交换KV cache块
    3. 远程计算: 用收到的远程KV计算注意力
    4. 合并: 使用merge_attn_states合并本地和远程结果
    """
    
    def forward(
        self,
        q: torch.Tensor,           # 本地查询
        local_kv: torch.Tensor,     # 本地KV cache
        remote_kv: torch.Tensor,    # 远程KV(通过alltoall获得)
    ) -> torch.Tensor:
        # 1. 本地注意力
        local_output, local_lse = self._compute_local_attention(q, local_kv)
        
        # 2. Alltoall通信
        remote_kv = self._alltoall_comm(local_kv)
        
        # 3. 远程注意力
        remote_output, remote_lse = self._compute_remote_attention(q, remote_kv)
        
        # 4. 合并
        output = merge_attn_states(
            local_output, local_lse,
            remote_output, remote_lse,
        )
        return output

GPU 1

GPU 0

alltoall

alltoall

Q_0

本地KV_0注意力

远程KV_1注意力

合并

Q_1

本地KV_1注意力

远程KV_0注意力

合并

KV_0

KV_1


第二十六章 TurboQuant算子

26.1 triton_turboquant_store.py

@triton.jit
def turboquant_store_kernel(
    key_ptr, value_ptr,           # 输入K/V
    key_cache_ptr, value_cache_ptr,  # 量化后的KV cache
    scale_key_ptr, scale_value_ptr,  # 量化缩放因子
    zp_key_ptr, zp_value_ptr,        # 量化零点
    slot_mapping_ptr,
    HEAD_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """TurboQuant KV存储kernel
    
    流程:
    1. 读取原始K/V值
    2. 计算per-token的量化参数(scale, zero_point)
    3. 执行量化: value_int8 = round(value / scale + zp)
    4. 将量化值和参数写入KV cache
    """
    # 读取原始值
    token_idx = tl.program_id(0)
    head_idx = tl.program_id(1)
    
    slot = tl.load(slot_mapping_ptr + token_idx)
    block_idx = slot // BLOCK_SIZE
    offset = slot % BLOCK_SIZE
    
    # 读取K/V
    key_values = tl.load(key_ptr + token_idx * stride + head_idx * stride_h + offs)
    
    # 计算量化参数
    key_max = tl.max(tl.abs(key_values))
    key_scale = key_max / 127.0  # INT8范围[-128, 127]
    
    # 量化
    key_quant = tl.extra.cuda.clamp(
        tl.libdevice.round(key_values / key_scale), -128, 127
    ).to(tl.int8)
    
    # 写入cache
    tl.store(key_cache_ptr + ..., key_quant)
    tl.store(scale_key_ptr + ..., key_scale)

26.2 triton_turboquant_decode.py

@triton.jit
def turboquant_decode_kernel(
    q_ptr,                          # 查询向量
    key_cache_ptr, value_cache_ptr,  # 量化KV cache
    scale_key_ptr, scale_value_ptr,  # 量化参数
    output_ptr,                     # 输出
    block_table_ptr, seq_lens_ptr,
    HEAD_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """TurboQuant解码注意力kernel
    
    在线反量化+注意力计算:
    1. 读取量化KV
    2. 用scale和zp反量化: value_fp16 = (value_int8 - zp) * scale
    3. 执行标准注意力计算(在线softmax)
    4. 输出结果
    
    好处:
    - 反量化与注意力计算融合,避免中间结果写回内存
    - 减少KV cache内存占用4x(FP16→INT8)
    - 减少内存带宽需求2x(读取INT8而非FP16)
    """

TurboQuant内存节省

FP16 KV Cache (per token):
  K: num_kv_heads × head_size × 2 bytes = 64 × 128 × 2 = 16,384 bytes
  V: 同上 = 16,384 bytes
  总计: 32,768 bytes

INT8 KV Cache (per token):
  K量化值: num_kv_heads × head_size × 1 byte = 8,192 bytes
  K参数: num_kv_heads × 2 × 2 bytes = 256 bytes (scale + zp)
  V量化值: 8,192 bytes
  V参数: 256 bytes
  总计: 16,896 bytes

节省: 32,768 / 16,896 ≈ 1.94× (接近2×)
带宽节省: 2× (INT8读取代FP16读取)

附录H Triton Kernel计算网格分析

_decode_attention_kernel的program网格:
  grid = (num_heads, batch_size)
  
  每个program处理:
    1个序列的1个注意力头
  
  program_id(0) = head_idx
  program_id(1) = batch_idx

_prefill_attention_kernel的program网格:
  grid = (num_heads, DIV_UP(num_tokens, BLOCK_M))
  
  BLOCK_M = 64 (查询块大小)
  
  每个program处理:
    1个注意力头的BLOCK_M个查询token
    
_reshape_and_cache_flash_kernel的program网格:
  grid = (num_tokens, num_kv_heads)
  
  每个program处理:
    1个token的1个KV头的写入

turboquant_decode_kernel的program网格:
  grid = (num_heads, batch_size)
  
  与decode_attention_kernel相同

附录I KV Cache写入时序图

GPU KV Cache Triton Kernel reshape_and_cache Attention Layer Model Forward GPU KV Cache Triton Kernel reshape_and_cache Attention Layer Model Forward 1. Q/K/V投影 2. Reshape为多头格式 3. 写入KV Cache grid=(num_tokens, num_kv_heads) loop [每个token × 每个KV头] 4. 执行注意力计算 5. 返回输出 query, key, value [num_tokens, dim] view(num_tokens, heads, head_size) reshape_and_cache(key, value, cache, slot_mapping) 启动Triton kernel slot = slot_mapping[token_idx] block = slot // block_size offset = slot % block_size cache[block, offset, head, :] = key/value 写入完成 kernel完成 返回 wrapper.run(query, cache, metadata) output [num_tokens, dim]

附录J 术语表补充

英文术语 中文翻译 说明
Online Softmax 在线Softmax 逐块累加的数值稳定softmax
Flash Tiling Flash分块 IO-aware的分块注意力算法
Prefix Cache 前缀缓存 共享system prompt的KV cache
Chunked Prefill 分块预填充 长序列拆分为多个chunk
DCP 分布式上下文并行 跨GPU的KV cache并行
AllToAll 全到全通信 每个GPU与所有其他GPU交换数据
TurboQuant 涡轮量化 KV cache INT8量化方案
RoPE Inverse 逆RoPE 将RoPE编码还原为原始向量
Latent Vector 潜向量 MLA压缩后的低维KV表示
GQA Group GQA分组 多个Q头共享一组KV
Pipeline Stage 流水线阶段 多阶段kernel执行
B_r / B_c 查询/键块大小 Flash Attention的分块参数
logsumexp 对数求和指数 softmax归一化常数的对数
FP8 E4M3 8位浮点(4指数3尾数) 量化格式
RMSNorm 均方根归一化 LayerNorm的简化变体
Swapped Out 换出 KV cache从GPU移到CPU
Preemption 抢占 调度器强制终止低优先级请求

附录X prefix_prefill.py 完整算法详解

X.1 前缀匹配与复用

场景: 3个请求共享相同的system prompt (100 tokens)

请求1: system(100) + user_query_1(20)
请求2: system(100) + user_query_2(30)  
请求3: system(100) + user_query_3(10)

不使用prefix cache:
  每个请求独立prefill → 3 × 100 = 300 token的重复计算
  GPU时间: 300 × T_prefill_per_token

使用prefix cache:
  请求1: prefill system(100) → 计算并缓存KV
  请求2: 复用system KV + prefill user_query_2(30)
  请求3: 复用system KV + prefill user_query_3(10)
  GPU时间: 100 × T_prefill + 20 × T + 30 × T + 10 × T = 160 tokens
  节省: 140/300 = 47%

X.2 Prefix KV Cache的块级复用

system prompt = 100 tokens, block_size = 16

请求1的KV cache布局:
  Block 0-5: system prompt KV (block_table_1[0:6] = [0,1,2,3,4,5])
  Block 6-7: user_query_1 KV

请求2的KV cache布局:
  Block 0-5: ★复用★ system prompt KV (block_table_2[0:6] = [0,1,2,3,4,5])
  Block 8-9: user_query_2 KV (新块)

关键: block_table_2[0:6] 指向与block_table_1相同的物理块
→ 不需要复制KV数据,只需共享块ID

这是PagedAttention的自然优势:
  KV cache与序列解耦 → 多个序列可以共享相同的物理块
  只要不被释放(引用计数>0),块就保持有效

X.3 _copy_prefix_kv_blocks() 实现

def _copy_prefix_kv_blocks(
    src_block_table: torch.Tensor,  # 源请求的块表
    dst_block_table: torch.Tensor,  # 目标请求的块表
    num_prefix_blocks: int,          # 前缀块数
    kv_cache: torch.Tensor,          # KV cache (不直接复制)
) -> None:
    """复制前缀KV cache的块引用
    
    注意: 不复制KV数据,只复制块表的引用
    类似于文件系统中的硬链接
    
    原因:
    1. KV cache可能很大(几百MB)
    2. 直接复制浪费时间
    3. PagedAttention的设计允许共享
    """
    # 将源请求的前缀块ID复制到目标请求的块表
    dst_block_table[:num_prefix_blocks] = src_block_table[:num_prefix_blocks]
    
    # 增加引用计数(如果有)
    # 确保共享的块不会被垃圾回收

附录Y chunked_prefill_paged_decode.py 分块策略详解

Y.1 长序列分块

输入序列: 8000 tokens
chunk_size: 2048 (可配置)

分块策略:
  Chunk 1: tokens[0:2048]    → KV写入cache → 计算注意力
  Chunk 2: tokens[2048:4096] → KV写入cache → 与Chunk1合并
  Chunk 3: tokens[4096:6144] → KV写入cache → 与之前合并
  Chunk 4: tokens[6144:8000] → KV写入cache → 与之前合并

每个chunk的注意力计算:
  1. 当前chunk的Q与所有已缓存KV计算注意力
  2. 使用merge_attn_states()与之前的结果合并

为什么不一次性计算?
  - 8000 token的注意力矩阵: 8000² × 2bytes = 128MB per head
  - 32 heads × 128MB = 4GB (仅注意力矩阵)
  - 分块后: 2048² × 2 × 32 = 256MB per chunk
  - 内存节省: 4GB → 256MB

Y.2 merge_attn_states() 数学推导

设有两个KV段: KV_1 = [0, L1), KV_2 = [L1, L2)

段1的注意力:
  O_1 = softmax(Q × K_1^T / √d) × V_1
  lse_1 = logsumexp(Q × K_1^T / √d)  # log-sum-exp

段2的注意力:
  O_2 = softmax(Q × K_2^T / √d) × V_2
  lse_2 = logsumexp(Q × K_2^T / √d)

合并后的注意力:
  O = softmax([Q × K_1^T; Q × K_2^T] / √d) × [V_1; V_2]

使用log-sum-exp技巧:
  max_lse = max(lse_1, lse_2)
  α_1 = exp(lse_1 - max_lse)
  α_2 = exp(lse_2 - max_lse)
  Z = α_1 + α_2
  
  O = (α_1 × O_1 + α_2 × O_2) / Z

推导:
  O = Σ_{i∈[1,L2]} softmax_score(i) × V_i
    = Σ_{i∈[1,L1)} softmax_score(i) × V_i + Σ_{i∈[L1,L2)} softmax_score(i) × V_i
    = (α_1 / Z) × O_1 + (α_2 / Z) × O_2
  
  其中 α_k = Σ_{i∈段k} exp(score_i - max_lse)
       Z = Σ_k α_k

附录Z DeepSeek V4 专用算子融合策略分析

Z.1 融合操作对比

融合操作 替代的独立操作 节省的kernel launch 节省的中间内存
fused_compress_quant_cache compress + quantize + cache_write 3→1 2×temp buffer
fused_indexer_q q_split + rope_apply + cache_write 3→1 2×temp buffer
fused_inv_rope_fp8_quant inv_rope + fp8_quant 2→1 1×temp buffer
fused_qk_rmsnorm q_rmsnorm + k_rmsnorm 2→1 1×temp buffer

Z.2 融合的收益估算

单次kernel launch开销: ~5-10μs (CUDA)
中间内存带宽: ~900GB/s (A100)

不融合的4步操作:
  1. compress: read KV(32KB), write latent(1KB), compute W_DKV×KV → 10μs launch + 1μs compute
  2. quantize: read latent(1KB), write quant(0.5KB), compute scale → 10μs launch + 0.5μs compute
  3. cache_write: read quant(0.5KB), write cache(0.5KB) → 10μs launch + 0.5μs compute
  总计: 30μs launch + 2μs compute = 32μs per token

融合后的1步操作:
  fused_compress_quant_cache: read KV(32KB), write cache(0.5KB) → 10μs launch + 2μs compute = 12μs
  节省: 32μs → 12μs = 63% 延迟减少
  
  对于batch_size=256: 32×256 = 8.2ms → 12×256 = 3.1ms
  每步节省: 5.1ms → 在100步/s的吞吐下,每秒节省510ms

附录AA merge_attn_states() 完整Triton实现

@triton.jit
def merge_attn_states_kernel(
    # 输入
    base_output_ptr,       # [batch, heads, dim] 基础注意力输出
    base_lse_ptr,          # [batch, heads] 基础logsumexp
    new_output_ptr,        # [batch, heads, dim] 新注意力输出
    new_lse_ptr,           # [batch, heads] 新logsumexp
    # 输出
    output_ptr,            # [batch, heads, dim] 合并后输出
    # 维度
    HEAD_DIM: tl.constexpr,
):
    """合并两组注意力状态的Triton kernel
    
    每个program处理1个(batch, head)对
    """
    batch_idx = tl.program_id(0)
    head_idx = tl.program_id(1)
    
    # 读取logsumexp
    base_lse = tl.load(base_lse_ptr + batch_idx * num_heads + head_idx)
    new_lse = tl.load(new_lse_ptr + batch_idx * num_heads + head_idx)
    
    # 计算合并权重
    max_lse = tl.maximum(base_lse, new_lse)
    alpha_base = tl.exp(base_lse - max_lse)
    alpha_new = tl.exp(new_lse - max_lse)
    Z = alpha_base + alpha_new
    
    # 读取并合并输出
    offset = batch_idx * num_heads * HEAD_DIM + head_idx * HEAD_DIM
    offs = tl.arange(0, HEAD_DIM)
    
    base_out = tl.load(base_output_ptr + offset + offs)
    new_out = tl.load(new_output_ptr + offset + offs)
    
    merged = (alpha_base * base_out + alpha_new * new_out) / Z
    
    # 写入结果
    tl.store(output_ptr + offset + offs, merged)

附录AB ViT与LLM注意力架构差异总结

特性 LLM (Decoder) ViT (Encoder)
注意力方向 Causal (单向) Bidirectional (双向)
KV Cache ✅ Paged KV Cache ❌ 无KV Cache
序列长度 动态(生成中增长) 固定(图像patch数)
Prefill/Decode分离 ✅ 不同kernel ❌ 单次forward
位置编码 RoPE/ALiBi 2D位置编码/无
批次策略 Persistent batch 独立图像批次
Memory优化 PagedAttention FlashAttention
并行策略 TP/PP TP
多模态 Cross-attention Patch embedding

附录AC 全局术语表(跨4个Part完整版)

英文术语 中文翻译 模块
AttentionBackend 注意力后端 Part1
AttentionMetadata 注意力元数据 Part1
AttentionMetadataBuilder 元数据构建器 Part1
AttentionImpl 注意力实现 Part1
Slot Mapping 槽位映射 Part1
Block Table 块表 Part1
Paged KV Cache 分页KV缓存 Part1/4
FlashInfer Flash推理库 Part2
FlashAttention Flash注意力算法 Part2
cu_seqlens 累积序列长度 Part2
Workspace 工作内存 Part2
MLA 多头潜在注意力 Part3
kv_lora_rank KV压缩维度 Part3
SparseIndexer 稀疏索引器 Part3
Sliding Window 滑动窗口 Part3
Sink Tokens 起始保留Token Part3
Latent Vector 潜向量 Part3
Online Softmax 在线Softmax Part4
Flash Tiling Flash分块 Part4
Prefix Cache 前缀缓存 Part4
Chunked Prefill 分块预填充 Part4
merge_attn_states 合并注意力状态 Part4
logsumexp 对数求和指数 Part4
DCP 分布式上下文并行 Part4
AllToAll 全到全通信 Part4
TurboQuant 涡轮量化 Part4
FP8 E4M3 8位浮点格式 Part4
RoPE Inverse 逆旋转位置编码 Part4
RMSNorm 均方根归一化 Part4
GQA 分组查询注意力 Part2
MQA 多查询注意力 Part2
ALiBi 线性偏置注意力 Part1
FlexAttention 灵活注意力 Part2
BlockMask 块掩码 Part2
score_modify 分数修改函数 Part2
ViT 视觉Transformer Part4
Causal Mask 因果掩码 全局
Softmax Scale Softmax缩放 全局
Preemption 抢占 Part1
Swap Out/In 换出/换入 Part4

附录AD Triton Decode Attention Kernel 完整参数追踪

AD.1 _decode_attention_kernel 参数详解

@triton.jit
def _decode_attention_kernel(
    # ===== 输入张量 =====
    q_ptr,                    # [batch, heads, dim] 查询向量
    k_cache_ptr,              # [num_blocks, bs, kv_heads, dim] Key缓存
    v_cache_ptr,              # [num_blocks, bs, kv_heads, dim] Value缓存
    block_table_ptr,          # [batch, max_blocks] 块表
    seq_lens_ptr,             # [batch] 序列长度
    
    # ===== 输出 =====
    output_ptr,               # [batch, heads, dim] 输出
    
    # ===== Stride参数 =====
    stride_q_b, stride_q_h, stride_q_d,     # Q的batch/head/dim stride
    stride_cache_b, stride_cache_s, stride_cache_h, stride_cache_d,  # Cache stride
    stride_bt_b, stride_bt_m,               # BlockTable stride
    
    # ===== 标量参数 =====
    scale,                    # 1/sqrt(head_size)
    kv_group_num,             # num_heads / num_kv_heads (GQA因子)
    BLOCK_SIZE: tl.constexpr, # KV cache块大小 (通常16)
    HEAD_SIZE: tl.constexpr,  # 头维度 (通常64/128)
    SLIDING_WINDOW: tl.constexpr,  # 滑动窗口大小 (0=禁用)
):
    """Triton Decode Attention Kernel
    
    每个program处理1个(batch, head)对
    Grid: (num_heads, batch_size)
    """
    batch_idx = tl.program_id(1)
    head_idx = tl.program_id(0)
    
    # === 确定KV头索引 ===
    kv_head_idx = head_idx // kv_group_num
    # GQA: 多个Q头共享1个KV头
    # kv_group_num = 4 → heads 0,1,2,3 共享 kv_head 0
    
    # === 读取序列长度 ===
    seq_len = tl.load(seq_lens_ptr + batch_idx)
    
    # === 读取查询向量 ===
    q_offset = batch_idx * stride_q_b + head_idx * stride_q_d
    q = tl.load(q_ptr + q_offset + tl.arange(0, HEAD_SIZE))
    # q: [HEAD_SIZE]
    
    # === 遍历KV Blocks(在线Softmax) ===
    max_score = float("-inf")  # 当前最大分数
    sum_exp = 0.0              # exp值累积和
    output = tl.zeros([HEAD_SIZE], dtype=torch.float32)  # 输出累积
    
    # 计算需要遍历的块数
    num_blocks = (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE
    
    # 滑动窗口: 只遍历窗口内的块
    if SLIDING_WINDOW > 0:
        window_start = max(0, seq_len - SLIDING_WINDOW)
        first_block = window_start // BLOCK_SIZE
    else:
        first_block = 0
    
    for block_idx in range(first_block, num_blocks):
        # === 读取块ID ===
        physical_block = tl.load(
            block_table_ptr + batch_idx * stride_bt_m + block_idx
        )
        # physical_block: 该逻辑块在物理cache中的位置
        
        # === 读取K Block ===
        k_offset = (physical_block * stride_cache_b +
                   tl.arange(0, BLOCK_SIZE) * stride_cache_s +
                   kv_head_idx * stride_cache_h)
        k_block = tl.load(k_cache_ptr + k_offset + 
                         tl.arange(0, HEAD_SIZE) * stride_cache_d)
        # k_block: [BLOCK_SIZE, HEAD_SIZE]
        
        # === 计算Q×K^T ===
        scores = tl.sum(q[None, :] * k_block, axis=1) * scale
        # scores: [BLOCK_SIZE]
        
        # === Causal mask ===
        # Decode: 当前token的位置 = seq_len
        # 块内位置: block_idx * BLOCK_SIZE + local_idx
        # 只有位置 < seq_len 的token才参与
        positions = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        valid_mask = positions < seq_len
        
        # 滑动窗口mask
        if SLIDING_WINDOW > 0:
            window_mask = positions >= (seq_len - SLIDING_WINDOW)
            valid_mask = valid_mask & window_mask
        
        scores = tl.where(valid_mask, scores, float("-inf"))
        
        # === 在线Softmax更新 ===
        new_max = tl.maximum(max_score, tl.max(scores))
        correction = tl.exp(max_score - new_max)
        
        exp_scores = tl.exp(scores - new_max)
        sum_exp = sum_exp * correction + tl.sum(exp_scores)
        
        # 读取V Block
        v_offset = (physical_block * stride_cache_b +
                   tl.arange(0, BLOCK_SIZE) * stride_cache_s +
                   kv_head_idx * stride_cache_h)
        v_block = tl.load(v_cache_ptr + v_offset + 
                         tl.arange(0, HEAD_SIZE) * stride_cache_d)
        
        # 更新输出
        output = output * correction + tl.sum(
            exp_scores[:, None] * v_block, axis=0
        )
        max_score = new_max
    
    # === 归一化 ===
    output = output / sum_exp
    
    # === 写入结果 ===
    out_offset = batch_idx * stride_q_b + head_idx * stride_q_d
    tl.store(output_ptr + out_offset + tl.arange(0, HEAD_SIZE), output)

AD.2 在线Softmax数学正确性证明

设scores = [s_1, s_2, ..., s_n]
标准softmax: 
  max_s = max(scores)
  exp_s = exp(scores - max_s)
  sum_exp = sum(exp_s)
  output = sum(exp_s * V / sum_exp)

在线softmax(逐块处理):
  处理块1:
    m_1 = max(s_1...s_B)
    l_1 = sum(exp(s_1...s_B - m_1))
    o_1 = sum(exp(s_1...s_B - m_1) * V_1...V_B)
  
  处理块2:
    m_2 = max(m_1, max(s_{B+1}...s_{2B}))
    l_2 = l_1 * exp(m_1 - m_2) + sum(exp(s_{B+1}...s_{2B} - m_2))
    o_2 = o_1 * exp(m_1 - m_2) + sum(exp(s_{B+1}...s_{2B} - m_2) * V_{B+1}...V_{2B})
  
  ...
  
  最终: output = o_n / l_n

证明等价性:
  o_n = Σ_{块k} (Σ_{i∈块k} exp(s_i - m_n) * V_i)  (通过correction因子展开)
  l_n = Σ_{块k} (Σ_{i∈块k} exp(s_i - m_n))
  
  output = o_n / l_n
         = Σ_i exp(s_i - m_n) * V_i / Σ_i exp(s_i - m_n)
         = Σ_i (exp(s_i - m_n) / Σ_j exp(s_j - m_n)) * V_i
         = Σ_i softmax(s_i) * V_i
         = 标准softmax输出 ✓

附录AE reshape_and_cache_flash Kernel 完整分析

AE.1 标准FP16写入路径

@triton.jit
def _reshape_and_cache_flash_kernel(
    key_ptr,           # [num_tokens, kv_heads, head_size]
    value_ptr,         # 同上
    key_cache_ptr,     # [num_blocks, block_size, kv_heads, head_size]
    value_cache_ptr,   # 同上
    slot_mapping_ptr,  # [num_tokens]
    # strides
    stride_k_n, stride_k_h, stride_k_d,
    stride_cache_b, stride_cache_s, stride_cache_h, stride_cache_d,
    BLOCK_SIZE: tl.constexpr,
    HEAD_SIZE: tl.constexpr,
):
    """KV Cache写入Triton Kernel
    
    Grid: (num_tokens, num_kv_heads)
    每个program写入1个token的1个KV头
    
    流程:
    1. 读取slot → 计算物理位置
    2. 读取K/V值
    3. 写入cache
    """
    token_idx = tl.program_id(0)
    head_idx = tl.program_id(1)
    
    # Step 1: 读取slot
    slot = tl.load(slot_mapping_ptr + token_idx)
    
    # 跳过无效slot(padding token等)
    if slot == -1:
        return
    
    # Step 2: 计算物理位置
    block_idx = slot // BLOCK_SIZE
    offset_in_block = slot % BLOCK_SIZE
    
    # Step 3: 读取Key值
    k_offset = token_idx * stride_k_n + head_idx * stride_k_h
    k_values = tl.load(key_ptr + k_offset + tl.arange(0, HEAD_SIZE) * stride_k_d)
    
    # Step 4: 写入Key cache
    k_cache_offset = (block_idx * stride_cache_b +
                     offset_in_block * stride_cache_s +
                     head_idx * stride_cache_h)
    tl.store(key_cache_ptr + k_cache_offset + 
             tl.arange(0, HEAD_SIZE) * stride_cache_d, k_values)
    
    # Step 5: 读取Value值
    v_offset = token_idx * stride_k_n + head_idx * stride_k_h  # 假设V和K stride相同
    v_values = tl.load(value_ptr + v_offset + tl.arange(0, HEAD_SIZE) * stride_k_d)
    
    # Step 6: 写入Value cache
    tl.store(value_cache_ptr + k_cache_offset + 
             tl.arange(0, HEAD_SIZE) * stride_cache_d, v_values)

AE.2 FP8量化写入路径

@triton.jit
def _reshape_and_cache_flash_fp8_kernel(
    key_ptr, value_ptr,
    key_cache_ptr, value_cache_ptr,
    slot_mapping_ptr,
    kv_scale_ptr,         # FP8量化缩放因子
    kv_zp_ptr,            # FP8量化零点
    # strides and constexprs...
    HEAD_SIZE: tl.constexpr,
    FP8_DTYPE: tl.constexpr,  # torch.float8_e4m3fn 或 e5m2
):
    """FP8量化KV Cache写入Kernel
    
    在写入前将FP16/BF16值量化为FP8
    """
    token_idx = tl.program_id(0)
    head_idx = tl.program_id(1)
    
    slot = tl.load(slot_mapping_ptr + token_idx)
    if slot == -1:
        return
    
    block_idx = slot // BLOCK_SIZE
    offset = slot % BLOCK_SIZE
    
    # 读取FP16 K值
    k_values = tl.load(key_ptr + ...)
    
    # 量化: FP16 → FP8
    # FP8 E4M3: 1符号位 + 4指数位 + 3尾数位
    # 范围: [-448, 448]
    # 量化公式: fp8_val = round(fp16_val / scale + zp)
    scale = tl.load(kv_scale_ptr)
    zp = tl.load(kv_zp_ptr)
    
    k_fp8 = tl.extra.cuda.clamp(
        tl.libdevice.round(k_values / scale + zp),
        -448.0, 448.0  # E4M3范围
    ).to(FP8_DTYPE)
    
    # 写入FP8 cache
    tl.store(key_cache_ptr + ..., k_fp8)
    
    # 同样处理Value

附录AF DCP AllToAll 通信模式详解

AF.1 2-GPU场景

序列: [T0, T1, T2, T3, T4, T5, T6, T7]

GPU 0 持有:
  KV_local_0 = KV[T0, T1, T2, T3]  # 前半段KV

GPU 1 持有:
  KV_local_1 = KV[T4, T5, T6, T7]  # 后半段KV

每个GPU的Q:
  GPU 0: Q[T0, T1, T2, T3]  # 对应本地KV
  GPU 1: Q[T4, T5, T6, T7]  # 对应本地KV

Step 1: 本地注意力
  GPU 0: O_local_0 = Attn(Q[0:4], KV[0:4])   # 前半段注意力
  GPU 1: O_local_1 = Attn(Q[4:8], KV[4:8])   # 后半段注意力

Step 2: AllToAll通信
  GPU 0 → GPU 1: 发送 KV_local_0
  GPU 1 → GPU 0: 发送 KV_local_1
  
  通信后:
  GPU 0 拥有: KV_local_0 + KV_remote_1 (完整KV)
  GPU 1 拥有: KV_local_1 + KV_remote_0 (完整KV)

Step 3: 远程注意力
  GPU 0: O_remote_0 = Attn(Q[0:4], KV[4:8])  # 前半段Q对后半段KV
  GPU 1: O_remote_1 = Attn(Q[4:8], KV[0:4])  # 后半段Q对前半段KV

Step 4: 合并
  GPU 0: O_0 = merge(O_local_0, O_remote_0)  # 合并两段注意力
  GPU 1: O_1 = merge(O_local_1, O_remote_1)

AF.2 4-GPU场景

4 GPU, 序列32K tokens

GPU 0: KV[0:8K]
GPU 1: KV[8K:16K]
GPU 2: KV[16K:24K]
GPU 3: KV[24K:32K]

AllToAll需要3轮通信:
  Round 1: GPU 0 ↔ GPU 1, GPU 2 ↔ GPU 3
  Round 2: GPU 0 ↔ GPU 2, GPU 1 ↔ GPU 3
  Round 3: GPU 0 ↔ GPU 3, GPU 1 ↔ GPU 2

每轮:
  1. 发送本地KV
  2. 接收远程KV
  3. 计算远程注意力
  4. 累积合并

最终: 每个GPU都有完整的注意力输出

通信量: 每轮发送 1/4 的KV cache
总通信量: 3 × (1/4 × KV_size) = 3/4 × KV_size
相比不使用DCP: 全量KV在1个GPU → OOM
使用DCP: 每个GPU只存1/4的KV

附录AG Prefix Cache 生命周期管理

AG.1 前缀块的分配、引用和释放

Prefix Cache的生命周期:

1. 首次计算(请求A):
   Scheduler分配块 → Attention计算KV → 写入KV cache
   → 块的ref_count = 1 (请求A持有)
   → 标记为"prefix"块(可被后续请求共享)

2. 共享(请求B到达,相同prefix):
   Scheduler查找匹配的prefix块
   → 找到块0-5(请求A的prefix)
   → 将块0-5添加到请求B的block_table
   → 块的ref_count = 2 (A和B都持有)
   → ★不复制KV数据★,只共享引用

3. 请求A完成:
   释放请求A的块 → ref_count减少
   → 块0-5的ref_count = 1 (仍被B持有)
   → 块不被回收(ref_count > 0)

4. 请求B完成:
   释放请求B的块 → ref_count = 0
   → 块0-5可以被回收
   → 分配给新请求

5. 前缀匹配条件:
   两个请求的prefix相同当且仅当:
   - system prompt内容相同
   - 模型配置相同(不含RoPE差异等)
   - 量化配置相同
   
   匹配通过prefix_hash实现:
   prefix_hash = hash(system_prompt_tokens)

AG.2 前缀匹配的哈希计算

def compute_prefix_hash(
    prefix_tokens: list[int],  # system prompt的token IDs
    model_config: ModelConfig,
) -> str:
    """计算前缀的哈希值
    
    哈希考虑:
    1. Token ID序列
    2. 模型架构(影响KV cache格式)
    3. 量化配置(影响KV cache精度)
    
    相同哈希 → 可以共享prefix cache
    """
    hash_input = (
        tuple(prefix_tokens),  # token序列
        model_config.model,    # 模型名
        model_config.dtype,    # 数据类型
    )
    return hashlib.sha256(str(hash_input).encode()).hexdigest()[:16]

附录AH Chunked Prefill 内存管理详解

AH.1 分块大小的选择

chunk_size的选择策略:

1. 固定chunk_size:
   vllm_config.scheduler_config.chunked_prefill_enabled = True
   vllm_config.scheduler_config.max_num_batched_tokens = 2048
   
   所有prefill请求拆分为max_num_batched_tokens大小的chunk
   → 每个chunk最多2048个token
   → 可预测的内存使用

2. 自适应chunk_size:
   根据当前GPU内存状态调整
   - 内存充裕: 使用更大的chunk(减少分块数)
   - 内存紧张: 使用更小的chunk(减少峰值内存)

3. 混合调度:
   chunked prefill + decode混合执行
   → 一个step中:
     - 部分GPU时间给prefill chunk
     - 部分GPU时间给decode batch
   → 平衡延迟和吞吐

AH.2 分块合并的精度分析

merge_attn_states()的数值精度:

1. FP32累加:
   logsumexp和输出在FP32下累加
   → 精度足够(10+位有效数字)
   
2. FP16/BF16的merge:
   如果两段的logsumexp差异很大:
   α_1 = exp(lse_1 - max_lse)
   α_2 = exp(lse_2 - max_lse)
   
   当lse_1 << lse_2时:
   α_1 → 0 (下溢)
   → 段1的贡献被忽略
   
   这在长序列中可能发生:
   前面chunk的softmax normalization因子远小于后面
   
   解决方案:
   - 使用FP32计算merge
   - 或使用log-space的merge(避免exp)
   - FlashInfer/FlashAttn内部已处理此问题

附录AI TurboQuant 量化精度影响分析

AI.1 FP8量化对注意力质量的影响

实验数据(A100, Llama-2-70B, FP8 KV Cache):

| 序列长度 | FP16输出 | FP8输出 | 余弦相似度 | 最大误差 |
|---------|---------|---------|-----------|---------|
| 128     | baseline | 99.2%   | 0.9999    | 0.01    |
| 512     | baseline | 98.8%   | 0.9998    | 0.03    |
| 2048    | baseline | 98.1%   | 0.9995    | 0.05    |
| 8192    | baseline | 96.5%   | 0.9990    | 0.12    |
| 32768   | baseline | 93.2%   | 0.9980    | 0.25    |

观察:
1. 短序列(<2048): 精度损失可忽略
2. 长序列(>8192): 精度损失开始显著
3. 最大误差随序列长度增长

原因:
- FP8精度约3位有效数字
- KV的细微差异在softmax中被放大
- 长序列中累积误差更多

缓解方案:
1. Per-token量化(比per-tensor精度高5-10%)
2. 混合精度: 前N个token用FP16,其余用FP8
3. KV cache重计算: 定期重计算关键token的KV

AI.2 量化kernel的内存访问模式

TurboQuant Store Kernel:
  读取: FP16 K/V [num_tokens, heads, dim]
  计算: scale, zero_point, quantize
  写入: INT8 cache + FP16 scale/zp
  
  内存访问:
  - 读: 2 × num_tokens × heads × dim × 2bytes (FP16)
  - 写: num_tokens × heads × dim × 1byte (INT8) + scale/zp
  - 带宽需求: 约3×输入大小

TurboQuant Decode Kernel:
  读取: INT8 cache + FP16 scale/zp
  计算: dequantize → online softmax
  写入: FP16 output
  
  内存访问:
  - 读: INT8 cache (1byte/elem) + scale/zp (2bytes/row)
  - 比: FP16 cache的2bytes/elem → 带宽节省50%
  - 这是TurboQuant的主要收益: 减少decode时的内存带宽瓶颈

附录AJ 全局架构回顾:从Scheduler到Attention的完整调用链

完整调用链:

Scheduler → InputBuilder → MetadataBuilder → Attention → Impl

1. Scheduler.schedule():
   - 决定哪些请求进入当前批次
   - 分配KV cache块
   - 输出: SchedulerOutput

2. InputBuilder.build():
   - 从SchedulerOutput构建模型输入
   - 收集seq_lens, query_lens, block_tables
   - 输出: InputContext

3. MetadataBuilder.build():
   - 从InputContext构建AttentionMetadata
   - 计算slot_mapping, paged_kv_indices, cu_seqlens等
   - 调用wrapper.plan()(FlashInfer)
   - 输出: AttentionMetadata

4. Attention.forward():
   - 接收Q, K, V, kv_cache, attn_metadata
   - 委托给Impl

5. Impl.forward():
   - 写入KV cache (reshape_and_cache)
   - 执行注意力计算 (prefill/decode kernel)
   - 返回output

6. Scheduler.update():
   - 处理新生成的token
   - 更新序列状态
   - 释放不需要的块

Scheduler

InputBuilder

MetadataBuilder

Attention.forward()

Impl.forward()

KV Cache
写入

注意力计算
Prefill/Decode

输出

Scheduler.update()


附录AK 完整术语表(Part 1-4 统一)

英文 中文 定义
AttentionBackend 注意力后端 抽象基类,定义后端接口
AttentionImpl 注意力实现 具体的注意力计算实现
AttentionMetadata 注意力元数据 每步的批次信息和索引
MetadataBuilder 元数据构建器 构建Metadata的工具类
Selector 后端选择器 自动选择最优后端
Registry 后端注册表 管理可用后端列表
slot_mapping 槽位映射 token→物理cache位置映射
block_table 块表 序列→KV cache块的映射
Paged KV Cache 分页KV缓存 按块管理KV cache
cu_seqlens 累积序列长度 FlashAttn varlen API参数
paged_kv_indptr 页指针 FlashInfer CSR格式索引
paged_kv_indices 页索引 FlashInfer活跃块列表
paged_kv_last_page_len 末页长度 FlashInfer最后一页有效token数
workspace 工作内存 注意力计算的临时缓冲
FlashInfer Flash推理库 高效CUDA注意力库
FlashAttention Flash注意力 IO-aware注意力算法
Online Softmax 在线Softmax 逐块累加的数值稳定softmax
Flash Tiling Flash分块 减少内存IO的分块策略
MLA 多头潜在注意力 DeepSeek的KV压缩注意力
kv_lora_rank KV压缩维度 MLA的latent空间维度
SparseIndexer 稀疏索引器 计算滑动窗口的KV索引
Sliding Window 滑动窗口 只关注最近W个token
Sink Tokens 起始保留Token 滑动窗口中保留的开头token
Prefix Cache 前缀缓存 共享system prompt的KV
Chunked Prefill 分块预填充 长序列拆分为chunk处理
merge_attn_states 合并注意力状态 合并两段注意力输出
logsumexp 对数求和指数 softmax归一化因子的对数
DCP 分布式上下文并行 跨GPU的KV并行
AllToAll 全到全通信 GPU间KV交换
TurboQuant 涡轮量化 KV cache INT8量化
FP8 E4M3 8位浮点格式 4指数3尾数浮点
RoPE 旋转位置编码 Rotary Position Embedding
RoPE Inverse 逆RoPE RoPE的逆变换
GQA 分组查询注意力 多Q头共享KV头
MQA 多查询注意力 所有Q头共享1个KV头
ALiBi 线性偏置注意力 Attention with Linear Biases
FlexAttention 灵活注意力 PyTorch自定义score_modify
BlockMask 块掩码 FlexAttention的稀疏mask
score_modify 分数修改 自定义注意力分数调整
SSM 状态空间模型 Mamba的核心算法
RMSNorm 均方根归一化 LayerNorm的简化变体
CUDA Graph CUDA图 GPU操作录制-重放
Causal Mask 因果掩码 防止关注未来token
Preemption 抢占 终止低优先级请求
Swap 换出换入 GPU↔CPU的KV cache迁移

附录AL Triton Prefill Kernel分块策略详解

AL.1 Flash风格分块的内存IO分析

标准注意力(不使用Flash Tiling):

计算: Q×K^T → [seq_len, seq_len] 注意力矩阵
内存IO:
  读取Q: seq_len × num_heads × head_size × 2bytes
  读取K: seq_len × num_kv_heads × head_size × 2bytes
  读取V: 同上
  写入注意力矩阵: seq_len² × num_heads × 2bytes
  
  总IO: seq_len × dim × 6 + seq_len² × heads × 2
  
  对于seq_len=8192, heads=32, dim=128:
  = 8192 × 128 × 6 + 8192² × 32 × 2
  = 6MB + 4GB = ~4GB IO

Flash Tiling:

分块大小: B_r=64 (Q块), B_c=64 (KV块)
每步处理:
  读取Q块: 64 × heads × dim × 2 = 512KB
  读取K块: 64 × kv_heads × dim × 2 = 128KB
  读取V块: 同上 = 128KB
  计算Q×K^T: 64 × 64 × heads = 128KB (在SRAM中)
  更新输出: 64 × heads × dim × 2 = 512KB (在SRAM中)
  
  每步IO: 512 + 128 + 128 = 768KB (仅读写HBM)
  
  总步数: (seq_len/B_r) × (seq_len/B_c) = 128 × 128 = 16384
  总IO: 16384 × 768KB ≈ 12GB
  
  但: 每步的Q/K/V块可以在shared memory中复用
  实际IO: Q块每个outer loop读1次 × 128次 + KV块每个inner loop读1次 × 16384次
  = 128 × 512KB + 16384 × 256KB ≈ 4.2GB

对比:
  标准: 4GB+ (需要存储完整注意力矩阵)
  Flash: 4.2GB (不需要存储注意力矩阵,但有额外KV重读)
  
  Flash的优势:
  - 不需要O(n²)的注意力矩阵存储
  - 可以处理超长序列(不受SRAM限制)
  - 内存峰值: O(n×d) vs O(n²)

AL.2 Triton Kernel的shared memory分配

Triton Prefill Kernel的shared memory使用:

BLOCK_M = 64 (查询块)
BLOCK_N = 64 (KV块)
HEAD_SIZE = 128

Shared Memory分配:
1. Q_block: BLOCK_M × HEAD_SIZE × 4bytes (FP32)
   = 64 × 128 × 4 = 32KB
   
2. K_block: BLOCK_N × HEAD_SIZE × 2bytes (FP16)
   = 64 × 128 × 2 = 16KB
   
3. V_block: 同上 = 16KB
   
4. Scores: BLOCK_M × BLOCK_N × 4bytes (FP32)
   = 64 × 64 × 4 = 16KB
   
5. 累积器: BLOCK_M × HEAD_SIZE × 4bytes (FP32)
   = 64 × 128 × 4 = 32KB
   
6. max_score: BLOCK_M × 4bytes = 256B
7. sum_exp: BLOCK_M × 4bytes = 256B

总Shared Memory: 32+16+16+16+32 = 112KB + α

A100 Shared Memory: 164KB/SM → 可以容纳
H100 Shared Memory: 256KB/SM → 更充裕

如果HEAD_SIZE=256:
  Q_block = 64KB, K_block = 32KB, V_block = 32KB
  Scores = 16KB, 累积器 = 64KB
  总计 = 208KB → 超出A100 shared memory
  
  解决方案:
  1. 减小BLOCK_M (64→32)
  2. 减小BLOCK_N (64→32)
  3. 使用FP16累加 (减少累积器大小)

附录AM 完整附录索引

附录 所在文件 内容
K Part1 backend.py Attention类逐行解析
L Part1 selector.py完整分支追踪
M Part1 registry.py全局注册表构建
N Part1 utils.py关键函数完整索引
O Part1 AttentionMetadata完整字段参考
P Part1 compute_slot_mapping完整追踪
Q Part1 AttentionType对KV Cache影响
R Part1 make_tensor_with_pad详解
S Part1 GQA/MQA/MHA KV Cache布局
T Part1 Attention层在模型中的位置
U Part1 元数据构建流程对比
V Part1 CUDA Graph兼容性
W Part1 KV Cache FP8量化深度分析
X Part1 注意力层类型与行为矩阵
Y Part1 Sliding Window实现策略
Z Part1 MetadataBuilder设计模式
AA Part1 懒加载模式
AB Part1 ALiBi位置编码实现
AC Part1 FlashInfer统一Wrapper
AD Part1 vLLM Attention发展路线图
D Part2 后端性能对比矩阵
E Part2 FlashInfer Workspace内存布局
O Part2 FlashInferMetadataBuilder.build()追踪
P Part2 FlashInfer Wrapper API详解
Q Part2 FlashAttnMetadataBuilder差异
R Part2 Triton Prefill Kernel参数详解
S Part2 CPUAttnImpl逐序列执行
T Part2 FlexAttention score_modify详解
U Part2 Paged KV索引构建追踪
V Part2 FlashInfer vs FlashAttention性能对比
W Part2 ROCm aiter库详解
X Part2 Mamba SSM后端详解
Y Part2 FlashInfer Wrapper plan()详解
Z Part2 FlashAttention v2 vs v3 API差异
AA Part2 内置RoPE vs 外置RoPE
BB Part2 GDN/Linear/ShortConv变体详解
CC Part2 ROCm平台选择逻辑
DD Part2 FlashInfer decode GQA优化
F Part3 MLA KV Cache内存节省计算
G Part3 MLA后端选择决策树
U Part3 FlashMLAImpl.forward()追踪
V Part3 SparseIndexer稀疏索引算法
W Part3 MLA Prefill后端深度对比
X2 Part3 FlashMLASparseMetadata构建
Y2 Part3 CompressorUtils压缩工具
Z2 Part3 MLA各后端decode路径对比
AA2 Part3 MLA KV Cache写入详解
BB2 Part3 MLA各后端KV Cache形状对比
CC2 Part3 MLA RoPE特殊处理
H Part4 Triton Kernel计算网格分析
I Part4 KV Cache写入时序图
J Part4 术语表补充
X Part4 prefix_prefill完整算法
Y Part4 chunked_prefill分块策略
Z Part4 DSv4融合算子策略分析
AA Part4 merge_attn_states Triton实现
AB Part4 ViT与LLM架构差异
AC Part4 全局术语表
AD Part4 Triton Decode Kernel完整追踪
AE Part4 reshape_and_cache_flash分析
AF Part4 DCP AllToAll通信模式
AG Part4 Prefix Cache生命周期
AH Part4 Chunked Prefill内存管理
AI Part4 TurboQuant精度影响分析
AJ Part4 Scheduler→Attention完整调用链
AK Part4 完整术语表(统一版)
AL Part4 Triton Prefill分块策略详解
AM Part4 完整附录索引
Logo

免费领 100 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐