【vllm】(九)vLLM v1 Speculative Decoding (spec_decode) — 模块超深度分析之二
文件内容:空文件。该文件仅作为 Python 包标记文件存在,不包含任何代码逻辑。在 vLLM v1 的模块中,所有功能类均通过显式 import 导入,而非通过聚合导出。这种设计使得每个 Proposer 类可以由上层根据按需选择加载,避免不必要的模块初始化开销。 是整个模块的底层计算基础设施,提供了 5 个 Triton GPU 内核和若干工具函数,所有内核均用于加速输入数据的复制、扩展和 s
三、核心业务逻辑深度解析
3.1 __init__.py - 模块标记
文件内容:空文件。
该文件仅作为 Python 包标记文件存在,不包含任何代码逻辑。在 vLLM v1 的 spec_decode 模块中,所有功能类均通过显式 import 导入,而非通过 __init__.py 聚合导出。这种设计使得每个 Proposer 类可以由上层根据 speculative_config.method 按需选择加载,避免不必要的模块初始化开销。
3.2 utils.py - Triton Kernels 与工具函数
utils.py 是整个 spec_decode 模块的底层计算基础设施,提供了 5 个 Triton GPU 内核和若干工具函数,所有内核均用于加速输入数据的复制、扩展和 slot mapping 计算,是推测解码中 CPU→GPU 数据路径上的关键瓶颈优化。
3.2.1 next_power_of_2(n: int) -> int
作用:返回大于等于 n 的最小 2 的幂。
实现细节:
- 经典的位运算算法:先
n -= 1,然后依次右移 1、2、4、8、16、32 位做 OR 操作,最终+1。 - 用于 Triton kernel 中 BLOCK_SIZE 的参数设置,确保 block size 为 2 的幂以匹配 GPU warp 调度。
- 时间复杂度 O(1),无循环。
3.2.2 PADDING_SLOT_ID = -1
全局常量,表示无效的 KV Cache slot 位置。当 token 被拒绝(rejected)或超出模型最大长度时,slot mapping 被设置为此值,防止 KV Cache 被无效 token 覆盖。
3.2.3 eagle_step_slot_mapping_metadata_kernel - EAGLE 单步 slot mapping 更新 Kernel
这是一个融合(fused)Triton kernel,将 EAGLE 自回归推测中的三个独立操作合并为一个 kernel 启动,显著减少 kernel launch 开销。
函数签名:
@triton.jit
def eagle_step_slot_mapping_metadata_kernel(
positions_ptr, # [batch_size] - 当前位置
block_table_ptr, # [batch_size, n_blocks_per_req]
block_table_stride, # block_table 维度 1 的步长
seq_lens_ptr, # [batch_size] - 读/写
out_clamped_positions_ptr, # [batch_size] (输出)
out_slot_mapping_ptr, # [input_batch_size] (输出)
block_size: tl.constexpr,
max_model_len: tl.constexpr,
n_blocks_per_req: tl.constexpr,
PAD_ID: tl.constexpr,
batch_size,
):
逐行逻辑分析:
req_idx = tl.program_id(0)— 每个线程块处理一个请求- Padding 槽处理:
if req_idx >= batch_size— 对于 cudagraph 填充槽,仅写入PAD_ID到out_slot_mapping后直接 return,不做任何其他计算 position = tl.load(positions_ptr + req_idx)— 加载当前位置new_position = position + 1— 推测下一个位置(自回归递增)exceeds_max = new_position >= max_model_len— 检查是否超出模型最大长度clamped_position = tl.where(exceeds_max, 0, new_position)— 如果超出,clamped 为 0(避免后续 block table 越界)block_number = clamped_position // block_size— 计算 block 编号block_number = tl.minimum(block_number, n_blocks_per_req - 1)— 防止越界block_id = tl.load(block_table_ptr + req_idx * block_table_stride + block_number)— 从 block table 查找 block IDslot_id = block_id * block_size + (clamped_position % block_size)— slot = block_id × block_size + 块内偏移slot_id = tl.where(exceeds_max, PAD_ID, slot_id)— 如果超出 max length,slot 设为 -1seq_len = tl.load(seq_lens_ptr + req_idx)— 加载当前序列长度new_seq_len = tl.where(exceeds_max, 1, seq_len + 1)— 超出时重置为 1,否则递增new_seq_len = tl.minimum(new_seq_len, max_model_len)— 上限截断- 存储三个输出:
out_clamped_positions、out_slot_mapping、seq_lens(原位更新)
关键设计:此 kernel 以 input_batch_size 个线程启动(而非 batch_size),多出来的线程用于处理 cudagraph padding,确保在 CUDA Graph replay 时 grid 尺寸不变。
eagle_step_update_slot_mapping_and_metadata 封装函数:
- 提取
batch_size = positions_1d.shape[0] - 如果
input_batch_size未指定,默认为batch_size - 获取
n_blocks_per_req = block_table_tensor.shape[1] - 以
(input_batch_size,)的 grid 启动 kernel
3.2.4 eagle_prepare_inputs_padded_kernel - 准备填充输入 Kernel
作用:为 padded drafter batch 模式计算每个请求的采样 token 索引和被拒绝的 token 数量。
关键参数:
cu_num_draft_tokens_ptr:累积 draft token 数量(包含求和)valid_sampled_tokens_count_ptr:每个请求有效的采样 token 数(= 1 + 接受数)query_start_loc_gpu_ptr:query 起始位置[num_reqs + 1]- 输出:
token_indices_to_sample和num_rejected_tokens_gpu
逐行逻辑:
cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + req_idx)— 加载当前请求的累积 draft 数num_draft_tokens计算:第一个请求直接等于累积值,后续为cu_draft_curr - cu_draft_prevvalid_count = tl.load(valid_sampled_tokens_count_ptr + req_idx)— 有效采样 token 数num_rejected_tokens = num_draft_tokens + 1 - valid_count— 被拒绝 token 数 = draft 数 + 1(bonus token)- 有效数num_rejected_tokens = tl.where(num_draft_tokens > 0, num_rejected_tokens, 0)— 如果没有 draft,拒绝数也为 0q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + req_idx + 1) - 1— 当前请求的最后一个 token 索引index_to_sample = q_last_tok_idx - num_rejected_tokens— 需要采样的索引 = 最后索引 - 拒绝数
3.2.5 eagle_prepare_next_token_padded_kernel - 准备下一个 Token Kernel
作用:确定每个请求的 “next token id”(用于传入推测器),同时统计有效采样 token 数量。
逐行逻辑:
is_discarded = tl.load(discard_request_mask_ptr + req_idx)— 检查是否被丢弃- 丢弃分支:直接使用 backup token,
valid_count = 0 - 正常分支:
- 加载该行所有
sampled_token_ids(通过BLOCK_SIZE_TOKENS的 block 加载) is_valid_mask = (token_ids != -1) & (token_ids < vocab_size) & token_mask— 有效 token 的 maskvalid_count = tl.sum(is_valid_mask)— 计数- 如果
valid_count > 0:找到最后一个有效 token 的索引(tl.max(tl.where(...))),使用 sum trick 提取对应 token - 如果
valid_count == 0:使用 backup token
- 加载该行所有
设计亮点:使用 tl.sum(tl.where(token_offs == last_valid_index, token_ids, 0)) 技巧避免第二次内存加载,因为 Triton 没有直接的按索引选取操作。
3.2.6 compute_new_slot_mapping — 纯 PyTorch slot mapping 计算
作用:在扩展了 query 长度后,重新计算所有 token 的 slot mapping。
逐行分析:
req_indices = torch.arange(batch_size, device=...)— 请求索引torch.repeat_interleave(req_indices, cad.naive_query_lens() + num_new_tokens, ...)— 根据每个请求的 token 数扩展,得到每个 token 所属的请求索引clamped_positions = torch.clamp(new_positions, max=max_model_len - 1)— 防止 block table 越界block_table_indices = req_indices * n_blocks_per_req + clamped_positions // block_size— 展平后的 block table 索引block_nums = cad.block_table_tensor.view(-1)[block_table_indices]— 批量获取 block 号new_slot_mapping = block_nums * block_size + block_offsets— 计算 slotnew_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)— 超出长度位置的 slot 设为 -1new_slot_mapping.masked_fill_(is_rejected_token_mask, PADDING_SLOT_ID)— 被拒绝 token 的 slot 设为 -1
3.2.7 extend_all_queries_by_N — 扩展所有 Query 长度
作用:将所有请求的 query 长度增加 N,同时更新相关元数据。用于并行推测(parallel drafting)模式。
逐行分析:
new_query_start_loc = cad.query_start_loc + N * arange[:len(cad.query_start_loc)]— 每个请求的起始位置递增[0, N, 2N, ..., batch_size*N]new_query_start_loc_cpu = ...— 同样更新 CPU 端cad.replace(...)创建新的CommonAttentionMetadata:seq_lens = cad.seq_lens + N— 序列长度递增num_actual_tokens = cad.num_actual_tokens + cad.batch_size() * N— 总 token 数增加max_query_len = cad.max_query_len + N— 最大 query 长度增加max_seq_len = cad.max_seq_len + N— 最大序列长度增加slot_mapping = new_slot_mapping— 使用新的 slot mapping
3.2.8 copy_and_expand_eagle_inputs_kernel — EAGLE 输入复制与扩展 Kernel
这是 utils.py 中最复杂、最重要的 Triton kernel,负责将目标模型的输入复制到草稿模型的缓冲区中,同时处理 padding、并行推测 slot 和被拒绝 token。
参数详解:
| 参数 | 形状 | 说明 |
|---|---|---|
target_token_ids_ptr |
[total_tokens_in_batch] |
目标模型的 token IDs |
target_positions_ptr |
[total_tokens_in_batch] |
目标模型的位置编码 |
next_token_ids_ptr |
[num_reqs] |
目标模型采样的下一个 token |
out_input_ids_ptr |
[total_draft_tokens] |
输出:草稿模型的输入 IDs |
out_positions_ptr |
[total_draft_tokens] |
输出:位置编码 |
out_is_rejected_token_mask_ptr |
[total_draft_tokens] |
输出:被拒绝 token mask |
out_is_masked_token_mask_ptr |
[total_draft_tokens] |
输出:并行推测 mask |
out_new_token_indices_ptr |
[num_padding_slots_per_req * num_reqs] |
输出:新 token 的索引 |
out_hidden_state_mapping_ptr |
[total_tokens_in_batch] |
输出:隐藏状态映射 |
query_start_loc_ptr |
[num_reqs + 1] |
query 起始位置 |
query_end_loc_ptr |
[num_reqs] |
query 结束位置 |
padding_token_id |
scalar | padding token(通常为 0) |
parallel_drafting_token_id |
scalar | 并行推测占位 token |
total_input_tokens |
scalar | 输入总 token 数 |
num_padding_slots_per_request |
scalar | 每请求新增 slot 数 |
shift_input_ids |
bool | 是否移位输入(EAGLE 模式为 true) |
执行逻辑(2D grid: request_idx × token_batch_idx):
- 加载 query 边界:从
query_start_loc和query_end_loc获取当前请求的有效 token 范围 - 计算有效 token 数:
shift_input_ids=True时(EAGLE):num_valid_tokens = query_end_loc - query_start_loc(跳过第一个 token),input_offset = 1shift_input_ids=False时(Draft Model):num_valid_tokens = query_end_loc - query_start_loc + 1,input_offset = 0
- 计算输出起始位置:考虑 cudagraph padding 导致的位移
- 被拒绝 token 数:
num_rejected = next_query_start_loc - query_end_loc - 1 - 总输出 token 数:
num_valid_tokens + num_padding_slots_per_request + num_rejected - 区域分类 mask:
is_valid_region:j < num_valid_tokens— 从输入复制is_bonus_region:j == num_valid_tokens— bonus token(next_token_id)is_parallel_draft_region: 并行推测占位区域is_rejected_region: 被拒绝 token 区域
- Token IDs 组装:通过
tl.where链式选择,不同区域写入不同的 token - Position 计算:
start_pos + j,位置不跟随输入移位 - 隐藏状态映射:当
shift_input_ids=True时,记录每个输入位置在输出缓冲区中的对应索引 - 存储所有输出
输出缓冲区布局:
[valid_tokens | bonus_token | parallel_draft_slots | rejected_slots]
3.2.9 copy_and_expand_dflash_inputs_kernel — DFlash 输入复制 Kernel
与 EAGLE kernel 类似但针对 DFlash 的交叉注意力机制做了定制:
关键差异:
- 将**上下文(context)和查询(query)**分存到不同的缓冲区中
- Context 的 slot mapping 和 query 的 slot mapping 分别存储
- Query 部分由
[next_token, mask, mask, ...]组成 - 处理
num_rejected_tokens(padded 模式下) - 使用
HAS_NUM_REJECTEDcompile-time 常量条件
逐行逻辑:
- 加载 context 范围
ctx_start到ctx_end - 计算
num_ctx和total_tokens = num_ctx + num_query_per_req - 位置处理:
- Context:从
target_positions加载 - Query:
last_pos + 1 + query_off(在最后一个有效位置之后递增)
- Context:从
- Slot mapping:通过 block table 查找,context 和 query 分别写入不同缓冲区
- Input IDs:query 部分的第一个 token 为 bonus token,其余为
parallel_drafting_token_id - Token indices:记录 mask token 的位置(需要采样的位置)
3.2.10 update_num_computed_tokens_for_batch_change — Batch 变更修正
使用 @torch.compile 编译的函数,用于异步推测解码中的 num_computed_tokens 修正:
gather_indices = prev_positions.clamp(min=0)— 防止新请求的 -1 位置导致越界- 从旧位置 gather 对应的 valid_count、prev_computed、prev_drafts
participating = (prev_positions >= 0) & (prev_drafts > 0)— 参与推测的请求corrected = prev_computed + valid_counts.int()— 修正计算torch.where(participating, corrected, cpu_num_computed_tokens)— 参与修正,未参与使用 CPU 值
3.3 metadata.py - SpecDecodeMetadata 数据类
SpecDecodeMetadata 是推测解码中的元数据容器,封装了 draft token 的布局信息和索引信息。
@dataclass
class SpecDecodeMetadata:
draft_token_ids: torch.Tensor # [num_tokens] - 扁平化的 draft token IDs
num_draft_tokens: list[int] # [batch_size] - 每个请求的 draft token 数
cu_num_draft_tokens: torch.Tensor # [batch_size] - 累积 draft token 数
cu_num_sampled_tokens: torch.Tensor # [batch_size] - 累积采样 token 数(= draft + 1)
target_logits_indices: torch.Tensor # [num_tokens] - 需要 target 验证的 logit 索引
bonus_logits_indices: torch.Tensor # [batch_size] - bonus token 的 logit 索引
logits_indices: torch.Tensor # [num_tokens + batch_size] - 所有 logit 索引
逐字段分析:
-
draft_token_ids:所有请求的 draft token IDs 扁平化拼接。例如请求 1 有 2 个 draft[a, b],请求 2 有 3 个[c, d, e],则合并为[a, b, c, d, e]。 -
num_draft_tokens:每个请求的 draft token 数量列表,用于后续按请求切分。 -
cu_num_draft_tokens:num_draft_tokens的累积和(np.cumsum)。例如[2, 3]→[2, 5]。通过索引此数组可以快速定位每个请求在draft_token_ids中的起始和结束位置。 -
cu_num_sampled_tokens:num_draft_tokens + 1的累积和。“+1” 是因为每个请求还有一个 bonus token(目标模型实际采样的 token)。 -
target_logits_indices:用于在 target model 的 logits 中索引需要验证的 draft token 位置。在初始化时为零填充,由 verifier 阶段填入。 -
bonus_logits_indices:bonus token 对应的 logit 索引。 -
logits_indices:所有 logits 索引的合并数组(target + bonus),shape 为[num_tokens + batch_size]。
__post_init__:计算 max_spec_len = max(self.num_draft_tokens),即 batch 中最大的 draft 长度,用于后续的 tensor 维度分配。
make_dummy 方法:
@classmethod
def make_dummy(cls, draft_token_ids: list[list[int]], device: torch.device) -> "SpecDecodeMetadata":
用于创建 dummy metadata,主要场景是 CUDA Graph 捕获阶段,此时还没有真实的 draft token。该方法:
- 将
list[list[int]]扁平化并转为 tensor - 计算
num_draft_tokens和累积和 target_logits_indices、bonus_logits_indices、logits_indices均初始化为零填充- 返回完整的
SpecDecodeMetadata实例
3.4 metrics.py - 指标体系
指标体系由三个组件构成,形成"采集 → 聚合 → 上报"的三层架构。
3.4.1 SpecDecodingStats — 单步统计
@dataclass
class SpecDecodingStats:
num_spec_tokens: int # 推测 token 数(配置值)
num_drafts: int = 0 # 推测次数
num_draft_tokens: int = 0 # 总 draft token 数
num_accepted_tokens: int = 0 # 总接受 token 数
num_accepted_tokens_per_pos: list[int] # 每个位置的接受计数
new(cls, num_spec_tokens) 工厂方法:
- 创建新实例,
num_accepted_tokens_per_pos初始化为[0] * num_spec_tokens - 长度为
num_spec_tokens,索引0对应第 1 个推测位置,索引1对应第 2 个,以此类推
observe_draft(num_draft_tokens, num_accepted_tokens):
num_drafts += 1— 推测次数 +1num_draft_tokens += num_draft_tokens— 累计 draftnum_accepted_tokens += num_accepted_tokens— 累计接受- 前
num_accepted_tokens个位置的num_accepted_tokens_per_pos各 +1 - assert 确保
num_accepted_tokens <= num_spec_tokens(接受数不能超过推测数)
3.4.2 SpecDecodingLogging — 日志聚合器
class SpecDecodingLogging:
状态变量:
num_drafts: list[int]— 各次观察的推测次数num_draft_tokens: list[int]— 各次观察的 draft token 数num_accepted_tokens: list[int]— 各次观察的接受 token 数accepted_tokens_per_pos_lists: list[list[int]]— 各次观察的位置接受计数last_log_time— 上次日志时间戳
observe(spec_decoding_stats):将单步统计追加到各列表中。
log(log_fn=logger.info) 核心聚合逻辑:
num_drafts = np.sum(self.num_drafts)— 总推测次数num_draft_tokens = np.sum(...)— 总 draftnum_accepted_tokens = np.sum(...)— 总接受draft_throughput = num_draft_tokens / elapsed_time— draft 吞吐(tokens/s)accepted_throughput = num_accepted_tokens / elapsed_time— 接受吞吐draft_acceptance_rate = num_accepted_tokens / num_draft_tokens * 100— 接受率mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts)— 平均接受长度(含 bonus token)pos_matrix = np.array(self.accepted_tokens_per_pos_lists)— 位置矩阵acceptance_rates = np.sum(pos_matrix, axis=0) / num_drafts— 每个位置的接受率- 格式化为日志字符串,包含所有指标
3.4.3 SpecDecodingProm — Prometheus 集成
class SpecDecodingProm:
初始化时创建的 Counter:
| Counter 名称 | 含义 | 标签 |
|---|---|---|
vllm:spec_decode_num_drafts |
推测次数 | engine 标签 |
vllm:spec_decode_num_draft_tokens |
draft token 总数 | engine 标签 |
vllm:spec_decode_num_accepted_tokens |
接受 token 总数 | engine 标签 |
vllm:spec_decode_num_accepted_tokens_per_pos |
各位置接受数 | engine + position 标签 |
PromQL 计算公式:
- 接受率:
rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / rate(vllm:spec_decode_num_draft_tokens_total[$interval]) - 平均接受长度:
1 + (rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / rate(vllm:spec_decode_num_drafts[$interval])) - 各位置接受率向量:
vllm:spec_decode_num_accepted_tokens_per_pos[$interval] / vllm:spec_decode_num_drafts[$interval]
observe(spec_decoding_stats, engine_idx=0):
- 根据 engine 索引递增对应的 counter
- per-position counter 通过位置索引分别递增
3.5 eagle.py - SpecDecodeBaseProposer 基类与 EagleProposer
这是整个 spec_decode 模块中最核心、最复杂的文件(1793 行),包含了所有基于自回归推测的 Proposer 的公共基类以及 EAGLE 专用实现。
3.5.1 SpecDecodeBaseProposer.__init__() — 初始化
核心变量逐行分析:
def __init__(self, vllm_config, device, pass_hidden_states_to_model, runner=None):
配置提取:
self.vllm_config = vllm_config— 保存完整的 vLLM 配置self.speculative_config = vllm_config.speculative_config— 推测解码配置self.draft_model_config = self.speculative_config.draft_model_config— 草稿模型配置self.method = self.speculative_config.method— 推测方法名称(“eagle”, “eagle3”, “dflash” 等)self.pass_hidden_states_to_model = pass_hidden_states_to_model— 是否将 target 的 hidden states 传给 draft model
类型和维度参数:
6. self.device = device — GPU 设备
7. self.dtype = vllm_config.model_config.dtype — 模型数据类型(float16/bfloat16)
8. self.max_model_len = vllm_config.model_config.max_model_len — 模型最大长度
9. self.dp_rank = vllm_config.parallel_config.data_parallel_rank — 数据并行 rank
10. self.num_speculative_tokens = self.speculative_config.num_speculative_tokens — 每次推测的 token 数
隐藏层维度:
11. self.hidden_size = self.draft_model_config.get_hidden_size() — 从 draft model 获取 hidden size(可能与 target model 不同,如 Llama 3.3 70B 场景)
12. self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size() — 嵌入层维度
并行推测参数:
13. self.parallel_drafting: bool = self.speculative_config.parallel_drafting — 是否启用并行推测
14. self.extra_slots_per_request = 1 if not self.parallel_drafting else self.num_speculative_tokens — 每个请求的额外 slot 数
- 非并行模式(EAGLE 递归):每个请求只需 1 个额外 slot 放 bonus token
- 并行模式(parallel drafting):需要 num_speculative_tokens 个 slot
15. self.net_num_new_slots_per_request = self.extra_slots_per_request - (1 if ... else 0) — 净新增 slot
- 当 pass_hidden_states_to_model=True 且非 DFlash 时,EAGLE 会"吃掉"第一个 token(移位),因此净新增减 1
16. self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0 — 是否需要额外输入 slot
并行推测专用参数:
17. self.parallel_drafting_token_id: int = 0 — 并行推测的占位 token ID
18. self.parallel_drafting_hidden_state_tensor: Tensor | None = None — 并行推测的占位 hidden state
19. 如果 parallel_drafting,调用 _init_parallel_drafting_params() 初始化
本地 argmax 优化:
20. self.use_local_argmax_reduction: bool = self.speculative_config.use_local_argmax_reduction — 是否使用本地 argmax 归约(减少 TP 通信,O(2*tp_size) vs O(vocab_size))
批处理大小限制:
21. self.max_batch_size = vllm_config.scheduler_config.max_num_seqs — 最大 batch size
22. self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens — 最大 token 数
23. self.token_arange_np = np.arange(self.max_num_tokens) — CPU 端 token 索引数组
Query/Position 限制:
24. self.max_query_tokens = self.max_num_tokens — 最大 query token 数(可被子类如 DFlash 缩小)
25. self.max_positions = self.max_num_tokens — 最大位置数
多模态支持:
26. self.mm_registry = MULTIMODAL_REGISTRY — 多模态注册表
27. self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(vllm_config.model_config) — 是否支持多模态输入
注意力后端:
28. self.draft_attn_groups: list[AttentionGroup] = [] — 草稿模型的注意力组(延迟初始化)
29. self.kv_cache_gid: int = -1 — KV Cache 组 ID
30. self.eagle3_use_aux_hidden_state: bool = ... — EAGLE3 是否使用辅助隐藏状态
编译配置:
31. self.compilation_config = self.vllm_config.compilation_config
CUDA Graph 调度器:
32. self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) — 仅用于 PIECEWISE 模式
33. 在 initialize_cudagraph_keys() 中延迟初始化 key
GPU 持久化缓冲区(核心内存分配):
34. self.input_ids = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device) — 输入 token IDs 缓冲区
35. self.uses_mrope = self.draft_model_config.uses_mrope — 是否使用 M-RoPE(使用 draft model 的设置,而非 target)
36. self.uses_xdrope_dim / self.draft_uses_xdrope_dim — XDrop 维度
37. M-RoPE 位置缓冲:self.mrope_positions = torch.zeros((3, self.max_positions + 1), dtype=torch.int64, device=device)
- 故意多分配 1 个位置使其非连续(non-contiguous),以兼容 torch compile
- 3 个维度分别对应 M-RoPE 的三个空间维度
38. XDrop 位置缓冲:self.xdrope_positions(当使用 XDrop 时)
39. 普通位置缓冲:self.positions = torch.zeros(self.max_positions, dtype=torch.int64, device=device)
40. self.hidden_states = torch.zeros((self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device) — 隐藏状态缓冲
41. self.inputs_embeds = torch.zeros((self.max_num_tokens, self.inputs_embeds_size), dtype=self.dtype, device=device) — 嵌入缓冲
42. self.block_size: int = -1 — KV cache block size(延迟初始化)
ARange 缓冲:
43. self.arange = torch.arange(max(self.max_batch_size + 1, self.max_num_tokens), device=device, dtype=torch.int32) — 需要 +1 因为 query_start_loc 比 batch_size 多一个元素
拒绝/mask 标记缓冲(当 needs_extra_input_slots 时):
44. self.is_rejected_token_mask = torch.zeros((self.max_num_tokens,), dtype=torch.bool, device=device) — 被拒绝 token 标记
45. self.is_masked_token_mask = torch.zeros((self.max_num_tokens,), dtype=torch.bool, device=device) — 并行推测占位标记
Backup token 缓冲:
46. self.backup_next_token_ids = CpuGpuBuffer(self.max_batch_size, dtype=torch.int32, ...) — 备用的下一个 token ID(CPU-GPU 双端缓冲)
Slot mapping 缓冲:
47. self._slot_mapping_buffer = torch.zeros(self.max_positions, dtype=torch.int64, device=device)
ROCm 注意力类型:
48. 在 ROCm 平台上,构建 allowed_attn_types 元组,包含各种 ROCm 注意力后端类型
推测 token 树:
49. spec_token_tree = self.speculative_config.speculative_token_tree — 从配置解析 token 树结构
50. self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree) — token 树表示
- 例如 [(0,), (0,0), (0,1), (1,), (1,0), (1,1)] 表示一个 2 叉深度 2 的树
51. tree_depth = len(self.tree_choices[-1]) — 树的深度
52. num_drafts_per_level — 每层的 draft 数
53. self.cu_drafts_per_level — 累积 draft 数(前缀和)
54. self.child_drafts_per_level — 每层的孩子数(每个节点的孩子数)
55. self.tree_draft_pos_offsets = torch.arange(1, len(self.tree_choices) + 1, ...).repeat(batch_size, 1) — 树中 draft 的位置偏移
GPU 内存预算(以 max_num_tokens=4096, hidden_size=4096, batch_size=128 为例):
input_ids: 4096 × 4 = 16 KBpositions: 4096 × 8 = 32 KBmrope_positions: 3 × 4097 × 8 = 98 KBhidden_states: 4096 × 4096 × 2 (bfloat16) = 32 MBinputs_embeds: 4096 × 4096 × 2 = 32 MBis_rejected_token_mask: 4096 bytesis_masked_token_mask: 4096 bytesbackup_next_token_ids: 128 × 4 × 2 = 1 KB_slot_mapping_buffer: 4096 × 8 = 32 KB- 总计约 64.5 MB(不含模型权重)
3.5.2 _init_parallel_drafting_params()
初始化并行推测所需的特殊参数:
def _init_parallel_drafting_params(self):
model_hf_config = self.draft_model_config.hf_config
dflash_config = getattr(model_hf_config, "dflash_config", None)
if dflash_config and "mask_token_id" in dflash_config:
self.parallel_drafting_token_id = dflash_config["mask_token_id"]
elif hasattr(model_hf_config, "pard_token"):
self.parallel_drafting_token_id = model_hf_config.pard_token
elif hasattr(model_hf_config, "ptd_token_id"):
self.parallel_drafting_token_id = model_hf_config.ptd_token_id
else:
raise ValueError(...) # 必须有 mask token id
if self.pass_hidden_states_to_model:
self.parallel_drafting_hidden_state_tensor = torch.empty(
self.hidden_size, dtype=self.dtype, device=self.device
)
逻辑:从模型的 config.json 中依次查找 dflash_config.mask_token_id → pard_token → ptd_token_id,获取并行推测使用的 mask token ID。如果 pass_hidden_states_to_model=True(EAGLE 模式),还需分配一个空的 hidden state tensor 用于 mask token 的 hidden state。
3.5.3 load_model() — 模型加载
def load_model(self, target_model: nn.Module) -> None:
逐行分析:
-
获取目标模型的注意力层名称:
target_attn_layer_names = set( get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() )记录目标模型中所有
AttentionLayerBase子类的层名。 -
加载草稿模型:
self.model = self._get_model()调用
_get_model()加载 EAGLE head 模型。 -
识别草稿模型的注意力层:
all_attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) self._draft_attn_layer_names = set(all_attn_layers.keys()) - target_attn_layer_names草稿模型的注意力层 = 所有注意力层 - 目标模型的注意力层。这些是 EAGLE head 新增的层。
-
多模态兼容性检查:如果目标模型是多模态但草稿模型不支持,降级为纯文本模式。
-
多模态 image_token_index 对齐:对一系列特定模型(Qwen3VL、Gemma4 等),将草稿模型的
image_token_index设置为与目标模型一致。 -
共享嵌入层:
self._maybe_share_embeddings(target_language_model) -
共享 LM Head:
self._maybe_share_lm_head(target_language_model) -
并行推测 hidden state 初始化:
if self.parallel_drafting and self.pass_hidden_states_to_model: flat_mask = self.model.mask_hidden.view(-1) if self.eagle3_use_aux_hidden_state: self.parallel_drafting_hidden_state_tensor.copy_( self.model.combine_hidden_states(flat_mask) ) else: self.parallel_drafting_hidden_state_tensor.copy_(flat_mask)将 EAGLE 模型内部的 mask hidden state 复制到缓冲区中,用于并行推测时的占位隐藏状态。
3.5.4 _maybe_share_embeddings() — 可选共享嵌入层
目的:当草稿模型没有自己的嵌入层(或嵌入层与目标模型完全相同)时,直接共享目标模型的嵌入层权重,节省 GPU 内存。
逐行逻辑:
-
仅 PP=1 时生效:
if get_pp_group().world_size == 1 -
获取目标模型的嵌入层:
- 先尝试
inner_model.embed_tokens - 再尝试
inner_model.embedding - 都找不到则 raise
AttributeError
- 先尝试
-
EAGLE 模型判断:
hasattr(self.model, "has_own_embed_tokens")— EAGLE 模型特有属性has_own_embed_tokens=False:模型没有自己的嵌入 → 直接共享has_own_embed_tokens=True:比较权重是否相同 → 相同则共享- 权重不同:保留独立嵌入
-
MTP 模型判断:没有
has_own_embed_tokens属性的模型(MTP),默认共享 -
执行共享:
if share_embeddings: if hasattr(self.model.model, "embed_tokens"): del self.model.model.embed_tokens self.model.model.embed_tokens = target_embed_tokens先删除草稿模型的嵌入层,然后直接引用目标模型的嵌入层(共享同一份 GPU 内存)。
3.5.5 _maybe_share_lm_head() — 可选共享 LM Head
与 _maybe_share_embeddings 类似,但针对 LM Head:
额外处理 MTP 模型:
inner = getattr(self.model, "model", None)
layers = getattr(inner, "layers", None) if inner else None
if layers is not None:
items = layers.values() if isinstance(layers, nn.ModuleDict) else layers
for layer in items:
sh = getattr(layer, "shared_head", None)
if sh is not None and hasattr(sh, "head"):
del sh.head
sh.head = target_language_model.lm_head
MTP 模型的 compute_logits 通过 shared_head.head(一个 ParallelLMHead)执行,如果 checkpoint 中没有独立的 lm_head 权重,这个 head 会保持未初始化状态产生 NaN。因此必须显式共享。
Local Argmax 检查:当 use_local_argmax_reduction=True 时:
- 检查模型是否实现了
get_top_tokens()方法 - 如果模型有
draft_id_to_target_idvocab 映射,会 fallback 到完整 logits 路径(因为 argmax 后还需要映射)
3.5.6 propose() — 核心推测流程
这是整个模块的入口方法,实现了完整的推测解码流程。
def propose(
self,
target_token_ids: torch.Tensor, # [num_tokens]
target_positions: torch.Tensor, # [num_tokens] 或 [3, num_tokens]
target_hidden_states: torch.Tensor, # [num_tokens, hidden_size]
next_token_ids: torch.Tensor, # [batch_size]
token_indices_to_sample: torch.Tensor | None,
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embed_inputs: tuple[...] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None,
slot_mappings: dict[...] | None = None,
) -> torch.Tensor:
逐行流程分析:
Phase 0: 隐藏状态组合(EAGLE3/DFlash)
if self.method in ("eagle3", "dflash"):
target_hidden_states = self.model.combine_hidden_states(target_hidden_states)
EAGLE3 和 DFlash 需要将目标模型的多个辅助隐藏状态合并为单一隐藏状态,以匹配 EAGLE 模型的 hidden_size。
Phase 1: 设置第一轮输入
num_tokens, token_indices_to_sample, common_attn_metadata = self.set_inputs_first_pass(...)
调用 set_inputs_first_pass() 处理输入数据,将目标模型的 token/position/hidden_states 复制到草稿模型的缓冲区中。
Phase 2: 构建注意力元数据
per_group_attn_metadata, per_layer_attn_metadata = self.build_per_group_and_layer_attn_metadata(common_attn_metadata)
遍历 self.draft_attn_groups,为每个注意力组调用 build_for_drafting() 构建注意力元数据。
Phase 3: 确定执行模式
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = self._determine_batch_execution_and_padding(num_tokens)
通过 CudagraphDispatcher 确定是否使用 CUDA Graph,以及 batch 填充大小。
Phase 4: 构建第一轮模型输入
model_kwargs, slot_mapping_size = self.build_model_inputs_first_pass(num_tokens, num_input_tokens, mm_embed_inputs)
准备 input_ids、positions、inputs_embeds、hidden_states 等 model forward 参数。
Phase 5: 第一轮前向传播
with set_forward_context(per_layer_attn_metadata, self.vllm_config, ...):
ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple():
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
执行草稿模型的前向传播。EAGLE 模型返回 (last_hidden_states, hidden_states) 元组,分别用于 logits 计算和后续自回归步骤。DFlash/Draft Model 只返回单一 hidden states。
Phase 6: 第一轮采样
sample_hidden_states = last_hidden_states[token_indices_to_sample]
if self.num_speculative_tokens == 1 or self.parallel_drafting:
draft_token_ids = self._greedy_sample(sample_hidden_states)
return draft_token_ids.view(-1, self.num_speculative_tokens)
如果只需推测 1 个 token 或启用并行推测,直接 argmax 采样后返回。
Phase 7: 提取采样位置的 hidden states
positions = self.mrope_positions[:, token_indices_to_sample] if self.uses_mrope else self.positions[token_indices_to_sample]
hidden_states = hidden_states[token_indices_to_sample]
根据 token 索引提取对应的位置编码和隐藏状态。
Phase 8: 树状推测或自回归推测
if any(isinstance(md, TreeAttentionMetadata) for md in per_group_attn_metadata):
logits = self.model.compute_logits(sample_hidden_states)
draft_token_ids_list = self.propose_tree(...)
return torch.cat(draft_token_ids_list, dim=1)
如果使用了 Tree Attention,走树状推测路径(propose_tree())。
否则走自回归递归路径:
draft_token_ids = self._greedy_sample(sample_hidden_states)
draft_token_ids_list = [draft_token_ids]
Phase 9: 自回归循环(AR Loop)
for token_index in range(self.num_speculative_tokens - 1):
input_ids = draft_token_ids_list[-1].int()
eagle_step_update_slot_mapping_and_metadata(...)
common_attn_metadata.slot_mapping = self._slot_mapping_buffer[:batch_size]
# ... update positions ...
_, per_layer_attn_metadata = self.build_per_group_and_layer_attn_metadata(...)
self.input_ids[:batch_size] = input_ids
self.hidden_states[:batch_size] = hidden_states
with set_forward_context(...):
ret_hidden_states = self.model(**model_kwargs)
last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size]
draft_token_ids = self._greedy_sample(last_hidden_states[:batch_size])
draft_token_ids_list.append(draft_token_ids)
每个 AR 步骤:
- 将上一步的 draft token 作为
input_ids - 更新 slot mapping 和元数据(位置+1、序列长度+1)
- 重建注意力元数据
- 执行模型 forward(单 token 前向)
- 贪婪采样下一个 draft token
Phase 10: 拼接结果
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids # [batch_size, num_speculative_tokens]
3.5.7 set_inputs_first_pass() — 第一轮输入准备
此方法有两个分支:EAGLE 默认路径和扩展路径。
默认 EAGLE 路径(needs_extra_input_slots=False):
if token_indices_to_sample is None:
token_indices_to_sample = cad.query_start_loc[1:] - 1
# Shift the input ids by one token
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token
self.input_ids[token_indices_to_sample] = next_token_ids
移位操作示例:
- 输入:
[a1, b1, b2, c1, c2, c3](a 有 1 个 token,b 有 2 个,c 有 3 个) - 移位后:
[b1, b2, c1, c2, c3, c3](每个请求丢掉第一个 token,最后一个位置保留) - 替换:
[a2, b2, b3, c2, c3, c4](在token_indices_to_sample位置放入next_token_ids)
这利用了 EAGLE 模型的架构特性:EAGLE head 接收 target model 的 hidden states(已包含所有上下文信息),而 token IDs 只需要提供"当前位置之后的 token"作为输入,因此可以安全地跳过每个请求的第一个 token(其 hidden state 已经包含了前一个 token 的信息)。
扩展路径(needs_extra_input_slots=True,用于 parallel drafting 和 draft model):
- 调用
copy_and_expand_eagle_inputs_kernelTriton kernel - 如果
pass_hidden_states_to_model,通过out_hidden_state_mapping将 target hidden states 映射到草稿缓冲区 - 使用 mask token 的 hidden state 填充并行推测 slot
- 调用
compute_new_slot_mapping计算新的 slot mapping - 调用
extend_all_queries_by_N扩展 query 长度
3.5.8 prepare_inputs_padded() — 填充模式输入准备
使用 eagle_prepare_inputs_padded_kernel 计算每个请求的 token_indices_to_sample 和 num_rejected_tokens_gpu。然后创建新的 CommonAttentionMetadata,其中 token 布局保持不变(不重新排列),被拒绝的 token 作为 padding 保留在 buffer 中,后续由 token_indices_to_sample 过滤。
3.5.9 prepare_next_token_ids_padded() — 填充模式下一个 Token 准备
使用 eagle_prepare_next_token_padded_kernel 确定每个请求的 next token ID(来自采样结果或 backup token)。
3.5.10 dummy_run() — CUDA Graph 捕获
@torch.inference_mode()
def dummy_run(self, num_tokens, use_cudagraphs=True, is_graph_capturing=False, slot_mappings=None):
用于 CUDA Graph 捕获阶段的 dummy 前向传播:
only_one_forward_pass = is_graph_capturing or self.parallel_drafting— 捕获时或并行推测时只执行一次 forward- 非并行推测时需要
num_speculative_tokens次 forward(模拟 AR loop) - 使用 EAGLE 自己的 slot mapping buffer
3.5.11 propose_tree() — 树状推测
当使用 Tree Attention 时,推测不再是一维链而是树结构:
- 第一层:对每个请求的 bonus token 位置,从 logits 中 argmax 或 top-k 采样 draft token
- 循环:每层将新的 draft token 与之前的上下文拼接,构建 tree attention metadata
- concat:将 draft tokens、positions、hidden states 沿 token 维度拼接
- Forward:执行树状注意力的 forward
- 采样:从输出 logits 中继续采样下一层
- 位置偏移:
flattened_draft_positions = positions + self.tree_draft_pos_offsets
3.5.12 build_per_group_and_layer_attn_metadata()
遍历 draft_attn_groups,为每个组调用 build_for_drafting(common_attn_metadata, draft_index),构建 per_layer_attn_metadata 字典,key 为层名,value 为注意力元数据。
3.5.13 辅助方法
| 方法 | 作用 |
|---|---|
_get_positions(num_tokens) |
根据 RoPE 类型返回位置缓冲区的切片 |
_set_positions(num_tokens, positions) |
将位置写入缓冲区(处理 M-RoPE/XDrop 转换) |
_get_slot_mapping(num_tokens, slot_mapping) |
将 slot mapping 拷贝到 buffer 并返回 dict |
_greedy_sample(hidden_states) |
argmax 或 get_top_tokens() 采样 |
_determine_batch_execution_and_padding(num_tokens) |
通过 cudagraph_dispatcher 确定执行模式和填充 |
model_returns_tuple() |
判断模型返回类型(EAGLE 返回元组,DFlash/Draft 返回单一值) |
validate_same_kv_cache_group() |
确保所有草稿层在同一 KV Cache 组中 |
initialize_attn_backend() |
初始化草稿层的注意力后端 |
3.5.14 EagleProposer — EAGLE 专用子类
class EagleProposer(SpecDecodeBaseProposer):
def __init__(self, vllm_config, device, runner=None):
super().__init__(
vllm_config,
device,
pass_hidden_states_to_model=True,
runner=runner,
)
EagleProposer 是 SpecDecodeBaseProposer 的极简子类,唯一自定义的行为是在构造时将 pass_hidden_states_to_model=True,表示 EAGLE 模型需要接收目标模型的隐藏状态作为输入(这是 EAGLE 架构的核心设计:EAGLE head 以 target model 的 hidden states 为条件生成 draft tokens)。
文件末尾的 compute_probs_and_sample_next_token() 函数:
目前未使用(注释标注为 FIXME),保留了将来使用带温度采样的 draft token 生成逻辑。当前实现仅使用 argmax 贪婪采样。
3.6 dflash.py - DFlashProposer
DFlashProposer 继承自 SpecDecodeBaseProposer,实现了 DFlash(Dynamic Flash)推测方法。DFlash 的核心创新是使用交叉注意力(Cross-Attention)机制,而非 EAGLE 的自回归递归。
3.6.1 初始化
def __init__(self, vllm_config, device, runner=None):
super().__init__(..., pass_hidden_states_to_model=True, ...)
# 关键差异:DFlash 只将推测 token 作为 query
self.max_query_tokens = self.max_batch_size * (1 + self.num_speculative_tokens)
# 位置缓冲区涵盖 context + query
self.max_positions = self.max_num_tokens + self.max_query_tokens
DFlash 的内存布局:
max_query_tokens:每个请求(1 + num_speculative_tokens)个 query token(1 个 bonus + k 个 mask),batch_size × 这个数max_positions:context positions + query positions 的总和_context_slot_mapping_buffer:context 的 slot mapping(大小 = max_num_tokens)_slot_mapping_buffer:query 的 slot mapping(大小 = max_query_tokens)_context_positions_buffer:context 位置(大小 = max_num_tokens)positions:query 位置(大小 = max_query_tokens)parallel_drafting_hidden_state_tensor = None:DFlash 不使用 hidden state 传递,而是使用嵌入
3.6.2 _raise_if_multimodal() 重写
DFlash 覆盖了父类方法为空操作,允许多模态输入(支持 Qwen3.5 系列模型)。
3.6.3 set_inputs_first_pass() 重写
DFlash 的输入处理与 EAGLE 完全不同:
num_context = target_token_ids.shape[0]
num_query_per_req = 1 + self.num_speculative_tokens
num_query_total = batch_size * num_query_per_req
关键差异:
- Context K/V 直接来自
target_hidden_states,不需要复制到 buffer - Query tokens 只有 bonus + mask tokens,数量极少
- 使用
copy_and_expand_dflash_inputs_kernel而非 EAGLE 的 kernel - Query 的 slot mapping 和 context 的 slot mapping 分别管理
kernel 参数配置:
copy_and_expand_dflash_inputs_kernel[grid](
next_token_ids_ptr=next_token_ids,
target_positions_ptr=target_positions,
out_input_ids_ptr=self.input_ids,
out_context_positions_ptr=self._context_positions_buffer,
out_query_positions_ptr=self.positions,
out_context_slot_mapping_ptr=self._context_slot_mapping_buffer,
out_query_slot_mapping_ptr=self._slot_mapping_buffer,
out_token_indices_ptr=token_indices_to_sample,
...
)
新的 CommonAttentionMetadata:
causal=False— DFlash 需要非因果注意力(交叉注意力可以访问所有 context)num_actual_tokens = num_query_total— 只有 query tokens 参与 forwardslot_mapping = query_slot_mapping— 仅 query 的 slot mapping
3.6.4 dummy_run() 重写
DFlash 的 dummy run 有特殊处理:
- 先执行
precompute_and_store_context_kv()— 将 context 的 KV 预计算并存入 cache - 再执行 query 的 forward pass
context_positions = self._context_positions_buffer[:num_tokens]
context_states = self.hidden_states[:num_tokens]
self.model.precompute_and_store_context_kv(context_states, context_positions)
3.6.5 build_model_inputs_first_pass() 重写
def build_model_inputs_first_pass(self, num_tokens, num_input_tokens, mm_embed_inputs):
num_context = self._dflash_num_context
# 预插入 context KV
self.model.precompute_and_store_context_kv(
self._dflash_hidden_states,
self._context_positions_buffer[:num_context],
self._context_slot_mapping_buffer[:num_context],
)
return dict(input_ids=..., positions=..., inputs_embeds=None), num_input_tokens
关键:context 的 KV 通过 precompute_and_store_context_kv 直接预计算并插入 KV Cache,不经过常规的 model forward。
3.6.6 build_per_group_and_layer_attn_metadata() 重写
在父类基础上增加 causal=False 检查,确保所有注意力后端都支持非因果注意力。
DFlash 交叉注意力机制详解:
3.7 draft_model.py - DraftModelProposer
DraftModelProposer 使用一个独立的小型语言模型作为草稿模型(如 Qwen2.5-0.5B 作为 Qwen2.5-72B 的草稿)。
3.7.1 初始化
def __init__(self, vllm_config, device, runner=None):
super().__init__(
vllm_config,
device,
pass_hidden_states_to_model=False, # 关键差异
runner=runner,
)
self._raise_if_vocab_size_mismatch()
self._raise_if_draft_tp_mismatch()
pass_hidden_states_to_model=False 是 EAGLE 和 Draft Model 的核心差异:Draft Model 是一个完整的语言模型,不需要接收目标模型的 hidden states,而是通过 token IDs 自主推理。
3.7.2 _raise_if_vocab_size_mismatch()
def _raise_if_vocab_size_mismatch(self):
self.speculative_config.verify_equal_vocab_size_if_draft_model()
确保目标模型和草稿模型的词汇表大小一致。这是 Draft Model 的必要条件,因为草稿模型生成的 token IDs 需要直接用于目标模型。
3.7.3 _raise_if_draft_tp_mismatch()
def _raise_if_draft_tp_mismatch(self):
tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
if draft_tp != tgt_tp:
raise ValueError(...)
确保目标模型和草稿模型的 TP 并行度一致。原因:如果 TP > 1 的目标模型与 TP = 1 的草稿模型共存,所有 rank 都会在 rank 0 上编译草稿模型,导致 torch.compile 缓存被覆盖和损坏。
3.7.4 _create_draft_vllm_config()
为草稿模型创建专用的 VllmConfig:
quant_config=None— 草稿模型不使用量化parallel_config— 使用草稿模型的并行配置model_config— 使用草稿模型的配置
3.7.5 _get_model()
def _get_model(self) -> nn.Module:
with set_model_tag("draft_model"):
model = get_model(
vllm_config=draft_vllm_config,
prefix="draft_model",
)
return model
加载独立的草稿模型,model tag 为 “draft_model”,prefix 为 “draft_model”(与 EAGLE 的 “eagle_head” 区分)。
3.7.6 嵌入和 LM Head 不共享
def _maybe_share_embeddings(self, target_language_model):
pass # Draft models don't share embeddings
def _maybe_share_lm_head(self, target_language_model):
pass # Draft models don't share lm_head
Draft Model 拥有自己独立的嵌入层和 LM Head,不与目标模型共享。
3.8 medusa.py - MedusaProposer
MedusaProposer 是一个独立类(不继承 SpecDecodeBaseProposer),实现了 Medusa 推测方法。Medusa 的核心思想是在目标模型的最后几个 transformer 层之后附加多个"Medusa Head",每个 Head 预测未来第 k 个 token。
3.8.1 初始化
def __init__(self, vllm_config, device):
self.vllm_config = vllm_config
self.spec_config = vllm_config.speculative_config
self.device = device
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.hidden_size = self.spec_config.draft_model_config.get_hidden_size()
self.dtype = vllm_config.model_config.dtype
极简初始化,不维护任何 GPU 缓冲区。
3.8.2 propose() — 推测
def propose(self, target_hidden_states, sampling_metadata, slot_mappings=None):
blocks = self.model(target_hidden_states)
logits = self.model.compute_logits(blocks)
draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)
return draft_tokens # [batch_size, num_heads]
流程:
self.model(target_hidden_states)— 将目标模型的隐藏状态传入 Medusa 模型(实际上是多个 Medusa Head)blocks是一个包含每个 Head 输出的列表self.model.compute_logits(blocks)— 计算每个 Head 的 logits- 对每个 Head 的 logits 做 argmax,得到一个 draft token
torch.stack(..., dim=1)拼接为[batch_size, num_heads]
3.8.3 load_model()
def load_model(self, target_model):
with set_model_tag("medusa_head"):
self.model = get_model(
vllm_config=self.vllm_config,
model_config=self.spec_config.draft_model_config,
)
assert not (is_mixture_of_experts(self.model) and enable_eplb)
加载 Medusa Head 模型,tag 为 “medusa_head”。检查 Mixture of Experts 模型与 EPLB 的兼容性。
3.8.4 dummy_run()
使用零填充的 hidden states 执行 dummy forward,用于 CUDA Graph 捕获。
3.9 extract_hidden_states.py - ExtractHiddenStatesProposer
ExtractHiddenStatesProposer 是一个特殊用途的 Proposer,其核心功能不是推测 token,而是缓存隐藏状态到 KV Cache。它主要用于 vLLM 的 KV 传输(KV Transfer) 场景。
3.9.1 初始化
def __init__(self, vllm_config, device):
assert vllm_config.speculative_config.num_speculative_tokens == 1
强制要求 num_speculative_tokens == 1,因为它本质上只做单步操作。
关键参数:
layer_ids = getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", None)— 辅助隐藏状态的层 ID 列表self.num_hidden_states = len(layer_ids)— 辅助隐藏状态的层数self.hidden_states = torch.zeros((max_num_tokens, num_hidden_states, hidden_size), ...)— 3D 隐藏状态缓冲区
3.9.2 propose() — “推测”(实际上是缓存)
def propose(self, sampled_token_ids, target_hidden_states, common_attn_metadata, slot_mappings=None):
stacked_hidden_states = torch.stack(target_hidden_states, dim=1)
num_tokens = stacked_hidden_states.shape[0]
self.hidden_states[:num_tokens] = stacked_hidden_states
流程:
- 将
target_hidden_states(list of tensors)沿 dim=1 拼接为[num_tokens, num_hidden_states, hidden_size] - 写入
self.hidden_states缓冲区 - 构建注意力元数据
- 执行
self.model(hidden_states=...)— 将隐藏状态存入 KV Cache(不做注意力计算) - 返回
sampled_token_ids[:, :1]— 直接返回目标模型采样的 token 作为 “draft”(保证 100% 接受率)
核心洞察:这个 Proposer 并不真正做推测。它的作用是将目标模型的辅助隐藏状态缓存到 KV Cache 中,供后续使用(如跨节点 KV 传输)。返回的 “draft tokens” 就是目标模型自己采样的 token,因此在验证阶段必然全部接受。
3.9.3 prepare_next_token_ids_padded()
简化版实现:num_speculative_tokens=1,直接判断 sampled token 是否有效,无效则使用 backup token。
3.9.4 load_model()
加载 ExtractHiddenStatesModel 模型(只有 cache-only attention 层),识别出恰好 1 个注意力层。
3.10 ngram_proposer.py - NgramProposer (CPU Numba)
NgramProposer 使用纯 CPU 端的 Numba JIT 编译实现 n-gram 文本匹配,无需加载任何模型。它利用提示文本中的重复模式来推测后续 token。
3.10.1 初始化
def __init__(self, vllm_config):
self.min_n = vllm_config.speculative_config.prompt_lookup_min # 最小 n-gram 长度
self.max_n = vllm_config.speculative_config.prompt_lookup_max # 最大 n-gram 长度
self.k = vllm_config.speculative_config.num_speculative_tokens # 推测 token 数
self.max_model_len = vllm_config.model_config.max_model_len
# 预分配 Numba 缓冲区
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
self.valid_ngram_draft = np.zeros((max_num_seqs, self.k), dtype=np.int32)
self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32)
线程配置:
num_tokens_threshold = 8192— 启用多线程的 token 数阈值num_numba_thread_available = min(1, cpu_count // 2)— 当前默认为 1 线程- 除以
tp_size确保每个 TP rank 都有线程
JIT 预热:在初始化时立即执行一次 propose([[]] * 1024, ...) 触发 Numba JIT 编译,避免首次实际调用时的延迟。
3.10.2 propose() — 推测入口
def propose(self, sampled_token_ids, num_tokens_no_spec, token_ids_cpu, slot_mappings=None):
valid_ngram_requests = []
for i, sampled_ids in enumerate(sampled_token_ids):
if not len(sampled_ids):
continue
if num_tokens_no_spec[i] >= self.max_model_len:
continue
valid_ngram_requests.append(i)
return self.batch_propose(len(sampled_token_ids), valid_ngram_requests, ...)
过滤条件:
- 必须有采样结果(非 partial prefill)
- 未超过最大模型长度
3.10.3 batch_propose() — 批量 Numba 推测
def batch_propose(self, num_requests, valid_ngram_requests, num_tokens_no_spec, token_ids_cpu):
if len(valid_ngram_requests):
# 动态调整线程数
if total_tokens >= self.num_tokens_threshold:
final_num_threads = max(1, min(self.num_numba_thread_available, num_ngram_requests))
else:
set_num_threads(1)
batch_propose_numba(valid_ngram_requests, ...)
for i in range(num_requests):
if i in valid_ngram_requests and self.valid_ngram_num_drafts[i] > 0:
draft_token_ids.append(self.valid_ngram_draft[i, :self.valid_ngram_num_drafts[i]].tolist())
else:
draft_token_ids.append([])
return draft_token_ids
多线程策略:当总 token 数 >= 8192 时启用多线程,否则使用单线程(小 batch 多线程开销大于收益)。
3.10.4 batch_propose_numba — Numba 并行入口
@njit(parallel=True)
def batch_propose_numba(valid_ngram_requests, num_tokens_no_spec, token_ids_cpu, ...):
for i in prange(len(valid_ngram_requests)):
idx = valid_ngram_requests[i]
context_token_ids = token_ids_cpu[idx, :num_tokens_no_spec[idx]]
drafter_output = _find_longest_matched_ngram_and_propose_tokens(...)
valid_ngram_num_drafts[idx] = drafter_output.shape[0]
if len(drafter_output):
valid_ngram_draft[idx, :drafter_output.shape[0]] = drafter_output
使用 prange 实现请求级并行,每个请求独立执行 n-gram 匹配。
3.10.5 _find_longest_matched_ngram_and_propose_tokens — KMP 算法核心
这是 N-gram 匹配的核心算法,基于 KMP(Knuth-Morris-Pratt)字符串匹配算法的变体:
@jit(nopython=True)
def _find_longest_matched_ngram_and_propose_tokens(origin_tokens, min_ngram, max_ngram, max_model_len, k):
算法步骤:
-
边界检查:如果 context 长度 <
min_ngram或已超出 max model length,直接返回空。 -
翻转 token 序列:
tokens = origin_tokens[::-1]- 翻转后,原序列的后缀变为前缀
- 目标变为:在翻转后的序列中找到最长的前缀后缀匹配(LPS, Longest Prefix which is also Suffix)
-
构建 LPS 数组:
lps = np.zeros(max_ngram, dtype=np.int32)lps[i]表示tokens[0:i+1]的最长公共前后缀长度。由于 n-gram 长度上限为max_ngram,LPS 数组大小仅需max_ngram。 -
KMP 核心循环:
prev_lps = 0 i = 1 while i < total_token: if tokens[prev_lps] == tokens[i]: prev_lps += 1 if prev_lps >= longest_ngram: longest_ngram = prev_lps position = i if i < max_ngram: lps[i] = prev_lps if prev_lps == max_ngram: prev_lps = lps[max_ngram - 1] i += 1 elif prev_lps != 0: prev_lps = lps[prev_lps - 1] else: i += 1- 匹配成功:
tokens[prev_lps] == tokens[i]→prev_lps += 1,更新最长匹配 - 匹配失败且有 fallback:
prev_lps = lps[prev_lps - 1]→ 回退到次长前缀 - 匹配失败且无 fallback:
i += 1→ 继续扫描 - 长度截断:当
prev_lps == max_ngram时,截断到lps[max_ngram - 1],防止超过最大 n-gram 长度
- 匹配成功:
-
结果提取:
if longest_ngram < min_ngram: return np.empty((0,), dtype=origin_tokens.dtype) start_position = total_token - 1 - position + longest_ngram k = min(k, total_token - start_position) return origin_tokens[start_position : start_position + k]position是翻转后的位置,total_token - 1 - position是翻转回原序列的位置- 从匹配 n-gram 的末尾开始,提取后续的
k个 token 作为推测结果
算法复杂度:O(N) 时间,O(max_ngram) 空间,其中 N 为序列长度。
3.11 ngram_proposer_gpu.py - NgramProposerGPU
NgramProposerGPU 是 N-gram 推测的GPU 加速版本,使用纯 PyTorch tensor 操作和 torch.compile 编译优化,完全避免了 CPU-GPU 数据传输。
3.11.1 架构
由两个类组成:
NgramGPUKernel(nn.Module)— 可编译的 GPU kernel,使用@support_torch_compile()装饰器NgramProposerGPU— 外层封装,负责数据准备和调度
3.11.2 NgramGPUKernel.__init__()
def __init__(self, vllm_config, prefix="", device):
self.min_n = ... # prompt_lookup_min
self.max_n = ... # prompt_lookup_max
self.k = ... # num_speculative_tokens
self.max_model_len = ...
self.max_num_seqs = ...
self.device = device
作为 nn.Module 的子类,可以被 torch.compile 编译。
3.11.3 _find_first_and_extract_all_n_parallel() — GPU 向量化匹配
这是 GPU 版本的核心算法,与 CPU 版本的 KMP 完全不同:
def _find_first_and_extract_all_n_parallel(
self, token_ids, seq_lengths, min_ngram_len, max_ngram_len, num_draft_tokens
):
算法思路:
-
遍历所有 n-gram 长度(
min_ngram到max_ngram):for i, ngram_len in enumerate(range(min_ngram_len, max_ngram_len + 1)): -
Sliding Window:使用
torch.unfold创建所有大小为ngram_len的滑动窗口search_windows = token_ids.unfold(1, ngram_len, 1)unfold是 O(1) view 操作,不复制数据。 -
Suffix 提取:获取每个序列末尾
ngram_len个 tokensuffix_starts = seq_lengths - ngram_len suffix_indices = suffix_starts.unsqueeze(1) + torch.arange(ngram_len, device=device) suffix = torch.gather(token_ids, 1, suffix_indices.clamp(min=0)) -
批量匹配:
matches = (search_windows == suffix.unsqueeze(1)).all(dim=-1)逐元素比较后沿 n-gram 维度做 all 操作,得到
(batch_size, num_windows)的布尔 mask。 -
有效性过滤:匹配位置必须允许至少 1 个 draft token
max_valid_suffix_start = seq_lengths - ngram_len - 1 valid_mask = window_positions <= max_valid_suffix_start.unsqueeze(1) final_matches = matches & valid_mask -
最早匹配:
first_match_idx = torch.argmax(final_matches.int(), dim=1) has_match = final_matches[batch_indices, first_match_idx]argmax在全部为 false 时返回 0,所以需要额外的has_match检查。 -
选择最长有效匹配:
best_ngram_idx = (first_match_positions >= 0).int().flip(dims=[1]).argmax(dim=1) best_ngram_idx = num_ngram_sizes - 1 - best_ngram_idx翻转后 argmax 找到最后一个(即最长的)有匹配的 n-gram 长度。
-
提取 draft tokens:
draft_start = best_match_pos + best_ngram_lengths draft_indices = draft_start.unsqueeze(1) + torch.arange(num_draft_tokens, device=device) draft_tokens = torch.gather(token_ids, 1, draft_indices.clamp(min=0, max=max_seq_len - 1)) -
有效性 mask:
- 超出可用 token 数的位置填充 -1
- 无匹配的位置全部填充 -1
算法复杂度:O(batch_size × (max_ngram - min_ngram) × max_seq_len),完全向量化,无 Python 循环(n-gram 长度的 for 循环在编译时展开)。
3.11.4 forward() — 前向传播
def forward(self, num_tokens_no_spec, token_ids_gpu, combined_mask):
actual_batch_size = token_ids_gpu.shape[0]
draft_tokens = torch.full((actual_batch_size, self.k), -1, ...)
results = self._find_first_and_extract_all_n_parallel(...)
draft_tokens = torch.where(combined_mask.unsqueeze(1), results, -1)
# 计算前导有效 token 数
is_valid = draft_tokens != -1
cum_valid = is_valid.int().cumsum(dim=1)
positions = torch.arange(1, self.k + 1, device=device).unsqueeze(0)
num_valid_draft_tokens = (cum_valid == positions).int().sum(dim=1)
return draft_tokens, num_valid_draft_tokens
有效 token 计数技巧:cumsum(dim=1) 然后与位置索引比较。如果 token 0 有效则 cumsum[0]=1, 与位置 1 相等;如果 token 1 也是有效则 cumsum[1]=2, 与位置 2 相等;一旦遇到 -1,cumsum 停止增长,不再匹配。
3.11.5 NgramProposerGPU.__init__() — 编译配置
compilation_config = CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["none"],
splitting_ops=[],
compile_sizes=[],
inductor_compile_config={
"enable_auto_functionalized_v2": False,
"max_autotune": True,
"aggressive_fusion": True,
"triton.autotune_pointwise": True,
"coordinate_descent_tuning": True,
"use_mixed_mm": False,
},
cudagraph_mode=CUDAGraphMode.NONE,
)
编译选项详解:
enable_auto_functionalized_v2=False:禁用自动功能化,避免不必要的 in-place 操作转换max_autotune=True:启用 Triton 的自动调优aggressive_fusion=True:激进算子融合triton.autotune_pointwise=True:pointwise 操作的自动调优coordinate_descent_tuning=True:坐标下降调优(Triton kernel 的调参策略)use_mixed_mm=False:禁用混合精度矩阵乘法(N-gram 不需要矩阵乘法)cudagraph_mode=CUDAGraphMode.NONE:不使用 CUDA Graph(因为 shape 动态变化)
3.11.6 _dummy_run() — 预热
生成随机测试数据,执行 3 次 forward 以触发 torch.compile 编译。
3.11.7 propose() — GPU 推测入口
def propose(self, num_tokens_no_spec, token_ids_gpu, valid_sampled_token_ids_gpu, valid_sampled_tokens_count):
Scatter 操作:将新采样的 token 写入 token_ids_gpu(原位修改):
offsets = torch.arange(max_new_tokens, device=self.device)
write_positions = num_tokens_no_spec.unsqueeze(1) + offsets.unsqueeze(0)
scatter_mask = (valid_write_mask & (valid_sampled_token_ids_gpu != -1) & in_bounds)
token_ids_gpu.scatter_(1, write_positions_long, tokens_to_scatter)
Validity mask:
combined_mask = sampled_flags & valid_mask & (num_tokens_tmp >= self.min_n)
3.11.8 update_token_ids_ngram() — Token IDs 更新
用于准备推测解码输入:
- 处理
list[list[int]]到 tensor 的转换(padded mode) - 计算 backup next token IDs
- Mask 被丢弃请求的 token
- 计算每个请求的有效 token 数
- 找到最后一个有效 token 作为 next_token_id
3.11.9 辅助函数
update_scheduler_for_invalid_drafts():
- 等待异步 D2H 传输完成
- 根据
num_valid_draft_tokens裁剪调度器中的推测 token
update_ngram_gpu_tensors_incremental():
- 增量更新
token_ids_gpu_tensor和num_tokens_no_spec_gpu - 处理新请求、重排序请求
- 使用 pinned buffer 减少内存分配
3.12 suffix_decoding.py - SuffixDecodingProposer
SuffixDecodingProposer 实现了**后缀解码(Suffix Decoding)方法,基于 Arctic Inference 项目的官方实现。后缀解码利用提示文本的后缀树(Suffix Tree)**来推测 token。
3.12.1 初始化
def __init__(self, vllm_config):
self.num_speculative_tokens = config.num_speculative_tokens
self.max_tree_depth = config.suffix_decoding_max_tree_depth
self.max_spec_factor = config.suffix_decoding_max_spec_factor
self.min_token_prob = config.suffix_decoding_min_token_prob
self.max_model_len = vllm_config.model_config.max_model_len
from arctic_inference.suffix_decoding import SuffixDecodingCache
self.suffix_cache = SuffixDecodingCache(
max_tree_depth=config.suffix_decoding_max_tree_depth,
max_cached_requests=config.suffix_decoding_max_cached_requests,
)
关键配置:
max_tree_depth:后缀树的最大深度(控制 pattern 匹配长度)max_spec_factor:最大推测因子(限制动态推测的 token 数上限)min_token_prob:最小 token 概率阈值(低于此阈值的推测被丢弃)max_cached_requests:缓存的最大请求数
SuffixDecodingCache 是外部库对象,负责:
- 缓存请求的输出
- 管理过期请求的驱逐
- 维护每个 prompt 的后缀树
3.12.2 propose() — 推测
def propose(self, input_batch, sampled_token_ids, slot_mappings=None):
draft_token_ids = []
for i, sampled_ids in enumerate(sampled_token_ids):
if not sampled_ids:
draft_token_ids.append([])
continue
req_id = input_batch.req_ids[i]
num_tokens = input_batch.num_tokens_no_spec[i]
if num_tokens >= self.max_model_len:
draft_token_ids.append([])
continue
index = input_batch.req_id_to_index[req_id]
请求生命周期管理:
-
新请求:
if req_id not in self.suffix_cache.active_requests: if req_id in self.suffix_cache.cached_requests: self.suffix_cache.evict_cached_response(req_id) num_prompt_tokens = input_batch.num_prompt_tokens[index] prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens] self.suffix_cache.start_request(req_id, prompt_token_ids)- 如果请求之前被缓存但现在是新的,先驱逐旧缓存
- 使用 prompt token 构建后缀树
-
活跃请求:
self.suffix_cache.add_active_response(req_id, sampled_ids)将新采样的 token 添加到后缀缓存
-
推测:
start = max(0, num_tokens - self.max_tree_depth) pattern = input_batch.token_ids_cpu[i, start:num_tokens] draft = self.suffix_cache.speculate( req_id, pattern, max_spec_tokens=min(self.num_speculative_tokens, self.max_model_len - num_tokens - 1), max_spec_factor=self.max_spec_factor, min_token_prob=self.min_token_prob, ) draft_token_ids.append(draft.token_ids)- 提取最近的
max_tree_depth个 token 作为 pattern - 调用
suffix_cache.speculate()进行推测 - 推测结果是动态数量的(每请求不同)
- 提取最近的
-
清理:
for req_id in (self.suffix_cache.active_requests - input_batch.req_id_to_index.keys()): self.suffix_cache.stop_request(req_id)移除不在当前 batch 中的请求
四、综合架构分析
4.1 Proposer 方法对比矩阵
| 维度 | EAGLE / EAGLE3 | DFlash | Draft Model | Medusa | N-gram CPU | N-gram GPU | Suffix Decoding | Extract HS |
|---|---|---|---|---|---|---|---|---|
| 继承基类 | SpecDecodeBaseProposer |
SpecDecodeBaseProposer |
SpecDecodeBaseProposer |
无 | 无 | 无 | 无 | 无 |
| 需要模型加载 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ |
| 需要额外权重 | EAGLE head | DFlash head | 完整小型模型 | Medusa heads | 无 | 无 | 无 | Cache layer |
| Hidden State 传递 | ✅ (pass_hidden_states_to_model) | ✅ | ❌ | ❌ (直接传 hidden states) | N/A | N/A | N/A | ✅ |
| 推测机制 | 自回归递归 | 交叉注意力 | 自回归递归 | 多头并行 | 文本匹配 | 文本匹配 | 后缀树 | 缓存 |
| 并行推测 | 可选 (parallel_drafting) | 原生支持 | 可选 | 原生 | N/A | N/A | 动态 | N/A |
| 树状推测 | ✅ (tree_choices) | ❌ | ✅ (通过基类) | ❌ | N/A | N/A | ✅ (后缀树) | ❌ |
| 设备 | GPU | GPU | GPU | GPU | CPU | GPU | CPU | GPU |
| CUDA Graph 支持 | ✅ (PIECEWISE) | ✅ | ✅ | ✅ | N/A | ❌ | N/A | ✅ |
| torch.compile 支持 | ✅ | ✅ | ✅ | ✅ | N/A | ✅ | N/A | ✅ |
| 多模态支持 | ❌ (需要 extra slots) | ✅ (已覆盖) | ❌ | ❌ | N/A | N/A | N/A | N/A |
| TP 一致性要求 | 无 | 无 | 必须相同 | 无 | N/A | N/A | N/A | 无 |
| Vocab 一致性要求 | 可选 | 可选 | 必须相同 | 无 | N/A | N/A | N/A | 无 |
| num_spec_tokens 限制 | 任意 | 任意 | 任意 | 任意 | 任意 | 任意 | 动态 | 固定为 1 |
| 接受率影响因素 | EAGLE head 质量 | 交叉注意力质量 | 草稿模型质量 | Head 数量/质量 | 文本重复度 | 文本重复度 | 后缀树匹配度 | 固定 100% |
| 适用场景 | 通用 | Qwen3 系列 | 有独立草稿模型 | Medusa 架构 | 高重复文本 | 高吞吐重复文本 | 长尾请求 | KV 传输 |
4.2 性能优化手段汇总
4.2.1 Triton GPU Kernels
utils.py 中的 5 个 Triton kernel 是核心性能优化手段:
| Kernel | 优化目标 | 优化手段 |
|---|---|---|
eagle_step_slot_mapping_metadata_kernel |
减少 kernel launch | 融合 position 更新、slot mapping 计算、seq_lens 更新为一个 kernel |
eagle_prepare_inputs_padded_kernel |
避免 Python 循环 | GPU 端并行计算每个请求的采样索引 |
eagle_prepare_next_token_padded_kernel |
避免 CPU-GPU 同步 | GPU 端确定 next token ID |
copy_and_expand_eagle_inputs_kernel |
减少内存复制 | 2D grid 并行复制+扩展,一个 kernel 完成所有操作 |
copy_and_expand_dflash_inputs_kernel |
DFlash 专用优化 | context 和 query 分离处理的融合 kernel |
4.2.2 CUDA Graph
CudagraphDispatcher支持 PIECEWISE 模式的 CUDA Graphdummy_run()用于捕获阶段的 dummy forward_determine_batch_execution_and_padding()确定何时使用 CUDA Graph- 对固定 batch size 的推理步骤有显著加速(消除 kernel launch 开销)
4.2.3 torch.compile
NgramGPUKernel使用@support_torch_compile()装饰器- 编译配置启用了
max_autotune、aggressive_fusion、coordinate_descent_tuning update_num_computed_tokens_for_batch_change也使用了@torch.compile- 对于 M-RoPE 位置缓冲,故意制造 non-contiguous 布局以兼容 torch compile
4.2.4 内存共享
_maybe_share_embeddings():当草稿模型嵌入层与目标模型相同时,共享同一份 GPU 内存_maybe_share_lm_head():同理共享 LM Head- 节省约 2×hidden_size × vocab_size 的 GPU 内存
4.2.5 Local Argmax Reduction
use_local_argmax_reduction=True时,使用get_top_tokens()替代完整的compute_logits().argmax()- 通信复杂度从 O(vocab_size) 降到 O(2 × tp_size)
- 当存在
draft_id_to_target_id映射时 fallback 到完整路径
4.2.6 Numba JIT (CPU N-gram)
@njit(parallel=True)和@jit(nopython=True)实现 CPU 端的 JIT 编译prange实现请求级并行- 预编译(初始化时执行 dummy call)避免首次调用延迟
4.2.7 GPU 向量化 (N-gram GPU)
torch.unfold创建 O(1) view 的滑动窗口- 全 batch 并行匹配,无 Python 循环
torch.compile进一步优化算子融合
4.3 内存布局分析
4.3.1 核心缓冲区布局(SpecDecodeBaseProposer)
GPU Memory Layout (按 max_num_tokens = N, hidden_size = H 估算)
┌────────────────────────────────────────────────────┐
│ input_ids [N] int32 4N bytes │
├────────────────────────────────────────────────────┤
│ positions [N] int64 8N bytes │
│ ─ or ─ │
│ mrope_positions [3,N+1] int64 24(N+1) │
│ ─ or ─ │
│ xdrope_positions [D,N+1] int64 8D(N+1) │
├────────────────────────────────────────────────────┤
│ hidden_states [N,H] dtype 2HN bytes │
├────────────────────────────────────────────────────┤
│ inputs_embeds [N,H'] dtype 2HN' bytes │
├────────────────────────────────────────────────────┤
│ is_rejected_mask [N] bool N bytes │
│ is_masked_mask [N] bool N bytes │
├────────────────────────────────────────────────────┤
│ _slot_mapping_buf [N] int64 8N bytes │
│ arange [max] int32 4max bytes │
├────────────────────────────────────────────────────┤
│ backup_next_tokens [B] int32 4B bytes │
│ tree_pos_offsets [B,T] int32 4BT bytes │
└────────────────────────────────────────────────────┘
4.3.2 DFlash 专用内存布局
DFlash Extra Buffers:
┌────────────────────────────────────────────────────┐
│ _context_slot_mapping [N] int64 8N bytes │
│ _slot_mapping_buf [M] int64 8M bytes │
│ _context_positions [N] int64 8N bytes │
│ positions [M] int64 8M bytes │
│ arange [N+M+1] int32 4(N+M+1) │
└────────────────────────────────────────────────────┘
M = max_batch_size × (1 + num_speculative_tokens)
4.3.3 ExtractHiddenStates 专用内存布局
┌────────────────────────────────────────────────────┐
│ hidden_states [N,L,H] dtype 2N*L*H bytes │
│ _slot_mapping_buf [N] int64 8N bytes │
└────────────────────────────────────────────────────┘
L = num_hidden_states (辅助隐藏状态的层数)
4.4 GPU 缓冲区总览
| Proposer | 主要缓冲区 | 总内存估算 (N=4096, H=4096, B=128) |
|---|---|---|
| EagleProposer | input_ids, positions/mrope, hidden_states, inputs_embeds, masks, slot_mapping, backup, tree_offsets | ~65 MB |
| DFlashProposer | 上述 + context_slot_mapping, context_positions, query_positions | ~65 MB + ~0.1 MB |
| DraftModelProposer | 与 Eagle 相同 (不共享嵌入/lm_head) | ~65 MB + 草稿模型权重 |
| MedusaProposer | 无缓冲区 | 0 MB (仅模型权重) |
| NgramProposer | CPU 端缓冲区 | ~0.5 MB (CPU) |
| NgramProposerGPU | 无预分配缓冲区 | 0 MB |
| SuffixDecodingProposer | 外部库管理 | 取决于后缀树大小 |
| ExtractHiddenStates | hidden_states [N,L,H], slot_mapping | ~128 MB (L=4) + 0.03 MB |
五、总结与最佳实践
5.1 架构设计亮点
-
统一的基类设计:
SpecDecodeBaseProposer为所有基于模型的推测方法提供了统一的接口和缓冲区管理,大幅减少了重复代码。 -
灵活的输入准备策略:通过
needs_extra_input_slots标志在"移位模式"(EAGLE 默认)和"扩展模式"(parallel drafting / draft model)之间切换,使用不同的 Triton kernel。 -
多层次的性能优化:从 Triton kernel 融合、CUDA Graph 捕获、torch.compile 编译,到内存共享和 local argmax reduction,覆盖了 GPU 推理的多个优化维度。
-
设备无关的 N-gram 实现:提供 CPU Numba 和 GPU 向量化两种实现,适应不同场景需求。
-
模块化设计:8 种推测方法完全独立,按需加载,不相互依赖。
5.2 各方法的适用场景建议
| 场景 | 推荐方法 | 理由 |
|---|---|---|
| 通用 LLM 推理加速 | EAGLE / EAGLE3 | 成熟稳定,支持树状推测和并行推测 |
| Qwen3 系列模型 | DFlash | 原生支持,交叉注意力效率高 |
| 有预训练的小型草稿模型 | Draft Model | 利用已有的小模型 |
| Medusa 架构模型 | Medusa | 原生支持多头并行预测 |
| 提示文本重复度高 | N-gram GPU | 零模型开销,GPU 加速 |
| 多模态推理 | DFlash | 唯一明确支持多模态的 Proposer |
| KV 传输场景 | Extract Hidden States | 专为缓存隐藏状态设计 |
| 动态推测长度 | Suffix Decoding | 基于后缀树的动态推测 |
5.3 常见问题排查
-
M-RoPE 不兼容 parallel drafting:
_raise_if_mrope()会在需要额外 slot 且使用 M-RoPE 时抛出NotImplementedError。 -
TP 不匹配:Draft Model 要求目标模型和草稿模型的 TP 并行度一致。
-
Vocab 不匹配:Draft Model 要求词汇表大小一致。
-
CUDA Graph 捕获失败:确保
dummy_run()中使用的 buffer 与运行时一致。 -
NaN logits:MTP 模型需要显式共享
shared_head.head,否则会产生 NaN。
5.4 未来演进方向
-
M-RoPE 支持 parallel drafting:当前
_raise_if_mrope()阻止了 M-RoPE 模型的并行推测,扩展支持后可以覆盖更多模型。 -
多模态 draft model 支持:
_raise_if_multimodal()目前阻止了除 DFlash 外的多模态推测,扩展后可利用 EAGLE 等方法加速多模态推理。 -
DBO ubatching:代码中多次提到
TODO(Flechman): support DBO ubatching,支持数据并行下的 microbatching 可以进一步提升吞吐。 -
带概率采样的 draft:文件末尾注释表明将来可能使用带温度参数的采样替代 argmax,以获得更好的接受率。
文档结束
本文档基于 vLLM v1spec_decode模块的源码进行逐行分析,共涵盖 12 个源文件,总计约 3600 行代码。所有分析均基于实际源码,未添加任何推测性内容。
附录 A: EAGLE 推测解码完整时序图
附录 B: CUDA Graph 捕获与 Replay 流程
附录 C: 推测 Token 树结构示意图
附录 D: N-gram CPU vs GPU 对比
附录 E: SpecDecodingStats 指标计算流程
for i i ----------------------^ Expecting 'SQE', 'DOUBLECIRCLEEND', 'PE', '-)', 'STADIUMEND', 'SUBROUTINEEND', 'PIPE', 'CYLINDEREND', 'DIAMOND_STOP', 'TAGEND', 'TRAPEND', 'INVTRAPEND', 'UNICODE_TEXT', 'TEXT', 'TAGSTART', got 'SQS'
更多推荐



所有评论(0)