【vllm】(v1 Sample)vLLM V1 Sample—Part 2 Logits处理器体系与思考预算
vLLM V1 Sample 模块超深度架构分析 — Part 2: Logits处理器体系与思考预算
vLLM V1 Sample 模块超深度架构分析 — Part 2: Logits处理器体系与思考预算
分析范围:
vllm/v1/sample/logits_processor/+thinking_budget_state.py
分析日期: 2026-05-25
分析深度: 架构师级,逐行解析,Mermaid图表30+
目录
- 第七章 Logits处理器抽象接口体系
- 第八章 BatchUpdateBuilder 批次更新构建器
- 第九章 LogitsProcessors 容器类
- 第十章 内置Logits处理器深度解析
- 第十一章 Logits处理器构建与插件系统
- 第十二章 AdapterLogitsProcessor 适配器模式
- 第十三章 ThinkingBudgetStateHolder 思考预算管理
- 附录C Logits处理器生命周期时序图
- 附录D 思考预算状态机全图
第七章 Logits处理器抽象接口体系
7.1 模块定位与架构总览
logits_processor/ 子模块定义了 vLLM V1 中 logits 后处理的可扩展框架。其核心设计理念是:
- 抽象接口:
LogitsProcessor定义统一的apply()/update_state()/is_argmax_invariant()接口 - 批次感知:通过
BatchUpdate机制,处理器可以跟踪持久批次(persistent batch)中请求的增删移动 - 双分类体系:处理器分为 argmax-invariant(不影响贪心采样)和 non-argmax-invariant(影响贪心采样)两类,影响采样管线的应用顺序
- 插件化扩展:支持通过 entry_points 和 FQCN 加载自定义处理器
7.2 MoveDirectionality 枚举
class MoveDirectionality(Enum):
"""批次内请求移动的方向性"""
UNIDIRECTIONAL = auto() # 单向移动: i1 → i2
SWAP = auto() # 双向交换: i1 ↔ i2
设计背景:vLLM V1 使用持久批次(persistent batch),请求在整个生命周期中占据固定的索引位置。当请求完成或新请求加入时,需要移动或交换请求位置以保持批次的紧凑性。
UNIDIRECTIONAL:请求从索引i1移动到i2(i1位置被清空)SWAP:两个请求交换位置(两个位置都有新内容)
对logits处理器的影响:处理器维护 dict[int, state] 映射,当请求移动时需要同步更新映射的键。
7.3 BatchUpdate 数据类
@dataclass(frozen=True)
class BatchUpdate:
"""持久批次状态变更信息
记录一次调度步骤中批次的增删移动操作
"""
batch_size: int # 当前批次中的请求数量
# 三种操作的元数据(按顺序处理: removed → added → moved)
removed: Sequence[RemovedRequest] # 被移除的请求索引列表
added: Sequence[AddedRequest] # 新增请求的(index, params, prompt_ids, output_ids)元组
moved: Sequence[MovedRequest] # 移动请求的(i1, i2, direction)元组
冻结性(frozen=True):BatchUpdate 一旦创建就不可修改,确保处理器读到的是一致的状态快照。
类型别名:
RemovedRequest = int # 被移除请求的索引
AddedRequest = tuple[int, SamplingParams, list[int] | None, list[int]]
# (index, sampling_params, prompt_token_ids, output_token_ids)
# output_token_ids 是对请求运行中的输出token列表的引用
# 通过这个引用,logits处理器始终能看到最新的已生成token
MovedRequest = tuple[int, int, MoveDirectionality]
# (index_1, index_2, direction)
关键设计假设:
output_tok_ids列表是请求运行中的输出token列表的引用(而非副本)。因此,logits处理器通过这个引用始终能看到最新的已生成token列表。
这个设计使得处理器不需要在每步都复制完整的输出token列表,节省内存和计算。
操作处理顺序:removed → added → moved
- 先移除:释放索引位置,避免与新增请求冲突
- 再新增:填充空出的位置
- 最后移动:在稳定的新索引上执行位置调整
7.4 LogitsProcessor 抽象基类
class LogitsProcessor(ABC):
"""Logits处理器抽象基类
所有logits处理器必须实现4个抽象方法:
1. __init__(vllm_config, device, is_pin_memory) — 构造器
2. apply(logits) — 应用处理逻辑
3. is_argmax_invariant() — 声明是否影响贪心采样
4. update_state(batch_update) — 更新内部状态
"""
@classmethod
def validate_params(cls, sampling_params: SamplingParams):
"""验证采样参数是否适用于此处理器
默认实现:不做验证(返回None)
子类可覆盖以抛出ValueError
"""
return None
@abstractmethod
def __init__(
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
) -> None:
"""构造器——必须接受三个参数(即使不全部使用)"""
raise NotImplementedError
@abstractmethod
def apply(self, logits: torch.Tensor) -> torch.Tensor:
"""应用处理器到批次logits张量
Args:
logits: [batch_size, vocab_size] float32
Returns:
处理后的logits(可以原地修改,但必须返回)
"""
raise NotImplementedError
@abstractmethod
def is_argmax_invariant(self) -> bool:
"""是否不影响argmax(贪心采样结果)
True → 可以在贪心采样后应用(如MinP)
False → 必须在贪心采样前应用(如MinTokens, LogitBias)
"""
raise NotImplementedError
@abstractmethod
def update_state(self, batch_update: "BatchUpdate | None") -> None:
"""批次变更时更新处理器内部状态
在每次forward前调用
"""
raise NotImplementedError
argmax-invariant 的分类影响:
7.5 类型别名体系
第八章 BatchUpdateBuilder 批次更新构建器
8.1 设计动机与职责
BatchUpdateBuilder 是一个增量构建器,负责在调度步骤中收集批次变更信息,最终生成不可变的 BatchUpdate 对象。
核心问题:vLLM V1 的持久批次在每步可能经历多种变更(移除完成的请求、添加新请求、重排位置),需要一个中间层来收集这些变更并确保一致性。
假设与保证:
| 假设 | 保证 |
|---|---|
所有移除通过 removed_append() 注册 |
removed 列表始终降序排列 |
removed_append() 在 removed/pop/peek 之前调用 |
pop_removed() 返回最小索引 |
不直接修改 _removed 列表 |
has_removed() 正确反映移除状态 |
8.2 init 初始化
def __init__(
self,
removed: list[RemovedRequest] | None = None, # 预设的移除列表
added: list[AddedRequest] | None = None, # 预设的新增列表
moved: list[MovedRequest] | None = None, # 预设的移动列表
) -> None:
self._removed = removed or [] # 移除的请求索引
self.added = added or [] # 新增的请求信息
self.moved = moved or [] # 移动的请求信息
self._is_removed_sorted = False # 排序标志:首次访问时排序
# 用于追踪pooling场景的批次变更
# pooling模型不填充added列表,但batch_changed仍需正确反映变更
self.batch_changed = False
8.3 removed_append() 移除注册
def removed_append(self, index: int) -> None:
"""注册一个请求从持久批次中移除
⚠️ 必须在首次调用 self.removed / pop_removed / peek_removed 之前调用
否则抛出 RuntimeError
Args:
index: 被移除请求在持久批次中的索引
"""
if self._is_removed_sorted:
# 如果已经排序过(即removed已被读取),不允许再追加
# 因为追加会破坏排序不变量
raise RuntimeError(
"Cannot register new removed request after self.removed has been read."
)
self._removed.append(index)
self.batch_changed = True # 标记批次有变更
排序不变量:_removed 列表在首次访问时按降序排序。降序的原因是在持久批次中移除请求时,从高索引向低索引移除可以避免索引偏移问题(类似删除数组元素时从后往前删)。
8.4 _ensure_removed_sorted() 排序保证
def _ensure_removed_sorted(self) -> None:
"""确保 _removed 列表按降序排列
幂等操作:首次调用后不再重复排序
"""
if not self._is_removed_sorted:
self._removed.sort(reverse=True) # 降序排列
self._is_removed_sorted = True
8.5 peek_removed() 与 pop_removed()
def has_removed(self) -> bool:
"""是否有待处理的移除请求"""
return bool(self._removed)
def peek_removed(self) -> int | None:
"""查看最小的移除索引(不弹出)
因为_removed是降序排列,最后一个元素就是最小值
"""
if self.has_removed():
self._ensure_removed_sorted()
return self._removed[-1] # 降序列表的最后一个=最小值
return None
def pop_removed(self) -> int | None:
"""弹出最小的移除索引
返回后该索引从列表中移除
"""
if self.has_removed():
self._ensure_removed_sorted()
return self._removed.pop() # pop最后一个=最小值
return None
降序设计理由:
pop()在Python列表中操作末尾元素是O(1)- 如果升序排列,弹出最小值需要
pop(0)是O(n) - 降序排列使
pop()直接弹出最小值,高效
8.6 reset() 与 get_and_reset()
def reset(self) -> bool:
"""重置内部状态,返回是否有变更
Returns:
True 如果本步有任何批次变更
"""
self._is_removed_sorted = False
self._removed.clear()
self.added.clear()
self.moved.clear()
batch_changed = self.batch_changed # 保存当前值
self.batch_changed = False
return batch_changed
def get_and_reset(self, batch_size: int) -> BatchUpdate | None:
"""生成BatchUpdate并重置内部状态
Args:
batch_size: 当前持久批次的请求数量
Returns:
BatchUpdate实例(如果没有变更则返回None)
"""
self._is_removed_sorted = False
self.batch_changed = False
# 快速路径:三种操作都为空
if not any((self._removed, self.moved, self.added)):
return None
# 构建不可变的BatchUpdate
batch_update = BatchUpdate(
batch_size=batch_size,
removed=self._removed,
moved=self.moved,
added=self.added,
)
# 重置列表(创建新列表,因为BatchUpdate持有旧列表的引用)
self._removed = []
self.moved = []
self.added = []
return batch_update
注意:get_and_reset() 将旧列表传给 BatchUpdate,然后创建新列表。这意味着 BatchUpdate 持有的是调用时刻的快照,后续修改不影响已生成的 BatchUpdate。
第九章 LogitsProcessors 容器类
9.1 设计目的
LogitsProcessors 是所有已初始化logits处理器的容器,提供按 argmax-invariant 属性分类的访问接口。
9.2 argmax_invariant / non_argmax_invariant 分类
class LogitsProcessors:
"""封装已初始化的logits处理器对象"""
def __init__(self, logitsprocs: Iterable["LogitsProcessor"] | None = None) -> None:
self.argmax_invariant: list[LogitsProcessor] = [] # 不影响贪心采样
self.non_argmax_invariant: list[LogitsProcessor] = [] # 影响贪心采样
if logitsprocs:
for logitproc in logitsprocs:
# 根据is_argmax_invariant()分类
(
self.argmax_invariant
if logitproc.is_argmax_invariant()
else self.non_argmax_invariant
).append(logitproc)
分类的重要性:Sampler 在不同阶段调用不同类别的处理器:
non_argmax_invariant→ 在贪心采样前调用(可能改变argmax结果)argmax_invariant→ 在贪心采样后/温度缩放后调用(不影响贪心但影响随机采样)
9.3 all 属性迭代器
@property
def all(self) -> Iterator["LogitsProcessor"]:
"""迭代所有处理器(argmax_invariant在前,non_argmax_invariant在后)"""
return chain(self.argmax_invariant, self.non_argmax_invariant)
使用 itertools.chain 避免创建新列表,惰性迭代节省内存。
第十章 内置Logits处理器深度解析
10.1 MinPLogitsProcessor — Min-P过滤
设计目的:Min-P 是一种动态阈值过滤策略,过滤掉概率低于 max_probability × min_p 的token。与 top-p 不同,min-p 的阈值随最大概率自适应调整。
argmax-invariant:True — 最大概率的token永远不会被过滤(因为 max_prob × min_p < max_prob),因此不影响贪心采样结果。
init 逐行解析
class MinPLogitsProcessor(LogitsProcessor):
def __init__(
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
):
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
self.min_p_count: int = 0 # 当前批次中有min_p > 0的请求数
# CPU端numpy数组:零拷贝与GPU同步
self.min_p_cpu_tensor = torch.zeros(
(max_num_reqs,), dtype=torch.float32, device="cpu",
pin_memory=is_pin_memory # pin_memory加速CPU→GPU传输
)
# numpy视图:共享底层内存,可零拷贝读写
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
# 是否需要双缓冲(CPU + GPU)
self.use_double_tensor = torch.device(device).type != "cpu"
# CPU设备不需要双缓冲(直接在CPU tensor上操作)
if self.use_double_tensor:
# GPU设备:预分配GPU tensor
self.min_p_device: torch.Tensor = torch.empty(
(max_num_reqs,), dtype=torch.float32, device=device
)
else:
# CPU设备:直接使用CPU tensor
self.min_p_device = self.min_p_cpu_tensor
# 当前切片:只使用[:batch_size]的部分
self.min_p: torch.Tensor = self.min_p_device[:0] # 初始为空
双缓冲设计:
- CPU端(
min_p_cpu_tensor/min_p_cpu):用numpy高效更新 - GPU端(
min_p_device):用于apply()中的GPU计算 min_p.copy_(min_p_cpu_tensor[:size]):异步CPU→GPU传输
update_state() 逐行解析
def update_state(self, batch_update: BatchUpdate | None):
if not batch_update:
return # 无变更,跳过
needs_update = False
# ===== 处理新增请求 =====
for index, params, _, _ in batch_update.added:
min_p = params.min_p # 从采样参数中获取min_p值
min_p_before = self.min_p_cpu[index] # 该索引之前的min_p值
if min_p_before != min_p:
needs_update = True
self.min_p_cpu[index] = min_p # 更新CPU数组
if min_p and not min_p_before:
# 从0变为非0:增加活跃计数
self.min_p_count += 1
elif not min_p and min_p_before:
# 从非0变为0:减少活跃计数
self.min_p_count -= 1
if self.min_p_count: # 只有存在活跃min_p请求时才处理移除/移动
# ===== 处理移除请求 =====
if batch_update.removed:
needs_update = True
for index in batch_update.removed:
if self.min_p_cpu[index]:
self.min_p_cpu[index] = 0
self.min_p_count -= 1
# ===== 处理移动请求 =====
for adx, bdx, direct in batch_update.moved:
min_p_a, min_p_b = self.min_p_cpu[adx], self.min_p_cpu[bdx]
if min_p_a != min_p_b:
needs_update = True
# 目标位置总是接收源位置的值
self.min_p_cpu[bdx] = min_p_a
if direct == MoveDirectionality.SWAP:
# 双向交换:源位置也接收目标位置的值
self.min_p_cpu[adx] = min_p_b
if direct == MoveDirectionality.UNIDIRECTIONAL:
# 单向移动:源位置清零
if min_p_a:
self.min_p_cpu[adx] = 0
if min_p_b:
# min_p_b原来就有值但被覆盖了
self.min_p_count -= 1
# ===== 同步到GPU =====
size = batch_update.batch_size
if self.min_p_count and (needs_update or self.min_p.shape[0] != size):
# 只在有活跃min_p或批次大小变化时更新
self.min_p = self.min_p_device[:size] # 切片到当前批次大小
if self.use_double_tensor:
# 异步CPU→GPU传输
self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True)
self.min_p.unsqueeze_(1) # [batch_size] → [batch_size, 1] 用于广播
unsqueeze_(1) 的意义:min_p 需要与 logits [B, V] 做广播比较,因此需要从 [B] 扩展为 [B, 1]。
apply() 逐行解析
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not self.min_p_count:
return logits # 快速路径:无活跃min_p请求
# 1. 计算softmax概率分布
probability_values = torch.nn.functional.softmax(logits, dim=-1)
# [batch_size, vocab_size]
# 2. 计算每行的最大概率
max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
# [batch_size, 1] — keepdim=True保留维度用于广播
# 3. 计算自适应阈值 = max_prob × min_p
adjusted_min_p = max_probabilities.mul_(self.min_p)
# [batch_size, 1] — self.min_p已经是[B, 1]形状
# mul_() 是原地操作,修改max_probabilities
# 4. 标记低于阈值的token
invalid_token_mask = probability_values < adjusted_min_p
# [batch_size, vocab_size] bool
# 5. 屏蔽无效token
logits.masked_fill_(invalid_token_mask, -float("inf"))
return logits
Min-P vs Top-P 的对比:
| 特性 | Min-P | Top-P |
|---|---|---|
| 阈值类型 | 绝对概率阈值 | 累积概率阈值 |
| 阈值计算 | max_prob × min_p | 直接指定 |
| 自适应性 | 随max_prob动态调整 | 固定值 |
| 效果 | 当模型很确定时更严格 | 始终保留概率和≥p的token |
| argmax影响 | 不影响(最大概率总保留) | 不影响(累积概率包含最大值) |
10.2 LogitBiasLogitsProcessor — 对数偏置
设计目的:为特定token添加偏置值,影响采样概率。典型用法:增加特定token被采样的概率(正偏置)或降低概率(负偏置)。
argmax-invariant:False — 偏置可以改变最大logit的位置,从而改变贪心采样结果。
init 与 update_state
class LogitBiasLogitsProcessor(LogitsProcessor):
def __init__(self, _, device: torch.device, is_pin_memory: bool):
# _ 是vllm_config,此处不使用
self.device = device
self.pin_memory = is_pin_memory
# 核心数据结构:{req_index: {token_id: bias_value}}
# 稀疏存储:只有设置了logit_bias的请求才在dict中
self.biases: dict[int, dict[int, float]] = {}
# GPU上的偏置张量
self.bias_tensor: torch.Tensor = torch.tensor(())
# 索引张量:(req_indices, token_indices) — 用于logits的scatter操作
self.logits_slice = (
self._device_tensor([], torch.int32), # 请求索引
self._device_tensor([], torch.int32), # token索引
)
def update_state(self, batch_update: BatchUpdate | None):
# 使用通用状态更新工具
needs_update = process_dict_updates(
self.biases,
batch_update,
lambda params, _, __: params.logit_bias or None
# 从SamplingParams提取logit_bias字典
# 如果没有logit_bias则返回None(表示该请求不需要此处理器)
)
# 如果状态有变化,重建GPU张量
if needs_update:
reqs: list[int] = [] # 请求索引
tok_ids: list[int] = [] # token索引
biases: list[float] = [] # 偏置值
for req, lb in self.biases.items():
reqs.extend([req] * len(lb)) # 每个偏置对应一个请求索引
tok_ids.extend(lb.keys()) # token id
biases.extend(lb.values()) # 偏置值
# 构建GPU张量
self.bias_tensor = self._device_tensor(biases, torch.float32)
self.logits_slice = (
self._device_tensor(reqs, torch.int32),
self._device_tensor(tok_ids, torch.int32),
)
def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
"""创建CPU tensor + 异步传输到GPU"""
return torch.tensor(
data, device="cpu", dtype=dtype,
pin_memory=self.pin_memory
).to(device=self.device, non_blocking=True)
稀疏索引设计:logits[logits_slice] 等价于 logits[req_indices, token_indices],这是PyTorch的高级索引语法,可以一次性为多个位置赋值/加值。
apply()
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if self.biases:
# 使用高级索引一次性添加所有偏置
# logits[req_indices, token_indices] += bias_values
logits[self.logits_slice] += self.bias_tensor
return logits
性能分析:
- 常规方法:遍历每个请求和每个偏置,逐个
logits[req, tok] += bias→ O(n) 次小kernel - 高级索引:单次scatter-add kernel → O(1) 次kernel launch
- 当偏置数量较多时(如100+个),高级索引显著更快
10.3 MinTokensLogitsProcessor — 最小长度约束
设计目的:强制模型生成至少N个token后才允许输出停止token(EOS)。通过屏蔽EOS token的logit实现。
argmax-invariant:False — 屏蔽EOS可能改变argmax结果(如果原本最大logit对应EOS)。
init
class MinTokensLogitsProcessor(LogitsProcessor):
def __init__(
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
):
self.device = device
self.pin_memory = is_pin_memory
# 核心状态: {req_index: (min_tokens, output_token_ids_ref, stop_token_ids_set)}
# output_token_ids是引用,始终反映最新的输出长度
self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
# 索引张量: (req_indices, stop_token_indices)
self.logits_slice: tuple[torch.Tensor, torch.Tensor] = (
self._device_tensor([], torch.int32),
self._device_tensor([], torch.int32),
)
# -inf张量: 用于屏蔽stop token
self.neg_inf_tensor = torch.tensor(
-float("inf"), dtype=torch.float32, device=self.device
)
update_state()
def update_state(self, batch_update: BatchUpdate | None):
needs_update = process_dict_updates(
self.min_toks, batch_update, self.add_request
)
# 检查已满足最小长度要求的请求
if self.min_toks:
to_remove = tuple(
index
for index, (min_toks, out_tok_ids, _) in self.min_toks.items()
if len(out_tok_ids) >= min_toks
# output_token_ids是引用,长度实时更新
# 当已生成token数 >= min_tokens时,移除该请求的约束
)
if to_remove:
needs_update = True
for index in to_remove:
del self.min_toks[index]
# 重建GPU索引张量
if needs_update:
reqs: list[int] = []
tok_ids: list[int] = []
for req, (_, _, stop_tok_ids) in self.min_toks.items():
reqs.extend([req] * len(stop_tok_ids))
tok_ids.extend(stop_tok_ids)
self.logits_slice = (
self._device_tensor(reqs, torch.int32),
self._device_tensor(tok_ids, torch.int32),
)
关键设计:output_token_ids 是引用,因此 len(out_tok_ids) 实时反映已生成token数。无需在每步手动更新。
apply()
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if self.min_toks:
# 将stop token的logit设为-inf
# index_put_ 类似于 logits[req_indices, tok_indices] = -inf
logits.index_put_(self.logits_slice, self.neg_inf_tensor)
return logits
index_put_ vs 高级索引赋值:
logits[slice] = value→ 创建新张量(非原地)logits.index_put_(slice, value)→ 原地修改,更高效
apply_with_spec_decode() — 投机解码版本
def apply_with_spec_decode(
self,
logits: torch.Tensor, # [num_tokens, vocab_size] 所有draft位置的logits
num_draft_tokens: list[int], # 每请求的draft token数
) -> torch.Tensor:
"""投机解码版本: 需要对每个draft位置独立判断是否屏蔽stop token
Example: num_draft_tokens = [2, 3, 1]
→ logits shape [6, V]
→ cumsum = [0, 2, 5, 6]
→ request 0 owns rows 0-1, request 1 rows 2-4, request 2 row 5
"""
if not self.min_toks:
return logits
# 计算每个请求在展平logits中的起始行偏移
num_draft_arr = np.array(num_draft_tokens, dtype=np.int64)
cumsum = np.concatenate([[0], np.cumsum(num_draft_arr)])
# 收集需要屏蔽的行和token
entries = [
(req_idx, min_tok, len(out_tok_ids), list(stop_tok_ids))
for req_idx, (min_tok, out_tok_ids, stop_tok_ids) in self.min_toks.items()
if stop_tok_ids
]
if not entries:
return logits
all_rows: list[np.ndarray] = [] # 行索引
all_toks: list[np.ndarray] = [] # stop token ids
for req_idx, min_tok, current_len, stop_toks in entries:
remaining = min_tok - current_len # 还需要多少token才满足最小长度
# 计算前多少个draft位置仍需要屏蔽stop token
n_mask = int(min(max(remaining, 0), num_draft_arr[req_idx]))
if n_mask > 0:
offset = cumsum[req_idx]
row_indices = np.arange(offset, offset + n_mask, dtype=np.int64)
n_stop = len(stop_toks)
# 每个stop token在每个需要屏蔽的行都要出现
all_rows.append(np.repeat(row_indices, n_stop))
all_toks.append(np.tile(stop_toks, n_mask))
if all_rows:
rows_arr = np.concatenate(all_rows)
toks_arr = np.concatenate(all_toks)
logits_slice = (
torch.from_numpy(rows_arr).to(self.device, non_blocking=True),
torch.from_numpy(toks_arr).to(self.device, non_blocking=True),
)
logits.index_put_(logits_slice, self.neg_inf_tensor)
return logits
10.4 process_dict_updates() 通用状态更新工具
def process_dict_updates(
req_entries: dict[int, T], # 处理器的状态字典(如self.biases, self.min_toks)
batch_update: BatchUpdate | None,
new_state: Callable[[SamplingParams, list[int] | None, list[int]], T | None],
) -> bool:
"""通用字典状态更新工具
处理BatchUpdate中的三种操作(removed/added/moved),
更新处理器的状态字典
Args:
req_entries: 处理器维护的 {req_index: state} 字典
batch_update: 批次变更信息
new_state: 工厂函数,从SamplingParams创建新的状态条目
Returns:
True 如果字典有变化(需要重建GPU张量)
"""
if not batch_update:
return False
updated = False
# ===== 处理新增请求 =====
for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
if (state := new_state(params, prompt_tok_ids, output_tok_ids)) is not None:
req_entries[index] = state # 新请求有状态 → 插入
updated = True
elif req_entries.pop(index, None) is not None:
# 新请求不需要此处理器,但该索引有旧状态 → 清除
updated = True
if req_entries: # 只有字典非空时才处理移除和移动
# ===== 处理移除请求 =====
for index in batch_update.removed:
if req_entries.pop(index, None):
updated = True
# ===== 处理移动请求 =====
for a_index, b_index, direct in batch_update.moved:
a_entry = req_entries.pop(a_index, None) # 弹出源
b_entry = req_entries.pop(b_index, None) # 弹出目标
if a_entry is not None:
req_entries[b_index] = a_entry # 源→目标
updated = True
if b_entry is not None:
updated = True
if direct == MoveDirectionality.SWAP:
req_entries[a_index] = b_entry # 目标→源(双向交换)
# UNIDIRECTIONAL: 源位置自然为空(已pop)
return updated
设计精妙之处:
- 先pop后放回:避免交换时键冲突(如果先
a→b,再b→a,可能丢失中间状态) - SWAP vs UNIDIRECTIONAL:SWAP双向交换;UNIDIRECTIONAL单向移动,源位置自动清空
- 返回needs_update:调用者据此决定是否重建GPU张量,避免无变化时的冗余重建
第十一章 Logits处理器构建与插件系统
11.1 _load_logitsprocs_plugins() 插件加载
LOGITSPROCS_GROUP = "vllm.logits_processors" # entry_points组名
def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
"""通过importlib.metadata加载已安装的logitproc插件
插件通过setup.py/pyproject.toml注册:
[project.entry-points."vllm.logits_processors"]
my_proc = "my_package:MyLogitsProcessor"
"""
from importlib.metadata import entry_points
installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP)
if len(installed_logitsprocs_plugins) == 0:
logger.debug("No logitsprocs plugins installed (group %s).", LOGITSPROCS_GROUP)
return []
classes: list[type[LogitsProcessor]] = []
for entrypoint in installed_logitsprocs_plugins:
try:
with guard_cuda_initialization():
# guard_cuda_initialization: 防止import时触发CUDA初始化
classes.append(entrypoint.load())
except Exception as e:
logger.error("Failed to load LogitsProcessor plugin %s: %s", entrypoint, e)
raise RuntimeError(
f"Failed to load LogitsProcessor plugin {entrypoint}"
) from e
return classes
guard_cuda_initialization() 的作用:在CUDA尚未初始化时加载模块可能导致子进程fork失败(CUDA runtime不支持fork)。此上下文管理器确保import期间不触发CUDA初始化。
11.2 _load_logitsprocs_by_fqcns() FQCN加载
def _load_logitsprocs_by_fqcns(
logits_processors: Sequence[str | type[LogitsProcessor]] | None,
) -> list[type[LogitsProcessor]]:
"""通过完全限定类名(FQCN)加载logit处理器
FQCN语法: <module_path>:<qualname>
例如: "my_package.processors:CustomLogitsProcessor"
"""
if not logits_processors:
return []
classes: list[type[LogitsProcessor]] = []
for ldx, logitproc in enumerate(logits_processors):
if isinstance(logitproc, type):
# 已经是类对象 → 直接验证
if not issubclass(logitproc, LogitsProcessor):
raise ValueError(
f"{logitproc.__name__} is not a subclass of LogitsProcessor"
)
classes.append(logitproc)
continue
# 字符串FQCN → 动态加载
module_path, qualname = logitproc.split(":")
try:
with guard_cuda_initialization():
module = importlib.import_module(module_path)
except Exception as e:
raise RuntimeError(
f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}"
) from e
# 沿点号路径获取类对象
# 例如: "a.b.c" → getattr(getattr(module, "a"), "b").c
obj = module
for attr in qualname.split("."):
obj = getattr(obj, attr)
if not isinstance(obj, type):
raise ValueError("Loaded logit processor must be a type.")
if not issubclass(obj, LogitsProcessor):
raise ValueError(f"{obj.__name__} must be a subclass of LogitsProcessor")
classes.append(obj)
return classes
11.3 _load_custom_logitsprocs() 统一加载
def _load_custom_logitsprocs(
logits_processors: Sequence[str | type[LogitsProcessor]] | None,
) -> list[type[LogitsProcessor]]:
"""加载所有自定义logits处理器
加载顺序:
1. 先加载entry_points插件
2. 再加载用户指定的FQCN
"""
from vllm.platforms import current_platform
if current_platform.is_tpu():
# TPU不支持自定义logits处理器
return []
return _load_logitsprocs_plugins() + _load_logitsprocs_by_fqcns(logits_processors)
11.4 build_logitsprocs() 构建入口
BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
MinTokensLogitsProcessor,
LogitBiasLogitsProcessor,
MinPLogitsProcessor,
]
def build_logitsprocs(
vllm_config: "VllmConfig",
device: torch.device,
is_pin_memory: bool,
is_pooling_model: bool,
custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (),
) -> LogitsProcessors:
"""构建LogitsProcessors容器
Args:
vllm_config: vLLM配置
device: 计算设备
is_pin_memory: 是否使用pin_memory
is_pooling_model: 是否为pooling模型
custom_logitsprocs: 用户自定义的logit处理器列表
Returns:
LogitsProcessors容器(已分类所有处理器实例)
"""
# Pooling模型不支持logits处理器
if is_pooling_model:
if custom_logitsprocs:
raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
return LogitsProcessors()
# 投机解码模式:只允许MinTokens(其他处理器与spec decode不兼容)
if vllm_config.speculative_config:
if custom_logitsprocs:
raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS)
logger.warning(
"min_p and logit_bias parameters won't work with speculative decoding."
)
return LogitsProcessors(
[MinTokensLogitsProcessor(vllm_config, device, is_pin_memory)]
)
# 正常模式:加载内置 + 自定义处理器
custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
return LogitsProcessors(
ctor(vllm_config, device, is_pin_memory)
for ctor in itertools.chain(
BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes
)
)
三种模式的处理器配置:
| 模式 | 可用处理器 | 原因 |
|---|---|---|
| Pooling | 无 | Pooling模型不产生logits |
| Spec Decode | 仅MinTokens | 其他处理器与draft验证逻辑冲突 |
| 正常 | MinTokens + LogitBias + MinP + 自定义 | 完整支持 |
11.5 validate_logits_processors_parameters() 参数验证
cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)
def validate_logits_processors_parameters(
logits_processors: Sequence[str | type[LogitsProcessor]] | None,
sampling_params: SamplingParams,
):
"""验证采样参数是否被所有logits处理器接受
使用lru_cache缓存加载结果,避免重复加载
"""
logits_processors = (
tuple(logits_processors) if logits_processors is not None else None
)
# tuple化以支持lru_cache(list不可hash)
for logits_procs in cached_load_custom_logitsprocs(logits_processors):
logits_procs.validate_params(sampling_params)
第十二章 AdapterLogitsProcessor 适配器模式
12.1 设计动机
AdapterLogitsProcessor 是一个适配器基类,将vLLM V0风格的per-request logits处理器(RequestLogitsProcessor)包装为V1风格的批次处理器。
V0的处理器的签名:def __call__(input_ids, scores) → scores
V1的处理器的签名:def apply(logits: [B, V]) → [B, V]
适配器为每个请求创建 partial 函数,预填充 input_ids 和 output_ids 参数。
12.2 init 初始化
class AdapterLogitsProcessor(LogitsProcessor):
def __init__(
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
):
# Map req_index → partial[Tensor]
# partial函数携带了预填充的output_ids引用
# 由于是引用,始终操作最新的output_ids列表
self.req_info: dict[int, partial[torch.Tensor]] = {}
12.3 new_req_logits_processor() 工厂方法
@abstractmethod
def new_req_logits_processor(
self, params: SamplingParams,
) -> RequestLogitsProcessor | None:
"""为请求创建per-request logits处理器
返回None表示此请求不需要此处理器
"""
raise NotImplementedError
12.4 _new_state() 状态构建
def _new_state(
self,
params: SamplingParams,
prompt_ids: list[int] | None,
output_ids: list[int],
) -> partial[torch.Tensor] | None:
"""为新请求构建状态(partial函数)"""
if req_lp := self.new_req_logits_processor(params):
# 检查处理器签名:是否需要prompt_ids参数
if len(inspect.signature(req_lp).parameters) == 3:
# 3参数: (input_ids, scores, ...) — 需要prompt_ids
if prompt_ids is None:
raise ValueError(
"Prompt token ids are required for this logits processor "
"but were not provided."
)
args = [prompt_ids, output_ids]
else:
# 2参数: (scores, ...) — 只需要output_ids
args = [output_ids]
# 创建partial:预填充prompt_ids和output_ids参数
# output_ids是引用,后续调用时看到最新值
return partial(req_lp, *args)
return None
inspect.signature 的使用:通过反射检查处理器的参数数量,自动判断是否需要prompt_ids。这使得同一个适配器可以兼容两种签名的处理器。
12.5 update_state() 状态更新
def update_state(self, batch_update: BatchUpdate | None):
# 使用通用工具,传入_new_state作为工厂函数
process_dict_updates(
self.req_info,
batch_update,
self._new_state,
)
12.6 apply() 逐行应用
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if self.req_info:
# 对每个有per-request处理器的请求,逐行应用
for req_idx, req_lp in self.req_info.items():
req_logits = logits[req_idx] # [vocab_size] 单行
new_logits = req_lp(req_logits) # 调用partial函数
if new_logits is not req_logits:
# 处理器返回了新张量(而非原地修改)→ 写回
logits[req_idx] = new_logits
return logits
逐行应用的性能考量:每个per-request处理器独立调用,无法批量处理。这是设计上的取舍:
- Per-request处理器的逻辑各不相同,无法batch化
- 大多数请求没有per-request处理器(self.req_info通常很小或为空)
- 性能影响可接受
第十三章 ThinkingBudgetStateHolder 思考预算管理
13.1 设计背景与动机
问题:推理模型(如DeepSeek-R1, QwQ)使用特殊token标记思考过程(<think>...</think>)。当设置 thinking_token_budget 时,需要强制模型在思考token数达到预算时输出结束token。
挑战:
- 思考token计数需要跨多步累积
- 投机解码中,每个draft位置都可能需要独立判断是否强制结束
- 拒绝采样可能拒绝强制结束的token,需要恢复思考状态
- 多token的结束标记(如
<|end_thinking|>可能是多个token序列)
13.2 maybe_create_thinking_budget_state_holder() 工厂函数
def maybe_create_thinking_budget_state_holder(
reasoning_config: "ReasoningConfig | None",
max_num_seqs: int,
num_spec_tokens: int,
device: torch.device,
is_pin_memory: bool,
) -> "ThinkingBudgetStateHolder | None":
"""条件创建:只有配置了reasoning_config时才创建holder"""
if reasoning_config is None:
return None # 非推理模型不需要
return ThinkingBudgetStateHolder(
reasoning_config, max_num_seqs, num_spec_tokens, device, is_pin_memory
)
13.3 init 初始化逐行解析
class ThinkingBudgetStateHolder:
think_start_token_ids: list[int] # 思考开始token序列(如 <think>的token ids)
think_end_token_ids: list[int] # 思考结束token序列(如 </think>的token ids)
def __init__(
self,
reasoning_config: "ReasoningConfig | None",
max_num_seqs: int, # 最大并发序列数
num_spec_tokens: int, # 投机解码的draft token数(0=非spec模式)
device: torch.device,
is_pin_memory: bool,
):
_ = is_pin_memory # API对齐,未使用
max_num_reqs = max_num_seqs
self.in_spec_mode = num_spec_tokens > 0 # 是否在投机解码模式
self.num_spec_tokens = num_spec_tokens
# 启用标志:reasoning_config非None即启用
self.is_enabled = reasoning_config is not None
# 初始化思考开始/结束token ids
if reasoning_config is None:
self.think_start_token_ids = []
self.think_end_token_ids = []
else:
rs = reasoning_config.reasoning_start_token_ids
re = reasoning_config.reasoning_end_token_ids
self.think_start_token_ids = rs if rs else []
self.think_end_token_ids = re if re else []
self.device = device
self._state: dict[int, dict[str, Any]] = {} # req_index → 状态字典
self.cu_num_tokens: dict[int, int] = {} # req_index → 累积token偏移
# 预分配GPU张量用于强制结束
if self.num_spec_tokens > 0:
# 投机模式:每请求 (spec_tokens + 1) 个位置
total = max_num_reqs * (self.num_spec_tokens + 1)
self.mask = torch.zeros(total, dtype=torch.bool, device=device)
self.force_token_ids = torch.full(
(total,), -1, dtype=torch.long, device=device
)
else:
# 常规模式:每请求1个位置
self.mask = torch.zeros(max_num_reqs, dtype=torch.bool, device=device)
self.force_token_ids = torch.full(
(max_num_reqs,), -1, dtype=torch.long, device=device
)
mask / force_token_ids 的设计:
mask[i] = True→ 第i行需要强制输出特定tokenforce_token_ids[i]→ 第i行需要强制输出的token id- 预分配避免每步动态创建张量
13.4 has_tracked_requests() 活跃检测
def has_tracked_requests(self) -> bool:
"""是否有正在追踪思考预算的请求
用于决定采样是否需要output-token行和spec combining
不同于仅仅有holder实例(reasoning可能开启但当前批次没有预算请求)
"""
return bool(self._state)
13.5 sync_batch() 批次同步
def sync_batch(self, batch_update: BatchUpdate | None) -> None:
"""根据批次变更增删移动状态(不调用_update_think_state)"""
if not self.is_enabled or not batch_update:
return
# 移除已完成的请求
for index in batch_update.removed:
self._state.pop(index, None)
# 新增请求:如果设置了thinking_token_budget则初始化状态
for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
thinking_token_budget = params.thinking_token_budget
if thinking_token_budget is not None:
self._state[index] = self._init_state_entry(
prompt_tok_ids, thinking_token_budget
)
self._state[index]["output_tok_ids"] = output_tok_ids # 引用
self._state[index]["spec_token_ids"] = []
else:
self._state.pop(index, None) # 不需要预算追踪
# 移动请求
for i1, i2, direction in batch_update.moved:
if direction == MoveDirectionality.SWAP:
state1 = self._state.get(i1)
state2 = self._state.get(i2)
if state1 is not None:
self._state[i2] = state1
if state2 is not None:
self._state[i1] = state2
else:
state = self._state.pop(i1, None)
if state is not None:
self._state[i2] = state
13.6 update_state() 状态更新
def update_state(
self,
output_token_ids: list[list[int]], # 每请求最新输出token
spec_token_ids: list[list[int]] | None, # 投机draft token
repeat_indices: torch.Tensor | None = None, # spec模式下请求→行映射
) -> None:
"""刷新output/spec数据并重新计算思考状态"""
if not self.is_enabled or not self._state:
return
spec_lists = spec_token_ids or []
# 构建请求到最新输出行的映射(spec模式下需要)
last_row_for_req: dict[int, int] | None = None
if repeat_indices is not None:
last_row_for_req = {}
rpt = repeat_indices.cpu().tolist()
for batch_row, req_i in enumerate(rpt):
last_row_for_req[req_i] = batch_row # 最后出现的行号
# 更新每个追踪请求的output_tok_ids引用
for seq_idx, state in list(self._state.items()):
if last_row_for_req is not None:
output_row = last_row_for_req.get(seq_idx)
if output_row is None or output_row >= len(output_token_ids):
continue
state["output_tok_ids"] = output_token_ids[output_row]
elif seq_idx >= len(output_token_ids):
continue
else:
state["output_tok_ids"] = output_token_ids[seq_idx]
# 更新spec token ids
if seq_idx < len(spec_lists):
state["spec_token_ids"] = list(spec_lists[seq_idx])
else:
state["spec_token_ids"] = []
state["in_spec_mode"] = self.in_spec_mode
state["force_index"] = [] # 清空强制索引
# 剥离draft token后缀
if len(state["output_tok_ids"]) > 0:
spec_len = len(state["spec_token_ids"])
if spec_len > 0 and len(state["output_tok_ids"]) >= spec_len:
state["output_tok_ids"] = state["output_tok_ids"][:-spec_len]
# 去掉draft tokens,只保留已确认的输出
# 重新计算思考状态
self._update_think_state(state)
13.7 _init_state_entry() 状态初始化
def _init_state_entry(
self, prompt_tok_ids: list[int] | None, thinking_token_budget: int
) -> dict[str, Any]:
"""为新请求初始化思考状态条目"""
if prompt_tok_ids is None:
# 无prompt:初始状态为"未进入思考"
in_think = False
think_count = 0
start_thinking = -1
countdown = thinking_token_budget
continue_thinking = False
in_end = False
else:
# 有prompt:检查prompt中是否已在思考中
start_thinking = -1
countdown = thinking_token_budget
continue_thinking = False
in_end = False
# 在prompt中查找最后出现的start/end token序列
last_start = self._find_last_sequence_index(
prompt_tok_ids, self.think_start_token_ids
)
last_end = self._find_last_sequence_index(
prompt_tok_ids, self.think_end_token_ids
)
# 判断是否在思考中:start出现在end之后
in_think = last_start > last_end
if in_think:
# 计算已使用的思考token数
think_count = len(prompt_tok_ids) - (
last_start + len(self.think_start_token_ids)
)
start_thinking = len(prompt_tok_ids) - think_count - 1
countdown -= think_count # 剩余预算
continue_thinking = True
# 检查是否在prompt内已超预算
token_exhausted = thinking_token_budget - think_count
in_end = token_exhausted <= 0
return {
"in_think": in_think, # 是否在思考中
"in_end": in_end, # 是否需要强制结束
"check_count_down": countdown, # 倒计数
"think_count": think_count, # 已用思考token数
"end_count": 0, # 结束标记进度(多token序列)
"prompt_tok_ids": prompt_tok_ids,
"output_tok_ids": [],
"thinking_token_budget": thinking_token_budget,
"prev_output_length": 0, # 上次处理时的输出长度
"spec_token_ids": [], # 当前draft tokens
"force_index": [], # 需要强制的位置索引
"start_thinking": start_thinking, # 思考开始的绝对位置
"end_thinking": -1, # 思考结束的绝对位置
"in_spec_mode": False,
"bonus_token_forced": False, # bonus token是否已被强制
"continue_thinking": continue_thinking, # 从prompt继续思考
}
13.8 _update_think_state() 思考状态机
这是整个ThinkingBudget系统最复杂的部分,实现了一个有限状态机来管理思考→结束→恢复的状态转换。
状态机核心逻辑(简化版):
def _update_think_state(self, state: dict[str, Any]) -> None:
# 1. 检查是否仍需要追踪
if state.get("thinking_token_budget", -1) == -1:
return
if len(self.think_end_token_ids) == 0:
# 无结束token → 禁用预算追踪
state["thinking_token_budget"] = -1
return
# 2. 查找output中的start/end位置
if state["start_thinking"] == -1:
state["start_thinking"] = self._find_last_sequence_index(
state.get("output_tok_ids", []), self.think_start_token_ids
)
if state["end_thinking"] == -1:
state["end_thinking"] = self._find_last_sequence_index(
state.get("output_tok_ids", []), self.think_end_token_ids
)
# 3. 如果从未开始思考,直接返回
if state["start_thinking"] == -1:
return
# 4. 计算当前步新增的token数
sampled_tokens_from_previous_step = len(output) - prev_length
# 5. 更新倒计数
current_step_countdown = state["check_count_down"] - sampled_tokens
predicted_countdown = current_step_countdown - spec_len - 1
# 6. 如果倒计数仍为正,只更新countdown,不需要强制
if not in_end and predicted_countdown >= 0 and start_thinking > -1:
state["check_count_down"] = current_step_countdown
state["prev_output_length"] = len(output)
return
# 7. 需要强制结束的情况
if state["in_end"] and state["end_count"] == 0:
# 检查拒绝采样是否回退了end token
new_tokens = output[prev_length:]
stopping_thinking = self.think_end_token_ids[0] in new_tokens
if not stopping_thinking:
# 回退 → 重新进入思考模式
state["in_think"] = True
state["in_end"] = False
state["end_count"] = 0
# 8. 根据start/end位置判断当前状态
if not state["in_end"]:
if absolute_start_pos > absolute_end_pos:
# start在end之后 → 在思考中
state["in_think"] = True
elif absolute_end_pos > absolute_start_pos:
# end在start之后 → 不在思考
state["in_think"] = False
# 9. 检查是否超预算
if state["in_think"] and total_thinking_tokens > budget:
state["in_think"] = False
state["in_end"] = True # 切换到强制结束模式
# 计算force_index(在哪个draft位置开始强制)
remaining = budget - think_count
if remaining <= 0:
state["force_index"] = [0] # 从第一个位置开始强制
elif remaining < spec_len:
state["force_index"] = [remaining] # 从中间位置开始
else:
state["force_index"] = [spec_len] # 在bonus位置强制
else: # in_end状态
# 10. 追踪end token序列的进度
# 多token结束标记需要逐token强制
if len(spec_token_ids) > 0:
for i, token_id in enumerate(spec_token_ids):
if end_count + 1 < len(think_end_token_ids):
if token_id == think_end_token_ids[end_count + 1]:
end_count += 1 # 匹配,继续
else:
end_count += 1
force_index = [i] # 不匹配,强制
break
# 11. 完成end序列
if end_count >= len(think_end_token_ids):
state.update({
"in_end": False,
"end_count": 0,
"check_count_down": budget,
})
13.9 apply_to_logits() logits强制
def apply_to_logits(
self,
logits: torch.Tensor,
predict_bonus_token: bool,
spec_token_ids: list[list[int]] | None,
) -> torch.Tensor:
"""将强制结束逻辑应用到logits"""
if not self.is_enabled or not self._state:
return logits
spec_lists = spec_token_ids or []
return self._apply_forcing_to_logits(logits, predict_bonus_token, spec_lists)
13.10 _apply_forcing_to_logits() 强制逻辑
def _apply_forcing_to_logits(
self,
logits: torch.Tensor,
predict_bonus_token: bool,
spec_token_ids_for_layout: list[list[int]],
) -> torch.Tensor:
"""核心强制逻辑:为需要强制end token的位置设置极大logit"""
# 清空mask和force_token_ids
self.mask[:] = False
cumulative_total = 0
self.cu_num_tokens.clear()
# 计算每个请求在展平logits中的位置偏移
n_layout = len(spec_token_ids_for_layout)
if self._state:
n_layout = max(n_layout, max(self._state.keys()) + 1)
for index in range(n_layout):
self.cu_num_tokens[index] = cumulative_total
spec_tokens = (
spec_token_ids_for_layout[index]
if index < len(spec_token_ids_for_layout) else []
)
if self.in_spec_mode:
cumulative_total += len(spec_tokens) if not predict_bonus_token else 1
else:
cumulative_total += 1
# 为每个追踪状态的请求设置强制位置
for seq_idx in sorted(self._state.keys()):
if seq_idx not in self.cu_num_tokens:
continue
state = self._state[seq_idx]
if state.get("in_end", False):
# 投机解码中logits处理器被调用两次:
# 1. bonus token logits
# 2. target logits
if predict_bonus_token:
if state.get("force_index") and state["force_index"][0] < len(
state["spec_token_ids"]
):
# force_index指向draft位置 → 跳过bonus调用
continue
else:
# force_index指向bonus位置 → 在bonus调用中强制
state["force_index"] = [0]
if state.get("end_count", 0) > 0:
state["bonus_token_forced"] = False
if state and not state["bonus_token_forced"]:
force_index = state.get("force_index", [])
if len(force_index) == 0:
continue
end_count = state.get("end_count", 0)
for force_idx in force_index:
if end_count < len(self.think_end_token_ids):
# 计算在展平logits中的绝对行索引
mask_idx = self.cu_num_tokens[seq_idx] + force_idx
if mask_idx < len(self.mask) and mask_idx < logits.shape[0]:
self.mask[mask_idx] = True
self.force_token_ids[mask_idx] = (
self.think_end_token_ids[end_count]
)
if predict_bonus_token:
if state["end_count"] > 0:
state["bonus_token_forced"] = False
state["force_index"] = []
else:
state["bonus_token_forced"] = True
# 应用强制:将对应token的logit设为极大值
has_active_thinking = any(
state.get("in_end", False) for state in self._state.values()
)
if has_active_thinking:
active_indices = self.mask.nonzero(as_tuple=False).view(-1)
if len(active_indices) > 0:
force_tokens = self.force_token_ids[active_indices]
# 设置极大logit(1e9)确保该token被采样
logits[active_indices, force_tokens] = 1e9
return logits
1e9 的设计选择:使用 1e9 而非 float("inf") 的原因:
float("inf")可能导致softmax产生NaN1e9足够大,确保在float32精度下,该token的概率接近1- 同时保持数值稳定性
附录C Logits处理器生命周期时序图
附录D 思考预算状态机全图
force_index 的三种情况:
| 场景 | remaining = budget - think_count | force_index | 含义 |
|---|---|---|---|
| 已超预算 | remaining ≤ 0 | [0] | 从第一个draft位置强制 |
| 部分超 | 0 < remaining < spec_len | [remaining] | 从第remaining个draft位置强制 |
| 恰好在边界 | remaining ≥ spec_len | [spec_len] | 在bonus token位置强制 |
_find_last_sequence_index() 工具方法:
@staticmethod
def _find_last_sequence_index(target_list: list[int], token_ids: list[int]) -> int:
"""在target_list中查找token_ids序列最后一次出现的位置
Args:
target_list: 搜索目标(如prompt_token_ids或output_token_ids)
token_ids: 要查找的token序列(如think_start_token_ids)
Returns:
最后匹配的起始索引,-1表示未找到
"""
if not token_ids:
return -1
# 从后往前搜索
for i in range(len(target_list) - len(token_ids), -1, -1):
if target_list[i: i + len(token_ids)] == token_ids:
return i
return -1
附录N process_dict_updates() 全场景推演
N.1 场景1: 新增请求(有logit_bias)
初始状态: biases = {}
BatchUpdate.added = [(0, params_with_bias, prompt0, output0)]
Step 1: 调用 new_state(params_with_bias, prompt0, output0)
→ params.logit_bias = {100: 0.5, 200: -0.3}
→ 返回 {100: 0.5, 200: -0.3}
Step 2: req_entries[0] = {100: 0.5, 200: -0.3}
→ updated = True
Step 3: 重建GPU张量
reqs = [0, 0] # 两个偏置都在请求0
tok_ids = [100, 200] # token id
biases = [0.5, -0.3] # 偏置值
logits_slice = (tensor([0, 0], int32), tensor([100, 200], int32))
bias_tensor = tensor([0.5, -0.3], float32)
N.2 场景2: 请求移除
初始状态: biases = {0: {100: 0.5}, 1: {200: -0.3}}
BatchUpdate.removed = [1]
Step 1: req_entries.pop(1, None) → {200: -0.3}
→ updated = True
Step 2: 重建GPU张量
biases = {0: {100: 0.5}}
reqs = [0]
tok_ids = [100]
biases_vals = [0.5]
N.3 场景3: SWAP交换
初始状态: biases = {0: {100: 0.5}, 1: {200: -0.3}}
BatchUpdate.moved = [(0, 1, SWAP)]
Step 1: a_entry = pop(0) = {100: 0.5}
Step 2: b_entry = pop(1) = {200: -0.3}
Step 3: req_entries[1] = {100: 0.5} # a→b
Step 4: req_entries[0] = {200: -0.3} # b→a (SWAP)
→ updated = True
结果: biases = {0: {200: -0.3}, 1: {100: 0.5}}
# 两个请求的状态互换了
N.4 场景4: UNIDIRECTIONAL单向移动
初始状态: biases = {0: {100: 0.5}, 2: {300: 0.8}}
BatchUpdate.moved = [(0, 1, UNIDIRECTIONAL)]
Step 1: a_entry = pop(0) = {100: 0.5}
Step 2: b_entry = pop(1) = None # 位置1原来没有状态
Step 3: req_entries[1] = {100: 0.5} # a→b
Step 4: b_entry is None → 不做额外操作
→ updated = True
结果: biases = {1: {100: 0.5}, 2: {300: 0.8}}
# 位置0被清空,位置1获得了位置0的状态
N.5 场景5: 新请求不需要此处理器
初始状态: biases = {0: {100: 0.5}} # 位置0有logit_bias
BatchUpdate.added = [(0, params_without_bias, prompt0, output0)]
# 位置0的新请求没有logit_bias → 覆盖旧状态
Step 1: new_state(params_without_bias) → params.logit_bias is None → None
Step 2: state is None → 不插入
Step 3: req_entries.pop(0) = {100: 0.5} → 清除旧状态
→ updated = True
结果: biases = {} # 位置0的旧状态被清除
附录O ThinkingBudget 状态字段完整参考
O.1 状态字典字段说明
| 字段 | 类型 | 初始值 | 含义 |
|---|---|---|---|
in_think |
bool | False | 当前是否在思考模式中 |
in_end |
bool | False | 当前是否需要强制结束思考 |
check_count_down |
int | budget | 剩余可用思考token计数 |
think_count |
int | 0 | 已使用的思考token数 |
end_count |
int | 0 | 结束标记序列中已完成的token数 |
prompt_tok_ids |
list|None | prompt | 原始prompt token ids |
output_tok_ids |
list[int] | [] (ref) | 已生成输出token ids(引用) |
thinking_token_budget |
int | budget | 用户设定的思考token预算 |
prev_output_length |
int | 0 | 上次处理时的输出长度 |
spec_token_ids |
list[int] | [] | 当前投机draft token ids |
force_index |
list[int] | [] | 需要强制end token的位置索引 |
start_thinking |
int | -1 | 思考开始序列在output中的位置 |
end_thinking |
int | -1 | 思考结束序列在output中的位置 |
in_spec_mode |
bool | False | 是否在投机解码模式 |
bonus_token_forced |
bool | False | bonus token是否已被强制 |
continue_thinking |
bool | False | 是否从prompt继续思考 |
O.2 状态转换矩阵
| in_think=F | in_think=T | in_end=T
-----------|------------|-------------|----------
触发条件 | 检测到 | budget超限 | 完成end序列
| think_start| 或spec超限 | 或被拒绝回退
-----------|------------|-------------|----------
下一状态 | in_think=T | in_end=T | in_think=T
| think_count=0| end_count=0| in_end=F
| | force_index | end_count=0
-----------|------------|-------------|----------
force_index| [] | [pos] | [] or [0]
O.3 force_index 计算公式
# 当 in_think → in_end 转换时:
remaining_budget = thinking_token_budget - think_count
spec_len = len(spec_token_ids)
if remaining_budget <= 0:
# 已超预算 → 从第一个位置强制
force_index = [0]
elif 0 < remaining_budget < spec_len:
# 部分超预算 → 从第remaining_budget个spec位置强制
force_index = [remaining_budget]
else:
# 还在预算内但spec+1超 → 在bonus token位置强制
force_index = [spec_len]
图示:
假设: budget=10, think_count=8, spec_len=3
remaining = 10 - 8 = 2
force_index = [2]
draft位置: [0] [1] [2] [bonus]
✓ ✓ ✗force ✗force
正常 正常 强制end 强制end
位置0-1: 在预算内(8+2=10 ≤ budget=10)
位置2: 超预算(8+3=11 > budget=10)→ 强制
bonus: 更超预算 → 强制
附录P AdapterLogitsProcessor 适配器深度分析
P.1 V0 vs V1 LogitsProcessor接口对比
| 特性 | V0 (RequestLogitsProcessor) | V1 (LogitsProcessor) |
|---|---|---|
| 粒度 | per-request | per-batch |
| 签名 | __call__(input_ids, logits) |
apply(logits: [B,V]) |
| 状态管理 | 无(每次调用独立) | update_state(BatchUpdate) |
| 批量处理 | 逐请求循环 | 全批次一次 |
| argmax分类 | 无 | is_argmax_invariant() |
| 注册方式 | sampling_params.logits_processors | entry_points / FQCN |
P.2 适配器模式工作流
1. 用户定义V0处理器:
class MyProcessor:
def __call__(self, input_ids, logits):
# 逐token处理逻辑
return modified_logits
2. AdapterLogitsProcessor包装:
- new_req_logits_processor(params) → 返回MyProcessor实例
- _new_state(params, prompt, output) → partial(MyProcessor(), prompt, output)
- self.req_info[req_idx] = partial_fn
3. apply()时:
for req_idx, partial_fn in self.req_info.items():
# partial_fn(logits_row) 等价于 MyProcessor(prompt_ids, output_ids, logits_row)
result = partial_fn(logits[req_idx])
logits[req_idx] = result # 如果结果不是原地修改
P.3 inspect.signature 参数检测
# 3参数签名: (input_ids, logits, ...) → 需要prompt_ids
class ThreeParamProcessor:
def __call__(self, input_ids, logits, **kwargs):
# input_ids = prompt token ids
# logits = 当前logits
pass
# 2参数签名: (logits, ...) → 只需要output_ids
class TwoParamProcessor:
def __call__(self, logits, **kwargs):
# logits = 当前logits
pass
# 检测逻辑:
sig = inspect.signature(req_lp)
if len(sig.parameters) == 3:
args = [prompt_ids, output_ids] # 预填充两个参数
else:
args = [output_ids] # 预填充一个参数
partial_fn = partial(req_lp, *args)
# 调用: partial_fn(logits_row) → req_lp(prompt_ids, output_ids, logits_row)
P.4 为什么output_ids是引用而非拷贝
# 在BatchUpdate.added中:
AddedRequest = (index, params, prompt_tok_ids, output_tok_ids)
# output_tok_ids是请求的运行中输出列表的引用
# 在_new_state中:
args = [output_ids]
partial_fn = partial(req_lp, *args)
# partial_fn持有output_ids的引用
# 当模型生成新token时:
output_ids.append(new_token) # 请求的输出列表增长
# partial_fn自动看到最新的output_ids(因为是引用)
# 无需每步重新创建partial_fn
这个设计确保了per-request处理器始终操作最新的输出历史,而不需要每步都重建状态。
附录Q build_logitsprocs() 全路径决策树
Q.1 插件加载流程
_load_custom_logitsprocs(custom)
├── _load_logitsprocs_plugins()
│ ├── entry_points(group="vllm.logits_processors")
│ ├── 每个entrypoint.load()
│ │ ├── guard_cuda_initialization() # 防止CUDA过早初始化
│ │ └── 验证 issubclass(cls, LogitsProcessor)
│ └── 返回 [CustomPlugin1, CustomPlugin2, ...]
│
└── _load_logitsprocs_by_fqcns(custom)
├── 遍历custom列表
├── type对象 → 直接验证
└── 字符串FQCN → 动态加载
├── split(":") → (module_path, qualname)
├── importlib.import_module(module_path)
├── getattr链 → 获取类对象
└── 验证 issubclass(obj, LogitsProcessor)
Q.2 lru_cache缓存策略
cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)
- 缓存key:
custom_logitsprocs的tuple版本 - 首次调用: 加载所有插件和自定义处理器
- 后续调用: 直接返回缓存结果
- 好处: 避免每次请求都重新加载(importlib开销大)
- 限制: 如果运行时安装新插件,需要重启进程才能生效
附录R2 MinPLogitsProcessor 完整状态更新推演
R2.1 场景: 批次从3个请求变为4个
初始状态:
min_p_cpu = [0.1, 0.0, 0.05] # 3个请求
min_p_count = 2 (请求0和2有min_p > 0)
min_p = min_p_device[:3] # [0.1, 0.0, 0.05]
BatchUpdate:
removed = [1] # 请求1完成
added = [(1, params_with_min_p_0.2, prompt1, output1), # 新请求在位置1
(3, params_no_min_p, prompt3, output3)] # 新请求在位置3
moved = [(2, 3, UNIDIRECTIONAL)] # 请求2移动到位置3(但3已被新增占据)
处理added:
index=1, params.min_p=0.2:
min_p_before = min_p_cpu[1] = 0.0
0.0 != 0.2 → needs_update = True
min_p_cpu[1] = 0.2
min_p_count: 0→0.2 (激活) → count += 1 = 3
index=3, params.min_p=0.0:
min_p_before = min_p_cpu[3] = 0.0
0.0 == 0.0 → needs_update不变
处理removed:
index=1: (但1已被新请求覆盖,此时min_p_cpu[1]=0.2)
min_p_cpu[1] = 0.2 ≠ 0 → min_p_cpu[1] = 0, count -= 1 = 2
# ⚠️ 这里似乎有问题:刚add的又被remove了
# 实际上:removed在added之前处理
# 但在process_dict_updates中,added先于removed处理
# 所以新请求1的min_p=0.2被设置后,removed[1]又清除了它
# 这是vLLM V1持久批次设计的一个已知边界情况
# 在实际调度中,removed和added不会同时包含同一索引
# 因为:先移除旧请求,然后新请求填入空位
处理moved:
(2, 3, UNIDIRECTIONAL):
min_p_a = min_p_cpu[2] = 0.05
min_p_b = min_p_cpu[3] = 0.0
0.05 != 0.0 → needs_update = True
min_p_cpu[3] = 0.05 # 请求2的min_p移到位置3
# UNIDIRECTIONAL: 源位置清零
min_p_cpu[2] = 0 # 但min_p_a=0.05 > 0
# 实际代码: if min_p_a: min_p_cpu[adx] = 0
# 位置3之前有值0.0,被覆盖
min_p_b = 0.0 → min_p_count -= 1 = 1
# ⚠️ 但位置3现在有0.05(从位置2移来),应该count不变
# 代码逻辑可能有微小bug,但实际影响很小
最终状态:
min_p_cpu = [0.1, 0.0, 0.0, 0.05]
min_p_count = 1 (只有位置0有min_p > 0)
# 注意:位置3的0.05没有被计数
# 这是因为moved处理中的计数逻辑有缺陷
# 但由于后续会重建GPU张量,实际效果只是多了一次不必要的传输
附录S2 LogitBiasLogitsProcessor 索引构建详解
S2.1 从字典到张量的转换
输入: biases = {
0: {100: 0.5, 200: -0.3}, # 请求0: 2个偏置
2: {50: 1.0}, # 请求2: 1个偏置
}
转换为3个平行数组:
reqs = [0, 0, 2] # 请求索引
tok_ids = [100, 200, 50] # token索引
biases = [0.5, -0.3, 1.0] # 偏置值
构建GPU张量:
logits_slice = (
tensor([0, 0, 2], int32, device=cuda), # 请求索引
tensor([100, 200, 50], int32, device=cuda), # token索引
)
bias_tensor = tensor([0.5, -0.3, 1.0], float32, device=cuda)
应用:
logits[logits_slice] += bias_tensor
等价于:
logits[0, 100] += 0.5
logits[0, 200] += -0.3
logits[2, 50] += 1.0
但在一次kernel调用中完成
S2.2 _device_tensor 的优化
def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
return torch.tensor(
data, device="cpu", dtype=dtype,
pin_memory=self.pin_memory
).to(device=self.device, non_blocking=True)
为什么先创建CPU tensor再传输:
torch.tensor(data, device="cuda")在CUDA上下文中创建,可能触发CUDA初始化- 先创建CPU tensor更安全(CUDA可能未初始化)
pin_memory=True+non_blocking=True实现真正的异步传输- CPU tensor创建后被立即传输,临时CPU内存很快释放
pin_memory的工作原理:
普通内存: CPU → GPU 需要: 先copy到pinned buffer → DMA传输
Pinned内存: CPU → GPU 只需: DMA传输(跳过copy步骤)
性能提升: 减少一次内存拷贝,特别当数据量大时
附录T2 ThinkingBudget _apply_forcing_to_logits 逐行注释
FUNCTION _apply_forcing_to_logits(logits, predict_bonus_token, spec_token_ids_for_layout)
# ===== 清空预分配张量 =====
self.mask[:] = False # 所有位置标记为"不强制"
cumulative_total = 0 # 累积token偏移
self.cu_num_tokens.clear() # 清空请求→偏移映射
# ===== 计算布局 =====
n_layout = len(spec_token_ids_for_layout) # spec配置的请求数
IF self._state:
n_layout = max(n_layout, max(self._state.keys()) + 1)
# 如果_state中有更多请求(无spec但需追踪),扩展布局范围
FOR index IN range(n_layout):
self.cu_num_tokens[index] = cumulative_total
# 记录每个请求的起始偏移
spec_tokens = spec_token_ids_for_layout[index] if index < len(...) else []
IF self.in_spec_mode:
cumulative_total += len(spec_tokens) IF NOT predict_bonus_token ELSE 1
# 投机模式: bonus调用时只算1个位置,target调用时算所有draft位置
ELSE:
cumulative_total += 1
# 常规模式: 每请求1个位置
# ===== 设置强制位置 =====
FOR seq_idx IN sorted(self._state.keys()):
IF seq_idx NOT IN self.cu_num_tokens:
continue # 该请求不在当前布局中
state = self._state[seq_idx]
IF NOT state.get("in_end", False):
continue # 不在强制结束状态
# ===== Bonus调用的特殊处理 =====
IF predict_bonus_token:
IF state.get("force_index") AND state["force_index"][0] < len(state["spec_token_ids"]):
continue # force_index指向draft位置,不在bonus调用中处理
ELSE:
state["force_index"] = [0] # force_index指向bonus → 在bonus位置强制
# ===== end_count > 0 的处理 =====
IF state.get("end_count", 0) > 0:
state["bonus_token_forced"] = False
# 如果已经强制了一些end token,重置bonus标记
IF state AND NOT state["bonus_token_forced"]:
force_index = state.get("force_index", [])
IF len(force_index) == 0:
continue # 没有需要强制的位置
end_count = state.get("end_count", 0)
FOR force_idx IN force_index:
IF end_count < len(self.think_end_token_ids):
# 计算绝对行索引
mask_idx = self.cu_num_tokens[seq_idx] + force_idx
IF mask_idx < len(self.mask) AND mask_idx < logits.shape[0]:
self.mask[mask_idx] = True
self.force_token_ids[mask_idx] = self.think_end_token_ids[end_count]
# 标记该位置需要强制输出end token序列的第end_count个token
IF predict_bonus_token:
IF state["end_count"] > 0:
state["bonus_token_forced"] = False
state["force_index"] = []
# 多token end序列的中间步骤
ELSE:
state["bonus_token_forced"] = True
# 标记bonus token已被强制
# ===== 应用强制 =====
has_active_thinking = any(s.get("in_end") for s in self._state.values())
IF has_active_thinking:
active_indices = self.mask.nonzero(as_tuple=False).view(-1)
IF len(active_indices) > 0:
force_tokens = self.force_token_ids[active_indices]
logits[active_indices, force_tokens] = 1e9
# 设置极大logit → softmax后概率接近1 → 确保被采样
RETURN logits
END FUNCTION
附录U2 BatchUpdateBuilder 与 GPUInputBatch 的协作关系
U2.1 调用时序
GPUInputBatch._update_states() {
// 每步调度时调用
// 1. 收集移除的请求索引
for completed_request in completed_requests:
builder.removed_append(request_index)
// 2. 填充新请求到空位
for new_request in waiting_queue:
empty_slot = builder.pop_removed() // 获取最低索引的空位
if empty_slot is not None:
// 将新请求放入空位
builder.added.append((empty_slot, params, prompt_ids, output_ids))
// 3. 处理移动(重排批次以保持紧凑)
// ...
// 4. 生成BatchUpdate
batch_update = builder.get_and_reset(batch_size)
// 5. 通知logits处理器
for processor in logitsprocs.all:
processor.update_state(batch_update)
}
U2.2 为何removed必须降序排列
假设: _removed = [7, 3, 5]
如果降序排列: [7, 5, 3]
pop_removed() → 3 (最小,最后弹出)
新请求填入位置3 ✓
pop_removed() → 5
新请求填入位置5 ✓
pop_removed() → 7
新请求填入位置7 ✓
如果乱序: [3, 7, 5]
需要排序才能找到最小值
每次pop都需要O(n)搜索
降序 + pop() = O(1) 获取最小值
这是典型的"最小堆"优化:降序列表 + pop()等价于最小堆的extract-min
U2.3 BatchUpdate的frozen设计
@dataclass(frozen=True)
class BatchUpdate:
batch_size: int
removed: Sequence[RemovedRequest]
added: Sequence[AddedRequest]
moved: Sequence[MovedRequest]
frozen=True的意义:
- 不可变性: 一旦创建就不能修改
- 安全性: 多个处理器读取同一个BatchUpdate,不会互相影响
- 可哈希: 可以用作dict的key或放入set(虽然目前没用到)
- 调试友好: 不存在"部分更新"的中间状态
与Builder的关系:
- Builder是可变的(收集变更)
- BatchUpdate是不可变的(变更的快照)
- get_and_reset()将可变→不可变的原子转换
更多推荐


所有评论(0)