vLLM V1 Attention Module In-Depth Architecture Analysis, Part 2: Standard Attention Backend Implementations

Scope: the standard backends under v1/attention/backends/ (FlashInfer/FlashAttn/Triton/CPU/Flex/ROCm/TurboQuant)




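Chapter 6: FlashInfer Backend Deep Dive

6.1 FlashInfer Overview and Architecture

FlashInfer is one of the main CUDA attention backends in vLLM V1 (which backend is selected by default depends on the vLLM version and the GPU architecture). It provides:

  • A unified Prefill/Decode API
  • Efficient batched prefill with variable-length sequences
  • Paged KV cache management
  • CUDA Graph compatibility
  • Sliding-window support
  • Multiple RoPE implementations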

[Figure: FlashInfer architecture]
  Backend classes:
    FlashInferBackend (factory class)
    FlashInferMetadata (metadata)
    FlashInferMetadataBuilder (builder)
    FlashInferImpl (execution layer)
  Wrappers:
    Prefill wrapper: BatchPrefillWithPagedKVCacheWrapper
    Decode wrapper: BatchDecodeWithPagedKVCacheWrapper
    Full wrapper: BatchPrefillWithPagedKVCacheWrapper (unified mode)
  Underlying FlashInfer library:
    libflashinfer_a (CUDA kernels)

6.2 FlashInferBackend Class Structure

class FlashInferBackend(AttentionBackend):
    """FlashInfer backend.

    Characteristics:
    - Uses the FlashInfer library's paged KV cache API
    - Supports CUDA Graph (the workspace is preallocated)
    - Prefill and decode use different wrappers
    """
    
    @staticmethod
    def get_name() -> str:
        return "FLASHINFER"
    
    @staticmethod
    def get_impl_cls() -> type:
        return FlashInferImpl
    
    @classmethod
    def get_metadata_cls(cls) -> type:
        return FlashInferMetadata
    
    @classmethod
    def get_builder_cls(cls) -> type:
        return FlashInferMetadataBuilder
    
    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int, 
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        # FlashInfer uses the standard two-tensor layout:
        # Key:   [num_blocks, block_size, num_kv_heads, head_size]
        # Value: [num_blocks, block_size, num_kv_heads, head_size]
        # Stacked: [2, num_blocks, block_size, num_kv_heads, head_size]
        return (2, num_blocks, block_size, num_kv_heads, head_size)
    
    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [32, 64, 96, 128, 256]

6.3 FlashInferMetadata

@dataclass
class FlashInferMetadata(AttentionMetadata):
    """FlashInfer-specific metadata.

    Adds FlashInfer-specific fields on top of the base AttentionMetadata.
    """
    # ---- FlashInfer workspace ----
    # FlashInfer needs workspace memory for intermediate results.
    # Its size depends on the maximum batch size and sequence length.
    workspace: torch.Tensor | None = None

    # ---- Prefill/decode split ----
    # FlashInfer uses different APIs for prefill and decode.
    prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper | None" = None
    decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper | None" = None

    # ---- Sequence groups ----
    # seq_groups describe the sequences in the batch.
    prefill_seq_groups: list[tuple[int, int]] | None = None
    decode_seq_groups: list[tuple[int, int]] | None = None

    # ---- Paged KV cache indices ----
    # paged_kv_indices and paged_kv_indptr form a CSR-style sparse index.
    paged_kv_indices: torch.Tensor | None = None   # [num_pages] page indices
    paged_kv_indptr: torch.Tensor | None = None    # [batch+1] page pointers
    paged_kv_last_page_len: torch.Tensor | None = None  # [batch] valid tokens in the last page

    # ---- RoPE parameters ----
    # FlashInfer has built-in RoPE, driven by the position information of Q/K.
    rotary_dim: int | None = None
    rotary_inv_q: torch.Tensor | None = None   # RoPE inverse frequencies (Q)
    rotary_inv_k: torch.Tensor | None = None   # RoPE inverse frequencies (K)
    rotary_cos_q: torch.Tensor | None = None   # RoPE cosines (Q)
    rotary_cos_k: torch.Tensor | None = None   # RoPE cosines (K)

Paged KV Cache Index Structure

batch_size = 3
block_size = 16
block_table = [[0, 1, 3, -1], [2, 5, -1, -1], [4, 6, 7, 8]]  # per-sequence block list, -1 = padding
seq_lens = [40, 25, 50]

paged_kv_indptr = [0, 3, 5, 9]
# Request 0: pages[0:3] = [0, 1, 3]
# Request 1: pages[3:5] = [2, 5]
# Request 2: pages[5:9] = [4, 6, 7, 8]

paged_kv_indices = [0, 1, 3, 2, 5, 4, 6, 7, 8]
# Flattened indices of all active pages

paged_kv_last_page_len = [8, 9, 2]
# Request 0: ceil(40 / 16) = 3 pages, last_page_len = 40 % 16 = 8
# Request 1: ceil(25 / 16) = 2 pages, last_page_len = 25 % 16 = 9
# Request 2: ceil(50 / 16) = 4 pages, last_page_len = 50 % 16 = 2
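The conversion from a padded block table to the CSR-style triple can be sketched in a few lines of plain Python. This is a minimal illustration, assuming a block size of 16, a block table padded with -1, and sequence lengths of at least 1; the function name build_paged_kv_index is a hypothetical helper and not part of vLLM or FlashInfer.

```python
def build_paged_kv_index(block_tables, seq_lens, block_size):
    """Flatten per-sequence block tables into (indptr, indices, last_page_len).

    Hypothetical helper for illustration; assumes every seq_len >= 1.
    """
    indptr = [0]
    indices = []
    last_page_len = []
    for row, seq_len in zip(block_tables, seq_lens):
        num_pages = (seq_len + block_size - 1) // block_size  # ceiling division
        indices.extend(row[:num_pages])                       # drops the -1 padding
        indptr.append(len(indices))
        # A completely full last page reports block_size rather than 0.
        last_page_len.append(seq_len % block_size or block_size)
    return indptr, indices, last_page_len


block_tables = [[0, 1, 3, -1], [2, 5, -1, -1], [4, 6, 7, 8]]
indptr, indices, last_page_len = build_paged_kv_index(block_tables, [40, 25, 50], 16)
print(indptr, indices, last_page_len)
# [0, 3, 5, 9] [0, 1, 3, 2, 5, 4, 6, 7, 8] [8, 9, 2]
```

Sequence i owns the slice indices[indptr[i]:indptr[i + 1]], which is the same access pattern as the row pointer array of a CSR sparse matrix.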

6.4 FlashInferMetadataBuilder

class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
    """FlashInfer metadata builder.

    Builds a FlashInferMetadata at every step, which involves:
    1. Creating/updating the prefill and decode wrappers
    2. Building paged_kv_indices/indptr/last_page_len
    3. Allocating the workspace
    """
    
    def __init__(self, input_builder):
        self.input_builder = input_builder
        self.vllm_config = input_builder.vllm_config
        self.device = input_builder.device
        
        # Preallocate the FlashInfer wrappers
        self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
            workspace=None,  # allocated on first use
        )
        self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
            workspace=None,
        )

        # Preallocate the index tensors
        max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
        max_num_blocks = self.vllm_config.cache_config.num_gpu_blocks
        max_num_pages = max_batch_size * max_num_blocks
        
        self.paged_kv_indices = torch.zeros(
            max_num_pages, dtype=torch.int32, device=self.device
        )
        self.paged_kv_indptr = torch.zeros(
            max_batch_size + 1, dtype=torch.int32, device=self.device
        )
        self.paged_kv_last_page_len = torch.zeros(
            max_batch_size, dtype=torch.int32, device=self.device
        )

Core flow of build()

def build(self, input_ids, seq_lens, query_lens, ...) -> FlashInferMetadata:
    """Build the FlashInferMetadata for the current step."""

    # 1. Split the sequences into prefill and decode
    prefill_indices = [i for i, ql in enumerate(query_lens) if ql > 1]
    decode_indices = [i for i, ql in enumerate(query_lens) if ql == 1]

    # 2. Build the prefill paged_kv index
    if prefill_indices:
        self._build_paged_kv_index(
            prefill_indices, seq_lens, query_lens,
            is_prefill=True
        )
        # Number of pages referenced by the prefill sequences
        total_pages = sum(
            (seq_lens[i] + block_size - 1) // block_size for i in prefill_indices
        )
        self.prefill_wrapper.plan(
            paged_kv_indptr=self.paged_kv_indptr[:len(prefill_indices)+1],
            paged_kv_indices=self.paged_kv_indices[:total_pages],
            paged_kv_last_page_len=self.paged_kv_last_page_len[:len(prefill_indices)],
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            ...
        )

    # 3. Build the decode paged_kv index
    if decode_indices:
        self._build_paged_kv_index(
            decode_indices, seq_lens, query_lens,
            is_prefill=False
        )
        # Number of pages referenced by the decode sequences
        total_pages = sum(
            (seq_lens[i] + block_size - 1) // block_size for i in decode_indices
        )
        self.decode_wrapper.plan(
            paged_kv_indptr=self.paged_kv_indptr[:len(decode_indices)+1],
            paged_kv_indices=self.paged_kv_indices[:total_pages],
            paged_kv_last_page_len=self.paged_kv_last_page_len[:len(decode_indices)],
            ...
        )

    # 4. Build slot_mapping
    slot_mapping = compute_slot_mapping(
        block_tables, seq_lens, query_lens, block_size
    )

    # 5. Assemble the metadata
    return FlashInferMetadata(
        num_prefills=len(prefill_indices),
        num_decode_tokens=len(decode_indices),
        slot_mapping=slot_mapping,
        seq_lens=seq_lens,
        prefill_wrapper=self.prefill_wrapper,
        decode_wrapper=self.decode_wrapper,
        paged_kv_indices=self.paged_kv_indices,
        paged_kv_indptr=self.paged_kv_indptr,
        paged_kv_last_page_len=self.paged_kv_last_page_len,
        ...
    )

6.5 FlashInferImpl Attention Implementation

class FlashInferImpl(AttentionImpl):
    """FlashInfer attention execution layer.

    Calls into the FlashInfer library to run the attention computation.
    """
    
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
    ):
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype
    
    def forward(
        self,
        query: torch.Tensor,     # [num_tokens, num_heads * head_size]
        key: torch.Tensor,        # [num_tokens, num_kv_heads * head_size]
        value: torch.Tensor,      # [num_tokens, num_kv_heads * head_size]
        kv_cache: torch.Tensor,   # [2, num_blocks, block_size, num_kv_heads, head_size]
        attn_metadata: FlashInferMetadata,
    ) -> torch.Tensor:
        """Run the FlashInfer attention computation."""

        num_tokens = query.shape[0]

        # 1. Reshape Q/K/V into the multi-head layout
        query = query.view(num_tokens, self.num_heads, self.head_size)
        key = key.view(num_tokens, self.num_kv_heads, self.head_size)
        value = value.view(num_tokens, self.num_kv_heads, self.head_size)

        # 2. Write to the KV cache
        # slot_mapping gives the cache slot of every new K/V token
        key_cache = kv_cache[0]  # [num_blocks, block_size, num_kv_heads, head_size]
        value_cache = kv_cache[1]

        # reshape_and_cache_flash scatters K/V into the paged KV cache.
        # It targets the flash-style cache layout shown above, which differs
        # from the legacy PagedAttention cache write.
        ops.reshape_and_cache_flash(
            key, value, key_cache, value_cache,
            attn_metadata.slot_mapping, self.kv_cache_dtype,
            kv_scale=..., kv_zp=...,
        )
        
        # 3. Run the attention computation
        if attn_metadata.num_prefills > 0 and attn_metadata.num_decode_tokens > 0:
            # Mixed batch: prefill and decode requests are both present,
            # which the V1 FlashInfer backend supports
            output = self._run_mixed_batch(query, kv_cache, attn_metadata)
        elif attn_metadata.num_prefills > 0:
            # Prefill-only batch
            output = self._run_prefill(query, kv_cache, attn_metadata)
        else:
            # Decode-only batch
            output = self._run_decode(query, kv_cache, attn_metadata)

        return output.view(num_tokens, self.num_heads * self.head_size)

Three execution modes

def _run_prefill(self, query, kv_cache, attn_metadata):
    """Run prefill attention.

    Uses BatchPrefillWithPagedKVCacheWrapper, which supports
    batched prefill over variable-length sequences.
    """
    prefill_query = query[:attn_metadata.num_prefill_tokens]

    output = attn_metadata.prefill_wrapper.run(
        prefill_query,
        paged_kv_cache=kv_cache,  # cache holding both K and V
        rotary_inv_q=attn_metadata.rotary_inv_q,
        rotary_inv_k=attn_metadata.rotary_inv_k,
        sm_scale=self.scale,
    )
    return output

def _run_decode(self, query, kv_cache, attn_metadata):
    """Run decode attention.

    Uses BatchDecodeWithPagedKVCacheWrapper,
    with one query token per sequence.
    """
    decode_query = query[attn_metadata.num_prefill_tokens:]

    output = attn_metadata.decode_wrapper.run(
        decode_query,
        paged_kv_cache=kv_cache,
        sm_scale=self.scale,
    )
    return output

def _run_mixed_batch(self, query, kv_cache, attn_metadata):
    """Run a mixed batch (prefill + decode)."""
    # The V1 FlashInfer backend can run prefill and decode through one wrapper,
    # the unified BatchPrefillWithPagedKVCacheWrapper. This assumes the wrapper
    # was planned over every sequence in the batch, not only the prefill ones.
    output = attn_metadata.prefill_wrapper.run(
        query,
        paged_kv_cache=kv_cache,
        sm_scale=self.scale,
    )
    return output

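6.6 Workspace Memory Management

FlashInfer needs workspace memory to hold intermediate results (for example, the running softmax accumulators). The workspace size depends on the batch size and the sequence length.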

def _prepare_workspace(
    self,
    batch_size: int,
    max_seq_len: int,
) -> torch.Tensor:
    """Prepare the FlashInfer workspace.

    Workspace uses:
    - softmax accumulators (running sum of exp values, running max)
    - partial results of the attention matrix
    - temporary buffers

    Size estimate:
    - per decode token: about num_heads * head_size * 4 bytes
    - per prefill token: about seq_len * head_size * 4 additional bytes
    - total: about batch_size * max_seq_len * num_heads * head_size * 4 bytes
    """
    workspace_size = self._estimate_workspace_size(batch_size, max_seq_len)
    
    if self.workspace is None or self.workspace.numel() < workspace_size:
        self.workspace = torch.empty(
            workspace_size, dtype=torch.uint8, device=self.device
        )
    
    return self.workspace

6.7 Switching Between Prefill and Decode Modes

[Flowchart: FlashInferImpl.forward() dispatch]

FlashInferImpl.forward()
  1. Reshape query to [num_tokens, heads, dim]
  2. Write to the KV cache with reshape_and_cache_flash()
  3. num_prefills > 0?
       Yes: num_decode > 0?
              Yes: _run_mixed_batch() (handled by the unified wrapper)
              No:  _run_prefill() (BatchPrefillWrapper)
       No:  _run_decode() (BatchDecodeWrapper)
  4. wrapper.run()
  5. Return output

FlashInfer Wrapper API
  Prefill: BatchPrefillWithPagedKVCacheWrapper, .plan() → .run()
  Decode:  BatchDecodeWithPagedKVCacheWrapper, .plan() → .run()


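Chapter 7: FlashAttention Backend Deep Dive

7.1 FlashAttention Version Differences

Feature             | FlashAttention v2       | FlashAttention v3
Prefill API         | flash_attn_varlen_func  | flash_attn_with_kvcache
Decode API          | flash_attn_with_kvcache | flash_attn_with_kvcache
KV cache management | Managed externally      | Built-in support
CUDA Graph          | Managed manually        | Native support
Performance         | -                       | Higher (optimized for H100)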

7.2 FlashAttnBackend Class Structure

class FlashAttnBackend(AttentionBackend):
    """FlashAttention backend.

    Uses the FlashAttention 2/3 API.
    Main differences from FlashInfer:
    - No wrapper pattern
    - Calls the flash_attn functions directly
    - seq_lens and block_tables must be built manually
    """
    
    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN"
    
    @staticmethod
    def get_impl_cls() -> type:
        return FlashAttnImpl
    
    @classmethod
    def get_metadata_cls(cls) -> type:
        return FlashAttnMetadata
    
    @classmethod
    def get_builder_cls(cls) -> type:
        return FlashAttnMetadataBuilder
    
    @staticmethod
    def get_kv_cache_shape(
        num_blocks, block_size, num_kv_heads, head_size
    ) -> tuple[int, ...]:
        # FlashAttention also uses the standard two-tensor layout
        return (2, num_blocks, block_size, num_kv_heads, head_size)

7.3 FlashAttnMetadata

@dataclass
class FlashAttnMetadata(AttentionMetadata):
    """FlashAttention-specific metadata.

    Differences from FlashInfer:
    - Does not use paged_kv_indices/indptr
    - Passes seq_lens_tensor and block_tables directly to flash_attn
    """

    # seq_lens as required by FlashAttention (GPU tensor)
    seq_lens_tensor: torch.Tensor | None = None  # [batch_size]

    # Cumulative lengths of the prefill sequences (for the varlen API)
    cu_seqlens: torch.Tensor | None = None  # [batch_size + 1]

    # Maximum sequence length for prefill and for decode
    max_prefill_seq_len: int = 0
    max_decode_seq_len: int = 0

    # Block tables (GPU tensor)
    block_tables: torch.Tensor | None = None  # [batch_size, max_blocks]

    # FlashAttention v3 only:
    # v3 uses KV cache pointers instead of a block_table
    kvcache_start_idx: int = 0
    kvcache_end_idx: int = 0

Design of cu_seqlens

prefill_batch:
  Sequence 0: 10 tokens → cu_seqlens[0] = 0,  cu_seqlens[1] = 10
  Sequence 1: 20 tokens → cu_seqlens[1] = 10, cu_seqlens[2] = 30
  Sequence 2: 5 tokens  → cu_seqlens[2] = 30, cu_seqlens[3] = 35

flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_prefill_seq_len,
    max_seqlen_k=max_prefill_seq_len,
)
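The cumulative-length array is just an exclusive prefix sum over the per-sequence query lengths. A minimal sketch in plain Python follows; the helper names build_cu_seqlens and token_range are illustrative and not taken from vLLM, which would build the same array as a GPU tensor.

```python
from itertools import accumulate


def build_cu_seqlens(query_lens):
    """Return [0, l0, l0+l1, ...] for a list of per-sequence query lengths."""
    return [0] + list(accumulate(query_lens))


def token_range(cu_seqlens, i):
    """Half-open [start, end) token range of sequence i in the packed batch."""
    return cu_seqlens[i], cu_seqlens[i + 1]


cu_seqlens = build_cu_seqlens([10, 20, 5])
print(cu_seqlens)                  # [0, 10, 30, 35]
print(token_range(cu_seqlens, 1))  # (10, 30)
```

Because sequences are packed back to back without padding, the varlen kernel only needs this array and the maximum length to locate every sequence.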

7.4 FlashAttnImpl Implementation

class FlashAttnImpl(AttentionImpl):
    """FlashAttention attention execution layer."""

    def forward(self, query, key, value, kv_cache, attn_metadata):
        num_tokens = query.shape[0]

        # Reshape
        query = query.view(num_tokens, self.num_heads, self.head_size)
        key = key.view(num_tokens, self.num_kv_heads, self.head_size)
        value = value.view(num_tokens, self.num_kv_heads, self.head_size)

        # Write to the KV cache
        self._write_kv_cache(key, value, kv_cache, attn_metadata.slot_mapping)

        # Run attention
        if attn_metadata.num_prefills > 0:
            output = self._run_prefill(query, kv_cache, attn_metadata)

        if attn_metadata.num_decode_tokens > 0:
            decode_output = self._run_decode(query, kv_cache, attn_metadata)
            if attn_metadata.num_prefills > 0:
                # Mixed batch: concatenate the prefill and decode outputs
                output = torch.cat([output, decode_output], dim=0)
            else:
                output = decode_output

        return output.view(num_tokens, -1)
    
    def _run_prefill(self, query, kv_cache, attn_metadata):
        """Run prefill with flash_attn_varlen_func."""
        from flash_attn import flash_attn_varlen_func

        # Take the prefill part of the query
        prefill_query = query[:attn_metadata.num_prefill_tokens]

        # K/V are read from the KV cache
        key_cache = kv_cache[0]  # [num_blocks, block_size, ...]
        value_cache = kv_cache[1]

        # Use the varlen API. Reusing cu_seqlens for both Q and K assumes
        # that no sequence has a cached prefix (query_len == seq_len).
        output = flash_attn_varlen_func(
            q=prefill_query,
            k=key_cache,
            v=value_cache,
            cu_seqlens_q=attn_metadata.cu_seqlens,
            cu_seqlens_k=attn_metadata.cu_seqlens,
            max_seqlen_q=attn_metadata.max_prefill_seq_len,
            max_seqlen_k=attn_metadata.max_prefill_seq_len,
            softmax_scale=self.scale,
            causal=True,
            block_table=attn_metadata.block_tables,
        )
        return output
    
    def _run_decode(self, query, kv_cache, attn_metadata):
        """Run decode with flash_attn_with_kvcache."""
        from flash_attn import flash_attn_with_kvcache

        decode_query = query[attn_metadata.num_prefill_tokens:]

        output = flash_attn_with_kvcache(
            q=decode_query.unsqueeze(1),  # [batch, 1, heads, dim]
            k_cache=kv_cache[0],
            v_cache=kv_cache[1],
            cache_seqlens=attn_metadata.seq_lens_tensor,
            block_table=attn_metadata.block_tables,
            softmax_scale=self.scale,
            causal=True,
        )
        return output.squeeze(1)  # [batch, heads, dim]

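7.5 FlashAttnDiffKV Backend

class FlashAttnDiffKVBackend(FlashAttnBackend):
    """FlashAttention backend for differing K/V head dimensions.

    Used by models whose key and value head dimensions differ,
    for example when some layers use different sizes for Key and Value.

    Differences from the standard FlashAttn backend:
    - The head_size of K and V may differ
    - K and V must be reshaped separately
    """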
    
    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN_DIFFKV"

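Chapter 8: Triton Backend Deep Dive

8.1 TritonAttnBackend Class Structure

class TritonAttnBackend(AttentionBackend):
    """Triton backend.

    Implements attention with custom Triton kernels.
    Fallback option when neither FlashInfer nor FlashAttention is available.
    """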
    
    @staticmethod
    def get_name() -> str:
        return "TRITON"
    
    @staticmethod
    def get_impl_cls() -> type:
        return TritonAttnImpl
    
    @staticmethod
    def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size):
        return (2, num_blocks, block_size, num_kv_heads, head_size)

8.2 TritonAttnImpl Implementation

class TritonAttnImpl(AttentionImpl):
    """Triton attention execution layer.

    Implemented with Triton kernels:
    - prefill: triton_prefill_attention
    - decode: triton_decode_attention
    - reshape_and_cache: triton_reshape_and_cache_flash
    """

    def forward(self, query, key, value, kv_cache, attn_metadata):
        num_tokens = query.shape[0]

        # Reshape
        query = query.view(num_tokens, self.num_heads, self.head_size)
        key = key.view(num_tokens, self.num_kv_heads, self.head_size)
        value = value.view(num_tokens, self.num_kv_heads, self.head_size)

        # Write to the KV cache
        ops.triton_reshape_and_cache_flash(
            key, value, kv_cache[0], kv_cache[1],
            attn_metadata.slot_mapping,
        )

        # Run attention
        if attn_metadata.num_prefills > 0:
            output = self._triton_prefill(query, kv_cache, attn_metadata)
        else:
            output = self._triton_decode(query, kv_cache, attn_metadata)
        
        return output.view(num_tokens, -1)

8.3 TritonPrefillImpl and TritonDecodeImpl

[Flowchart: TritonAttnImpl.forward() dispatch]

TritonAttnImpl.forward()
  num_prefills > 0?
    Yes (prefill): triton_prefill_attention()
    No (decode):   triton_decode_attention()

Triton Prefill Kernel:
- Variable-length sequence support
- Flash-style tiling
- Online softmax
- Causal mask

Triton Decode Kernel:
- 1 query token per sequence
- Reads from the paged KV cache
- Block-wise K/V accumulation
- Sliding window


Chapter 9: CPU Backend

9.1 CPUAttnBackend

class CPUAttnBackend(AttentionBackend):
    """CPU backend.

    Runs the attention computation on CPU devices
    using a native PyTorch implementation.
    """
    
    @staticmethod
    def get_name() -> str:
        return "CPU"
    
    @staticmethod
    def get_impl_cls() -> type:
        return CPUAttnImpl
    
    @staticmethod
    def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size):
        # The CPU backend uses the standard PagedAttention layout
        return (2, num_blocks, block_size, num_kv_heads, head_size)

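9.2 CPU Attention Implementation

class CPUAttnImpl(AttentionImpl):
    """CPU attention execution layer.

    Uses PyTorch's native scaled_dot_product_attention
    and no external libraries.
    """

    def forward(self, query, key, value, kv_cache, attn_metadata):
        # The CPU path relies on standard PyTorch SDPA and does not use
        # FlashAttention-style tiling.

        # 1. Write to the KV cache
        self._write_kv_cache_cpu(key, value, kv_cache, attn_metadata)

        # 2. Run attention one sequence at a time
        outputs = []
        start = 0
        for i, (seq_len, query_len) in enumerate(
            zip(attn_metadata.seq_lens, attn_metadata.query_lens)
        ):
            end = start + query_len  # queries cover only the new tokens

            # Take the Q/K/V of the current sequence.
            # _gather_kv_for_seq is a placeholder for the paged-cache gather
            # and GQA expansion walked through in Appendix S.
            q = query[start:end]  # [query_len, heads, dim]
            k, v = self._gather_kv_for_seq(kv_cache, attn_metadata, i)  # [seq_len, heads, dim]

            # Causal mask aligned to the bottom-right corner, so it stays
            # correct when query_len < seq_len (decode, cached prefix)
            attn_mask = torch.ones(query_len, seq_len, dtype=torch.bool).tril(
                diagonal=seq_len - query_len
            )

            # Compute attention
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                q.transpose(0, 1),  # [heads, query_len, dim]
                k.transpose(0, 1),
                v.transpose(0, 1),
                attn_mask=attn_mask,
                scale=self.scale,
            )
            outputs.append(attn_output.transpose(0, 1))
            start = end

        return torch.cat(outputs, dim=0)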

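Chapter 10: FlexAttention Backend

10.1 FlexAttention Overview

FlexAttention is a flexible attention API introduced in PyTorch 2.5. It lets users supply a custom score-modification function (called score_mod in the PyTorch API; this document refers to it as score_modify) and supports:

  • Sliding Window
  • Document Masking
  • ALiBi
  • Custom attention patterns

class FlexAttentionBackend(AttentionBackend):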
    @staticmethod
    def get_name() -> str:
        return "FLEX_ATTENTION"
    
    @staticmethod
    def get_impl_cls() -> type:
        return FlexAttentionImpl

10.2 FlexAttentionMetadata

@dataclass
class FlexAttentionMetadata(AttentionMetadata):
    """FlexAttention-specific metadata.

    Mainly adds the fields related to score_modify.
    """
    # Block mask for FlexAttention:
    # defines which (query, key) positions take part in attention
    block_mask: "BlockMask | None" = None

    # Score modify function:
    # custom function that modifies the attention scores
    score_modify: "Callable | None" = None

Chapter 11: ROCm Backends

11.1 ROCmAttnBackend

class ROCmAttnBackend(AttentionBackend):
    """ROCm/AMD GPU backend.

    Uses AMD's aiter library for efficient attention.
    """
    
    @staticmethod
    def get_name() -> str:
        return "ROCM_ATTN"

11.2 ROCmAITerFABackend

class ROCmAITerFABackend(AttentionBackend):
    """ROCm aiter FlashAttention backend.

    Built on flash_attention_forward from the aiter library
    and optimized specifically for AMD GPUs.
    """
    
    @staticmethod
    def get_name() -> str:
        return "ROCM_AITER_FA"

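Chapter 12: TurboQuant Backend

12.1 TurboQuant Overview

TurboQuant is a KV cache quantization scheme: it quantizes K/V values to low precision (such as FP8/INT8), which reduces memory footprint and bandwidth requirements.

class TurboQuantBackend(AttentionBackend):
    """TurboQuant backend.

    Characteristics:
    - The KV cache is stored in FP8/INT8
    - Values are dequantized when decoding
    - Values are quantized when stored
    - Quantization and dequantization are implemented as Triton kernels
    """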
    
    @staticmethod
    def get_name() -> str:
        return "TURBOQUANT"
    
    @staticmethod
    def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size):
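        # The quantized KV cache has a different shape:
        # an extra dimension stores the quantization parameters (scale, zero_point)
        return (
            2,  # K + V
            num_blocks,
            block_size,
            num_kv_heads,
            head_size,
            # + quantization-parameter dimension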
        )
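The quantize-on-store, dequantize-on-read round trip can be illustrated with generic 8-bit affine quantization. This is a sketch of the general technique only and makes no claim about TurboQuant's actual algorithm or kernels; the function names are hypothetical, and the minimum value is stored as a floating-point offset in place of an integer zero point to keep the arithmetic simple.

```python
def quantize_uint8(values):
    """Map floats to integers in [0, 255] with a per-tensor scale and offset."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / 255 or 1.0  # fall back to 1.0 for constant input
    q = [max(0, min(255, round((v - lo) / scale))) for v in values]
    return q, scale, lo


def dequantize_uint8(q, scale, offset):
    """Recover approximate floats; the error is at most scale / 2 per element."""
    return [x * scale + offset for x in q]


kv_values = [-1.0, -0.25, 0.0, 0.5, 1.0]
q, scale, offset = quantize_uint8(kv_values)
restored = dequantize_uint8(q, scale, offset)
print(q)
print([round(r, 3) for r in restored])
```

Each stored element shrinks from 2 or 4 bytes to 1 byte, at the cost of keeping scale and offset alongside the data, which is why the cache shape needs room for quantization parameters.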

12.2 TurboQuantMetadata

@dataclass
class TurboQuantMetadata(AttentionMetadata):
    """TurboQuant-specific metadata."""

    # Quantization parameters
    kv_scale: torch.Tensor | None = None   # KV scale factors
    kv_zp: torch.Tensor | None = None      # KV zero points

    # Temporary buffers for dequantization
    dequant_key: torch.Tensor | None = None
    dequant_value: torch.Tensor | None = None

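Chapter 13: Backends for Special Architectures

13.1 GDNAttnBackend

class GDNAttnBackend(AttentionBackend):
    """GDN (Gated DeltaNet) attention backend.

    Characteristics:
    - Implements gated delta-rule linear attention
    - Does not use standard Q*K^T softmax attention
    - Implemented through a custom AttentionImpl
    """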
    @staticmethod
    def get_name() -> str:
        return "GDN"

13.2 MambaAttnBackend

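class MambaAttnBackend(AttentionBackend):
    """Mamba SSM backend.

    Characteristics:
    - Does not use the standard attention mechanism
    - Uses a state space model (SSM)
    - The KV cache stores SSM state instead of K/V
    - Attention layers are replaced by SSM layers
    """

    @staticmethod
    def get_impl_cls() -> type:
        # Mamba does not use the standard Attention layer
        return MambaImpl  # custom implementation

    @staticmethod
    def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size):
        # The cache shape for SSM state is different:
        # it holds the A, B, C, D parameters and the intermediate state
        return (num_blocks, block_size, ...)  # SSM-specific shape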

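Mamba1AttnBackend vs Mamba2AttnBackend

Feature         | Mamba1                     | Mamba2
SSM algorithm   | Selective SSM              | Parallel SSD
Implementation  | mamba_ssm library          | mamba_ssm v2
Cache           | Discrete step size + state | Continuous parameters + state
Parallelization | Limited                    | Natively parallel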

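13.3 LinearAttnBackend

class LinearAttnBackend(AttentionBackend):
    """Linear attention backend.

    Characteristics:
    - O(n) complexity instead of O(n²)
    - Never materializes the full attention matrix
    - Approximates softmax attention with a kernel trick
    - Suited to long sequences
    """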
    @staticmethod
    def get_name() -> str:
        return "LINEAR"
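The O(n) claim follows from replacing softmax(q·k) with a feature-map product phi(q)·phi(k), which lets the sum over past keys be carried as a running state. The sketch below illustrates that general idea using the elu(x)+1 feature map common in the linear-transformer literature; it is a generic illustration, not the kernel used by LinearAttnBackend, and all names in it are hypothetical.

```python
import math


def phi(x):
    """Positive feature map elu(x) + 1, applied element-wise."""
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]


def linear_attention_causal(queries, keys, values):
    """Causal linear attention in O(n): one state update per token."""
    d, dv = len(keys[0]), len(values[0])
    state = [[0.0] * dv for _ in range(d)]  # running sum of phi(k) v^T
    norm = [0.0] * d                        # running sum of phi(k)
    outputs = []
    for q, k, v in zip(queries, keys, values):
        fk = phi(k)
        for a in range(d):
            norm[a] += fk[a]
            for b in range(dv):
                state[a][b] += fk[a] * v[b]
        fq = phi(q)
        denom = sum(fq[a] * norm[a] for a in range(d))
        outputs.append(
            [sum(fq[a] * state[a][b] for a in range(d)) / denom for b in range(dv)]
        )
    return outputs


def quadratic_reference(queries, keys, values):
    """Same result computed the O(n^2) way, for comparison."""
    outputs = []
    for t, q in enumerate(queries):
        fq = phi(q)
        weights = [sum(a * b for a, b in zip(fq, phi(k))) for k in keys[: t + 1]]
        total = sum(weights)
        outputs.append(
            [sum(w * v[b] for w, v in zip(weights, values)) / total
             for b in range(len(values[0]))]
        )
    return outputs
```

The state has a fixed size of d × dv regardless of sequence length, which is also why such backends cache a recurrent state rather than per-token K/V.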

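Appendix D: Backend Performance Comparison Matrix

Backend        | Prefill | Decode | Memory efficiency | Feature completeness | CUDA Graph | Recommended scenario
FlashInfer     | ★★★★★   | ★★★★★  | ★★★★              | ★★★★★                | -          | General CUDA
FlashAttention | ★★★★    | ★★★★   | ★★★★              | ★★★                  | ✅ (v3)    | When FlashInfer is unavailable
Triton         | ★★★     | ★★★    | ★★★               | ★★★                  | -          | Fallback/debugging
CPU            | ★★      | ★★     | -                 | -                    | -          | CPU inference
FlexAttention  | ★★★     | ★★     | ★★                | ★★★★                 | -          | Custom masks
ROCm aiter     | ★★★★    | ★★★★   | ★★★★              | ★★★                  | -          | AMD GPU
TurboQuant     | ★★★★    | ★★★★   | ★★★★★             | ★★                   | -          | Quantized KV
FlashMLA       | ★★★★★   | ★★★★★  | ★★★★★             | ★★★★                 | -          | DeepSeek
Mamba          | N/A     | ★★★★★  | ★★★★★             | ★★                   | -          | SSM models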

Appendix E: FlashInfer Workspace Memory Layout

Workspace size estimate

workspace_size =
batch_size × num_heads × (head_size + max_seq_len/16 + 1) × 4 bytes
≈ a few MB to a few tens of MB
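A quick numeric check of the estimate above, as a sketch. The helper name is hypothetical, the constants 16 and 4 are taken from the formula as written, and rounding max_seq_len/16 up to a whole block is an assumption made here; this mirrors the document's estimate rather than FlashInfer's actual allocation logic.

```python
def estimate_workspace_bytes(batch_size, num_heads, head_size, max_seq_len,
                             block=16, elem_bytes=4):
    """Evaluate batch * heads * (head_size + seq_blocks + 1) * elem_bytes."""
    seq_blocks = (max_seq_len + block - 1) // block  # assumption: round up
    return batch_size * num_heads * (head_size + seq_blocks + 1) * elem_bytes


size = estimate_workspace_bytes(batch_size=256, num_heads=32,
                                head_size=128, max_seq_len=4096)
print(size, round(size / 2**20, 2))  # 12615680 bytes, about 12.03 MiB
```

For this configuration the estimate lands at roughly 12 MiB, inside the "a few MB to a few tens of MB" range quoted above.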

[Figure: FlashInfer workspace memory, from workspace start to workspace end]
  1. Softmax accumulator: [batch, heads, dim]
  2. exp-value buffer: [batch, heads, seq_blocks]
  3. Max-value tracking: [batch, heads]
  4. Temporary buffer


Appendix O: FlashInferMetadataBuilder.build() Full Flow Trace

O.1 Build Step Sequence

Full flow of FlashInferMetadataBuilder.build():

Step 1: Classify sequences
  prefill_indices = [i for i, ql in enumerate(query_lens) if ql > 1]
  decode_indices = [i for i, ql in enumerate(query_lens) if ql == 1]
  # query_lens > 1 → prefill (several tokens processed at once)
  # query_lens == 1 → decode (1 token per step)

Step 2: Build slot_mapping
  slot_mapping = torch.empty(num_tokens, dtype=torch.int64)
  token_offset = 0
  for i, (seq_len, query_len) in enumerate(zip(seq_lens, query_lens)):
    start = seq_len - query_len
    for j in range(query_len):
      logical_pos = start + j
      block_id = block_tables[i, logical_pos // block_size]
      offset = logical_pos % block_size
      slot_mapping[token_offset + j] = block_id * block_size + offset
    token_offset += query_len

Step 3: Build the prefill paged_kv index
  if prefill_indices:
    total_pages = 0
    paged_kv_indptr[0] = 0
    for prefill_idx, i in enumerate(prefill_indices):
      seq_len = seq_lens[i]
      num_pages = (seq_len + block_size - 1) // block_size
      # Copy the block IDs from block_table into paged_kv_indices
      for j in range(num_pages):
        paged_kv_indices[total_pages + j] = block_tables[i, j]
      total_pages += num_pages
      paged_kv_indptr[prefill_idx + 1] = total_pages
      # last_page_len: number of valid tokens in the last page
      paged_kv_last_page_len[prefill_idx] = seq_len % block_size or block_size

    # Call wrapper.plan() to precompute the schedule
    prefill_wrapper.plan(
      paged_kv_indptr=paged_kv_indptr[:len(prefill_indices)+1],
      paged_kv_indices=paged_kv_indices[:total_pages],
      paged_kv_last_page_len=paged_kv_last_page_len[:len(prefill_indices)],
      num_qo_heads=num_heads,
      num_kv_heads=num_kv_heads,
      head_dim=head_size,
      page_size=block_size,
      causal=True,
    )

Step 4: Build the decode paged_kv index (same as prefill, but with query_len=1)

Step 5: Assemble the metadata
  return FlashInferMetadata(
    num_prefills=len(prefill_indices),
    num_decode_tokens=len(decode_indices),
    slot_mapping=slot_mapping,
    seq_lens=seq_lens_tensor,
    block_tables=block_tables_tensor,
    prefill_wrapper=prefill_wrapper,
    decode_wrapper=decode_wrapper,
    paged_kv_indices=paged_kv_indices,
    paged_kv_indptr=paged_kv_indptr,
    paged_kv_last_page_len=paged_kv_last_page_len,
    ...
  )
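The slot-mapping computation in Step 2 can be run standalone on plain lists. This is a minimal sketch that mirrors the pseudocode above, with nested lists standing in for tensors; the function signature is an assumption for illustration and not vLLM's actual helper.

```python
def compute_slot_mapping(block_tables, seq_lens, query_lens, block_size):
    """Return the flat KV cache slot for every new token in the batch."""
    slot_mapping = []
    for row, seq_len, query_len in zip(block_tables, seq_lens, query_lens):
        start = seq_len - query_len  # tokens before this are already cached
        for logical_pos in range(start, seq_len):
            block_id = row[logical_pos // block_size]
            offset = logical_pos % block_size
            slot_mapping.append(block_id * block_size + offset)
    return slot_mapping


# Sequence 0 decodes 1 token at position 39; sequence 1 adds 3 tokens at 22..24.
slots = compute_slot_mapping(
    block_tables=[[0, 1, 3], [2, 5]],
    seq_lens=[40, 25],
    query_lens=[1, 3],
    block_size=16,
)
print(slots)  # [55, 86, 87, 88]
```

Position 39 falls in the third logical block (physical block 3) at offset 7, giving slot 3 × 16 + 7 = 55; the other three tokens land in physical block 5 at offsets 6, 7 and 8.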

Appendix P: FlashInfer Wrapper API in Detail

P.1 BatchPrefillWithPagedKVCacheWrapper

# FlashInfer Prefill Wrapper API
class BatchPrefillWithPagedKVCacheWrapper:
    """FlashInfer prefill wrapper.

    Lifecycle:
    1. Create:  wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace)
    2. Plan:    wrapper.plan(indptr, indices, last_page_len, ...)
    3. Run:     output = wrapper.run(query, paged_kv_cache, ...)
    4. Repeat steps 2-3

    What plan() does:
    - Precomputes the indices and metadata attention needs
    - Allocates workspace memory
    - Builds the launch parameters for the CUDA kernel
    - This work is done once per plan() call and reused by the run() calls that follow

    What run() does:
    - Performs the actual attention computation
    - Invokes FlashInfer's CUDA kernel
    """

P.2 BatchDecodeWithPagedKVCacheWrapper

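class BatchDecodeWithPagedKVCacheWrapper:
    """FlashInfer decode wrapper.

    Differences from the prefill wrapper:
    - Each sequence has exactly 1 query token
    - Uses a different kernel (decode-specific optimizations)
    - Different workspace size
    - Supports CUDA Graph (prefill usually does not)
    """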

Appendix Q: Key Differences in FlashAttnMetadataBuilder.build()

FlashAttnMetadataBuilder vs FlashInferMetadataBuilder:

1. Does not use paged_kv_indices/indptr
   → passes seq_lens_tensor and block_tables directly to flash_attn

2. Needs to compute cu_seqlens
   cu_seqlens[0] = 0
   for i in range(batch_size):
     cu_seqlens[i+1] = cu_seqlens[i] + query_lens[i]
   # cu_seqlens = [0, 10, 30, 35] means:
   #   sequence 0: tokens 0-9
   #   sequence 1: tokens 10-29
   #   sequence 2: tokens 30-34

3. block_tables must be padded to a uniform length
   FlashAttention requires every row of block_tables to have the same length
   → invalid positions are filled with -1

4. max_prefill_seq_len must be computed in advance
   → it is passed as the max_seqlen argument of flash_attn_varlen_func
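Item 3 above, padding ragged block lists into a rectangular table, can be sketched as follows. The helper name pad_block_tables is hypothetical, and plain lists stand in for the int32 GPU tensor that the real builder would produce.

```python
def pad_block_tables(block_tables, pad_value=-1):
    """Pad each per-sequence block list to the length of the longest one."""
    width = max((len(row) for row in block_tables), default=0)
    return [row + [pad_value] * (width - len(row)) for row in block_tables]


padded = pad_block_tables([[0, 1, 3], [2, 5], [4, 6, 7, 8]])
print(padded)  # [[0, 1, 3, -1], [2, 5, -1, -1], [4, 6, 7, 8]]
```

The kernel never reads the padded entries because it stops at the number of blocks implied by each sequence's length, so the pad value only needs to be recognizable when debugging.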

Appendix R: TritonAttnImpl Prefill Kernel Parameters in Detail

triton_prefill_attention() parameters:

Inputs:
  q:            [num_tokens, num_heads, head_size]
  k_cache:      [num_blocks, block_size, num_kv_heads, head_size]
  v_cache:      same as k_cache
  cu_seqlens_q: [batch_size + 1] cumulative query lengths
  cu_seqlens_k: [batch_size + 1] cumulative key lengths
  max_seqlen_q: longest query sequence
  max_seqlen_k: longest key sequence
  block_table:  [batch_size, max_blocks] block table
  scale:        1/sqrt(head_size)
  causal:       bool, whether to apply the causal mask

Kernel grid:
  grid = (num_heads, DIV_UP(num_tokens, BLOCK_M))
  BLOCK_M = 64 (query block size)
  BLOCK_N = 64 (key/value block size)

Each program handles:
  - 1 attention head
  - BLOCK_M consecutive query tokens

Tiling strategy (Flash Attention):
  Across programs: query blocks (BLOCK_M tokens each) are distributed over the grid
  Within a program: loop over KV blocks (BLOCK_N tokens each)

  For each (query block, KV block) pair:
    1. Load K_block and V_block from the cache
    2. Compute S = Q_block × K_block^T × scale
    3. Apply the causal mask (mask where query position i < key position j)
    4. Online softmax update (maintain max_score, sum_exp, O)

  Finally: O /= sum_exp
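The online softmax update in step 4 is the part that makes tiling possible, so a small reference version may help. The sketch below processes a single query vector over K/V blocks in plain Python and compares against a direct softmax; it illustrates the recurrence only and is not the Triton kernel, and the function names are hypothetical.

```python
import math


def attention_one_query_tiled(q, keys, values, scale, block_n=2):
    """Attention output for one query, scanning K/V in blocks of block_n."""
    max_score = float("-inf")       # running maximum of the scores
    sum_exp = 0.0                   # running sum of exp(score - max_score)
    acc = [0.0] * len(values[0])    # running unnormalized output O
    for start in range(0, len(keys), block_n):
        k_blk = keys[start:start + block_n]
        v_blk = values[start:start + block_n]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(max_score, max(scores))
        rescale = math.exp(max_score - new_max)  # 0.0 on the first block
        sum_exp *= rescale
        acc = [a * rescale for a in acc]
        for s, v in zip(scores, v_blk):
            p = math.exp(s - new_max)
            sum_exp += p
            acc = [a + p * x for a, x in zip(acc, v)]
        max_score = new_max
    return [a / sum_exp for a in acc]  # final O /= sum_exp


def attention_one_query_reference(q, keys, values, scale):
    """Direct softmax over all keys at once, for comparison."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]
```

Whenever a later block raises the running maximum, the previously accumulated sum and output are multiplied by exp(old_max - new_max), which keeps every exponent non-positive and the result identical to the single-pass softmax up to rounding.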

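Appendix S: CPUAttnImpl Per-Sequence Execution in Detail

CPU attention computation flow:

import torch
import torch.nn.functional as F

outputs = []
start = 0
for i in range(batch_size):
    seq_len = seq_lens[i]
    query_len = query_lens[i]

    # Gather all K/V of the current sequence from the paged KV cache
    k_seq = torch.empty(seq_len, num_kv_heads, head_size)
    v_seq = torch.empty(seq_len, num_kv_heads, head_size)

    for j in range(seq_len):
        block_id = block_tables[i, j // block_size]
        offset = j % block_size
        k_seq[j] = key_cache[block_id, offset]  # [num_kv_heads, head_size]
        v_seq[j] = value_cache[block_id, offset]

    # Standard SDPA computation
    q_seq = query[start:start+query_len]  # [query_len, num_heads, head_size]

    # GQA expansion: replicate each K/V head across its group of Q heads
    # (a no-op copy when num_queries_per_kv == 1)
    k_expanded = k_seq.repeat_interleave(num_queries_per_kv, dim=1)
    v_expanded = v_seq.repeat_interleave(num_queries_per_kv, dim=1)

    # Causal mask aligned to the bottom-right corner: query token t sits at
    # absolute position seq_len - query_len + t. is_causal=True aligns the
    # mask to the top-left instead, which is wrong when query_len < seq_len.
    attn_mask = None
    if attn_type == DECODER:
        attn_mask = torch.ones(query_len, seq_len, dtype=torch.bool).tril(
            diagonal=seq_len - query_len
        )

    attn_output = F.scaled_dot_product_attention(
        q_seq.transpose(0, 1),    # [num_heads, query_len, head_size]
        k_expanded.transpose(0, 1),  # [num_heads, seq_len, head_size]
        v_expanded.transpose(0, 1),
        attn_mask=attn_mask,
        scale=scale,
    )

    outputs.append(attn_output.transpose(0, 1))
    start += query_len

# Concatenate the outputs of all sequences
output = torch.cat(outputs, dim=0)  # [num_tokens, num_heads, head_size]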
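Two details of the per-sequence loop are easy to get wrong: how the causal mask is aligned when the query covers only the last tokens of a sequence, and how KV heads are replicated for GQA. The sketch below shows both on plain lists; the helper names are hypothetical and the code is independent of PyTorch.

```python
def causal_mask(query_len, seq_len):
    """mask[i][j] is True when query token i may attend to key j.

    Query token i sits at absolute position seq_len - query_len + i,
    so the mask is aligned to the bottom-right corner.
    """
    shift = seq_len - query_len
    return [[j <= i + shift for j in range(seq_len)] for i in range(query_len)]


def expand_kv_heads(kv_heads, num_queries_per_kv):
    """Repeat each KV head so it lines up with its group of query heads."""
    return [head for head in kv_heads for _ in range(num_queries_per_kv)]


print(causal_mask(1, 4))  # decode step: the single query sees all 4 keys
print(causal_mask(2, 3))  # [[True, True, False], [True, True, True]]
print(expand_kv_heads(["kv0", "kv1"], 2))  # ['kv0', 'kv0', 'kv1', 'kv1']
```

With query_len equal to seq_len the mask reduces to the familiar lower-triangular matrix, and the head ordering matches what repeat_interleave produces along the head dimension.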

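Appendix T: FlexAttentionImpl Score Modification in Detail

The core of FlexAttention is the score_modify function (score_mod in the PyTorch API), complemented by a boolean mask function (mask_mod) that is turned into a BlockMask:

Standard attention:
  scores = Q × K^T × scale
  probs = softmax(scores)

FlexAttention:
  scores = Q × K^T × scale
  scores = score_modify(scores, b, h, q_idx, kv_idx)  # ★ custom modification ★
  probs = softmax(scores)  # positions rejected by the mask are treated as -inf

score_modify examples:

1. ALiBi:
   def alibi_modify(score, b, h, q_idx, kv_idx):
       return score + slopes[h] * (kv_idx - q_idx)
   # slopes: a different bias slope for each head
   # (kv_idx - q_idx): relative position

2. Sliding Window (expressed as a mask function):
   def sliding_window_mask(b, h, q_idx, kv_idx):
       return (kv_idx <= q_idx) & (q_idx - kv_idx <= window_size)
   # attends only to tokens in the range [q_idx - window_size, q_idx]

3. Document Masking (expressed as a mask function):
   def document_mask(b, h, q_idx, kv_idx):
       # tokens within the same document may attend to each other
       # tokens in different documents may not
       return doc_ids[q_idx] == doc_ids[kv_idx]

FlexAttention compiles these functions into efficient fused GPU kernels
through torch.compile, and uses the BlockMask to skip fully masked blocks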
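The effect of a score modification and a mask on a single attention row can be reproduced in plain Python. This is a simplified reference that drops the batch and head indices the real PyTorch callbacks receive, assumes the scores are already scaled, and assumes at least one key position passes the mask; every name in it is hypothetical.

```python
import math


def attention_row(q_idx, scores, score_mod=None, mask_mod=None):
    """Softmax over one row of scaled scores with optional hooks."""
    modified = []
    for kv_idx, score in enumerate(scores):
        if mask_mod is not None and not mask_mod(q_idx, kv_idx):
            modified.append(float("-inf"))  # masked positions get zero weight
        elif score_mod is not None:
            modified.append(score_mod(score, q_idx, kv_idx))
        else:
            modified.append(score)
    top = max(modified)
    exps = [math.exp(s - top) for s in modified]
    total = sum(exps)
    return [e / total for e in exps]


def causal(q_idx, kv_idx):
    return kv_idx <= q_idx


def sliding_window(window_size):
    def mask(q_idx, kv_idx):
        return kv_idx <= q_idx and q_idx - kv_idx <= window_size
    return mask


def alibi(slope):
    def mod(score, q_idx, kv_idx):
        return score + slope * (kv_idx - q_idx)
    return mod


print(attention_row(3, [0.0] * 5, mask_mod=sliding_window(1)))
# [0.0, 0.0, 0.5, 0.5, 0.0]
print(attention_row(2, [0.0] * 3, score_mod=alibi(1.0), mask_mod=causal))
```

With uniform scores, the sliding window of size 1 leaves only positions 2 and 3 visible to query 3, and the ALiBi bias makes nearer keys receive more weight than distant ones.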

Appendix U FlashInfer Paged KV Cache Index Construction: A Complete Trace

U.1 Converting block_table to FlashInfer Indices

Input: block_tables (PagedAttention format)
  block_tables = [
    [0, 1, 3, -1],   # sequence 0: 3 blocks (0,1,3)
    [2, 5, -1, -1],  # sequence 1: 2 blocks (2,5)
    [4, 6, 7, 8],    # sequence 2: 4 blocks (4,6,7,8)
  ]
  seq_lens = [40, 25, 50]
  block_size = 16

Output: FlashInfer format
  paged_kv_indptr = [0, 3, 5, 9]
  #  sequence 0: blocks[0:3]
  #  sequence 1: blocks[3:5]
  #  sequence 2: blocks[5:9]
  
  paged_kv_indices = [0, 1, 3, 2, 5, 4, 6, 7, 8]
  #  sequence 0: [0, 1, 3]
  #  sequence 1: [2, 5]
  #  sequence 2: [4, 6, 7, 8]
  
  paged_kv_last_page_len = [8, 9, 2]
  #  sequence 0: 40%16=8 → 8 valid tokens in the last page
  #  sequence 1: 25%16=9 → 9 valid tokens in the last page
  #  sequence 2: 50%16=2 → 2 valid tokens in the last page
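
The conversion above can be reproduced with a few lines of plain Python. This is a sketch on Python lists rather than the tensor-based builder used in practice; the function name `build_paged_kv_index` is chosen here for illustration, and it assumes every sequence has at least one token.

```python
def build_paged_kv_index(block_tables, seq_lens, block_size):
    """Flatten a padded block table into CSR-style (indptr, indices, last_page_len)."""
    indptr, indices, last_page_len = [0], [], []
    for row, seq_len in zip(block_tables, seq_lens):
        num_pages = -(-seq_len // block_size)      # ceiling division
        indices.extend(row[:num_pages])            # drops the -1 padding entries
        indptr.append(len(indices))
        # A remainder of 0 means the last page is completely full.
        last_page_len.append(seq_len % block_size or block_size)
    return indptr, indices, last_page_len

block_tables = [
    [0, 1, 3, -1],
    [2, 5, -1, -1],
    [4, 6, 7, 8],
]
indptr, indices, last_page_len = build_paged_kv_index(block_tables, [40, 25, 50], 16)
```

The three returned lists match the `paged_kv_indptr`, `paged_kv_indices` and `paged_kv_last_page_len` values traced by hand above.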

U.2 Construction Code, Line by Line

def _build_paged_kv_index(
    self,
    indices: list[int],        # indices of the sequences to build entries for
    seq_lens: list[int],       # lengths of all sequences
    query_lens: list[int],     # query lengths (unused in this simplified version)
    is_prefill: bool,          # whether this is a prefill batch (unused here)
) -> None:
    """Build FlashInfer's paged_kv index arrays from block_tables."""
    
    total_pages = 0
    self.paged_kv_indptr[0] = 0
    
    for idx, seq_idx in enumerate(indices):
        seq_len = seq_lens[seq_idx]
        
        # Number of pages used by this sequence (ceiling division)
        num_pages = (seq_len + self.block_size - 1) // self.block_size
        
        # Copy block IDs from block_tables into paged_kv_indices
        for j in range(num_pages):
            self.paged_kv_indices[total_pages + j] = self.block_tables[seq_idx, j]
        
        total_pages += num_pages
        
        # Update indptr
        self.paged_kv_indptr[idx + 1] = total_pages
        
        # Compute last_page_len. A remainder of 0 means the last page is
        # exactly full, so block_size is used (assumes seq_len > 0)
        remainder = seq_len % self.block_size
        self.paged_kv_last_page_len[idx] = remainder if remainder > 0 else self.block_size
    
    # Trim to the actual size. Note that this rebinds the attributes to views,
    # so a persistent builder would keep the full buffers and store the views separately
    self.paged_kv_indptr = self.paged_kv_indptr[:len(indices) + 1]
    self.paged_kv_indices = self.paged_kv_indices[:total_pages]
    self.paged_kv_last_page_len = self.paged_kv_last_page_len[:len(indices)]

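Appendix V FlashInfer vs FlashAttention: Performance Characteristics Compared

The throughput figures in this appendix are rough, order-of-magnitude values; the model, sequence lengths, and benchmark setup behind them are not specified.

V.1 Decode Performance

FlashInfer Decode:
  - Dedicated decode kernel (optimized for one query token per sequence)
  - 1 Q token per sequence → K/V scan over the whole cache
  - Supports Paged KV Cache (reads from scattered blocks)
  - Preallocated workspace (CUDA Graph friendly)
  - Throughput: ~1000 tok/s per A100 (batch=256)

FlashAttention Decode:
  - flash_attn_with_kvcache
  - Same one-token-per-sequence pattern
  - Supports Paged KV Cache
  - Requires the block_table argument
  - Throughput: ~900 tok/s per A100 (batch=256)
  # Slightly below FlashInfer, which has more decode-specific optimizations

V.2 Prefill Performance

FlashInfer Prefill:
  - BatchPrefillWithPagedKVCacheWrapper
  - Supports variable-length batches (different sequence lengths)
  - Supports CUDA Graph (requires a fixed batch_size)
  - Throughput: ~15000 tok/s per A100 (prefill_only)

FlashAttention Prefill:
  - flash_attn_varlen_func
  - Supports variable-length batches
  - Limited CUDA Graph support
  - Throughput: ~14000 tok/s per A100 (prefill_only)

V.3 Mixed-Batch Performance

FlashInfer:
  - Native support for mixed batches (prefill and decode executed together)
  - Uses the unified wrapper
  - Throughput: about 10-20% higher than running the two phases separately

FlashAttention:
  - Mixed batches require separate prefill and decode calls
  - The outputs are then concatenated
  - Throughput: slightly lower than separate execution (concatenation overhead)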

Appendix W The aiter Library Behind the ROCm Backends

W.1 aiter Overview

aiter (AI Tensor Engine for ROCm) is AMD's high-performance inference operator library for the ROCm platform

Core components:
1. flash_attention_forward: FlashAttention optimized for AMD GPUs
2. paged_attention: paged attention for AMD GPUs
3. mla_attention: MLA attention for AMD GPUs
4. fused_ops: fused operators (RoPE, RMSNorm, etc.)

Correspondence with the CUDA backends:
  CUDA FlashInfer     → ROCm aiter FA
  CUDA FlashAttention → ROCm aiter FA
  CUDA FlashMLA       → ROCm aiter MLA

W.2 ROCmAITerFABackend Implementation

class ROCmAITerFABackend(AttentionBackend):
    """ROCm aiter FlashAttention backend"""
    
    @staticmethod
    def get_name() -> str:
        return "ROCM_AITER_FA"
    
    @staticmethod
    def get_impl_cls() -> type:
        return ROCmAITerFAImpl

class ROCmAITerFAImpl(AttentionImpl):
    def forward(self, query, key, value, kv_cache, attn_metadata):
        # Simplified prefill-only sketch: kv_cache is not consulted here.
        # flash_attention_forward is the entry-point name used in this article;
        # check the installed aiter version for the exact function name.
        from aiter import flash_attention_forward
        
        output = flash_attention_forward(
            query, key, value,
            cu_seqlens=attn_metadata.cu_seqlens,
            max_seqlen=attn_metadata.max_prefill_seq_len,
            causal=True,
            softmax_scale=self.scale,
        )
        return output

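Appendix X Mamba SSM Backend in Detail

X.1 The Mamba State Space Model

Mamba does not use standard attention; it uses a state space model (SSM):

Standard attention:
  y = softmax(Q × K^T / √d) × V
  Complexity: O(n²) in sequence length

Mamba SSM:
  h_t = A × h_{t-1} + B × x_t    # state recurrence
  y_t = C × h_t                    # output
  Complexity: O(n) in sequence length

Replacement for the KV cache:
  Mamba's "cache" is the SSM state h_t
  Its size is fixed and does not grow with sequence length, so for long sequences it is far smaller than a KV cache
  No PagedAttention-style paging is needed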
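
The recurrence can be illustrated with a scalar toy model. This is only a sketch of the state update: real Mamba layers use input-dependent, per-channel parameters and a fused scan kernel, and the function name `ssm_scan` is chosen here for illustration.

```python
def ssm_scan(a, b, c, xs, h0=0.0):
    """Run h_t = a*h_{t-1} + b*x_t, y_t = c*h_t over xs; return (ys, final_state)."""
    h, ys = h0, []
    for x in xs:
        h = a * h + b * x
        ys.append(c * h)
    return ys, h

xs = [1.0, 2.0, 3.0, 4.0]
full_ys, _ = ssm_scan(0.5, 1.0, 2.0, xs)

# "Prefill" on the first three tokens, then "decode" the fourth from the cached state.
prefix_ys, cached_h = ssm_scan(0.5, 1.0, 2.0, xs[:3])
step_ys, _ = ssm_scan(0.5, 1.0, 2.0, xs[3:], h0=cached_h)
```

The single value `cached_h` is all that has to be kept between steps, regardless of how many tokens came before, which is why the state cache does not grow with sequence length.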

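X.2 MambaAttnBackend Implementation

class MambaAttnBackend(AttentionBackend):
    """Mamba SSM backend"""
    
    @staticmethod
    def get_impl_cls() -> type:
        # Mamba does not use a standard Attention layer
        return MambaImpl
    
    @staticmethod
    def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size):
        # Shape of the SSM state cache. It holds the recurrent state
        # (convolution state and SSM state h_t); A, B, C, D and dt are model
        # weights or per-token projections and are not cached.
        # state_dim is a placeholder for the per-layer state size.
        return (num_blocks, block_size, state_dim)
    
    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        # An SSM has no notion of head_size
        return []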

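X.3 Mamba1 vs Mamba2

Feature         | Mamba1            | Mamba2 (SSD)
Algorithm       | Selective SSM     | Structured state space duality
Parallelism     | Limited (scan)    | Natively parallel (matrix form)
Throughput      | -                 | High (~2-5×)
Implementation  | mamba_ssm         | mamba_ssm v2
State size      | d_model × d_state | d_model × d_state
Flash-style     | -                 | Supports Flash-like tiling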

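Appendix Y FlashInfer Wrapper plan() in Detail

Y.1 BatchPrefillWithPagedKVCacheWrapper.plan()

# Simplified signature for exposition. In the FlashInfer library itself,
# qo_indptr is the first positional argument and causal defaults to False.
def plan(
    self,
    paged_kv_indptr: torch.Tensor,    # [batch+1] page pointers (CSR offsets)
    paged_kv_indices: torch.Tensor,   # [total_pages] page indices
    paged_kv_last_page_len: torch.Tensor,  # [batch] length of the last page
    num_qo_heads: int,                # number of Q heads
    num_kv_heads: int,                # number of KV heads
    head_dim: int,                    # head dimension
    page_size: int,                   # page size (block_size)
    causal: bool = True,              # whether to apply a causal mask
    pos_encoding_mode: str = "NONE",  # positional-encoding mode
    qo_indptr: torch.Tensor | None = None,  # CSR pointers for Q
    custom_mask: torch.Tensor | None = None, # custom mask
) -> None:
    """Precompute the indices and metadata needed for prefill.
    
    What plan() does:
    1. Derives the launch-grid parameters of the CUDA kernel
       grid = (num_heads, num_qo_tiles, batch_size)
       num_qo_tiles = sum(ceil(query_len / tile_size))
    
    2. Computes the KV length of each sequence
       For each sequence i:
         num_pages = paged_kv_indptr[i+1] - paged_kv_indptr[i]
         kv_len = (num_pages - 1) * page_size + last_page_len[i]
       (pages are scattered in memory, so only the length is meaningful,
        not a contiguous start/end address)
    
    3. Allocates the workspace
       workspace_size = batch_size * num_heads * head_dim * sizeof(float)
                     + accumulator space
    
    4. Copies the indices to the GPU
       paged_kv_indptr_gpu = paged_kv_indptr.to(device)
       paged_kv_indices_gpu = paged_kv_indices.to(device)
    
    Why plan() is needed:
    - FlashInfer's kernels are precompiled
    - plan() supplies every runtime parameter the kernel needs
    - It avoids recomputing them on each run()
    - It supports CUDA Graph (plan once, run many times)
    """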

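The per-sequence KV length that plan() needs can be recovered from the CSR arrays alone. The sketch below does this on plain Python lists; the helper name `kv_lens_from_paged_index` is an assumption for illustration and is not part of FlashInfer.

```python
def kv_lens_from_paged_index(paged_kv_indptr, paged_kv_last_page_len, page_size):
    """Number of valid KV tokens per sequence, derived from the paged index."""
    lens = []
    for i, last in enumerate(paged_kv_last_page_len):
        num_pages = paged_kv_indptr[i + 1] - paged_kv_indptr[i]
        # All pages but the last are full; the last holds `last` tokens.
        lens.append((num_pages - 1) * page_size + last if num_pages > 0 else 0)
    return lens

kv_lens = kv_lens_from_paged_index([0, 3, 5, 9], [8, 9, 2], 16)
```

Applied to the index arrays from Appendix U, this returns the original sequence lengths, which confirms that the three arrays carry enough information for the kernel without a separate seq_lens tensor.
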
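Y.2 BatchDecodeWithPagedKVCacheWrapper.plan()

def plan(
    self,
    paged_kv_indptr: torch.Tensor,
    paged_kv_indices: torch.Tensor,
    paged_kv_last_page_len: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    window_left: int = -1,  # sliding window size; -1 disables it
) -> None:
    """Precompute the indices and metadata needed for decode.
    
    Differences from the prefill plan():
    1. Each sequence has exactly 1 query token
       → no qo_indptr is needed
       → the grid is simpler
    
    2. Sliding window (window_left)
       → only the most recent window of KV entries is read
       → memory bandwidth is reduced
       (the prefill wrapper accepts the same parameter; the benefit is largest in decode)
    
    3. Different workspace size
       the decode workspace is smaller (only 1 output is accumulated per token)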

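Appendix Z FlashAttention v2 vs v3 API Differences in Detail

Z.1 Prefill API

FlashAttention v2:
  flash_attn_varlen_func(
    q, k, v,                    # Q/K/V
    cu_seqlens_q, cu_seqlens_k, # cumulative sequence lengths
    max_seqlen_q, max_seqlen_k, # maximum sequence lengths
    softmax_scale,
    causal=True,
  )
  → Returns: output [num_tokens, heads, dim]
  
  Characteristics:
  - Variable-length batch API
  - As called here, K/V are passed in as arguments (not read from the cache)
  - K/V must first be gathered out of the KV cache by hand
  - Recent flash-attn 2.x releases also accept block_table in this function, so paged reads are not exclusive to v3

FlashAttention v3:
  flash_attn_varlen_func(
    q, k, v,                    # same as v2
    cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q, max_seqlen_k,
    softmax_scale,
    causal=True,
    block_table=block_tables,   # read K/V directly from the Paged KV Cache
  )
  → Returns: output
  
  Characteristics:
  - Supports Paged KV Cache
  - No manual KV gathering required
  - More efficient memory access

Z.2 Decode API

FlashAttention v2:
  flash_attn_with_kvcache(
    q,                          # [batch, 1, heads, dim]
    k_cache, v_cache,           # KV cache
    cache_seqlens,              # [batch] cached sequence lengths
    block_table,                # [batch, max_blocks] block table
    softmax_scale,
    causal=True,
  )

FlashAttention v3:
  flash_attn_with_kvcache(
    q,
    k_cache, v_cache,
    cache_seqlens,
    block_table,
    softmax_scale,
    causal=True,
    # ★ Additional parameters listed for v3 ★
    cache_batch_idx=None,       # batch indices (for non-contiguous cache); also present in flash-attn 2.x
    q_batch_idx=None,           # batch indices for Q
    q_start_idx=0,              # start index for Q
    cache_leftpad=None,         # left padding (specific layouts); also present in flash-attn 2.x
  )
  
  v3 improvements:
  - Supports non-contiguous Q and KV cache layouts
  - Supports q_start_idx (offset of Q inside a mixed batch)
  - Better CUDA Graph compatibility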

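Appendix AA Built-in RoPE (FlashInfer) vs External RoPE (FlashAttention)

How FlashInfer handles RoPE:
  FlashInfer can apply RoPE inside the attention kernel. This article denotes its inputs as:
    rotary_inv_q: RoPE inverse frequencies (Q)
    rotary_inv_k: RoPE inverse frequencies (K)
  (In FlashInfer's Python API the feature is selected with pos_encoding_mode,
   e.g. "ROPE_LLAMA", together with rope_scale and rope_theta.)
  
  RoPE is computed inside the kernel:
    q_rope = apply_rope(q, positions, inv_freq_q)
    k_rope = apply_rope(k, positions, inv_freq_k)
  
  Pros: one fewer kernel launch (RoPE and attention are fused)
  Cons: only standard RoPE is supported, not custom variants

How FlashAttention handles RoPE:
  RoPE is computed externally, before flash_attn is called:
    q = apply_rope(q, positions, freqs)  # separate kernel
    k = apply_rope(k, positions, freqs)
    output = flash_attn(q, k, v, ...)
  
  Pros: supports any RoPE variant (YaRN, LongRoPE, etc.)
  Cons: up to 2 extra kernel launches

How Triton handles RoPE:
  Same as FlashAttention: external RoPE
  RoPE can alternatively be fused into the Q/K projection kernel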
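
For reference, the external `apply_rope` step can be sketched in pure Python for a single head vector. This version assumes the interleaved pairing convention, where elements (0,1), (2,3), ... are rotated together, and a base of 10000; real kernels operate on batched tensors and some models pair the first and second halves of the vector instead.

```python
import math

def apply_rope(x, position, base=10000.0):
    """Rotate consecutive pairs of x by angles that grow with position."""
    d = len(x)
    assert d % 2 == 0, "head dimension must be even"
    out = []
    for i in range(0, d, 2):
        theta = position * base ** (-i / d)   # position times the pair's inverse frequency
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        out.append(x[i] * cos_t - x[i + 1] * sin_t)
        out.append(x[i] * sin_t + x[i + 1] * cos_t)
    return out

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.0]
# The relative offset is 3 in both cases, so the two scores should agree.
score_a = dot(apply_rope(q, 5), apply_rope(k, 2))
score_b = dot(apply_rope(q, 13), apply_rope(k, 10))
```

The two scores match up to floating-point error because rotating Q and K by their absolute positions makes the dot product depend only on the position difference. That property holds whether the rotation runs in a separate kernel or inside the attention kernel.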

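Appendix BB GDN/Linear/ShortConv Attention Variants in Detail

BB.1 GDN (Gated DeltaNet) Attention

# In vLLM, GDN stands for Gated DeltaNet, a linear-attention variant
# used by hybrid models such as Qwen3-Next.
# It does not build an n×n Q*K^T score matrix. Each head keeps a
# fixed-size state matrix S and updates it with a gated delta rule:
#   S_t = α_t × S_{t-1} × (I − β_t × k_t k_t^T) + β_t × v_t k_t^T
#   o_t = S_t × q_t
#   α_t in (0,1): decay gate, β_t in (0,1): write strength

class GDNImpl(AttentionImpl):
    def forward(self, query, key, value, kv_cache, attn_metadata):
        # Simplified single-head sketch. self.state, self.alpha and self.beta
        # are hypothetical attributes holding the state and per-token gates.
        S = self.state  # [d_v, d_k]
        outputs = []
        for q_t, k_t, v_t, a_t, b_t in zip(query, key, value, self.alpha, self.beta):
            # S(I − β k k^T) = S − β (S k) k^T, followed by the new write β v k^T
            S = a_t * (S - b_t * torch.outer(S @ k_t, k_t)) + b_t * torch.outer(v_t, k_t)
            outputs.append(S @ q_t)
        return torch.stack(outputs)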

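BB.2 Linear Attention

# Linear attention: complexity linear in sequence length
# Core idea: use the kernel trick to avoid materializing the attention matrix
# 
# Standard attention: O(n²×d)
#   O = softmax(Q×K^T/√d) × V
# 
# Linear attention: O(n×d²)
#   O = φ(Q) × (φ(K)^T × V)
#   where φ is a feature map
#   By associativity of matrix multiplication: first compute φ(K)^T × V [d×d], then multiply by φ(Q) [n×d]
#   Cost: n×d² (vs n²×d)

# Common choices of φ:
# 1. ELU+1: φ(x) = elu(x) + 1
# 2. Softmax approximation: φ(x) = exp(x) / √d (Performer-style)
# 3. Random features: φ(x) = random_features(x) (FAVOR+, the mechanism used by Performer)

class LinearAttnImpl(AttentionImpl):
    def forward(self, query, key, value, kv_cache, attn_metadata):
        # Non-causal form: every query sees the sum over all keys
        phi_q = self.feature_map(query)  # [n, d]
        phi_k = self.feature_map(key)    # [n, d]
        
        # Accumulated KV state: S = Σ φ(k_i)^T × v_i
        # This state has size O(d²) and can be updated incrementally
        S = phi_k.T @ value  # [d, d_v]
        
        # Output: O = φ(Q) × S / (φ(Q) × z)
        # z = Σ φ(k_i) is the normalization term
        z = phi_q @ phi_k.sum(dim=0, keepdim=True).T  # [n, 1]
        output = phi_q @ S / z
        return output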
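
The associativity argument can be checked numerically. The sketch below, in plain Python with the ELU+1 feature map, computes non-causal linear attention twice: once through the explicit n×n matrix and once through the d×d_v state. All helper names are chosen here for illustration.

```python
import math

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def phi(x):
    """ELU(x) + 1 applied element-wise; the result is always positive."""
    return [[v + 1.0 if v > 0 else math.exp(v) for v in row] for row in x]

def linear_attention_quadratic(q, k, v):
    """Reference: build the full n x n matrix, normalize its rows, multiply by V."""
    a = matmul(phi(q), transpose(phi(k)))
    a = [[w / sum(row) for w in row] for row in a]
    return matmul(a, v)

def linear_attention_linear(q, k, v):
    """Same result via associativity, without forming the n x n matrix."""
    pq, pk = phi(q), phi(k)
    s = matmul(transpose(pk), v)            # [d, d_v] state
    z = [sum(col) for col in zip(*pk)]      # [d] normalizer state
    num = matmul(pq, s)
    den = [sum(a * b for a, b in zip(row, z)) for row in pq]
    return [[x / d for x in row] for row, d in zip(num, den)]

q = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6]]
k = [[0.2, 0.1], [-0.3, 0.5], [0.7, -0.1]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
ref = linear_attention_quadratic(q, k, v)
fast = linear_attention_linear(q, k, v)
```

Both paths agree to floating-point precision. Only `s` and `z` need to be carried forward between tokens, which is what lets a linear-attention backend replace the growing KV cache with a fixed-size state.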

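BB.3 ShortConv (Short Convolution in Place of Attention)

# Short convolution: replace attention with a 1D convolution
# Receptive field: a fixed window (kernel_size)
# Complexity: O(n × kernel_size)
# 
# Suited to: local dependencies (no global context required)
# Commonly found in: the early layers of long-sequence models

class ShortConvImpl(AttentionImpl):
    def forward(self, query, key, value, kv_cache, attn_metadata):
        # Apply a causal 1D convolution
        # kernel_size = 3 or 5
        output = self.causal_conv1d(value, self.weight, self.bias)
        return output
    
    @staticmethod
    def get_kv_cache_shape(*args):
        # A short convolution needs no KV cache. For token-by-token decode it
        # still needs the last kernel_size - 1 inputs as a small conv state,
        # which this simplified sketch does not model.
        return (0,)  # empty cache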
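
A causal 1D convolution and its decode-time counterpart can be sketched as follows. This is an illustration on Python lists for a single channel without bias; `causal_conv1d` and `StreamingConv` are names chosen here, not vLLM classes.

```python
def causal_conv1d(xs, weights):
    """y_t = sum_j weights[j] * x_{t-(K-1)+j}, with zeros before the sequence start."""
    k = len(weights)
    padded = [0.0] * (k - 1) + list(xs)
    return [sum(w * padded[t + j] for j, w in enumerate(weights))
            for t in range(len(xs))]

class StreamingConv:
    """Decode-time variant: keeps only the last K-1 inputs as state."""

    def __init__(self, weights):
        self.weights = weights
        self.state = [0.0] * (len(weights) - 1)

    def step(self, x):
        window = self.state + [x]
        self.state = window[1:]          # slide the window by one token
        return sum(w * v for w, v in zip(self.weights, window))

xs = [1.0, 2.0, 3.0, 4.0, 5.0]
weights = [0.25, 0.5, 1.0]
batch_out = causal_conv1d(xs, weights)
conv = StreamingConv(weights)
stream_out = [conv.step(x) for x in xs]
```

Processing the sequence one token at a time yields the same outputs as the batched call, while the state never exceeds kernel_size - 1 values per channel.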

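Appendix CC Complete Attention Backend Selection Logic on ROCm

Attention backend selection on the ROCm platform:

1. Detect the ROCm version and GPU model
   rocm_version = get_rocm_version()
   gpu_arch = get_gpu_arch()  # gfx90a, gfx942, etc.

2. Detect whether the aiter library is available
   aiter_available = importlib.util.find_spec("aiter") is not None

3. Backend selection logic:
   if aiter_available:
     if is_mla_model():
       → ROCmAITerMLABackend or ROCmAITerMLASparseBackend
     else:
       → ROCmAITerFABackend  # aiter's FlashAttention implementation
   else:
     → ROCmAttnBackend  # implementation based on composable_kernel

4. Performance characteristics of the aiter library:
   - Optimized specifically for AMD GPUs
   - Supports the MI200/MI300 series
   - Uses the MFMA (Matrix Fused Multiply-Add) matrix-core instructions of CDNA GPUs
   - Interface broadly mirrors the CUDA-side attention APIs

5. Backends not available on ROCm:
   ❌ FlashInfer (CUDA only)
   ❌ FlashAttention v3 (CUDA only)
   ❌ CUTLASS (CUDA only)
   ⚠ Triton: runs on ROCm, but with narrower tuning coverage than on CUDA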
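
The decision tree in step 3 can be written as a small function. This is a sketch that returns backend names as strings; the `use_sparse_mla` flag is a hypothetical switch introduced here to choose between the two MLA variants, and the real selection in vLLM involves additional configuration not shown in this appendix.

```python
import importlib.util

def select_rocm_backend(is_mla_model, use_sparse_mla=False, aiter_available=None):
    """Return the backend class name chosen by the decision tree described above."""
    if aiter_available is None:
        # Probe for the library without importing it.
        aiter_available = importlib.util.find_spec("aiter") is not None
    if not aiter_available:
        return "ROCmAttnBackend"
    if is_mla_model:
        return "ROCmAITerMLASparseBackend" if use_sparse_mla else "ROCmAITerMLABackend"
    return "ROCmAITerFABackend"

choice = select_rocm_backend(is_mla_model=False, aiter_available=True)
```

Passing `aiter_available` explicitly keeps the function testable on machines without ROCm; leaving it as `None` falls back to the `find_spec` probe from step 2.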

Appendix DD FlashInfer Decode Kernel GQA Optimization in Detail

DD.1 Memory Access Pattern of GQA in Decode

Key point of GQA decode: several Q heads share one KV head

Assume: num_heads=32, num_kv_heads=8, kv_group_num=4

Naive implementation:
  for head in range(32):
    kv_head = head // 4  # 0,0,0,0,1,1,1,1,...,7,7,7,7
    k = k_cache[kv_head]  # read the KV head
    scores = q[head] @ k.T
    output[head] = softmax(scores) @ v_cache[kv_head]
  
  Problem: each KV head is read 4 times (once for each of the 4 Q heads that share it)

Optimized implementation (FlashInfer):
  # Group the Q heads by KV head
  for kv_head in range(8):
    k = k_cache[kv_head]  # read only once!
    v = v_cache[kv_head]  # read only once!
    
    for q_head in range(kv_head*4, (kv_head+1)*4):
      scores = q[q_head] @ k.T
      output[q_head] = softmax(scores) @ v
  
  Result: KV head reads drop from 32 to 8
  Memory bandwidth saved: 75%
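
The read counts can be verified with a small bookkeeping sketch that records which KV head each loop iteration would fetch. It models only the loop structure, not an actual kernel, and the function names are chosen here for illustration.

```python
def kv_loads_naive(num_heads, num_kv_heads):
    """One K/V fetch per query head: returns the KV head fetched at each step."""
    group = num_heads // num_kv_heads
    return [head // group for head in range(num_heads)]

def kv_loads_grouped(num_heads, num_kv_heads):
    """One K/V fetch per KV head, reused by every Q head in its group."""
    group = num_heads // num_kv_heads
    loads, served = [], []
    for kv_head in range(num_kv_heads):
        loads.append(kv_head)
        served.extend(range(kv_head * group, (kv_head + 1) * group))
    return loads, served

naive = kv_loads_naive(32, 8)
grouped, served = kv_loads_grouped(32, 8)
saving = 1 - len(grouped) / len(naive)
```

Both loops serve all 32 query heads, but the grouped version fetches each KV head once instead of four times, which is where the 75% reduction comes from.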

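DD.2 FlashInfer's KV Cache Sharing Mechanism

KV read strategy of the FlashInfer decode kernel (conceptual; real kernels stream KV through shared memory tile by tile rather than loading a whole head at once):

1. All Q heads in a group share one KV load:
   # Step 1: load KV from the cache into shared memory
   for kv_head in range(num_kv_heads):
     k_shared[kv_head] = load_from_cache(k_cache, kv_head)
     v_shared[kv_head] = load_from_cache(v_cache, kv_head)
   
   # Step 2: each Q head uses the KV held in shared memory
   for q_head in range(num_heads):
     kv_head = q_head // kv_group_num
     scores = q[q_head] @ k_shared[kv_head].T
     output[q_head] = softmax(scores) @ v_shared[kv_head]

2. Shared memory utilization:
   KV resides in shared memory → shared by all Q heads in the group
   → Global memory read volume: num_kv_heads × seq_len × head_size × 2
   → Instead of: num_heads × seq_len × head_size × 2
   
   GQA saving: num_heads/num_kv_heads = 32/8 = 4× fewer KV reads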