三、核心业务逻辑深度解析

3.1 __init__.py - 模块标记

文件内容:空文件。

该文件仅作为 Python 包标记文件存在,不包含任何代码逻辑。在 vLLM v1 的 spec_decode 模块中,所有功能类均通过显式 import 导入,而非通过 __init__.py 聚合导出。这种设计使得每个 Proposer 类可以由上层根据 speculative_config.method 按需选择加载,避免不必要的模块初始化开销。


3.2 utils.py - Triton Kernels 与工具函数

utils.py 是整个 spec_decode 模块的底层计算基础设施,提供了 5 个 Triton GPU 内核和若干工具函数,所有内核均用于加速输入数据的复制、扩展和 slot mapping 计算,是推测解码中 CPU→GPU 数据路径上的关键瓶颈优化。

3.2.1 next_power_of_2(n: int) -> int

作用:返回大于等于 n 的最小 2 的幂。

实现细节

  • 经典的位运算算法:先 n -= 1,然后依次右移 1、2、4、8、16、32 位做 OR 操作,最终 +1
  • 用于 Triton kernel 中 BLOCK_SIZE 的参数设置,确保 block size 为 2 的幂以匹配 GPU warp 调度。
  • 时间复杂度 O(1),无循环。
3.2.2 PADDING_SLOT_ID = -1

全局常量,表示无效的 KV Cache slot 位置。当 token 被拒绝(rejected)或超出模型最大长度时,slot mapping 被设置为此值,防止 KV Cache 被无效 token 覆盖。

3.2.3 eagle_step_slot_mapping_metadata_kernel - EAGLE 单步 slot mapping 更新 Kernel

这是一个融合(fused)Triton kernel,将 EAGLE 自回归推测中的三个独立操作合并为一个 kernel 启动,显著减少 kernel launch 开销。

函数签名

@triton.jit
def eagle_step_slot_mapping_metadata_kernel(
    positions_ptr,              # [batch_size] - 当前位置
    block_table_ptr,            # [batch_size, n_blocks_per_req]
    block_table_stride,         # block_table 维度 1 的步长
    seq_lens_ptr,               # [batch_size] - 读/写
    out_clamped_positions_ptr,  # [batch_size] (输出)
    out_slot_mapping_ptr,       # [input_batch_size] (输出)
    block_size: tl.constexpr,
    max_model_len: tl.constexpr,
    n_blocks_per_req: tl.constexpr,
    PAD_ID: tl.constexpr,
    batch_size,
):

逐行逻辑分析

  1. req_idx = tl.program_id(0) — 每个线程块处理一个请求
  2. Padding 槽处理if req_idx >= batch_size — 对于 cudagraph 填充槽,仅写入 PAD_IDout_slot_mapping 后直接 return,不做任何其他计算
  3. position = tl.load(positions_ptr + req_idx) — 加载当前位置
  4. new_position = position + 1 — 推测下一个位置(自回归递增)
  5. exceeds_max = new_position >= max_model_len — 检查是否超出模型最大长度
  6. clamped_position = tl.where(exceeds_max, 0, new_position) — 如果超出,clamped 为 0(避免后续 block table 越界)
  7. block_number = clamped_position // block_size — 计算 block 编号
  8. block_number = tl.minimum(block_number, n_blocks_per_req - 1) — 防止越界
  9. block_id = tl.load(block_table_ptr + req_idx * block_table_stride + block_number) — 从 block table 查找 block ID
  10. slot_id = block_id * block_size + (clamped_position % block_size) — slot = block_id × block_size + 块内偏移
  11. slot_id = tl.where(exceeds_max, PAD_ID, slot_id) — 如果超出 max length,slot 设为 -1
  12. seq_len = tl.load(seq_lens_ptr + req_idx) — 加载当前序列长度
  13. new_seq_len = tl.where(exceeds_max, 1, seq_len + 1) — 超出时重置为 1,否则递增
  14. new_seq_len = tl.minimum(new_seq_len, max_model_len) — 上限截断
  15. 存储三个输出:out_clamped_positionsout_slot_mappingseq_lens(原位更新)

关键设计:此 kernel 以 input_batch_size 个线程启动(而非 batch_size),多出来的线程用于处理 cudagraph padding,确保在 CUDA Graph replay 时 grid 尺寸不变。

eagle_step_update_slot_mapping_and_metadata 封装函数

  • 提取 batch_size = positions_1d.shape[0]
  • 如果 input_batch_size 未指定,默认为 batch_size
  • 获取 n_blocks_per_req = block_table_tensor.shape[1]
  • (input_batch_size,) 的 grid 启动 kernel
3.2.4 eagle_prepare_inputs_padded_kernel - 准备填充输入 Kernel

作用:为 padded drafter batch 模式计算每个请求的采样 token 索引和被拒绝的 token 数量。

关键参数

  • cu_num_draft_tokens_ptr:累积 draft token 数量(包含求和)
  • valid_sampled_tokens_count_ptr:每个请求有效的采样 token 数(= 1 + 接受数)
  • query_start_loc_gpu_ptr:query 起始位置 [num_reqs + 1]
  • 输出:token_indices_to_samplenum_rejected_tokens_gpu

逐行逻辑

  1. cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + req_idx) — 加载当前请求的累积 draft 数
  2. num_draft_tokens 计算:第一个请求直接等于累积值,后续为 cu_draft_curr - cu_draft_prev
  3. valid_count = tl.load(valid_sampled_tokens_count_ptr + req_idx) — 有效采样 token 数
  4. num_rejected_tokens = num_draft_tokens + 1 - valid_count — 被拒绝 token 数 = draft 数 + 1(bonus token)- 有效数
  5. num_rejected_tokens = tl.where(num_draft_tokens > 0, num_rejected_tokens, 0) — 如果没有 draft,拒绝数也为 0
  6. q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + req_idx + 1) - 1 — 当前请求的最后一个 token 索引
  7. index_to_sample = q_last_tok_idx - num_rejected_tokens — 需要采样的索引 = 最后索引 - 拒绝数
3.2.5 eagle_prepare_next_token_padded_kernel - 准备下一个 Token Kernel

作用:确定每个请求的 “next token id”(用于传入推测器),同时统计有效采样 token 数量。

逐行逻辑

  1. is_discarded = tl.load(discard_request_mask_ptr + req_idx) — 检查是否被丢弃
  2. 丢弃分支:直接使用 backup token,valid_count = 0
  3. 正常分支
    • 加载该行所有 sampled_token_ids(通过 BLOCK_SIZE_TOKENS 的 block 加载)
    • is_valid_mask = (token_ids != -1) & (token_ids < vocab_size) & token_mask — 有效 token 的 mask
    • valid_count = tl.sum(is_valid_mask) — 计数
    • 如果 valid_count > 0:找到最后一个有效 token 的索引(tl.max(tl.where(...))),使用 sum trick 提取对应 token
    • 如果 valid_count == 0:使用 backup token

设计亮点:使用 tl.sum(tl.where(token_offs == last_valid_index, token_ids, 0)) 技巧避免第二次内存加载,因为 Triton 没有直接的按索引选取操作。

3.2.6 compute_new_slot_mapping — 纯 PyTorch slot mapping 计算

作用:在扩展了 query 长度后,重新计算所有 token 的 slot mapping。

逐行分析

  1. req_indices = torch.arange(batch_size, device=...) — 请求索引
  2. torch.repeat_interleave(req_indices, cad.naive_query_lens() + num_new_tokens, ...) — 根据每个请求的 token 数扩展,得到每个 token 所属的请求索引
  3. clamped_positions = torch.clamp(new_positions, max=max_model_len - 1) — 防止 block table 越界
  4. block_table_indices = req_indices * n_blocks_per_req + clamped_positions // block_size — 展平后的 block table 索引
  5. block_nums = cad.block_table_tensor.view(-1)[block_table_indices] — 批量获取 block 号
  6. new_slot_mapping = block_nums * block_size + block_offsets — 计算 slot
  7. new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) — 超出长度位置的 slot 设为 -1
  8. new_slot_mapping.masked_fill_(is_rejected_token_mask, PADDING_SLOT_ID) — 被拒绝 token 的 slot 设为 -1
3.2.7 extend_all_queries_by_N — 扩展所有 Query 长度

作用:将所有请求的 query 长度增加 N,同时更新相关元数据。用于并行推测(parallel drafting)模式。

逐行分析

  1. new_query_start_loc = cad.query_start_loc + N * arange[:len(cad.query_start_loc)] — 每个请求的起始位置递增 [0, N, 2N, ..., batch_size*N]
  2. new_query_start_loc_cpu = ... — 同样更新 CPU 端
  3. cad.replace(...) 创建新的 CommonAttentionMetadata
    • seq_lens = cad.seq_lens + N — 序列长度递增
    • num_actual_tokens = cad.num_actual_tokens + cad.batch_size() * N — 总 token 数增加
    • max_query_len = cad.max_query_len + N — 最大 query 长度增加
    • max_seq_len = cad.max_seq_len + N — 最大序列长度增加
    • slot_mapping = new_slot_mapping — 使用新的 slot mapping
3.2.8 copy_and_expand_eagle_inputs_kernel — EAGLE 输入复制与扩展 Kernel

这是 utils.py最复杂、最重要的 Triton kernel,负责将目标模型的输入复制到草稿模型的缓冲区中,同时处理 padding、并行推测 slot 和被拒绝 token。

参数详解

参数 形状 说明
target_token_ids_ptr [total_tokens_in_batch] 目标模型的 token IDs
target_positions_ptr [total_tokens_in_batch] 目标模型的位置编码
next_token_ids_ptr [num_reqs] 目标模型采样的下一个 token
out_input_ids_ptr [total_draft_tokens] 输出:草稿模型的输入 IDs
out_positions_ptr [total_draft_tokens] 输出:位置编码
out_is_rejected_token_mask_ptr [total_draft_tokens] 输出:被拒绝 token mask
out_is_masked_token_mask_ptr [total_draft_tokens] 输出:并行推测 mask
out_new_token_indices_ptr [num_padding_slots_per_req * num_reqs] 输出:新 token 的索引
out_hidden_state_mapping_ptr [total_tokens_in_batch] 输出:隐藏状态映射
query_start_loc_ptr [num_reqs + 1] query 起始位置
query_end_loc_ptr [num_reqs] query 结束位置
padding_token_id scalar padding token(通常为 0)
parallel_drafting_token_id scalar 并行推测占位 token
total_input_tokens scalar 输入总 token 数
num_padding_slots_per_request scalar 每请求新增 slot 数
shift_input_ids bool 是否移位输入(EAGLE 模式为 true)

执行逻辑(2D grid: request_idx × token_batch_idx)

  1. 加载 query 边界:从 query_start_locquery_end_loc 获取当前请求的有效 token 范围
  2. 计算有效 token 数
    • shift_input_ids=True 时(EAGLE):num_valid_tokens = query_end_loc - query_start_loc(跳过第一个 token),input_offset = 1
    • shift_input_ids=False 时(Draft Model):num_valid_tokens = query_end_loc - query_start_loc + 1input_offset = 0
  3. 计算输出起始位置:考虑 cudagraph padding 导致的位移
  4. 被拒绝 token 数num_rejected = next_query_start_loc - query_end_loc - 1
  5. 总输出 token 数num_valid_tokens + num_padding_slots_per_request + num_rejected
  6. 区域分类 mask
    • is_valid_region: j < num_valid_tokens — 从输入复制
    • is_bonus_region: j == num_valid_tokens — bonus token(next_token_id)
    • is_parallel_draft_region: 并行推测占位区域
    • is_rejected_region: 被拒绝 token 区域
  7. Token IDs 组装:通过 tl.where 链式选择,不同区域写入不同的 token
  8. Position 计算start_pos + j,位置不跟随输入移位
  9. 隐藏状态映射:当 shift_input_ids=True 时,记录每个输入位置在输出缓冲区中的对应索引
  10. 存储所有输出

输出缓冲区布局

[valid_tokens | bonus_token | parallel_draft_slots | rejected_slots]
3.2.9 copy_and_expand_dflash_inputs_kernel — DFlash 输入复制 Kernel

与 EAGLE kernel 类似但针对 DFlash 的交叉注意力机制做了定制:

关键差异

  • 将**上下文(context)查询(query)**分存到不同的缓冲区中
  • Context 的 slot mapping 和 query 的 slot mapping 分别存储
  • Query 部分由 [next_token, mask, mask, ...] 组成
  • 处理 num_rejected_tokens(padded 模式下)
  • 使用 HAS_NUM_REJECTED compile-time 常量条件

逐行逻辑

  1. 加载 context 范围 ctx_startctx_end
  2. 计算 num_ctxtotal_tokens = num_ctx + num_query_per_req
  3. 位置处理:
    • Context:从 target_positions 加载
    • Query:last_pos + 1 + query_off(在最后一个有效位置之后递增)
  4. Slot mapping:通过 block table 查找,context 和 query 分别写入不同缓冲区
  5. Input IDs:query 部分的第一个 token 为 bonus token,其余为 parallel_drafting_token_id
  6. Token indices:记录 mask token 的位置(需要采样的位置)
3.2.10 update_num_computed_tokens_for_batch_change — Batch 变更修正

使用 @torch.compile 编译的函数,用于异步推测解码中的 num_computed_tokens 修正:

  1. gather_indices = prev_positions.clamp(min=0) — 防止新请求的 -1 位置导致越界
  2. 从旧位置 gather 对应的 valid_count、prev_computed、prev_drafts
  3. participating = (prev_positions >= 0) & (prev_drafts > 0) — 参与推测的请求
  4. corrected = prev_computed + valid_counts.int() — 修正计算
  5. torch.where(participating, corrected, cpu_num_computed_tokens) — 参与修正,未参与使用 CPU 值

调用者

PyTorch Functions

Triton Kernels

eagle_step_slot_mapping
metadata_kernel

eagle_prepare_inputs
padded_kernel

eagle_prepare_next_token
padded_kernel

copy_and_expand
eagle_inputs_kernel

copy_and_expand
dflash_inputs_kernel

compute_new_slot_mapping

extend_all_queries_by_N

update_num_computed_tokens
for_batch_change

next_power_of_2

EagleProposer.propose
AR Loop

EagleProposer.set_inputs
first_pass

DFlashProposer.set_inputs
first_pass

EagleProposer.prepare
inputs_padded

EagleProposer.prepare_next
token_ids_padded


3.3 metadata.py - SpecDecodeMetadata 数据类

SpecDecodeMetadata 是推测解码中的元数据容器,封装了 draft token 的布局信息和索引信息。

@dataclass
class SpecDecodeMetadata:
    draft_token_ids: torch.Tensor        # [num_tokens] - 扁平化的 draft token IDs
    num_draft_tokens: list[int]          # [batch_size] - 每个请求的 draft token 数
    cu_num_draft_tokens: torch.Tensor    # [batch_size] - 累积 draft token 数
    cu_num_sampled_tokens: torch.Tensor  # [batch_size] - 累积采样 token 数(= draft + 1)
    target_logits_indices: torch.Tensor  # [num_tokens] - 需要 target 验证的 logit 索引
    bonus_logits_indices: torch.Tensor   # [batch_size] - bonus token 的 logit 索引
    logits_indices: torch.Tensor         # [num_tokens + batch_size] - 所有 logit 索引

逐字段分析

  • draft_token_ids:所有请求的 draft token IDs 扁平化拼接。例如请求 1 有 2 个 draft [a, b],请求 2 有 3 个 [c, d, e],则合并为 [a, b, c, d, e]

  • num_draft_tokens:每个请求的 draft token 数量列表,用于后续按请求切分。

  • cu_num_draft_tokensnum_draft_tokens 的累积和(np.cumsum)。例如 [2, 3][2, 5]。通过索引此数组可以快速定位每个请求在 draft_token_ids 中的起始和结束位置。

  • cu_num_sampled_tokensnum_draft_tokens + 1 的累积和。“+1” 是因为每个请求还有一个 bonus token(目标模型实际采样的 token)。

  • target_logits_indices:用于在 target model 的 logits 中索引需要验证的 draft token 位置。在初始化时为零填充,由 verifier 阶段填入。

  • bonus_logits_indices:bonus token 对应的 logit 索引。

  • logits_indices:所有 logits 索引的合并数组(target + bonus),shape 为 [num_tokens + batch_size]

__post_init__:计算 max_spec_len = max(self.num_draft_tokens),即 batch 中最大的 draft 长度,用于后续的 tensor 维度分配。

make_dummy 方法

@classmethod
def make_dummy(cls, draft_token_ids: list[list[int]], device: torch.device) -> "SpecDecodeMetadata":

用于创建 dummy metadata,主要场景是 CUDA Graph 捕获阶段,此时还没有真实的 draft token。该方法:

  1. list[list[int]] 扁平化并转为 tensor
  2. 计算 num_draft_tokens 和累积和
  3. target_logits_indicesbonus_logits_indiceslogits_indices 均初始化为零填充
  4. 返回完整的 SpecDecodeMetadata 实例

3.4 metrics.py - 指标体系

指标体系由三个组件构成,形成"采集 → 聚合 → 上报"的三层架构。

3.4.1 SpecDecodingStats — 单步统计
@dataclass
class SpecDecodingStats:
    num_spec_tokens: int                    # 推测 token 数(配置值)
    num_drafts: int = 0                     # 推测次数
    num_draft_tokens: int = 0               # 总 draft token 数
    num_accepted_tokens: int = 0            # 总接受 token 数
    num_accepted_tokens_per_pos: list[int]  # 每个位置的接受计数

new(cls, num_spec_tokens) 工厂方法:

  • 创建新实例,num_accepted_tokens_per_pos 初始化为 [0] * num_spec_tokens
  • 长度为 num_spec_tokens,索引 0 对应第 1 个推测位置,索引 1 对应第 2 个,以此类推

observe_draft(num_draft_tokens, num_accepted_tokens)

  • num_drafts += 1 — 推测次数 +1
  • num_draft_tokens += num_draft_tokens — 累计 draft
  • num_accepted_tokens += num_accepted_tokens — 累计接受
  • num_accepted_tokens 个位置的 num_accepted_tokens_per_pos 各 +1
  • assert 确保 num_accepted_tokens <= num_spec_tokens(接受数不能超过推测数)
3.4.2 SpecDecodingLogging — 日志聚合器
class SpecDecodingLogging:

状态变量

  • num_drafts: list[int] — 各次观察的推测次数
  • num_draft_tokens: list[int] — 各次观察的 draft token 数
  • num_accepted_tokens: list[int] — 各次观察的接受 token 数
  • accepted_tokens_per_pos_lists: list[list[int]] — 各次观察的位置接受计数
  • last_log_time — 上次日志时间戳

observe(spec_decoding_stats):将单步统计追加到各列表中。

log(log_fn=logger.info) 核心聚合逻辑:

  1. num_drafts = np.sum(self.num_drafts) — 总推测次数
  2. num_draft_tokens = np.sum(...) — 总 draft
  3. num_accepted_tokens = np.sum(...) — 总接受
  4. draft_throughput = num_draft_tokens / elapsed_time — draft 吞吐(tokens/s)
  5. accepted_throughput = num_accepted_tokens / elapsed_time — 接受吞吐
  6. draft_acceptance_rate = num_accepted_tokens / num_draft_tokens * 100 — 接受率
  7. mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts) — 平均接受长度(含 bonus token)
  8. pos_matrix = np.array(self.accepted_tokens_per_pos_lists) — 位置矩阵
  9. acceptance_rates = np.sum(pos_matrix, axis=0) / num_drafts — 每个位置的接受率
  10. 格式化为日志字符串,包含所有指标
3.4.3 SpecDecodingProm — Prometheus 集成
class SpecDecodingProm:

初始化时创建的 Counter

Counter 名称 含义 标签
vllm:spec_decode_num_drafts 推测次数 engine 标签
vllm:spec_decode_num_draft_tokens draft token 总数 engine 标签
vllm:spec_decode_num_accepted_tokens 接受 token 总数 engine 标签
vllm:spec_decode_num_accepted_tokens_per_pos 各位置接受数 engine + position 标签

PromQL 计算公式

  • 接受率:rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / rate(vllm:spec_decode_num_draft_tokens_total[$interval])
  • 平均接受长度:1 + (rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / rate(vllm:spec_decode_num_drafts[$interval]))
  • 各位置接受率向量:vllm:spec_decode_num_accepted_tokens_per_pos[$interval] / vllm:spec_decode_num_drafts[$interval]

observe(spec_decoding_stats, engine_idx=0)

  • 根据 engine 索引递增对应的 counter
  • per-position counter 通过位置索引分别递增

消费层

聚合层

采集层

SpecDecodingStats
单步观测

SpecDecodingLogging
日志聚合

SpecDecodingProm
Prometheus Counter

Logger.info
结构化日志

Grafana / Alert
PromQL 查询


3.5 eagle.py - SpecDecodeBaseProposer 基类与 EagleProposer

这是整个 spec_decode 模块中最核心、最复杂的文件(1793 行),包含了所有基于自回归推测的 Proposer 的公共基类以及 EAGLE 专用实现。

3.5.1 SpecDecodeBaseProposer.__init__() — 初始化

核心变量逐行分析

def __init__(self, vllm_config, device, pass_hidden_states_to_model, runner=None):

配置提取

  1. self.vllm_config = vllm_config — 保存完整的 vLLM 配置
  2. self.speculative_config = vllm_config.speculative_config — 推测解码配置
  3. self.draft_model_config = self.speculative_config.draft_model_config — 草稿模型配置
  4. self.method = self.speculative_config.method — 推测方法名称(“eagle”, “eagle3”, “dflash” 等)
  5. self.pass_hidden_states_to_model = pass_hidden_states_to_model — 是否将 target 的 hidden states 传给 draft model

类型和维度参数
6. self.device = device — GPU 设备
7. self.dtype = vllm_config.model_config.dtype — 模型数据类型(float16/bfloat16)
8. self.max_model_len = vllm_config.model_config.max_model_len — 模型最大长度
9. self.dp_rank = vllm_config.parallel_config.data_parallel_rank — 数据并行 rank
10. self.num_speculative_tokens = self.speculative_config.num_speculative_tokens — 每次推测的 token 数

隐藏层维度
11. self.hidden_size = self.draft_model_config.get_hidden_size() — 从 draft model 获取 hidden size(可能与 target model 不同,如 Llama 3.3 70B 场景)
12. self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size() — 嵌入层维度

并行推测参数
13. self.parallel_drafting: bool = self.speculative_config.parallel_drafting — 是否启用并行推测
14. self.extra_slots_per_request = 1 if not self.parallel_drafting else self.num_speculative_tokens — 每个请求的额外 slot 数
- 非并行模式(EAGLE 递归):每个请求只需 1 个额外 slot 放 bonus token
- 并行模式(parallel drafting):需要 num_speculative_tokens 个 slot
15. self.net_num_new_slots_per_request = self.extra_slots_per_request - (1 if ... else 0) — 净新增 slot
- 当 pass_hidden_states_to_model=True 且非 DFlash 时,EAGLE 会"吃掉"第一个 token(移位),因此净新增减 1
16. self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0 — 是否需要额外输入 slot

并行推测专用参数
17. self.parallel_drafting_token_id: int = 0 — 并行推测的占位 token ID
18. self.parallel_drafting_hidden_state_tensor: Tensor | None = None — 并行推测的占位 hidden state
19. 如果 parallel_drafting,调用 _init_parallel_drafting_params() 初始化

本地 argmax 优化
20. self.use_local_argmax_reduction: bool = self.speculative_config.use_local_argmax_reduction — 是否使用本地 argmax 归约(减少 TP 通信,O(2*tp_size) vs O(vocab_size))

批处理大小限制
21. self.max_batch_size = vllm_config.scheduler_config.max_num_seqs — 最大 batch size
22. self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens — 最大 token 数
23. self.token_arange_np = np.arange(self.max_num_tokens) — CPU 端 token 索引数组

Query/Position 限制
24. self.max_query_tokens = self.max_num_tokens — 最大 query token 数(可被子类如 DFlash 缩小)
25. self.max_positions = self.max_num_tokens — 最大位置数

多模态支持
26. self.mm_registry = MULTIMODAL_REGISTRY — 多模态注册表
27. self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(vllm_config.model_config) — 是否支持多模态输入

注意力后端
28. self.draft_attn_groups: list[AttentionGroup] = [] — 草稿模型的注意力组(延迟初始化)
29. self.kv_cache_gid: int = -1 — KV Cache 组 ID
30. self.eagle3_use_aux_hidden_state: bool = ... — EAGLE3 是否使用辅助隐藏状态

编译配置
31. self.compilation_config = self.vllm_config.compilation_config

CUDA Graph 调度器
32. self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) — 仅用于 PIECEWISE 模式
33. 在 initialize_cudagraph_keys() 中延迟初始化 key

GPU 持久化缓冲区(核心内存分配):
34. self.input_ids = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device) — 输入 token IDs 缓冲区
35. self.uses_mrope = self.draft_model_config.uses_mrope — 是否使用 M-RoPE(使用 draft model 的设置,而非 target)
36. self.uses_xdrope_dim / self.draft_uses_xdrope_dim — XDrop 维度
37. M-RoPE 位置缓冲self.mrope_positions = torch.zeros((3, self.max_positions + 1), dtype=torch.int64, device=device)
- 故意多分配 1 个位置使其非连续(non-contiguous),以兼容 torch compile
- 3 个维度分别对应 M-RoPE 的三个空间维度
38. XDrop 位置缓冲self.xdrope_positions(当使用 XDrop 时)
39. 普通位置缓冲self.positions = torch.zeros(self.max_positions, dtype=torch.int64, device=device)
40. self.hidden_states = torch.zeros((self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device) — 隐藏状态缓冲
41. self.inputs_embeds = torch.zeros((self.max_num_tokens, self.inputs_embeds_size), dtype=self.dtype, device=device) — 嵌入缓冲
42. self.block_size: int = -1 — KV cache block size(延迟初始化)

ARange 缓冲
43. self.arange = torch.arange(max(self.max_batch_size + 1, self.max_num_tokens), device=device, dtype=torch.int32) — 需要 +1 因为 query_start_loc 比 batch_size 多一个元素

拒绝/mask 标记缓冲(当 needs_extra_input_slots 时):
44. self.is_rejected_token_mask = torch.zeros((self.max_num_tokens,), dtype=torch.bool, device=device) — 被拒绝 token 标记
45. self.is_masked_token_mask = torch.zeros((self.max_num_tokens,), dtype=torch.bool, device=device) — 并行推测占位标记

Backup token 缓冲
46. self.backup_next_token_ids = CpuGpuBuffer(self.max_batch_size, dtype=torch.int32, ...) — 备用的下一个 token ID(CPU-GPU 双端缓冲)

Slot mapping 缓冲
47. self._slot_mapping_buffer = torch.zeros(self.max_positions, dtype=torch.int64, device=device)

ROCm 注意力类型
48. 在 ROCm 平台上,构建 allowed_attn_types 元组,包含各种 ROCm 注意力后端类型

推测 token 树
49. spec_token_tree = self.speculative_config.speculative_token_tree — 从配置解析 token 树结构
50. self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree) — token 树表示
- 例如 [(0,), (0,0), (0,1), (1,), (1,0), (1,1)] 表示一个 2 叉深度 2 的树
51. tree_depth = len(self.tree_choices[-1]) — 树的深度
52. num_drafts_per_level — 每层的 draft 数
53. self.cu_drafts_per_level — 累积 draft 数(前缀和)
54. self.child_drafts_per_level — 每层的孩子数(每个节点的孩子数)
55. self.tree_draft_pos_offsets = torch.arange(1, len(self.tree_choices) + 1, ...).repeat(batch_size, 1) — 树中 draft 的位置偏移

GPU 内存预算(以 max_num_tokens=4096, hidden_size=4096, batch_size=128 为例):

  • input_ids: 4096 × 4 = 16 KB
  • positions: 4096 × 8 = 32 KB
  • mrope_positions: 3 × 4097 × 8 = 98 KB
  • hidden_states: 4096 × 4096 × 2 (bfloat16) = 32 MB
  • inputs_embeds: 4096 × 4096 × 2 = 32 MB
  • is_rejected_token_mask: 4096 bytes
  • is_masked_token_mask: 4096 bytes
  • backup_next_token_ids: 128 × 4 × 2 = 1 KB
  • _slot_mapping_buffer: 4096 × 8 = 32 KB
  • 总计约 64.5 MB(不含模型权重)
3.5.2 _init_parallel_drafting_params()

初始化并行推测所需的特殊参数:

def _init_parallel_drafting_params(self):
    model_hf_config = self.draft_model_config.hf_config
    dflash_config = getattr(model_hf_config, "dflash_config", None)
    if dflash_config and "mask_token_id" in dflash_config:
        self.parallel_drafting_token_id = dflash_config["mask_token_id"]
    elif hasattr(model_hf_config, "pard_token"):
        self.parallel_drafting_token_id = model_hf_config.pard_token
    elif hasattr(model_hf_config, "ptd_token_id"):
        self.parallel_drafting_token_id = model_hf_config.ptd_token_id
    else:
        raise ValueError(...)  # 必须有 mask token id
    
    if self.pass_hidden_states_to_model:
        self.parallel_drafting_hidden_state_tensor = torch.empty(
            self.hidden_size, dtype=self.dtype, device=self.device
        )

逻辑:从模型的 config.json 中依次查找 dflash_config.mask_token_idpard_tokenptd_token_id,获取并行推测使用的 mask token ID。如果 pass_hidden_states_to_model=True(EAGLE 模式),还需分配一个空的 hidden state tensor 用于 mask token 的 hidden state。

3.5.3 load_model() — 模型加载
def load_model(self, target_model: nn.Module) -> None:

逐行分析

  1. 获取目标模型的注意力层名称

    target_attn_layer_names = set(
        get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
    )
    

    记录目标模型中所有 AttentionLayerBase 子类的层名。

  2. 加载草稿模型

    self.model = self._get_model()
    

    调用 _get_model() 加载 EAGLE head 模型。

  3. 识别草稿模型的注意力层

    all_attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
    self._draft_attn_layer_names = set(all_attn_layers.keys()) - target_attn_layer_names
    

    草稿模型的注意力层 = 所有注意力层 - 目标模型的注意力层。这些是 EAGLE head 新增的层。

  4. 多模态兼容性检查:如果目标模型是多模态但草稿模型不支持,降级为纯文本模式。

  5. 多模态 image_token_index 对齐:对一系列特定模型(Qwen3VL、Gemma4 等),将草稿模型的 image_token_index 设置为与目标模型一致。

  6. 共享嵌入层self._maybe_share_embeddings(target_language_model)

  7. 共享 LM Headself._maybe_share_lm_head(target_language_model)

  8. 并行推测 hidden state 初始化

    if self.parallel_drafting and self.pass_hidden_states_to_model:
        flat_mask = self.model.mask_hidden.view(-1)
        if self.eagle3_use_aux_hidden_state:
            self.parallel_drafting_hidden_state_tensor.copy_(
                self.model.combine_hidden_states(flat_mask)
            )
        else:
            self.parallel_drafting_hidden_state_tensor.copy_(flat_mask)
    

    将 EAGLE 模型内部的 mask hidden state 复制到缓冲区中,用于并行推测时的占位隐藏状态。

3.5.4 _maybe_share_embeddings() — 可选共享嵌入层

目的:当草稿模型没有自己的嵌入层(或嵌入层与目标模型完全相同)时,直接共享目标模型的嵌入层权重,节省 GPU 内存。

逐行逻辑

  1. 仅 PP=1 时生效if get_pp_group().world_size == 1

  2. 获取目标模型的嵌入层

    • 先尝试 inner_model.embed_tokens
    • 再尝试 inner_model.embedding
    • 都找不到则 raise AttributeError
  3. EAGLE 模型判断

    • hasattr(self.model, "has_own_embed_tokens") — EAGLE 模型特有属性
    • has_own_embed_tokens=False:模型没有自己的嵌入 → 直接共享
    • has_own_embed_tokens=True:比较权重是否相同 → 相同则共享
    • 权重不同:保留独立嵌入
  4. MTP 模型判断:没有 has_own_embed_tokens 属性的模型(MTP),默认共享

  5. 执行共享

    if share_embeddings:
        if hasattr(self.model.model, "embed_tokens"):
            del self.model.model.embed_tokens
        self.model.model.embed_tokens = target_embed_tokens
    

    先删除草稿模型的嵌入层,然后直接引用目标模型的嵌入层(共享同一份 GPU 内存)。

3.5.5 _maybe_share_lm_head() — 可选共享 LM Head

_maybe_share_embeddings 类似,但针对 LM Head:

额外处理 MTP 模型

inner = getattr(self.model, "model", None)
layers = getattr(inner, "layers", None) if inner else None
if layers is not None:
    items = layers.values() if isinstance(layers, nn.ModuleDict) else layers
    for layer in items:
        sh = getattr(layer, "shared_head", None)
        if sh is not None and hasattr(sh, "head"):
            del sh.head
            sh.head = target_language_model.lm_head

MTP 模型的 compute_logits 通过 shared_head.head(一个 ParallelLMHead)执行,如果 checkpoint 中没有独立的 lm_head 权重,这个 head 会保持未初始化状态产生 NaN。因此必须显式共享。

Local Argmax 检查:当 use_local_argmax_reduction=True 时:

  • 检查模型是否实现了 get_top_tokens() 方法
  • 如果模型有 draft_id_to_target_id vocab 映射,会 fallback 到完整 logits 路径(因为 argmax 后还需要映射)
3.5.6 propose() — 核心推测流程

这是整个模块的入口方法,实现了完整的推测解码流程。

def propose(
    self,
    target_token_ids: torch.Tensor,           # [num_tokens]
    target_positions: torch.Tensor,            # [num_tokens] 或 [3, num_tokens]
    target_hidden_states: torch.Tensor,        # [num_tokens, hidden_size]
    next_token_ids: torch.Tensor,              # [batch_size]
    token_indices_to_sample: torch.Tensor | None,
    common_attn_metadata: CommonAttentionMetadata,
    sampling_metadata: SamplingMetadata,
    mm_embed_inputs: tuple[...] | None = None,
    num_rejected_tokens_gpu: torch.Tensor | None = None,
    slot_mappings: dict[...] | None = None,
) -> torch.Tensor:

逐行流程分析

Phase 0: 隐藏状态组合(EAGLE3/DFlash)

if self.method in ("eagle3", "dflash"):
    target_hidden_states = self.model.combine_hidden_states(target_hidden_states)

EAGLE3 和 DFlash 需要将目标模型的多个辅助隐藏状态合并为单一隐藏状态,以匹配 EAGLE 模型的 hidden_size。

Phase 1: 设置第一轮输入

num_tokens, token_indices_to_sample, common_attn_metadata = self.set_inputs_first_pass(...)

调用 set_inputs_first_pass() 处理输入数据,将目标模型的 token/position/hidden_states 复制到草稿模型的缓冲区中。

Phase 2: 构建注意力元数据

per_group_attn_metadata, per_layer_attn_metadata = self.build_per_group_and_layer_attn_metadata(common_attn_metadata)

遍历 self.draft_attn_groups,为每个注意力组调用 build_for_drafting() 构建注意力元数据。

Phase 3: 确定执行模式

cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = self._determine_batch_execution_and_padding(num_tokens)

通过 CudagraphDispatcher 确定是否使用 CUDA Graph,以及 batch 填充大小。

Phase 4: 构建第一轮模型输入

model_kwargs, slot_mapping_size = self.build_model_inputs_first_pass(num_tokens, num_input_tokens, mm_embed_inputs)

准备 input_idspositionsinputs_embedshidden_states 等 model forward 参数。

Phase 5: 第一轮前向传播

with set_forward_context(per_layer_attn_metadata, self.vllm_config, ...):
    ret_hidden_states = self.model(**model_kwargs)
    if not self.model_returns_tuple():
        last_hidden_states = ret_hidden_states
        hidden_states = last_hidden_states
    else:
        last_hidden_states, hidden_states = ret_hidden_states

执行草稿模型的前向传播。EAGLE 模型返回 (last_hidden_states, hidden_states) 元组,分别用于 logits 计算和后续自回归步骤。DFlash/Draft Model 只返回单一 hidden states。

Phase 6: 第一轮采样

sample_hidden_states = last_hidden_states[token_indices_to_sample]
if self.num_speculative_tokens == 1 or self.parallel_drafting:
    draft_token_ids = self._greedy_sample(sample_hidden_states)
    return draft_token_ids.view(-1, self.num_speculative_tokens)

如果只需推测 1 个 token 或启用并行推测,直接 argmax 采样后返回。

Phase 7: 提取采样位置的 hidden states

positions = self.mrope_positions[:, token_indices_to_sample] if self.uses_mrope else self.positions[token_indices_to_sample]
hidden_states = hidden_states[token_indices_to_sample]

根据 token 索引提取对应的位置编码和隐藏状态。

Phase 8: 树状推测或自回归推测

if any(isinstance(md, TreeAttentionMetadata) for md in per_group_attn_metadata):
    logits = self.model.compute_logits(sample_hidden_states)
    draft_token_ids_list = self.propose_tree(...)
    return torch.cat(draft_token_ids_list, dim=1)

如果使用了 Tree Attention,走树状推测路径(propose_tree())。

否则走自回归递归路径:

draft_token_ids = self._greedy_sample(sample_hidden_states)
draft_token_ids_list = [draft_token_ids]

Phase 9: 自回归循环(AR Loop)

for token_index in range(self.num_speculative_tokens - 1):
    input_ids = draft_token_ids_list[-1].int()
    eagle_step_update_slot_mapping_and_metadata(...)
    common_attn_metadata.slot_mapping = self._slot_mapping_buffer[:batch_size]
    # ... update positions ...
    _, per_layer_attn_metadata = self.build_per_group_and_layer_attn_metadata(...)
    self.input_ids[:batch_size] = input_ids
    self.hidden_states[:batch_size] = hidden_states
    with set_forward_context(...):
        ret_hidden_states = self.model(**model_kwargs)
        last_hidden_states, hidden_states = ret_hidden_states
    hidden_states = hidden_states[:batch_size]
    draft_token_ids = self._greedy_sample(last_hidden_states[:batch_size])
    draft_token_ids_list.append(draft_token_ids)

每个 AR 步骤:

  1. 将上一步的 draft token 作为 input_ids
  2. 更新 slot mapping 和元数据(位置+1、序列长度+1)
  3. 重建注意力元数据
  4. 执行模型 forward(单 token 前向)
  5. 贪婪采样下一个 draft token

Phase 10: 拼接结果

draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids  # [batch_size, num_speculative_tokens]

EAGLE3/DFlash

Other

Tree Mode

AR Mode

Input: target_hidden_states
next_token_ids
positions

copy_and_expand
eagle_inputs_kernel

AR Loop
N iterations

Build TreeAttentionMetadata

Forward through
EAGLE head

logits → argmax/topk
draft_token_ids

eagle_step_update
slot_mapping

Update hidden_states
for next level

Output: draft_token_ids
batch × tree_depth

Method Check

combine_hidden_states

3.5.7 set_inputs_first_pass() — 第一轮输入准备

此方法有两个分支:EAGLE 默认路径扩展路径

默认 EAGLE 路径needs_extra_input_slots=False):

if token_indices_to_sample is None:
    token_indices_to_sample = cad.query_start_loc[1:] - 1

# Shift the input ids by one token
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token
self.input_ids[token_indices_to_sample] = next_token_ids

移位操作示例

  • 输入: [a1, b1, b2, c1, c2, c3](a 有 1 个 token,b 有 2 个,c 有 3 个)
  • 移位后: [b1, b2, c1, c2, c3, c3](每个请求丢掉第一个 token,最后一个位置保留)
  • 替换: [a2, b2, b3, c2, c3, c4](在 token_indices_to_sample 位置放入 next_token_ids

这利用了 EAGLE 模型的架构特性:EAGLE head 接收 target model 的 hidden states(已包含所有上下文信息),而 token IDs 只需要提供"当前位置之后的 token"作为输入,因此可以安全地跳过每个请求的第一个 token(其 hidden state 已经包含了前一个 token 的信息)。

扩展路径needs_extra_input_slots=True,用于 parallel drafting 和 draft model):

  1. 调用 copy_and_expand_eagle_inputs_kernel Triton kernel
  2. 如果 pass_hidden_states_to_model,通过 out_hidden_state_mapping 将 target hidden states 映射到草稿缓冲区
  3. 使用 mask token 的 hidden state 填充并行推测 slot
  4. 调用 compute_new_slot_mapping 计算新的 slot mapping
  5. 调用 extend_all_queries_by_N 扩展 query 长度
3.5.8 prepare_inputs_padded() — 填充模式输入准备

使用 eagle_prepare_inputs_padded_kernel 计算每个请求的 token_indices_to_samplenum_rejected_tokens_gpu。然后创建新的 CommonAttentionMetadata,其中 token 布局保持不变(不重新排列),被拒绝的 token 作为 padding 保留在 buffer 中,后续由 token_indices_to_sample 过滤。

3.5.9 prepare_next_token_ids_padded() — 填充模式下一个 Token 准备

使用 eagle_prepare_next_token_padded_kernel 确定每个请求的 next token ID(来自采样结果或 backup token)。

3.5.10 dummy_run() — CUDA Graph 捕获
@torch.inference_mode()
def dummy_run(self, num_tokens, use_cudagraphs=True, is_graph_capturing=False, slot_mappings=None):

用于 CUDA Graph 捕获阶段的 dummy 前向传播:

  • only_one_forward_pass = is_graph_capturing or self.parallel_drafting — 捕获时或并行推测时只执行一次 forward
  • 非并行推测时需要 num_speculative_tokens 次 forward(模拟 AR loop)
  • 使用 EAGLE 自己的 slot mapping buffer
3.5.11 propose_tree() — 树状推测

当使用 Tree Attention 时,推测不再是一维链而是树结构:

  1. 第一层:对每个请求的 bonus token 位置,从 logits 中 argmax 或 top-k 采样 draft token
  2. 循环:每层将新的 draft token 与之前的上下文拼接,构建 tree attention metadata
  3. concat:将 draft tokens、positions、hidden states 沿 token 维度拼接
  4. Forward:执行树状注意力的 forward
  5. 采样:从输出 logits 中继续采样下一层
  6. 位置偏移flattened_draft_positions = positions + self.tree_draft_pos_offsets
3.5.12 build_per_group_and_layer_attn_metadata()

遍历 draft_attn_groups,为每个组调用 build_for_drafting(common_attn_metadata, draft_index),构建 per_layer_attn_metadata 字典,key 为层名,value 为注意力元数据。

3.5.13 辅助方法
方法 作用
_get_positions(num_tokens) 根据 RoPE 类型返回位置缓冲区的切片
_set_positions(num_tokens, positions) 将位置写入缓冲区(处理 M-RoPE/XDrop 转换)
_get_slot_mapping(num_tokens, slot_mapping) 将 slot mapping 拷贝到 buffer 并返回 dict
_greedy_sample(hidden_states) argmax 或 get_top_tokens() 采样
_determine_batch_execution_and_padding(num_tokens) 通过 cudagraph_dispatcher 确定执行模式和填充
model_returns_tuple() 判断模型返回类型(EAGLE 返回元组,DFlash/Draft 返回单一值)
validate_same_kv_cache_group() 确保所有草稿层在同一 KV Cache 组中
initialize_attn_backend() 初始化草稿层的注意力后端

3.5.14 EagleProposer — EAGLE 专用子类
class EagleProposer(SpecDecodeBaseProposer):
    def __init__(self, vllm_config, device, runner=None):
        super().__init__(
            vllm_config,
            device,
            pass_hidden_states_to_model=True,
            runner=runner,
        )

EagleProposerSpecDecodeBaseProposer 的极简子类,唯一自定义的行为是在构造时将 pass_hidden_states_to_model=True,表示 EAGLE 模型需要接收目标模型的隐藏状态作为输入(这是 EAGLE 架构的核心设计:EAGLE head 以 target model 的 hidden states 为条件生成 draft tokens)。

文件末尾的 compute_probs_and_sample_next_token() 函数
目前未使用(注释标注为 FIXME),保留了将来使用带温度采样的 draft token 生成逻辑。当前实现仅使用 argmax 贪婪采样。


3.6 dflash.py - DFlashProposer

DFlashProposer 继承自 SpecDecodeBaseProposer,实现了 DFlash(Dynamic Flash)推测方法。DFlash 的核心创新是使用交叉注意力(Cross-Attention)机制,而非 EAGLE 的自回归递归。

3.6.1 初始化
def __init__(self, vllm_config, device, runner=None):
    super().__init__(..., pass_hidden_states_to_model=True, ...)
    
    # 关键差异:DFlash 只将推测 token 作为 query
    self.max_query_tokens = self.max_batch_size * (1 + self.num_speculative_tokens)
    # 位置缓冲区涵盖 context + query
    self.max_positions = self.max_num_tokens + self.max_query_tokens

DFlash 的内存布局

  • max_query_tokens:每个请求 (1 + num_speculative_tokens) 个 query token(1 个 bonus + k 个 mask),batch_size × 这个数
  • max_positions:context positions + query positions 的总和
  • _context_slot_mapping_buffer:context 的 slot mapping(大小 = max_num_tokens)
  • _slot_mapping_buffer:query 的 slot mapping(大小 = max_query_tokens)
  • _context_positions_buffer:context 位置(大小 = max_num_tokens)
  • positions:query 位置(大小 = max_query_tokens)
  • parallel_drafting_hidden_state_tensor = None:DFlash 不使用 hidden state 传递,而是使用嵌入
3.6.2 _raise_if_multimodal() 重写

DFlash 覆盖了父类方法为空操作,允许多模态输入(支持 Qwen3.5 系列模型)。

3.6.3 set_inputs_first_pass() 重写

DFlash 的输入处理与 EAGLE 完全不同:

num_context = target_token_ids.shape[0]
num_query_per_req = 1 + self.num_speculative_tokens
num_query_total = batch_size * num_query_per_req

关键差异

  1. Context K/V 直接来自 target_hidden_states,不需要复制到 buffer
  2. Query tokens 只有 bonus + mask tokens,数量极少
  3. 使用 copy_and_expand_dflash_inputs_kernel 而非 EAGLE 的 kernel
  4. Query 的 slot mapping 和 context 的 slot mapping 分别管理

kernel 参数配置

copy_and_expand_dflash_inputs_kernel[grid](
    next_token_ids_ptr=next_token_ids,
    target_positions_ptr=target_positions,
    out_input_ids_ptr=self.input_ids,
    out_context_positions_ptr=self._context_positions_buffer,
    out_query_positions_ptr=self.positions,
    out_context_slot_mapping_ptr=self._context_slot_mapping_buffer,
    out_query_slot_mapping_ptr=self._slot_mapping_buffer,
    out_token_indices_ptr=token_indices_to_sample,
    ...
)

新的 CommonAttentionMetadata

  • causal=False — DFlash 需要非因果注意力(交叉注意力可以访问所有 context)
  • num_actual_tokens = num_query_total — 只有 query tokens 参与 forward
  • slot_mapping = query_slot_mapping — 仅 query 的 slot mapping
3.6.4 dummy_run() 重写

DFlash 的 dummy run 有特殊处理:

  1. 先执行 precompute_and_store_context_kv() — 将 context 的 KV 预计算并存入 cache
  2. 再执行 query 的 forward pass
context_positions = self._context_positions_buffer[:num_tokens]
context_states = self.hidden_states[:num_tokens]
self.model.precompute_and_store_context_kv(context_states, context_positions)
3.6.5 build_model_inputs_first_pass() 重写
def build_model_inputs_first_pass(self, num_tokens, num_input_tokens, mm_embed_inputs):
    num_context = self._dflash_num_context
    # 预插入 context KV
    self.model.precompute_and_store_context_kv(
        self._dflash_hidden_states,
        self._context_positions_buffer[:num_context],
        self._context_slot_mapping_buffer[:num_context],
    )
    return dict(input_ids=..., positions=..., inputs_embeds=None), num_input_tokens

关键:context 的 KV 通过 precompute_and_store_context_kv 直接预计算并插入 KV Cache,不经过常规的 model forward。

3.6.6 build_per_group_and_layer_attn_metadata() 重写

在父类基础上增加 causal=False 检查,确保所有注意力后端都支持非因果注意力。

DFlash 交叉注意力机制详解

DFlash Cross-Attention

Query

Context K/V

K,V from

Q from embed

target_hidden_states
num_context x hidden_size

context_positions

context_slot_mapping

next_token_id + mask_tokens
batch x 1+k

query_positions

query_slot_mapping

Cross-Attention Layer

precompute_and_store_context_kv
GEMM + norms + RoPE

KV Cache Insert

Model Forward
Query only

logits
batch x 1+k x vocab

argmax per head

draft_token_ids
batch x k


3.7 draft_model.py - DraftModelProposer

DraftModelProposer 使用一个独立的小型语言模型作为草稿模型(如 Qwen2.5-0.5B 作为 Qwen2.5-72B 的草稿)。

3.7.1 初始化
def __init__(self, vllm_config, device, runner=None):
    super().__init__(
        vllm_config,
        device,
        pass_hidden_states_to_model=False,  # 关键差异
        runner=runner,
    )
    self._raise_if_vocab_size_mismatch()
    self._raise_if_draft_tp_mismatch()

pass_hidden_states_to_model=False 是 EAGLE 和 Draft Model 的核心差异:Draft Model 是一个完整的语言模型,不需要接收目标模型的 hidden states,而是通过 token IDs 自主推理。

3.7.2 _raise_if_vocab_size_mismatch()
def _raise_if_vocab_size_mismatch(self):
    self.speculative_config.verify_equal_vocab_size_if_draft_model()

确保目标模型和草稿模型的词汇表大小一致。这是 Draft Model 的必要条件,因为草稿模型生成的 token IDs 需要直接用于目标模型。

3.7.3 _raise_if_draft_tp_mismatch()
def _raise_if_draft_tp_mismatch(self):
    tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
    draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
    if draft_tp != tgt_tp:
        raise ValueError(...)

确保目标模型和草稿模型的 TP 并行度一致。原因:如果 TP > 1 的目标模型与 TP = 1 的草稿模型共存,所有 rank 都会在 rank 0 上编译草稿模型,导致 torch.compile 缓存被覆盖和损坏。

3.7.4 _create_draft_vllm_config()

为草稿模型创建专用的 VllmConfig

  • quant_config=None — 草稿模型不使用量化
  • parallel_config — 使用草稿模型的并行配置
  • model_config — 使用草稿模型的配置
3.7.5 _get_model()
def _get_model(self) -> nn.Module:
    with set_model_tag("draft_model"):
        model = get_model(
            vllm_config=draft_vllm_config,
            prefix="draft_model",
        )
    return model

加载独立的草稿模型,model tag 为 “draft_model”,prefix 为 “draft_model”(与 EAGLE 的 “eagle_head” 区分)。

3.7.6 嵌入和 LM Head 不共享
def _maybe_share_embeddings(self, target_language_model):
    pass  # Draft models don't share embeddings

def _maybe_share_lm_head(self, target_language_model):
    pass  # Draft models don't share lm_head

Draft Model 拥有自己独立的嵌入层和 LM Head,不与目标模型共享。


3.8 medusa.py - MedusaProposer

MedusaProposer 是一个独立类(不继承 SpecDecodeBaseProposer),实现了 Medusa 推测方法。Medusa 的核心思想是在目标模型的最后几个 transformer 层之后附加多个"Medusa Head",每个 Head 预测未来第 k 个 token。

3.8.1 初始化
def __init__(self, vllm_config, device):
    self.vllm_config = vllm_config
    self.spec_config = vllm_config.speculative_config
    self.device = device
    self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
    self.hidden_size = self.spec_config.draft_model_config.get_hidden_size()
    self.dtype = vllm_config.model_config.dtype

极简初始化,不维护任何 GPU 缓冲区。

3.8.2 propose() — 推测
def propose(self, target_hidden_states, sampling_metadata, slot_mappings=None):
    blocks = self.model(target_hidden_states)
    logits = self.model.compute_logits(blocks)
    draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)
    return draft_tokens  # [batch_size, num_heads]

流程

  1. self.model(target_hidden_states) — 将目标模型的隐藏状态传入 Medusa 模型(实际上是多个 Medusa Head)
  2. blocks 是一个包含每个 Head 输出的列表
  3. self.model.compute_logits(blocks) — 计算每个 Head 的 logits
  4. 对每个 Head 的 logits 做 argmax,得到一个 draft token
  5. torch.stack(..., dim=1) 拼接为 [batch_size, num_heads]
3.8.3 load_model()
def load_model(self, target_model):
    with set_model_tag("medusa_head"):
        self.model = get_model(
            vllm_config=self.vllm_config,
            model_config=self.spec_config.draft_model_config,
        )
    assert not (is_mixture_of_experts(self.model) and enable_eplb)

加载 Medusa Head 模型,tag 为 “medusa_head”。检查 Mixture of Experts 模型与 EPLB 的兼容性。

3.8.4 dummy_run()

使用零填充的 hidden states 执行 dummy forward,用于 CUDA Graph 捕获。

Medusa Architecture

target_hidden_states
num_tokens x hidden_size

Medusa Head 1
预测 t+1

Medusa Head 2
预测 t+2

Medusa Head N
预测 t+N

logits_1

argmax → token_1

logits_2

argmax → token_2

logits_N

argmax → token_N

stack → batch x N


3.9 extract_hidden_states.py - ExtractHiddenStatesProposer

ExtractHiddenStatesProposer 是一个特殊用途的 Proposer,其核心功能不是推测 token,而是缓存隐藏状态到 KV Cache。它主要用于 vLLM 的 KV 传输(KV Transfer) 场景。

3.9.1 初始化
def __init__(self, vllm_config, device):
    assert vllm_config.speculative_config.num_speculative_tokens == 1

强制要求 num_speculative_tokens == 1,因为它本质上只做单步操作。

关键参数

  • layer_ids = getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", None) — 辅助隐藏状态的层 ID 列表
  • self.num_hidden_states = len(layer_ids) — 辅助隐藏状态的层数
  • self.hidden_states = torch.zeros((max_num_tokens, num_hidden_states, hidden_size), ...) — 3D 隐藏状态缓冲区
3.9.2 propose() — “推测”(实际上是缓存)
def propose(self, sampled_token_ids, target_hidden_states, common_attn_metadata, slot_mappings=None):
    stacked_hidden_states = torch.stack(target_hidden_states, dim=1)
    num_tokens = stacked_hidden_states.shape[0]
    self.hidden_states[:num_tokens] = stacked_hidden_states

流程

  1. target_hidden_states(list of tensors)沿 dim=1 拼接为 [num_tokens, num_hidden_states, hidden_size]
  2. 写入 self.hidden_states 缓冲区
  3. 构建注意力元数据
  4. 执行 self.model(hidden_states=...) — 将隐藏状态存入 KV Cache(不做注意力计算)
  5. 返回 sampled_token_ids[:, :1] — 直接返回目标模型采样的 token 作为 “draft”(保证 100% 接受率)

核心洞察:这个 Proposer 并不真正做推测。它的作用是将目标模型的辅助隐藏状态缓存到 KV Cache 中,供后续使用(如跨节点 KV 传输)。返回的 “draft tokens” 就是目标模型自己采样的 token,因此在验证阶段必然全部接受。

3.9.3 prepare_next_token_ids_padded()

简化版实现:num_speculative_tokens=1,直接判断 sampled token 是否有效,无效则使用 backup token。

3.9.4 load_model()

加载 ExtractHiddenStatesModel 模型(只有 cache-only attention 层),识别出恰好 1 个注意力层。


3.10 ngram_proposer.py - NgramProposer (CPU Numba)

NgramProposer 使用纯 CPU 端的 Numba JIT 编译实现 n-gram 文本匹配,无需加载任何模型。它利用提示文本中的重复模式来推测后续 token。

3.10.1 初始化
def __init__(self, vllm_config):
    self.min_n = vllm_config.speculative_config.prompt_lookup_min    # 最小 n-gram 长度
    self.max_n = vllm_config.speculative_config.prompt_lookup_max    # 最大 n-gram 长度
    self.k = vllm_config.speculative_config.num_speculative_tokens   # 推测 token 数
    self.max_model_len = vllm_config.model_config.max_model_len
    
    # 预分配 Numba 缓冲区
    max_num_seqs = vllm_config.scheduler_config.max_num_seqs
    self.valid_ngram_draft = np.zeros((max_num_seqs, self.k), dtype=np.int32)
    self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32)

线程配置

  • num_tokens_threshold = 8192 — 启用多线程的 token 数阈值
  • num_numba_thread_available = min(1, cpu_count // 2) — 当前默认为 1 线程
  • 除以 tp_size 确保每个 TP rank 都有线程

JIT 预热:在初始化时立即执行一次 propose([[]] * 1024, ...) 触发 Numba JIT 编译,避免首次实际调用时的延迟。

3.10.2 propose() — 推测入口
def propose(self, sampled_token_ids, num_tokens_no_spec, token_ids_cpu, slot_mappings=None):
    valid_ngram_requests = []
    for i, sampled_ids in enumerate(sampled_token_ids):
        if not len(sampled_ids):
            continue
        if num_tokens_no_spec[i] >= self.max_model_len:
            continue
        valid_ngram_requests.append(i)
    
    return self.batch_propose(len(sampled_token_ids), valid_ngram_requests, ...)

过滤条件:

  1. 必须有采样结果(非 partial prefill)
  2. 未超过最大模型长度
3.10.3 batch_propose() — 批量 Numba 推测
def batch_propose(self, num_requests, valid_ngram_requests, num_tokens_no_spec, token_ids_cpu):
    if len(valid_ngram_requests):
        # 动态调整线程数
        if total_tokens >= self.num_tokens_threshold:
            final_num_threads = max(1, min(self.num_numba_thread_available, num_ngram_requests))
        else:
            set_num_threads(1)
        
        batch_propose_numba(valid_ngram_requests, ...)
    
    for i in range(num_requests):
        if i in valid_ngram_requests and self.valid_ngram_num_drafts[i] > 0:
            draft_token_ids.append(self.valid_ngram_draft[i, :self.valid_ngram_num_drafts[i]].tolist())
        else:
            draft_token_ids.append([])
    return draft_token_ids

多线程策略:当总 token 数 >= 8192 时启用多线程,否则使用单线程(小 batch 多线程开销大于收益)。

3.10.4 batch_propose_numba — Numba 并行入口
@njit(parallel=True)
def batch_propose_numba(valid_ngram_requests, num_tokens_no_spec, token_ids_cpu, ...):
    for i in prange(len(valid_ngram_requests)):
        idx = valid_ngram_requests[i]
        context_token_ids = token_ids_cpu[idx, :num_tokens_no_spec[idx]]
        drafter_output = _find_longest_matched_ngram_and_propose_tokens(...)
        valid_ngram_num_drafts[idx] = drafter_output.shape[0]
        if len(drafter_output):
            valid_ngram_draft[idx, :drafter_output.shape[0]] = drafter_output

使用 prange 实现请求级并行,每个请求独立执行 n-gram 匹配。

3.10.5 _find_longest_matched_ngram_and_propose_tokens — KMP 算法核心

这是 N-gram 匹配的核心算法,基于 KMP(Knuth-Morris-Pratt)字符串匹配算法的变体:

@jit(nopython=True)
def _find_longest_matched_ngram_and_propose_tokens(origin_tokens, min_ngram, max_ngram, max_model_len, k):

算法步骤

  1. 边界检查:如果 context 长度 < min_ngram 或已超出 max model length,直接返回空。

  2. 翻转 token 序列tokens = origin_tokens[::-1]

    • 翻转后,原序列的后缀变为前缀
    • 目标变为:在翻转后的序列中找到最长的前缀后缀匹配(LPS, Longest Prefix which is also Suffix)
  3. 构建 LPS 数组

    lps = np.zeros(max_ngram, dtype=np.int32)
    

    lps[i] 表示 tokens[0:i+1] 的最长公共前后缀长度。由于 n-gram 长度上限为 max_ngram,LPS 数组大小仅需 max_ngram

  4. KMP 核心循环

    prev_lps = 0
    i = 1
    while i < total_token:
        if tokens[prev_lps] == tokens[i]:
            prev_lps += 1
            if prev_lps >= longest_ngram:
                longest_ngram = prev_lps
                position = i
            if i < max_ngram:
                lps[i] = prev_lps
            if prev_lps == max_ngram:
                prev_lps = lps[max_ngram - 1]
            i += 1
        elif prev_lps != 0:
            prev_lps = lps[prev_lps - 1]
        else:
            i += 1
    
    • 匹配成功tokens[prev_lps] == tokens[i]prev_lps += 1,更新最长匹配
    • 匹配失败且有 fallbackprev_lps = lps[prev_lps - 1] → 回退到次长前缀
    • 匹配失败且无 fallbacki += 1 → 继续扫描
    • 长度截断:当 prev_lps == max_ngram 时,截断到 lps[max_ngram - 1],防止超过最大 n-gram 长度
  5. 结果提取

    if longest_ngram < min_ngram:
        return np.empty((0,), dtype=origin_tokens.dtype)
    
    start_position = total_token - 1 - position + longest_ngram
    k = min(k, total_token - start_position)
    return origin_tokens[start_position : start_position + k]
    
    • position 是翻转后的位置,total_token - 1 - position 是翻转回原序列的位置
    • 从匹配 n-gram 的末尾开始,提取后续的 k 个 token 作为推测结果

算法复杂度:O(N) 时间,O(max_ngram) 空间,其中 N 为序列长度。

origin_tokens
例: 1 2 3 4 1 2 3 4 5

翻转
5 4 3 2 1 4 3 2 1

KMP LPS 计算

最长匹配 ≥ min_ngram?

返回空

定位匹配位置

提取后续 k 个 token

返回 draft tokens
例: 4 5


3.11 ngram_proposer_gpu.py - NgramProposerGPU

NgramProposerGPU 是 N-gram 推测的GPU 加速版本,使用纯 PyTorch tensor 操作和 torch.compile 编译优化,完全避免了 CPU-GPU 数据传输。

3.11.1 架构

由两个类组成:

  • NgramGPUKernel(nn.Module) — 可编译的 GPU kernel,使用 @support_torch_compile() 装饰器
  • NgramProposerGPU — 外层封装,负责数据准备和调度
3.11.2 NgramGPUKernel.__init__()
def __init__(self, vllm_config, prefix="", device):
    self.min_n = ...    # prompt_lookup_min
    self.max_n = ...    # prompt_lookup_max
    self.k = ...        # num_speculative_tokens
    self.max_model_len = ...
    self.max_num_seqs = ...
    self.device = device

作为 nn.Module 的子类,可以被 torch.compile 编译。

3.11.3 _find_first_and_extract_all_n_parallel() — GPU 向量化匹配

这是 GPU 版本的核心算法,与 CPU 版本的 KMP 完全不同:

def _find_first_and_extract_all_n_parallel(
    self, token_ids, seq_lengths, min_ngram_len, max_ngram_len, num_draft_tokens
):

算法思路

  1. 遍历所有 n-gram 长度min_ngrammax_ngram):

    for i, ngram_len in enumerate(range(min_ngram_len, max_ngram_len + 1)):
    
  2. Sliding Window:使用 torch.unfold 创建所有大小为 ngram_len 的滑动窗口

    search_windows = token_ids.unfold(1, ngram_len, 1)
    

    unfold 是 O(1) view 操作,不复制数据。

  3. Suffix 提取:获取每个序列末尾 ngram_len 个 token

    suffix_starts = seq_lengths - ngram_len
    suffix_indices = suffix_starts.unsqueeze(1) + torch.arange(ngram_len, device=device)
    suffix = torch.gather(token_ids, 1, suffix_indices.clamp(min=0))
    
  4. 批量匹配

    matches = (search_windows == suffix.unsqueeze(1)).all(dim=-1)
    

    逐元素比较后沿 n-gram 维度做 all 操作,得到 (batch_size, num_windows) 的布尔 mask。

  5. 有效性过滤:匹配位置必须允许至少 1 个 draft token

    max_valid_suffix_start = seq_lengths - ngram_len - 1
    valid_mask = window_positions <= max_valid_suffix_start.unsqueeze(1)
    final_matches = matches & valid_mask
    
  6. 最早匹配

    first_match_idx = torch.argmax(final_matches.int(), dim=1)
    has_match = final_matches[batch_indices, first_match_idx]
    

    argmax 在全部为 false 时返回 0,所以需要额外的 has_match 检查。

  7. 选择最长有效匹配

    best_ngram_idx = (first_match_positions >= 0).int().flip(dims=[1]).argmax(dim=1)
    best_ngram_idx = num_ngram_sizes - 1 - best_ngram_idx
    

    翻转后 argmax 找到最后一个(即最长的)有匹配的 n-gram 长度。

  8. 提取 draft tokens

    draft_start = best_match_pos + best_ngram_lengths
    draft_indices = draft_start.unsqueeze(1) + torch.arange(num_draft_tokens, device=device)
    draft_tokens = torch.gather(token_ids, 1, draft_indices.clamp(min=0, max=max_seq_len - 1))
    
  9. 有效性 mask

    • 超出可用 token 数的位置填充 -1
    • 无匹配的位置全部填充 -1

算法复杂度:O(batch_size × (max_ngram - min_ngram) × max_seq_len),完全向量化,无 Python 循环(n-gram 长度的 for 循环在编译时展开)。

3.11.4 forward() — 前向传播
def forward(self, num_tokens_no_spec, token_ids_gpu, combined_mask):
    actual_batch_size = token_ids_gpu.shape[0]
    draft_tokens = torch.full((actual_batch_size, self.k), -1, ...)
    
    results = self._find_first_and_extract_all_n_parallel(...)
    draft_tokens = torch.where(combined_mask.unsqueeze(1), results, -1)
    
    # 计算前导有效 token 数
    is_valid = draft_tokens != -1
    cum_valid = is_valid.int().cumsum(dim=1)
    positions = torch.arange(1, self.k + 1, device=device).unsqueeze(0)
    num_valid_draft_tokens = (cum_valid == positions).int().sum(dim=1)
    
    return draft_tokens, num_valid_draft_tokens

有效 token 计数技巧cumsum(dim=1) 然后与位置索引比较。如果 token 0 有效则 cumsum[0]=1, 与位置 1 相等;如果 token 1 也是有效则 cumsum[1]=2, 与位置 2 相等;一旦遇到 -1,cumsum 停止增长,不再匹配。

3.11.5 NgramProposerGPU.__init__() — 编译配置
compilation_config = CompilationConfig(
    mode=CompilationMode.VLLM_COMPILE,
    custom_ops=["none"],
    splitting_ops=[],
    compile_sizes=[],
    inductor_compile_config={
        "enable_auto_functionalized_v2": False,
        "max_autotune": True,
        "aggressive_fusion": True,
        "triton.autotune_pointwise": True,
        "coordinate_descent_tuning": True,
        "use_mixed_mm": False,
    },
    cudagraph_mode=CUDAGraphMode.NONE,
)

编译选项详解

  • enable_auto_functionalized_v2=False:禁用自动功能化,避免不必要的 in-place 操作转换
  • max_autotune=True:启用 Triton 的自动调优
  • aggressive_fusion=True:激进算子融合
  • triton.autotune_pointwise=True:pointwise 操作的自动调优
  • coordinate_descent_tuning=True:坐标下降调优(Triton kernel 的调参策略)
  • use_mixed_mm=False:禁用混合精度矩阵乘法(N-gram 不需要矩阵乘法)
  • cudagraph_mode=CUDAGraphMode.NONE:不使用 CUDA Graph(因为 shape 动态变化)
3.11.6 _dummy_run() — 预热

生成随机测试数据,执行 3 次 forward 以触发 torch.compile 编译。

3.11.7 propose() — GPU 推测入口
def propose(self, num_tokens_no_spec, token_ids_gpu, valid_sampled_token_ids_gpu, valid_sampled_tokens_count):

Scatter 操作:将新采样的 token 写入 token_ids_gpu(原位修改):

offsets = torch.arange(max_new_tokens, device=self.device)
write_positions = num_tokens_no_spec.unsqueeze(1) + offsets.unsqueeze(0)
scatter_mask = (valid_write_mask & (valid_sampled_token_ids_gpu != -1) & in_bounds)
token_ids_gpu.scatter_(1, write_positions_long, tokens_to_scatter)

Validity mask

combined_mask = sampled_flags & valid_mask & (num_tokens_tmp >= self.min_n)
3.11.8 update_token_ids_ngram() — Token IDs 更新

用于准备推测解码输入:

  1. 处理 list[list[int]] 到 tensor 的转换(padded mode)
  2. 计算 backup next token IDs
  3. Mask 被丢弃请求的 token
  4. 计算每个请求的有效 token 数
  5. 找到最后一个有效 token 作为 next_token_id
3.11.9 辅助函数

update_scheduler_for_invalid_drafts()

  • 等待异步 D2H 传输完成
  • 根据 num_valid_draft_tokens 裁剪调度器中的推测 token

update_ngram_gpu_tensors_incremental()

  • 增量更新 token_ids_gpu_tensornum_tokens_no_spec_gpu
  • 处理新请求、重排序请求
  • 使用 pinned buffer 减少内存分配

调用链

GPU Kernel

NgramGPUKernel.forward

_find_first_and_extract_all_n_parallel

unfold sliding windows

batch match with suffix

argmax find first match

select longest ngram

gather draft tokens

mask invalid positions

NgramProposerGPU.propose

scatter sampled tokens

compute combined_mask

return draft_tokens + count


3.12 suffix_decoding.py - SuffixDecodingProposer

SuffixDecodingProposer 实现了**后缀解码(Suffix Decoding)方法,基于 Arctic Inference 项目的官方实现。后缀解码利用提示文本的后缀树(Suffix Tree)**来推测 token。

3.12.1 初始化
def __init__(self, vllm_config):
    self.num_speculative_tokens = config.num_speculative_tokens
    self.max_tree_depth = config.suffix_decoding_max_tree_depth
    self.max_spec_factor = config.suffix_decoding_max_spec_factor
    self.min_token_prob = config.suffix_decoding_min_token_prob
    self.max_model_len = vllm_config.model_config.max_model_len
    
    from arctic_inference.suffix_decoding import SuffixDecodingCache
    self.suffix_cache = SuffixDecodingCache(
        max_tree_depth=config.suffix_decoding_max_tree_depth,
        max_cached_requests=config.suffix_decoding_max_cached_requests,
    )

关键配置

  • max_tree_depth:后缀树的最大深度(控制 pattern 匹配长度)
  • max_spec_factor:最大推测因子(限制动态推测的 token 数上限)
  • min_token_prob:最小 token 概率阈值(低于此阈值的推测被丢弃)
  • max_cached_requests:缓存的最大请求数

SuffixDecodingCache 是外部库对象,负责:

  • 缓存请求的输出
  • 管理过期请求的驱逐
  • 维护每个 prompt 的后缀树
3.12.2 propose() — 推测
def propose(self, input_batch, sampled_token_ids, slot_mappings=None):
    draft_token_ids = []
    for i, sampled_ids in enumerate(sampled_token_ids):
        if not sampled_ids:
            draft_token_ids.append([])
            continue
        
        req_id = input_batch.req_ids[i]
        num_tokens = input_batch.num_tokens_no_spec[i]
        if num_tokens >= self.max_model_len:
            draft_token_ids.append([])
            continue
        
        index = input_batch.req_id_to_index[req_id]

请求生命周期管理

  1. 新请求

    if req_id not in self.suffix_cache.active_requests:
        if req_id in self.suffix_cache.cached_requests:
            self.suffix_cache.evict_cached_response(req_id)
        num_prompt_tokens = input_batch.num_prompt_tokens[index]
        prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens]
        self.suffix_cache.start_request(req_id, prompt_token_ids)
    
    • 如果请求之前被缓存但现在是新的,先驱逐旧缓存
    • 使用 prompt token 构建后缀树
  2. 活跃请求

    self.suffix_cache.add_active_response(req_id, sampled_ids)
    

    将新采样的 token 添加到后缀缓存

  3. 推测

    start = max(0, num_tokens - self.max_tree_depth)
    pattern = input_batch.token_ids_cpu[i, start:num_tokens]
    draft = self.suffix_cache.speculate(
        req_id,
        pattern,
        max_spec_tokens=min(self.num_speculative_tokens, self.max_model_len - num_tokens - 1),
        max_spec_factor=self.max_spec_factor,
        min_token_prob=self.min_token_prob,
    )
    draft_token_ids.append(draft.token_ids)
    
    • 提取最近的 max_tree_depth 个 token 作为 pattern
    • 调用 suffix_cache.speculate() 进行推测
    • 推测结果是动态数量的(每请求不同)
  4. 清理

    for req_id in (self.suffix_cache.active_requests - input_batch.req_id_to_index.keys()):
        self.suffix_cache.stop_request(req_id)
    

    移除不在当前 batch 中的请求

SuffixTree SuffixDecodingCache SuffixDecodingProposer InputBatch SuffixTree SuffixDecodingCache SuffixDecodingProposer InputBatch alt [新请求] [已有请求] loop [每个请求] propose(input_batch, sampled_ids) start_request(req_id, prompt_tokens) 构建后缀树 add_active_response(req_id, sampled_ids) 更新后缀树 提取 pattern (最近 max_tree_depth tokens) speculate(req_id, pattern, ...) 在后缀树中查找匹配 匹配的后续 token draft token_ids (动态数量) stop_request(已完成的请求) draft_token_ids


四、综合架构分析

4.1 Proposer 方法对比矩阵

维度 EAGLE / EAGLE3 DFlash Draft Model Medusa N-gram CPU N-gram GPU Suffix Decoding Extract HS
继承基类 SpecDecodeBaseProposer SpecDecodeBaseProposer SpecDecodeBaseProposer
需要模型加载
需要额外权重 EAGLE head DFlash head 完整小型模型 Medusa heads Cache layer
Hidden State 传递 ✅ (pass_hidden_states_to_model) ❌ (直接传 hidden states) N/A N/A N/A
推测机制 自回归递归 交叉注意力 自回归递归 多头并行 文本匹配 文本匹配 后缀树 缓存
并行推测 可选 (parallel_drafting) 原生支持 可选 原生 N/A N/A 动态 N/A
树状推测 ✅ (tree_choices) ✅ (通过基类) N/A N/A ✅ (后缀树)
设备 GPU GPU GPU GPU CPU GPU CPU GPU
CUDA Graph 支持 ✅ (PIECEWISE) N/A N/A
torch.compile 支持 N/A N/A
多模态支持 ❌ (需要 extra slots) ✅ (已覆盖) N/A N/A N/A N/A
TP 一致性要求 必须相同 N/A N/A N/A
Vocab 一致性要求 可选 可选 必须相同 N/A N/A N/A
num_spec_tokens 限制 任意 任意 任意 任意 任意 任意 动态 固定为 1
接受率影响因素 EAGLE head 质量 交叉注意力质量 草稿模型质量 Head 数量/质量 文本重复度 文本重复度 后缀树匹配度 固定 100%
适用场景 通用 Qwen3 系列 有独立草稿模型 Medusa 架构 高重复文本 高吞吐重复文本 长尾请求 KV 传输

4.2 性能优化手段汇总

4.2.1 Triton GPU Kernels

utils.py 中的 5 个 Triton kernel 是核心性能优化手段:

Kernel 优化目标 优化手段
eagle_step_slot_mapping_metadata_kernel 减少 kernel launch 融合 position 更新、slot mapping 计算、seq_lens 更新为一个 kernel
eagle_prepare_inputs_padded_kernel 避免 Python 循环 GPU 端并行计算每个请求的采样索引
eagle_prepare_next_token_padded_kernel 避免 CPU-GPU 同步 GPU 端确定 next token ID
copy_and_expand_eagle_inputs_kernel 减少内存复制 2D grid 并行复制+扩展,一个 kernel 完成所有操作
copy_and_expand_dflash_inputs_kernel DFlash 专用优化 context 和 query 分离处理的融合 kernel
4.2.2 CUDA Graph
  • CudagraphDispatcher 支持 PIECEWISE 模式的 CUDA Graph
  • dummy_run() 用于捕获阶段的 dummy forward
  • _determine_batch_execution_and_padding() 确定何时使用 CUDA Graph
  • 对固定 batch size 的推理步骤有显著加速(消除 kernel launch 开销)
4.2.3 torch.compile
  • NgramGPUKernel 使用 @support_torch_compile() 装饰器
  • 编译配置启用了 max_autotuneaggressive_fusioncoordinate_descent_tuning
  • update_num_computed_tokens_for_batch_change 也使用了 @torch.compile
  • 对于 M-RoPE 位置缓冲,故意制造 non-contiguous 布局以兼容 torch compile
4.2.4 内存共享
  • _maybe_share_embeddings():当草稿模型嵌入层与目标模型相同时,共享同一份 GPU 内存
  • _maybe_share_lm_head():同理共享 LM Head
  • 节省约 2×hidden_size × vocab_size 的 GPU 内存
4.2.5 Local Argmax Reduction
  • use_local_argmax_reduction=True 时,使用 get_top_tokens() 替代完整的 compute_logits().argmax()
  • 通信复杂度从 O(vocab_size) 降到 O(2 × tp_size)
  • 当存在 draft_id_to_target_id 映射时 fallback 到完整路径
4.2.6 Numba JIT (CPU N-gram)
  • @njit(parallel=True)@jit(nopython=True) 实现 CPU 端的 JIT 编译
  • prange 实现请求级并行
  • 预编译(初始化时执行 dummy call)避免首次调用延迟
4.2.7 GPU 向量化 (N-gram GPU)
  • torch.unfold 创建 O(1) view 的滑动窗口
  • 全 batch 并行匹配,无 Python 循环
  • torch.compile 进一步优化算子融合

4.3 内存布局分析

4.3.1 核心缓冲区布局(SpecDecodeBaseProposer)
GPU Memory Layout (按 max_num_tokens = N, hidden_size = H 估算)

┌────────────────────────────────────────────────────┐
│ input_ids           [N]       int32     4N bytes   │
├────────────────────────────────────────────────────┤
│ positions           [N]       int64     8N bytes   │
│ ─ or ─                                                  │
│ mrope_positions    [3,N+1]   int64    24(N+1)     │
│ ─ or ─                                                  │
│ xdrope_positions  [D,N+1]   int64     8D(N+1)    │
├────────────────────────────────────────────────────┤
│ hidden_states       [N,H]    dtype    2HN bytes   │
├────────────────────────────────────────────────────┤
│ inputs_embeds       [N,H']   dtype    2HN' bytes  │
├────────────────────────────────────────────────────┤
│ is_rejected_mask    [N]       bool      N bytes    │
│ is_masked_mask      [N]       bool      N bytes    │
├────────────────────────────────────────────────────┤
│ _slot_mapping_buf   [N]       int64     8N bytes   │
│ arange              [max]     int32     4max bytes │
├────────────────────────────────────────────────────┤
│ backup_next_tokens  [B]       int32     4B bytes   │
│ tree_pos_offsets    [B,T]     int32    4BT bytes   │
└────────────────────────────────────────────────────┘
4.3.2 DFlash 专用内存布局
DFlash Extra Buffers:

┌────────────────────────────────────────────────────┐
│ _context_slot_mapping [N]     int64    8N bytes   │
│ _slot_mapping_buf     [M]     int64    8M bytes   │
│ _context_positions    [N]     int64    8N bytes   │
│ positions             [M]     int64    8M bytes   │
│ arange                [N+M+1] int32   4(N+M+1)   │
└────────────────────────────────────────────────────┘
M = max_batch_size × (1 + num_speculative_tokens)
4.3.3 ExtractHiddenStates 专用内存布局
┌────────────────────────────────────────────────────┐
│ hidden_states     [N,L,H]  dtype   2N*L*H bytes  │
│ _slot_mapping_buf [N]       int64   8N bytes      │
└────────────────────────────────────────────────────┘
L = num_hidden_states (辅助隐藏状态的层数)

4.4 GPU 缓冲区总览

Proposer 主要缓冲区 总内存估算 (N=4096, H=4096, B=128)
EagleProposer input_ids, positions/mrope, hidden_states, inputs_embeds, masks, slot_mapping, backup, tree_offsets ~65 MB
DFlashProposer 上述 + context_slot_mapping, context_positions, query_positions ~65 MB + ~0.1 MB
DraftModelProposer 与 Eagle 相同 (不共享嵌入/lm_head) ~65 MB + 草稿模型权重
MedusaProposer 无缓冲区 0 MB (仅模型权重)
NgramProposer CPU 端缓冲区 ~0.5 MB (CPU)
NgramProposerGPU 无预分配缓冲区 0 MB
SuffixDecodingProposer 外部库管理 取决于后缀树大小
ExtractHiddenStates hidden_states [N,L,H], slot_mapping ~128 MB (L=4) + 0.03 MB

五、总结与最佳实践

5.1 架构设计亮点

  1. 统一的基类设计SpecDecodeBaseProposer 为所有基于模型的推测方法提供了统一的接口和缓冲区管理,大幅减少了重复代码。

  2. 灵活的输入准备策略:通过 needs_extra_input_slots 标志在"移位模式"(EAGLE 默认)和"扩展模式"(parallel drafting / draft model)之间切换,使用不同的 Triton kernel。

  3. 多层次的性能优化:从 Triton kernel 融合、CUDA Graph 捕获、torch.compile 编译,到内存共享和 local argmax reduction,覆盖了 GPU 推理的多个优化维度。

  4. 设备无关的 N-gram 实现:提供 CPU Numba 和 GPU 向量化两种实现,适应不同场景需求。

  5. 模块化设计:8 种推测方法完全独立,按需加载,不相互依赖。

5.2 各方法的适用场景建议

场景 推荐方法 理由
通用 LLM 推理加速 EAGLE / EAGLE3 成熟稳定,支持树状推测和并行推测
Qwen3 系列模型 DFlash 原生支持,交叉注意力效率高
有预训练的小型草稿模型 Draft Model 利用已有的小模型
Medusa 架构模型 Medusa 原生支持多头并行预测
提示文本重复度高 N-gram GPU 零模型开销,GPU 加速
多模态推理 DFlash 唯一明确支持多模态的 Proposer
KV 传输场景 Extract Hidden States 专为缓存隐藏状态设计
动态推测长度 Suffix Decoding 基于后缀树的动态推测

5.3 常见问题排查

  1. M-RoPE 不兼容 parallel drafting_raise_if_mrope() 会在需要额外 slot 且使用 M-RoPE 时抛出 NotImplementedError

  2. TP 不匹配:Draft Model 要求目标模型和草稿模型的 TP 并行度一致。

  3. Vocab 不匹配:Draft Model 要求词汇表大小一致。

  4. CUDA Graph 捕获失败:确保 dummy_run() 中使用的 buffer 与运行时一致。

  5. NaN logits:MTP 模型需要显式共享 shared_head.head,否则会产生 NaN。

5.4 未来演进方向

  1. M-RoPE 支持 parallel drafting:当前 _raise_if_mrope() 阻止了 M-RoPE 模型的并行推测,扩展支持后可以覆盖更多模型。

  2. 多模态 draft model 支持_raise_if_multimodal() 目前阻止了除 DFlash 外的多模态推测,扩展后可利用 EAGLE 等方法加速多模态推理。

  3. DBO ubatching:代码中多次提到 TODO(Flechman): support DBO ubatching,支持数据并行下的 microbatching 可以进一步提升吞吐。

  4. 带概率采样的 draft:文件末尾注释表明将来可能使用带温度参数的采样替代 argmax,以获得更好的接受率。


文档结束
本文档基于 vLLM v1 spec_decode 模块的源码进行逐行分析,共涵盖 12 个源文件,总计约 3600 行代码。所有分析均基于实际源码,未添加任何推测性内容。

附录 A: EAGLE 推测解码完整时序图

KV Cache SpecDecodeVerifier EagleProposer Target Model GpuModelRunner Scheduler KV Cache SpecDecodeVerifier EagleProposer Target Model GpuModelRunner Scheduler loop [AR Loop (num_spec_tokens - 1 steps)] scheduled_requests forward (decode step) target_hidden_states + sampled_token_ids propose(target_hidden_states, next_token_ids, ...) set_inputs_first_pass (shift/copy inputs) build_per_group_and_layer_attn_metadata model forward (first pass) greedy_sample (draft token 1) eagle_step_update_slot_mapping rebuild attn metadata model forward (single token) greedy_sample (next draft token) draft_token_ids [batch_size x num_spec_tokens] verify(draft_token_ids, target_logits) accept/reject → update KV Cache accepted_tokens + final_sampled_token output tokens

附录 B: CUDA Graph 捕获与 Replay 流程

运行阶段 Replay

捕获阶段 Capture

initialize_cudagraph_keys
PIECEWISE mode

dummy_run
dummy forward pass

cudagraph.capture
record kernel sequence

cudagraph.replay_ready
captured graph stored

_determine_batch_execution_and_padding

batch size in
captured range?

cudagraph.replay
replay captured kernels

eager mode fallback
normal forward

结果返回

附录 C: 推测 Token 树结构示意图

渲染错误: Mermaid 渲染失败: Parse error on line 2: ...raph Tree Structure (example: tree_choic -----------------------^ Expecting 'SEMI', 'NEWLINE', 'SPACE', 'EOF', 'GRAPH', 'DIR', 'subgraph', 'SQS', 'end', 'AMP', 'COLON', 'START_LINK', 'STYLE', 'LINKSTYLE', 'CLASSDEF', 'CLASS', 'CLICK', 'DOWN', 'UP', 'NUM', 'NODE_STRING', 'BRKT', 'MINUS', 'MULT', 'UNICODE_TEXT', got 'PS'

附录 D: N-gram CPU vs GPU 对比

渲染错误: Mermaid 渲染失败: Parse error on line 7: ... A5 --> A6[list[list[int]] output] -----------------------^ Expecting 'SQE', 'DOUBLECIRCLEEND', 'PE', '-)', 'STADIUMEND', 'SUBROUTINEEND', 'PIPE', 'CYLINDEREND', 'DIAMOND_STOP', 'TAGEND', 'TRAPEND', 'INVTRAPEND', 'UNICODE_TEXT', 'TEXT', 'TAGSTART', got 'SQS'

附录 E: SpecDecodingStats 指标计算流程

渲染错误: Mermaid 渲染失败: Parse error on line 8: ...] S6[per_pos[i] += 1
for i i ----------------------^ Expecting 'SQE', 'DOUBLECIRCLEEND', 'PE', '-)', 'STADIUMEND', 'SUBROUTINEEND', 'PIPE', 'CYLINDEREND', 'DIAMOND_STOP', 'TAGEND', 'TRAPEND', 'INVTRAPEND', 'UNICODE_TEXT', 'TEXT', 'TAGSTART', got 'SQS'
Logo

免费领 100 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐