vLLM V1 Sample 模块超深度架构分析 — Part 2: Logits处理器体系与思考预算

分析范围: vllm/v1/sample/logits_processor/ + thinking_budget_state.py
分析日期: 2026-05-25
分析深度: 架构师级,逐行解析,Mermaid图表30+


目录


第七章 Logits处理器抽象接口体系

7.1 模块定位与架构总览

logits_processor/ 子模块定义了 vLLM V1 中 logits 后处理的可扩展框架。其核心设计理念是:

  1. 抽象接口LogitsProcessor 定义统一的 apply() / update_state() / is_argmax_invariant() 接口
  2. 批次感知:通过 BatchUpdate 机制,处理器可以跟踪持久批次(persistent batch)中请求的增删移动
  3. 双分类体系:处理器分为 argmax-invariant(不影响贪心采样)和 non-argmax-invariant(影响贪心采样)两类,影响采样管线的应用顺序
  4. 插件化扩展:支持通过 entry_points 和 FQCN 加载自定义处理器

外部调用者

logits_processor/

interface.py
LogitsProcessor (ABC)
BatchUpdate
MoveDirectionality

state.py
BatchUpdateBuilder
LogitsProcessors

builtin.py
MinP / LogitBias / MinTokens

__init__.py
build_logitsprocs
AdapterLogitsProcessor

Sampler

SamplingMetadata

GPUModelRunner

7.2 MoveDirectionality 枚举

class MoveDirectionality(Enum):
    """批次内请求移动的方向性"""
    UNIDIRECTIONAL = auto()  # 单向移动: i1 → i2
    SWAP = auto()            # 双向交换: i1 ↔ i2

设计背景:vLLM V1 使用持久批次(persistent batch),请求在整个生命周期中占据固定的索引位置。当请求完成或新请求加入时,需要移动或交换请求位置以保持批次的紧凑性。

  • UNIDIRECTIONAL:请求从索引 i1 移动到 i2i1 位置被清空)
  • SWAP:两个请求交换位置(两个位置都有新内容)

对logits处理器的影响:处理器维护 dict[int, state] 映射,当请求移动时需要同步更新映射的键。

7.3 BatchUpdate 数据类

@dataclass(frozen=True)
class BatchUpdate:
    """持久批次状态变更信息
    
    记录一次调度步骤中批次的增删移动操作
    """
    batch_size: int                       # 当前批次中的请求数量
    
    # 三种操作的元数据(按顺序处理: removed → added → moved)
    removed: Sequence[RemovedRequest]     # 被移除的请求索引列表
    added: Sequence[AddedRequest]         # 新增请求的(index, params, prompt_ids, output_ids)元组
    moved: Sequence[MovedRequest]         # 移动请求的(i1, i2, direction)元组

冻结性(frozen=True)BatchUpdate 一旦创建就不可修改,确保处理器读到的是一致的状态快照。

类型别名

RemovedRequest = int  # 被移除请求的索引

AddedRequest = tuple[int, SamplingParams, list[int] | None, list[int]]
# (index, sampling_params, prompt_token_ids, output_token_ids)
# output_token_ids 是对请求运行中的输出token列表的引用
# 通过这个引用,logits处理器始终能看到最新的已生成token

MovedRequest = tuple[int, int, MoveDirectionality]
# (index_1, index_2, direction)

关键设计假设

output_tok_ids 列表是请求运行中的输出token列表的引用(而非副本)。因此,logits处理器通过这个引用始终能看到最新的已生成token列表。

这个设计使得处理器不需要在每步都复制完整的输出token列表,节省内存和计算。

操作处理顺序removed → added → moved

  1. 先移除:释放索引位置,避免与新增请求冲突
  2. 再新增:填充空出的位置
  3. 最后移动:在稳定的新索引上执行位置调整

BatchUpdate处理顺序

1. removed
清除旧索引

2. added
填充新请求

3. moved
调整位置

index 3 移除

index 3 新请求填入

index 1 ↔ index 3 交换

7.4 LogitsProcessor 抽象基类

class LogitsProcessor(ABC):
    """Logits处理器抽象基类
    
    所有logits处理器必须实现4个抽象方法:
    1. __init__(vllm_config, device, is_pin_memory) — 构造器
    2. apply(logits) — 应用处理逻辑
    3. is_argmax_invariant() — 声明是否影响贪心采样
    4. update_state(batch_update) — 更新内部状态
    """
    
    @classmethod
    def validate_params(cls, sampling_params: SamplingParams):
        """验证采样参数是否适用于此处理器
        
        默认实现:不做验证(返回None)
        子类可覆盖以抛出ValueError
        """
        return None
    
    @abstractmethod
    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ) -> None:
        """构造器——必须接受三个参数(即使不全部使用)"""
        raise NotImplementedError
    
    @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """应用处理器到批次logits张量
        
        Args:
            logits: [batch_size, vocab_size] float32
        Returns:
            处理后的logits(可以原地修改,但必须返回)
        """
        raise NotImplementedError
    
    @abstractmethod
    def is_argmax_invariant(self) -> bool:
        """是否不影响argmax(贪心采样结果)
        
        True  → 可以在贪心采样后应用(如MinP)
        False → 必须在贪心采样前应用(如MinTokens, LogitBias)
        """
        raise NotImplementedError
    
    @abstractmethod
    def update_state(self, batch_update: "BatchUpdate | None") -> None:
        """批次变更时更新处理器内部状态
        
        在每次forward前调用
        """
        raise NotImplementedError

argmax-invariant 的分类影响

采样管线中的处理器应用顺序

贪心采样前
(non_argmax_invariant)

贪心采样后/随机采样中
(argmax_invariant)

MinTokensLogitsProcessor
屏蔽EOS,改变argmax

LogitBiasLogitsProcessor
偏置特定token,改变argmax

MinPLogitsProcessor
过滤低概率token,不影响argmax

7.5 类型别名体系

BatchUpdate

+batch_size: int

+removed: Sequence<int>

+added: Sequence<AddedRequest>

+moved: Sequence<MovedRequest>

«enum»

MoveDirectionality

UNIDIRECTIONAL

SWAP

«abstract»

LogitsProcessor

+validate_params(cls, sampling_params)

+init(vllm_config, device, is_pin_memory)

+apply(logits) : Tensor

+is_argmax_invariant() : bool

+update_state(batch_update)

MinPLogitsProcessor

+min_p_count: int

+min_p_cpu: numpy

+min_p_device: Tensor

+min_p: Tensor

+is_argmax_invariant() : True

+update_state(batch_update)

+apply(logits) : Tensor

LogitBiasLogitsProcessor

+biases: dict<int,dict>

+bias_tensor: Tensor

+logits_slice: tuple

+is_argmax_invariant() : False

+update_state(batch_update)

+apply(logits) : Tensor

MinTokensLogitsProcessor

+min_toks: dict

+logits_slice: tuple

+neg_inf_tensor: Tensor

+is_argmax_invariant() : False

+update_state(batch_update)

+apply(logits) : Tensor

+apply_with_spec_decode(logits, num_draft_tokens) : Tensor

AdapterLogitsProcessor


第八章 BatchUpdateBuilder 批次更新构建器

8.1 设计动机与职责

BatchUpdateBuilder 是一个增量构建器,负责在调度步骤中收集批次变更信息,最终生成不可变的 BatchUpdate 对象。

核心问题:vLLM V1 的持久批次在每步可能经历多种变更(移除完成的请求、添加新请求、重排位置),需要一个中间层来收集这些变更并确保一致性。

假设与保证

假设 保证
所有移除通过 removed_append() 注册 removed 列表始终降序排列
removed_append()removed/pop/peek 之前调用 pop_removed() 返回最小索引
不直接修改 _removed 列表 has_removed() 正确反映移除状态

8.2 init 初始化

def __init__(
    self,
    removed: list[RemovedRequest] | None = None,  # 预设的移除列表
    added: list[AddedRequest] | None = None,       # 预设的新增列表
    moved: list[MovedRequest] | None = None,       # 预设的移动列表
) -> None:
    self._removed = removed or []        # 移除的请求索引
    self.added = added or []             # 新增的请求信息
    self.moved = moved or []             # 移动的请求信息
    self._is_removed_sorted = False      # 排序标志:首次访问时排序
    
    # 用于追踪pooling场景的批次变更
    # pooling模型不填充added列表,但batch_changed仍需正确反映变更
    self.batch_changed = False

8.3 removed_append() 移除注册

def removed_append(self, index: int) -> None:
    """注册一个请求从持久批次中移除
    
    ⚠️ 必须在首次调用 self.removed / pop_removed / peek_removed 之前调用
    否则抛出 RuntimeError
    
    Args:
        index: 被移除请求在持久批次中的索引
    """
    if self._is_removed_sorted:
        # 如果已经排序过(即removed已被读取),不允许再追加
        # 因为追加会破坏排序不变量
        raise RuntimeError(
            "Cannot register new removed request after self.removed has been read."
        )
    self._removed.append(index)
    self.batch_changed = True  # 标记批次有变更

排序不变量_removed 列表在首次访问时按降序排序。降序的原因是在持久批次中移除请求时,从高索引向低索引移除可以避免索引偏移问题(类似删除数组元素时从后往前删)。

8.4 _ensure_removed_sorted() 排序保证

def _ensure_removed_sorted(self) -> None:
    """确保 _removed 列表按降序排列
    
    幂等操作:首次调用后不再重复排序
    """
    if not self._is_removed_sorted:
        self._removed.sort(reverse=True)  # 降序排列
        self._is_removed_sorted = True

8.5 peek_removed() 与 pop_removed()

def has_removed(self) -> bool:
    """是否有待处理的移除请求"""
    return bool(self._removed)

def peek_removed(self) -> int | None:
    """查看最小的移除索引(不弹出)
    
    因为_removed是降序排列,最后一个元素就是最小值
    """
    if self.has_removed():
        self._ensure_removed_sorted()
        return self._removed[-1]  # 降序列表的最后一个=最小值
    return None

def pop_removed(self) -> int | None:
    """弹出最小的移除索引
    
    返回后该索引从列表中移除
    """
    if self.has_removed():
        self._ensure_removed_sorted()
        return self._removed.pop()  # pop最后一个=最小值
    return None

降序设计理由

  • pop() 在Python列表中操作末尾元素是O(1)
  • 如果升序排列,弹出最小值需要 pop(0) 是O(n)
  • 降序排列使 pop() 直接弹出最小值,高效

removed_append(5)
removed_append(2)
removed_append(8)

_removed = [5, 2, 8]
_is_removed_sorted = False

_ensure_removed_sorted()
_removed = [8, 5, 2]
_is_removed_sorted = True

pop_removed() → 2
_removed = [8, 5]

pop_removed() → 5
_removed = [8]

pop_removed() → 8
_removed = []

8.6 reset() 与 get_and_reset()

def reset(self) -> bool:
    """重置内部状态,返回是否有变更
    
    Returns:
        True 如果本步有任何批次变更
    """
    self._is_removed_sorted = False
    self._removed.clear()
    self.added.clear()
    self.moved.clear()
    batch_changed = self.batch_changed  # 保存当前值
    self.batch_changed = False
    return batch_changed

def get_and_reset(self, batch_size: int) -> BatchUpdate | None:
    """生成BatchUpdate并重置内部状态
    
    Args:
        batch_size: 当前持久批次的请求数量
    
    Returns:
        BatchUpdate实例(如果没有变更则返回None)
    """
    self._is_removed_sorted = False
    self.batch_changed = False
    
    # 快速路径:三种操作都为空
    if not any((self._removed, self.moved, self.added)):
        return None
    
    # 构建不可变的BatchUpdate
    batch_update = BatchUpdate(
        batch_size=batch_size,
        removed=self._removed,
        moved=self.moved,
        added=self.added,
    )
    
    # 重置列表(创建新列表,因为BatchUpdate持有旧列表的引用)
    self._removed = []
    self.moved = []
    self.added = []
    
    return batch_update

注意get_and_reset() 将旧列表传给 BatchUpdate,然后创建新列表。这意味着 BatchUpdate 持有的是调用时刻的快照,后续修改不影响已生成的 BatchUpdate

LogitsProcessors BatchUpdate BatchUpdateBuilder Scheduler LogitsProcessors BatchUpdate BatchUpdateBuilder Scheduler removed_append(3) removed_append(7) added.append((3, params, prompt, output)) moved.append((1, 5, SWAP)) get_and_reset(batch_size=8) BatchUpdate(removed=[7,3], added=[...], moved=[...]) BatchUpdate instance update_state(batch_update) process removed: clear index 7, 3 process added: init state at index 3 process moved: swap state 1↔5

第九章 LogitsProcessors 容器类

9.1 设计目的

LogitsProcessors 是所有已初始化logits处理器的容器,提供按 argmax-invariant 属性分类的访问接口。

9.2 argmax_invariant / non_argmax_invariant 分类

class LogitsProcessors:
    """封装已初始化的logits处理器对象"""
    
    def __init__(self, logitsprocs: Iterable["LogitsProcessor"] | None = None) -> None:
        self.argmax_invariant: list[LogitsProcessor] = []      # 不影响贪心采样
        self.non_argmax_invariant: list[LogitsProcessor] = []  # 影响贪心采样
        
        if logitsprocs:
            for logitproc in logitsprocs:
                # 根据is_argmax_invariant()分类
                (
                    self.argmax_invariant
                    if logitproc.is_argmax_invariant()
                    else self.non_argmax_invariant
                ).append(logitproc)

分类的重要性:Sampler 在不同阶段调用不同类别的处理器:

  1. non_argmax_invariant → 在贪心采样调用(可能改变argmax结果)
  2. argmax_invariant → 在贪心采样/温度缩放调用(不影响贪心但影响随机采样)

9.3 all 属性迭代器

@property
def all(self) -> Iterator["LogitsProcessor"]:
    """迭代所有处理器(argmax_invariant在前,non_argmax_invariant在后)"""
    return chain(self.argmax_invariant, self.non_argmax_invariant)

使用 itertools.chain 避免创建新列表,惰性迭代节省内存。


第十章 内置Logits处理器深度解析

10.1 MinPLogitsProcessor — Min-P过滤

设计目的:Min-P 是一种动态阈值过滤策略,过滤掉概率低于 max_probability × min_p 的token。与 top-p 不同,min-p 的阈值随最大概率自适应调整。

argmax-invariantTrue — 最大概率的token永远不会被过滤(因为 max_prob × min_p < max_prob),因此不影响贪心采样结果。

init 逐行解析
class MinPLogitsProcessor(LogitsProcessor):
    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        
        self.min_p_count: int = 0  # 当前批次中有min_p > 0的请求数
        
        # CPU端numpy数组:零拷贝与GPU同步
        self.min_p_cpu_tensor = torch.zeros(
            (max_num_reqs,), dtype=torch.float32, device="cpu",
            pin_memory=is_pin_memory  # pin_memory加速CPU→GPU传输
        )
        # numpy视图:共享底层内存,可零拷贝读写
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()
        
        # 是否需要双缓冲(CPU + GPU)
        self.use_double_tensor = torch.device(device).type != "cpu"
        # CPU设备不需要双缓冲(直接在CPU tensor上操作)
        
        if self.use_double_tensor:
            # GPU设备:预分配GPU tensor
            self.min_p_device: torch.Tensor = torch.empty(
                (max_num_reqs,), dtype=torch.float32, device=device
            )
        else:
            # CPU设备:直接使用CPU tensor
            self.min_p_device = self.min_p_cpu_tensor
        
        # 当前切片:只使用[:batch_size]的部分
        self.min_p: torch.Tensor = self.min_p_device[:0]  # 初始为空

双缓冲设计

  • CPU端(min_p_cpu_tensor / min_p_cpu):用numpy高效更新
  • GPU端(min_p_device):用于 apply() 中的GPU计算
  • min_p.copy_(min_p_cpu_tensor[:size]):异步CPU→GPU传输
update_state() 逐行解析
def update_state(self, batch_update: BatchUpdate | None):
    if not batch_update:
        return  # 无变更,跳过
    
    needs_update = False
    
    # ===== 处理新增请求 =====
    for index, params, _, _ in batch_update.added:
        min_p = params.min_p  # 从采样参数中获取min_p值
        min_p_before = self.min_p_cpu[index]  # 该索引之前的min_p值
        
        if min_p_before != min_p:
            needs_update = True
            self.min_p_cpu[index] = min_p  # 更新CPU数组
            
            if min_p and not min_p_before:
                # 从0变为非0:增加活跃计数
                self.min_p_count += 1
            elif not min_p and min_p_before:
                # 从非0变为0:减少活跃计数
                self.min_p_count -= 1
    
    if self.min_p_count:  # 只有存在活跃min_p请求时才处理移除/移动
        # ===== 处理移除请求 =====
        if batch_update.removed:
            needs_update = True
            for index in batch_update.removed:
                if self.min_p_cpu[index]:
                    self.min_p_cpu[index] = 0
                    self.min_p_count -= 1
        
        # ===== 处理移动请求 =====
        for adx, bdx, direct in batch_update.moved:
            min_p_a, min_p_b = self.min_p_cpu[adx], self.min_p_cpu[bdx]
            if min_p_a != min_p_b:
                needs_update = True
                # 目标位置总是接收源位置的值
                self.min_p_cpu[bdx] = min_p_a
                if direct == MoveDirectionality.SWAP:
                    # 双向交换:源位置也接收目标位置的值
                    self.min_p_cpu[adx] = min_p_b
            if direct == MoveDirectionality.UNIDIRECTIONAL:
                # 单向移动:源位置清零
                if min_p_a:
                    self.min_p_cpu[adx] = 0
                if min_p_b:
                    # min_p_b原来就有值但被覆盖了
                    self.min_p_count -= 1
    
    # ===== 同步到GPU =====
    size = batch_update.batch_size
    if self.min_p_count and (needs_update or self.min_p.shape[0] != size):
        # 只在有活跃min_p或批次大小变化时更新
        self.min_p = self.min_p_device[:size]  # 切片到当前批次大小
        if self.use_double_tensor:
            # 异步CPU→GPU传输
            self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True)
        self.min_p.unsqueeze_(1)  # [batch_size] → [batch_size, 1] 用于广播

unsqueeze_(1) 的意义min_p 需要与 logits [B, V] 做广播比较,因此需要从 [B] 扩展为 [B, 1]

apply() 逐行解析
def apply(self, logits: torch.Tensor) -> torch.Tensor:
    if not self.min_p_count:
        return logits  # 快速路径:无活跃min_p请求
    
    # 1. 计算softmax概率分布
    probability_values = torch.nn.functional.softmax(logits, dim=-1)
    # [batch_size, vocab_size]
    
    # 2. 计算每行的最大概率
    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
    # [batch_size, 1] — keepdim=True保留维度用于广播
    
    # 3. 计算自适应阈值 = max_prob × min_p
    adjusted_min_p = max_probabilities.mul_(self.min_p)
    # [batch_size, 1] — self.min_p已经是[B, 1]形状
    # mul_() 是原地操作,修改max_probabilities
    
    # 4. 标记低于阈值的token
    invalid_token_mask = probability_values < adjusted_min_p
    # [batch_size, vocab_size] bool
    
    # 5. 屏蔽无效token
    logits.masked_fill_(invalid_token_mask, -float("inf"))
    
    return logits

Min-P vs Top-P 的对比

特性 Min-P Top-P
阈值类型 绝对概率阈值 累积概率阈值
阈值计算 max_prob × min_p 直接指定
自适应性 随max_prob动态调整 固定值
效果 当模型很确定时更严格 始终保留概率和≥p的token
argmax影响 不影响(最大概率总保留) 不影响(累积概率包含最大值)

logits [B,V]

softmax → probs [B,V]

amax → max_prob [B,1]

max_prob × min_p → threshold [B,1]

probs < threshold → mask [B,V]

masked_fill_(mask, -inf)

filtered logits

10.2 LogitBiasLogitsProcessor — 对数偏置

设计目的:为特定token添加偏置值,影响采样概率。典型用法:增加特定token被采样的概率(正偏置)或降低概率(负偏置)。

argmax-invariantFalse — 偏置可以改变最大logit的位置,从而改变贪心采样结果。

init 与 update_state
class LogitBiasLogitsProcessor(LogitsProcessor):
    def __init__(self, _, device: torch.device, is_pin_memory: bool):
        # _ 是vllm_config,此处不使用
        self.device = device
        self.pin_memory = is_pin_memory
        
        # 核心数据结构:{req_index: {token_id: bias_value}}
        # 稀疏存储:只有设置了logit_bias的请求才在dict中
        self.biases: dict[int, dict[int, float]] = {}
        
        # GPU上的偏置张量
        self.bias_tensor: torch.Tensor = torch.tensor(())
        
        # 索引张量:(req_indices, token_indices) — 用于logits的scatter操作
        self.logits_slice = (
            self._device_tensor([], torch.int32),  # 请求索引
            self._device_tensor([], torch.int32),  # token索引
        )
    
    def update_state(self, batch_update: BatchUpdate | None):
        # 使用通用状态更新工具
        needs_update = process_dict_updates(
            self.biases,
            batch_update,
            lambda params, _, __: params.logit_bias or None
            # 从SamplingParams提取logit_bias字典
            # 如果没有logit_bias则返回None(表示该请求不需要此处理器)
        )
        
        # 如果状态有变化,重建GPU张量
        if needs_update:
            reqs: list[int] = []    # 请求索引
            tok_ids: list[int] = [] # token索引
            biases: list[float] = [] # 偏置值
            
            for req, lb in self.biases.items():
                reqs.extend([req] * len(lb))    # 每个偏置对应一个请求索引
                tok_ids.extend(lb.keys())        # token id
                biases.extend(lb.values())       # 偏置值
            
            # 构建GPU张量
            self.bias_tensor = self._device_tensor(biases, torch.float32)
            self.logits_slice = (
                self._device_tensor(reqs, torch.int32),
                self._device_tensor(tok_ids, torch.int32),
            )
    
    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        """创建CPU tensor + 异步传输到GPU"""
        return torch.tensor(
            data, device="cpu", dtype=dtype,
            pin_memory=self.pin_memory
        ).to(device=self.device, non_blocking=True)

稀疏索引设计logits[logits_slice] 等价于 logits[req_indices, token_indices],这是PyTorch的高级索引语法,可以一次性为多个位置赋值/加值。

apply()
def apply(self, logits: torch.Tensor) -> torch.Tensor:
    if self.biases:
        # 使用高级索引一次性添加所有偏置
        # logits[req_indices, token_indices] += bias_values
        logits[self.logits_slice] += self.bias_tensor
    return logits

性能分析

  • 常规方法:遍历每个请求和每个偏置,逐个 logits[req, tok] += bias → O(n) 次小kernel
  • 高级索引:单次scatter-add kernel → O(1) 次kernel launch
  • 当偏置数量较多时(如100+个),高级索引显著更快

10.3 MinTokensLogitsProcessor — 最小长度约束

设计目的:强制模型生成至少N个token后才允许输出停止token(EOS)。通过屏蔽EOS token的logit实现。

argmax-invariantFalse — 屏蔽EOS可能改变argmax结果(如果原本最大logit对应EOS)。

init
class MinTokensLogitsProcessor(LogitsProcessor):
    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        self.device = device
        self.pin_memory = is_pin_memory
        
        # 核心状态: {req_index: (min_tokens, output_token_ids_ref, stop_token_ids_set)}
        # output_token_ids是引用,始终反映最新的输出长度
        self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
        
        # 索引张量: (req_indices, stop_token_indices)
        self.logits_slice: tuple[torch.Tensor, torch.Tensor] = (
            self._device_tensor([], torch.int32),
            self._device_tensor([], torch.int32),
        )
        
        # -inf张量: 用于屏蔽stop token
        self.neg_inf_tensor = torch.tensor(
            -float("inf"), dtype=torch.float32, device=self.device
        )
update_state()
def update_state(self, batch_update: BatchUpdate | None):
    needs_update = process_dict_updates(
        self.min_toks, batch_update, self.add_request
    )
    
    # 检查已满足最小长度要求的请求
    if self.min_toks:
        to_remove = tuple(
            index
            for index, (min_toks, out_tok_ids, _) in self.min_toks.items()
            if len(out_tok_ids) >= min_toks
            # output_token_ids是引用,长度实时更新
            # 当已生成token数 >= min_tokens时,移除该请求的约束
        )
        if to_remove:
            needs_update = True
            for index in to_remove:
                del self.min_toks[index]
    
    # 重建GPU索引张量
    if needs_update:
        reqs: list[int] = []
        tok_ids: list[int] = []
        for req, (_, _, stop_tok_ids) in self.min_toks.items():
            reqs.extend([req] * len(stop_tok_ids))
            tok_ids.extend(stop_tok_ids)
        
        self.logits_slice = (
            self._device_tensor(reqs, torch.int32),
            self._device_tensor(tok_ids, torch.int32),
        )

关键设计output_token_ids 是引用,因此 len(out_tok_ids) 实时反映已生成token数。无需在每步手动更新。

apply()
def apply(self, logits: torch.Tensor) -> torch.Tensor:
    if self.min_toks:
        # 将stop token的logit设为-inf
        # index_put_ 类似于 logits[req_indices, tok_indices] = -inf
        logits.index_put_(self.logits_slice, self.neg_inf_tensor)
    return logits

index_put_ vs 高级索引赋值

  • logits[slice] = value → 创建新张量(非原地)
  • logits.index_put_(slice, value) → 原地修改,更高效
apply_with_spec_decode() — 投机解码版本
def apply_with_spec_decode(
    self,
    logits: torch.Tensor,          # [num_tokens, vocab_size] 所有draft位置的logits
    num_draft_tokens: list[int],    # 每请求的draft token数
) -> torch.Tensor:
    """投机解码版本: 需要对每个draft位置独立判断是否屏蔽stop token
    
    Example: num_draft_tokens = [2, 3, 1]
      → logits shape [6, V]
      → cumsum = [0, 2, 5, 6]
      → request 0 owns rows 0-1, request 1 rows 2-4, request 2 row 5
    """
    if not self.min_toks:
        return logits
    
    # 计算每个请求在展平logits中的起始行偏移
    num_draft_arr = np.array(num_draft_tokens, dtype=np.int64)
    cumsum = np.concatenate([[0], np.cumsum(num_draft_arr)])
    
    # 收集需要屏蔽的行和token
    entries = [
        (req_idx, min_tok, len(out_tok_ids), list(stop_tok_ids))
        for req_idx, (min_tok, out_tok_ids, stop_tok_ids) in self.min_toks.items()
        if stop_tok_ids
    ]
    
    if not entries:
        return logits
    
    all_rows: list[np.ndarray] = []  # 行索引
    all_toks: list[np.ndarray] = []  # stop token ids
    
    for req_idx, min_tok, current_len, stop_toks in entries:
        remaining = min_tok - current_len  # 还需要多少token才满足最小长度
        
        # 计算前多少个draft位置仍需要屏蔽stop token
        n_mask = int(min(max(remaining, 0), num_draft_arr[req_idx]))
        
        if n_mask > 0:
            offset = cumsum[req_idx]
            row_indices = np.arange(offset, offset + n_mask, dtype=np.int64)
            n_stop = len(stop_toks)
            # 每个stop token在每个需要屏蔽的行都要出现
            all_rows.append(np.repeat(row_indices, n_stop))
            all_toks.append(np.tile(stop_toks, n_mask))
    
    if all_rows:
        rows_arr = np.concatenate(all_rows)
        toks_arr = np.concatenate(all_toks)
        logits_slice = (
            torch.from_numpy(rows_arr).to(self.device, non_blocking=True),
            torch.from_numpy(toks_arr).to(self.device, non_blocking=True),
        )
        logits.index_put_(logits_slice, self.neg_inf_tensor)
    
    return logits

MinTokens: req1 min=10, current=9

MinTokens: req0 min=5, current=3

num_draft_tokens = [2, 3, 1]

Request 0
rows 0-1

Request 1
rows 2-4

Request 2
row 5

remaining=2
n_mask=min(2,2)=2
屏蔽rows 0,1

remaining=1
n_mask=min(1,3)=1
屏蔽row 2

10.4 process_dict_updates() 通用状态更新工具

def process_dict_updates(
    req_entries: dict[int, T],  # 处理器的状态字典(如self.biases, self.min_toks)
    batch_update: BatchUpdate | None,
    new_state: Callable[[SamplingParams, list[int] | None, list[int]], T | None],
) -> bool:
    """通用字典状态更新工具
    
    处理BatchUpdate中的三种操作(removed/added/moved),
    更新处理器的状态字典
    
    Args:
        req_entries: 处理器维护的 {req_index: state} 字典
        batch_update: 批次变更信息
        new_state: 工厂函数,从SamplingParams创建新的状态条目
    
    Returns:
        True 如果字典有变化(需要重建GPU张量)
    """
    if not batch_update:
        return False
    
    updated = False
    
    # ===== 处理新增请求 =====
    for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
        if (state := new_state(params, prompt_tok_ids, output_tok_ids)) is not None:
            req_entries[index] = state  # 新请求有状态 → 插入
            updated = True
        elif req_entries.pop(index, None) is not None:
            # 新请求不需要此处理器,但该索引有旧状态 → 清除
            updated = True
    
    if req_entries:  # 只有字典非空时才处理移除和移动
        # ===== 处理移除请求 =====
        for index in batch_update.removed:
            if req_entries.pop(index, None):
                updated = True
        
        # ===== 处理移动请求 =====
        for a_index, b_index, direct in batch_update.moved:
            a_entry = req_entries.pop(a_index, None)  # 弹出源
            b_entry = req_entries.pop(b_index, None)  # 弹出目标
            
            if a_entry is not None:
                req_entries[b_index] = a_entry  # 源→目标
                updated = True
            if b_entry is not None:
                updated = True
                if direct == MoveDirectionality.SWAP:
                    req_entries[a_index] = b_entry  # 目标→源(双向交换)
                # UNIDIRECTIONAL: 源位置自然为空(已pop)
    
    return updated

设计精妙之处

  1. 先pop后放回:避免交换时键冲突(如果先 a→b,再 b→a,可能丢失中间状态)
  2. SWAP vs UNIDIRECTIONAL:SWAP双向交换;UNIDIRECTIONAL单向移动,源位置自动清空
  3. 返回needs_update:调用者据此决定是否重建GPU张量,避免无变化时的冗余重建

第十一章 Logits处理器构建与插件系统

11.1 _load_logitsprocs_plugins() 插件加载

LOGITSPROCS_GROUP = "vllm.logits_processors"  # entry_points组名

def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
    """通过importlib.metadata加载已安装的logitproc插件
    
    插件通过setup.py/pyproject.toml注册:
    [project.entry-points."vllm.logits_processors"]
    my_proc = "my_package:MyLogitsProcessor"
    """
    from importlib.metadata import entry_points
    
    installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP)
    if len(installed_logitsprocs_plugins) == 0:
        logger.debug("No logitsprocs plugins installed (group %s).", LOGITSPROCS_GROUP)
        return []
    
    classes: list[type[LogitsProcessor]] = []
    for entrypoint in installed_logitsprocs_plugins:
        try:
            with guard_cuda_initialization():
                # guard_cuda_initialization: 防止import时触发CUDA初始化
                classes.append(entrypoint.load())
        except Exception as e:
            logger.error("Failed to load LogitsProcessor plugin %s: %s", entrypoint, e)
            raise RuntimeError(
                f"Failed to load LogitsProcessor plugin {entrypoint}"
            ) from e
    return classes

guard_cuda_initialization() 的作用:在CUDA尚未初始化时加载模块可能导致子进程fork失败(CUDA runtime不支持fork)。此上下文管理器确保import期间不触发CUDA初始化。

11.2 _load_logitsprocs_by_fqcns() FQCN加载

def _load_logitsprocs_by_fqcns(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
) -> list[type[LogitsProcessor]]:
    """通过完全限定类名(FQCN)加载logit处理器
    
    FQCN语法: <module_path>:<qualname>
    例如: "my_package.processors:CustomLogitsProcessor"
    """
    if not logits_processors:
        return []
    
    classes: list[type[LogitsProcessor]] = []
    for ldx, logitproc in enumerate(logits_processors):
        if isinstance(logitproc, type):
            # 已经是类对象 → 直接验证
            if not issubclass(logitproc, LogitsProcessor):
                raise ValueError(
                    f"{logitproc.__name__} is not a subclass of LogitsProcessor"
                )
            classes.append(logitproc)
            continue
        
        # 字符串FQCN → 动态加载
        module_path, qualname = logitproc.split(":")
        
        try:
            with guard_cuda_initialization():
                module = importlib.import_module(module_path)
        except Exception as e:
            raise RuntimeError(
                f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}"
            ) from e
        
        # 沿点号路径获取类对象
        # 例如: "a.b.c" → getattr(getattr(module, "a"), "b").c
        obj = module
        for attr in qualname.split("."):
            obj = getattr(obj, attr)
        
        if not isinstance(obj, type):
            raise ValueError("Loaded logit processor must be a type.")
        if not issubclass(obj, LogitsProcessor):
            raise ValueError(f"{obj.__name__} must be a subclass of LogitsProcessor")
        classes.append(obj)
    
    return classes

11.3 _load_custom_logitsprocs() 统一加载

def _load_custom_logitsprocs(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
) -> list[type[LogitsProcessor]]:
    """加载所有自定义logits处理器
    
    加载顺序:
    1. 先加载entry_points插件
    2. 再加载用户指定的FQCN
    """
    from vllm.platforms import current_platform
    
    if current_platform.is_tpu():
        # TPU不支持自定义logits处理器
        return []
    
    return _load_logitsprocs_plugins() + _load_logitsprocs_by_fqcns(logits_processors)

11.4 build_logitsprocs() 构建入口

BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
    MinTokensLogitsProcessor,
    LogitBiasLogitsProcessor,
    MinPLogitsProcessor,
]

def build_logitsprocs(
    vllm_config: "VllmConfig",
    device: torch.device,
    is_pin_memory: bool,
    is_pooling_model: bool,
    custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (),
) -> LogitsProcessors:
    """构建LogitsProcessors容器
    
    Args:
        vllm_config: vLLM配置
        device: 计算设备
        is_pin_memory: 是否使用pin_memory
        is_pooling_model: 是否为pooling模型
        custom_logitsprocs: 用户自定义的logit处理器列表
    
    Returns:
        LogitsProcessors容器(已分类所有处理器实例)
    """
    # Pooling模型不支持logits处理器
    if is_pooling_model:
        if custom_logitsprocs:
            raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
        return LogitsProcessors()
    
    # 投机解码模式:只允许MinTokens(其他处理器与spec decode不兼容)
    if vllm_config.speculative_config:
        if custom_logitsprocs:
            raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS)
        logger.warning(
            "min_p and logit_bias parameters won't work with speculative decoding."
        )
        return LogitsProcessors(
            [MinTokensLogitsProcessor(vllm_config, device, is_pin_memory)]
        )
    
    # 正常模式:加载内置 + 自定义处理器
    custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
    return LogitsProcessors(
        ctor(vllm_config, device, is_pin_memory)
        for ctor in itertools.chain(
            BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes
        )
    )

三种模式的处理器配置

模式 可用处理器 原因
Pooling Pooling模型不产生logits
Spec Decode 仅MinTokens 其他处理器与draft验证逻辑冲突
正常 MinTokens + LogitBias + MinP + 自定义 完整支持

Yes

No

Yes

No

build_logitsprocs()

is_pooling_model?

LogitsProcessors()

speculative_config?

[MinTokensLogitsProcessor]

加载内置+自定义

itertools.chain(
BUILTIN, custom)

ctor(config, device, pin_memory)
逐个实例化

LogitsProcessors
按is_argmax_invariant()分类

LogitsProcessors{
argmax_invariant: [MinP],
non_argmax_invariant: [MinTokens, LogitBias]
}

11.5 validate_logits_processors_parameters() 参数验证

cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)

def validate_logits_processors_parameters(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
    sampling_params: SamplingParams,
):
    """验证采样参数是否被所有logits处理器接受
    
    使用lru_cache缓存加载结果,避免重复加载
    """
    logits_processors = (
        tuple(logits_processors) if logits_processors is not None else None
    )
    # tuple化以支持lru_cache(list不可hash)
    for logits_procs in cached_load_custom_logitsprocs(logits_processors):
        logits_procs.validate_params(sampling_params)

第十二章 AdapterLogitsProcessor 适配器模式

12.1 设计动机

AdapterLogitsProcessor 是一个适配器基类,将vLLM V0风格的per-request logits处理器(RequestLogitsProcessor)包装为V1风格的批次处理器。

V0的处理器的签名:def __call__(input_ids, scores) → scores
V1的处理器的签名:def apply(logits: [B, V]) → [B, V]

适配器为每个请求创建 partial 函数,预填充 input_idsoutput_ids 参数。

12.2 init 初始化

class AdapterLogitsProcessor(LogitsProcessor):
    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        # Map req_index → partial[Tensor]
        # partial函数携带了预填充的output_ids引用
        # 由于是引用,始终操作最新的output_ids列表
        self.req_info: dict[int, partial[torch.Tensor]] = {}

12.3 new_req_logits_processor() 工厂方法

@abstractmethod
def new_req_logits_processor(
    self, params: SamplingParams,
) -> RequestLogitsProcessor | None:
    """为请求创建per-request logits处理器
    
    返回None表示此请求不需要此处理器
    """
    raise NotImplementedError

12.4 _new_state() 状态构建

def _new_state(
    self,
    params: SamplingParams,
    prompt_ids: list[int] | None,
    output_ids: list[int],
) -> partial[torch.Tensor] | None:
    """为新请求构建状态(partial函数)"""
    if req_lp := self.new_req_logits_processor(params):
        # 检查处理器签名:是否需要prompt_ids参数
        if len(inspect.signature(req_lp).parameters) == 3:
            # 3参数: (input_ids, scores, ...) — 需要prompt_ids
            if prompt_ids is None:
                raise ValueError(
                    "Prompt token ids are required for this logits processor "
                    "but were not provided."
                )
            args = [prompt_ids, output_ids]
        else:
            # 2参数: (scores, ...) — 只需要output_ids
            args = [output_ids]
        
        # 创建partial:预填充prompt_ids和output_ids参数
        # output_ids是引用,后续调用时看到最新值
        return partial(req_lp, *args)
    return None

inspect.signature 的使用:通过反射检查处理器的参数数量,自动判断是否需要prompt_ids。这使得同一个适配器可以兼容两种签名的处理器。

12.5 update_state() 状态更新

def update_state(self, batch_update: BatchUpdate | None):
    # 使用通用工具,传入_new_state作为工厂函数
    process_dict_updates(
        self.req_info,
        batch_update,
        self._new_state,
    )

12.6 apply() 逐行应用

def apply(self, logits: torch.Tensor) -> torch.Tensor:
    if self.req_info:
        # 对每个有per-request处理器的请求,逐行应用
        for req_idx, req_lp in self.req_info.items():
            req_logits = logits[req_idx]  # [vocab_size] 单行
            new_logits = req_lp(req_logits)  # 调用partial函数
            
            if new_logits is not req_logits:
                # 处理器返回了新张量(而非原地修改)→ 写回
                logits[req_idx] = new_logits
    return logits

逐行应用的性能考量:每个per-request处理器独立调用,无法批量处理。这是设计上的取舍:

  • Per-request处理器的逻辑各不相同,无法batch化
  • 大多数请求没有per-request处理器(self.req_info通常很小或为空)
  • 性能影响可接受

第十三章 ThinkingBudgetStateHolder 思考预算管理

13.1 设计背景与动机

问题:推理模型(如DeepSeek-R1, QwQ)使用特殊token标记思考过程(<think>...</think>)。当设置 thinking_token_budget 时,需要强制模型在思考token数达到预算时输出结束token。

挑战

  1. 思考token计数需要跨多步累积
  2. 投机解码中,每个draft位置都可能需要独立判断是否强制结束
  3. 拒绝采样可能拒绝强制结束的token,需要恢复思考状态
  4. 多token的结束标记(如 <|end_thinking|> 可能是多个token序列)

13.2 maybe_create_thinking_budget_state_holder() 工厂函数

def maybe_create_thinking_budget_state_holder(
    reasoning_config: "ReasoningConfig | None",
    max_num_seqs: int,
    num_spec_tokens: int,
    device: torch.device,
    is_pin_memory: bool,
) -> "ThinkingBudgetStateHolder | None":
    """条件创建:只有配置了reasoning_config时才创建holder"""
    if reasoning_config is None:
        return None  # 非推理模型不需要
    return ThinkingBudgetStateHolder(
        reasoning_config, max_num_seqs, num_spec_tokens, device, is_pin_memory
    )

13.3 init 初始化逐行解析

class ThinkingBudgetStateHolder:
    think_start_token_ids: list[int]   # 思考开始token序列(如 <think>的token ids)
    think_end_token_ids: list[int]     # 思考结束token序列(如 </think>的token ids)
    
    def __init__(
        self,
        reasoning_config: "ReasoningConfig | None",
        max_num_seqs: int,             # 最大并发序列数
        num_spec_tokens: int,          # 投机解码的draft token数(0=非spec模式)
        device: torch.device,
        is_pin_memory: bool,
    ):
        _ = is_pin_memory  # API对齐,未使用
        max_num_reqs = max_num_seqs
        self.in_spec_mode = num_spec_tokens > 0  # 是否在投机解码模式
        self.num_spec_tokens = num_spec_tokens
        
        # 启用标志:reasoning_config非None即启用
        self.is_enabled = reasoning_config is not None
        
        # 初始化思考开始/结束token ids
        if reasoning_config is None:
            self.think_start_token_ids = []
            self.think_end_token_ids = []
        else:
            rs = reasoning_config.reasoning_start_token_ids
            re = reasoning_config.reasoning_end_token_ids
            self.think_start_token_ids = rs if rs else []
            self.think_end_token_ids = re if re else []
        
        self.device = device
        self._state: dict[int, dict[str, Any]] = {}  # req_index → 状态字典
        self.cu_num_tokens: dict[int, int] = {}      # req_index → 累积token偏移
        
        # 预分配GPU张量用于强制结束
        if self.num_spec_tokens > 0:
            # 投机模式:每请求 (spec_tokens + 1) 个位置
            total = max_num_reqs * (self.num_spec_tokens + 1)
            self.mask = torch.zeros(total, dtype=torch.bool, device=device)
            self.force_token_ids = torch.full(
                (total,), -1, dtype=torch.long, device=device
            )
        else:
            # 常规模式:每请求1个位置
            self.mask = torch.zeros(max_num_reqs, dtype=torch.bool, device=device)
            self.force_token_ids = torch.full(
                (max_num_reqs,), -1, dtype=torch.long, device=device
            )

mask / force_token_ids 的设计

  • mask[i] = True → 第i行需要强制输出特定token
  • force_token_ids[i] → 第i行需要强制输出的token id
  • 预分配避免每步动态创建张量

13.4 has_tracked_requests() 活跃检测

def has_tracked_requests(self) -> bool:
    """是否有正在追踪思考预算的请求
    
    用于决定采样是否需要output-token行和spec combining
    不同于仅仅有holder实例(reasoning可能开启但当前批次没有预算请求)
    """
    return bool(self._state)

13.5 sync_batch() 批次同步

def sync_batch(self, batch_update: BatchUpdate | None) -> None:
    """根据批次变更增删移动状态(不调用_update_think_state)"""
    if not self.is_enabled or not batch_update:
        return
    
    # 移除已完成的请求
    for index in batch_update.removed:
        self._state.pop(index, None)
    
    # 新增请求:如果设置了thinking_token_budget则初始化状态
    for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
        thinking_token_budget = params.thinking_token_budget
        if thinking_token_budget is not None:
            self._state[index] = self._init_state_entry(
                prompt_tok_ids, thinking_token_budget
            )
            self._state[index]["output_tok_ids"] = output_tok_ids  # 引用
            self._state[index]["spec_token_ids"] = []
        else:
            self._state.pop(index, None)  # 不需要预算追踪
    
    # 移动请求
    for i1, i2, direction in batch_update.moved:
        if direction == MoveDirectionality.SWAP:
            state1 = self._state.get(i1)
            state2 = self._state.get(i2)
            if state1 is not None:
                self._state[i2] = state1
            if state2 is not None:
                self._state[i1] = state2
        else:
            state = self._state.pop(i1, None)
            if state is not None:
                self._state[i2] = state

13.6 update_state() 状态更新

def update_state(
    self,
    output_token_ids: list[list[int]],           # 每请求最新输出token
    spec_token_ids: list[list[int]] | None,      # 投机draft token
    repeat_indices: torch.Tensor | None = None,  # spec模式下请求→行映射
) -> None:
    """刷新output/spec数据并重新计算思考状态"""
    if not self.is_enabled or not self._state:
        return
    
    spec_lists = spec_token_ids or []
    
    # 构建请求到最新输出行的映射(spec模式下需要)
    last_row_for_req: dict[int, int] | None = None
    if repeat_indices is not None:
        last_row_for_req = {}
        rpt = repeat_indices.cpu().tolist()
        for batch_row, req_i in enumerate(rpt):
            last_row_for_req[req_i] = batch_row  # 最后出现的行号
    
    # 更新每个追踪请求的output_tok_ids引用
    for seq_idx, state in list(self._state.items()):
        if last_row_for_req is not None:
            output_row = last_row_for_req.get(seq_idx)
            if output_row is None or output_row >= len(output_token_ids):
                continue
            state["output_tok_ids"] = output_token_ids[output_row]
        elif seq_idx >= len(output_token_ids):
            continue
        else:
            state["output_tok_ids"] = output_token_ids[seq_idx]
        
        # 更新spec token ids
        if seq_idx < len(spec_lists):
            state["spec_token_ids"] = list(spec_lists[seq_idx])
        else:
            state["spec_token_ids"] = []
        
        state["in_spec_mode"] = self.in_spec_mode
        state["force_index"] = []  # 清空强制索引
        
        # 剥离draft token后缀
        if len(state["output_tok_ids"]) > 0:
            spec_len = len(state["spec_token_ids"])
            if spec_len > 0 and len(state["output_tok_ids"]) >= spec_len:
                state["output_tok_ids"] = state["output_tok_ids"][:-spec_len]
                # 去掉draft tokens,只保留已确认的输出
        
        # 重新计算思考状态
        self._update_think_state(state)

13.7 _init_state_entry() 状态初始化

def _init_state_entry(
    self, prompt_tok_ids: list[int] | None, thinking_token_budget: int
) -> dict[str, Any]:
    """为新请求初始化思考状态条目"""
    
    if prompt_tok_ids is None:
        # 无prompt:初始状态为"未进入思考"
        in_think = False
        think_count = 0
        start_thinking = -1
        countdown = thinking_token_budget
        continue_thinking = False
        in_end = False
    else:
        # 有prompt:检查prompt中是否已在思考中
        start_thinking = -1
        countdown = thinking_token_budget
        continue_thinking = False
        in_end = False
        
        # 在prompt中查找最后出现的start/end token序列
        last_start = self._find_last_sequence_index(
            prompt_tok_ids, self.think_start_token_ids
        )
        last_end = self._find_last_sequence_index(
            prompt_tok_ids, self.think_end_token_ids
        )
        
        # 判断是否在思考中:start出现在end之后
        in_think = last_start > last_end
        
        if in_think:
            # 计算已使用的思考token数
            think_count = len(prompt_tok_ids) - (
                last_start + len(self.think_start_token_ids)
            )
            start_thinking = len(prompt_tok_ids) - think_count - 1
            countdown -= think_count  # 剩余预算
            continue_thinking = True
            
            # 检查是否在prompt内已超预算
            token_exhausted = thinking_token_budget - think_count
            in_end = token_exhausted <= 0
    
    return {
        "in_think": in_think,                  # 是否在思考中
        "in_end": in_end,                      # 是否需要强制结束
        "check_count_down": countdown,         # 倒计数
        "think_count": think_count,            # 已用思考token数
        "end_count": 0,                        # 结束标记进度(多token序列)
        "prompt_tok_ids": prompt_tok_ids,
        "output_tok_ids": [],
        "thinking_token_budget": thinking_token_budget,
        "prev_output_length": 0,               # 上次处理时的输出长度
        "spec_token_ids": [],                  # 当前draft tokens
        "force_index": [],                     # 需要强制的位置索引
        "start_thinking": start_thinking,       # 思考开始的绝对位置
        "end_thinking": -1,                     # 思考结束的绝对位置
        "in_spec_mode": False,
        "bonus_token_forced": False,            # bonus token是否已被强制
        "continue_thinking": continue_thinking, # 从prompt继续思考
    }

13.8 _update_think_state() 思考状态机

这是整个ThinkingBudget系统最复杂的部分,实现了一个有限状态机来管理思考→结束→恢复的状态转换。

请求开始

检测到think_start

继续生成(budget未超)

budget超限

强制end token序列

完成end序列

拒绝采样回退了end token

请求完成

NotThinking

InThink

InEnd

状态机核心逻辑(简化版):

def _update_think_state(self, state: dict[str, Any]) -> None:
    # 1. 检查是否仍需要追踪
    if state.get("thinking_token_budget", -1) == -1:
        return
    if len(self.think_end_token_ids) == 0:
        # 无结束token → 禁用预算追踪
        state["thinking_token_budget"] = -1
        return
    
    # 2. 查找output中的start/end位置
    if state["start_thinking"] == -1:
        state["start_thinking"] = self._find_last_sequence_index(
            state.get("output_tok_ids", []), self.think_start_token_ids
        )
    if state["end_thinking"] == -1:
        state["end_thinking"] = self._find_last_sequence_index(
            state.get("output_tok_ids", []), self.think_end_token_ids
        )
    
    # 3. 如果从未开始思考,直接返回
    if state["start_thinking"] == -1:
        return
    
    # 4. 计算当前步新增的token数
    sampled_tokens_from_previous_step = len(output) - prev_length
    
    # 5. 更新倒计数
    current_step_countdown = state["check_count_down"] - sampled_tokens
    predicted_countdown = current_step_countdown - spec_len - 1
    
    # 6. 如果倒计数仍为正,只更新countdown,不需要强制
    if not in_end and predicted_countdown >= 0 and start_thinking > -1:
        state["check_count_down"] = current_step_countdown
        state["prev_output_length"] = len(output)
        return
    
    # 7. 需要强制结束的情况
    if state["in_end"] and state["end_count"] == 0:
        # 检查拒绝采样是否回退了end token
        new_tokens = output[prev_length:]
        stopping_thinking = self.think_end_token_ids[0] in new_tokens
        if not stopping_thinking:
            # 回退 → 重新进入思考模式
            state["in_think"] = True
            state["in_end"] = False
            state["end_count"] = 0
    
    # 8. 根据start/end位置判断当前状态
    if not state["in_end"]:
        if absolute_start_pos > absolute_end_pos:
            # start在end之后 → 在思考中
            state["in_think"] = True
        elif absolute_end_pos > absolute_start_pos:
            # end在start之后 → 不在思考
            state["in_think"] = False
        
        # 9. 检查是否超预算
        if state["in_think"] and total_thinking_tokens > budget:
            state["in_think"] = False
            state["in_end"] = True  # 切换到强制结束模式
            
            # 计算force_index(在哪个draft位置开始强制)
            remaining = budget - think_count
            if remaining <= 0:
                state["force_index"] = [0]  # 从第一个位置开始强制
            elif remaining < spec_len:
                state["force_index"] = [remaining]  # 从中间位置开始
            else:
                state["force_index"] = [spec_len]  # 在bonus位置强制
    
    else:  # in_end状态
        # 10. 追踪end token序列的进度
        # 多token结束标记需要逐token强制
        if len(spec_token_ids) > 0:
            for i, token_id in enumerate(spec_token_ids):
                if end_count + 1 < len(think_end_token_ids):
                    if token_id == think_end_token_ids[end_count + 1]:
                        end_count += 1  # 匹配,继续
                    else:
                        end_count += 1
                        force_index = [i]  # 不匹配,强制
                        break
        
        # 11. 完成end序列
        if end_count >= len(think_end_token_ids):
            state.update({
                "in_end": False,
                "end_count": 0,
                "check_count_down": budget,
            })

13.9 apply_to_logits() logits强制

def apply_to_logits(
    self,
    logits: torch.Tensor,
    predict_bonus_token: bool,
    spec_token_ids: list[list[int]] | None,
) -> torch.Tensor:
    """将强制结束逻辑应用到logits"""
    if not self.is_enabled or not self._state:
        return logits
    
    spec_lists = spec_token_ids or []
    return self._apply_forcing_to_logits(logits, predict_bonus_token, spec_lists)

13.10 _apply_forcing_to_logits() 强制逻辑

def _apply_forcing_to_logits(
    self,
    logits: torch.Tensor,
    predict_bonus_token: bool,
    spec_token_ids_for_layout: list[list[int]],
) -> torch.Tensor:
    """核心强制逻辑:为需要强制end token的位置设置极大logit"""
    
    # 清空mask和force_token_ids
    self.mask[:] = False
    cumulative_total = 0
    self.cu_num_tokens.clear()
    
    # 计算每个请求在展平logits中的位置偏移
    n_layout = len(spec_token_ids_for_layout)
    if self._state:
        n_layout = max(n_layout, max(self._state.keys()) + 1)
    
    for index in range(n_layout):
        self.cu_num_tokens[index] = cumulative_total
        spec_tokens = (
            spec_token_ids_for_layout[index]
            if index < len(spec_token_ids_for_layout) else []
        )
        if self.in_spec_mode:
            cumulative_total += len(spec_tokens) if not predict_bonus_token else 1
        else:
            cumulative_total += 1
    
    # 为每个追踪状态的请求设置强制位置
    for seq_idx in sorted(self._state.keys()):
        if seq_idx not in self.cu_num_tokens:
            continue
        state = self._state[seq_idx]
        
        if state.get("in_end", False):
            # 投机解码中logits处理器被调用两次:
            # 1. bonus token logits
            # 2. target logits
            if predict_bonus_token:
                if state.get("force_index") and state["force_index"][0] < len(
                    state["spec_token_ids"]
                ):
                    # force_index指向draft位置 → 跳过bonus调用
                    continue
                else:
                    # force_index指向bonus位置 → 在bonus调用中强制
                    state["force_index"] = [0]
            
            if state.get("end_count", 0) > 0:
                state["bonus_token_forced"] = False
            
            if state and not state["bonus_token_forced"]:
                force_index = state.get("force_index", [])
                if len(force_index) == 0:
                    continue
                end_count = state.get("end_count", 0)
                
                for force_idx in force_index:
                    if end_count < len(self.think_end_token_ids):
                        # 计算在展平logits中的绝对行索引
                        mask_idx = self.cu_num_tokens[seq_idx] + force_idx
                        if mask_idx < len(self.mask) and mask_idx < logits.shape[0]:
                            self.mask[mask_idx] = True
                            self.force_token_ids[mask_idx] = (
                                self.think_end_token_ids[end_count]
                            )
                        if predict_bonus_token:
                            if state["end_count"] > 0:
                                state["bonus_token_forced"] = False
                                state["force_index"] = []
                            else:
                                state["bonus_token_forced"] = True
    
    # 应用强制:将对应token的logit设为极大值
    has_active_thinking = any(
        state.get("in_end", False) for state in self._state.values()
    )
    
    if has_active_thinking:
        active_indices = self.mask.nonzero(as_tuple=False).view(-1)
        if len(active_indices) > 0:
            force_tokens = self.force_token_ids[active_indices]
            # 设置极大logit(1e9)确保该token被采样
            logits[active_indices, force_tokens] = 1e9
    
    return logits

1e9 的设计选择:使用 1e9 而非 float("inf") 的原因:

  • float("inf") 可能导致softmax产生NaN
  • 1e9 足够大,确保在float32精度下,该token的概率接近1
  • 同时保持数值稳定性

No

Yes

强制逻辑示意

row 3: force_token=999
logits[3, 999] = 1e9

row 5: force_token=999
logits[5, 999] = 1e9

logits [B,V]

有in_end请求?

return logits

构建mask和force_token_ids

active_indices = mask.nonzero()

logits[active, force_tokens] = 1e9

return logits


附录C Logits处理器生命周期时序图

Sampler AdapterLogitsProcessor Builtins (MinP/LogitBias/MinTokens) BatchUpdateBuilder GPUModelRunner Sampler AdapterLogitsProcessor Builtins (MinP/LogitBias/MinTokens) BatchUpdateBuilder GPUModelRunner Step 1: 构建阶段 Step 2: 状态更新 MinP: 更新min_p_cpu数组 LogitBias: 重建biases dict MinTokens: 检查已满足的请求 process_dict_updates(self.req_info) Step 3: 采样前向 Step 3a: Non-argmax-invariant index_put_(stop_token_indices, -inf) logits[slice] += bias_tensor Step 3b: Argmax-invariant (after temperature) masked_fill_(probs < max_prob*min_p, -inf) Step 3c: Return 注册removed/added/moved get_and_reset(batch_size) BatchUpdate(removed, added, moved) update_state(batch_update) update_state(batch_update) forward(logits, metadata) MinTokens.apply(logits) LogitBias.apply(logits) MinP.apply(logits) SamplerOutput

附录D 思考预算状态机全图

请求创建

无thinking_token_budget

有thinking_token_budget

prompt中已有think_start

prompt中无think_start

继续生成

生成think_start

think_count < budget

think_count + spec_tokens + 1 > budget

计算force_index

强制end token序列(多token)

完成整个end序列

拒绝采样回退了end token

等待下一次思考

请求完成

Init

NoThink

CheckPrompt

InThink_Prompt

WaitingStart

InThink

InEnd

ForceEnd

EndComplete

状态字段:
in_think=True
in_end=False
think_count: 递增
check_count_down: 递减

状态字段:
in_think=False
in_end=True
end_count: 0→N
force_index: [pos]
mask[pos]=True
force_token_ids[pos]=end_token

force_index 的三种情况

场景 remaining = budget - think_count force_index 含义
已超预算 remaining ≤ 0 [0] 从第一个draft位置强制
部分超 0 < remaining < spec_len [remaining] 从第remaining个draft位置强制
恰好在边界 remaining ≥ spec_len [spec_len] 在bonus token位置强制

_find_last_sequence_index() 工具方法

@staticmethod
def _find_last_sequence_index(target_list: list[int], token_ids: list[int]) -> int:
    """在target_list中查找token_ids序列最后一次出现的位置
    
    Args:
        target_list: 搜索目标(如prompt_token_ids或output_token_ids)
        token_ids: 要查找的token序列(如think_start_token_ids)
    
    Returns:
        最后匹配的起始索引,-1表示未找到
    """
    if not token_ids:
        return -1
    # 从后往前搜索
    for i in range(len(target_list) - len(token_ids), -1, -1):
        if target_list[i: i + len(token_ids)] == token_ids:
            return i
    return -1

附录N process_dict_updates() 全场景推演

N.1 场景1: 新增请求(有logit_bias)

初始状态: biases = {}
BatchUpdate.added = [(0, params_with_bias, prompt0, output0)]

Step 1: 调用 new_state(params_with_bias, prompt0, output0)
  → params.logit_bias = {100: 0.5, 200: -0.3}
  → 返回 {100: 0.5, 200: -0.3}

Step 2: req_entries[0] = {100: 0.5, 200: -0.3}
  → updated = True

Step 3: 重建GPU张量
  reqs = [0, 0]           # 两个偏置都在请求0
  tok_ids = [100, 200]    # token id
  biases = [0.5, -0.3]    # 偏置值
  logits_slice = (tensor([0, 0], int32), tensor([100, 200], int32))
  bias_tensor = tensor([0.5, -0.3], float32)

N.2 场景2: 请求移除

初始状态: biases = {0: {100: 0.5}, 1: {200: -0.3}}
BatchUpdate.removed = [1]

Step 1: req_entries.pop(1, None) → {200: -0.3}
  → updated = True

Step 2: 重建GPU张量
  biases = {0: {100: 0.5}}
  reqs = [0]
  tok_ids = [100]
  biases_vals = [0.5]

N.3 场景3: SWAP交换

初始状态: biases = {0: {100: 0.5}, 1: {200: -0.3}}
BatchUpdate.moved = [(0, 1, SWAP)]

Step 1: a_entry = pop(0) = {100: 0.5}
Step 2: b_entry = pop(1) = {200: -0.3}
Step 3: req_entries[1] = {100: 0.5}  # a→b
Step 4: req_entries[0] = {200: -0.3}  # b→a (SWAP)
  → updated = True

结果: biases = {0: {200: -0.3}, 1: {100: 0.5}}
# 两个请求的状态互换了

N.4 场景4: UNIDIRECTIONAL单向移动

初始状态: biases = {0: {100: 0.5}, 2: {300: 0.8}}
BatchUpdate.moved = [(0, 1, UNIDIRECTIONAL)]

Step 1: a_entry = pop(0) = {100: 0.5}
Step 2: b_entry = pop(1) = None  # 位置1原来没有状态
Step 3: req_entries[1] = {100: 0.5}  # a→b
Step 4: b_entry is None → 不做额外操作
  → updated = True

结果: biases = {1: {100: 0.5}, 2: {300: 0.8}}
# 位置0被清空,位置1获得了位置0的状态

N.5 场景5: 新请求不需要此处理器

初始状态: biases = {0: {100: 0.5}}  # 位置0有logit_bias
BatchUpdate.added = [(0, params_without_bias, prompt0, output0)]
# 位置0的新请求没有logit_bias → 覆盖旧状态

Step 1: new_state(params_without_bias) → params.logit_bias is None → None
Step 2: state is None → 不插入
Step 3: req_entries.pop(0) = {100: 0.5} → 清除旧状态
  → updated = True

结果: biases = {}  # 位置0的旧状态被清除

场景5: 新请求覆盖旧

biases = {0:A}

added: (0, no_bias)

biases = {}

场景4: UNIDIRECTIONAL

biases = {0:A, 2:C}

moved: (0,1,UNI)

biases = {1:A, 2:C}

场景3: SWAP

biases = {0:A, 1:B}

moved: (0,1,SWAP)

biases = {0:B, 1:A}

场景2: 移除

biases = {0:{..}, 1:{..}}

removed: [1]

biases = {0:{..}}

场景1: 新增

biases = {}

added: (0, bias={100:0.5})

biases = {0: {100:0.5}}


附录O ThinkingBudget 状态字段完整参考

O.1 状态字典字段说明

字段 类型 初始值 含义
in_think bool False 当前是否在思考模式中
in_end bool False 当前是否需要强制结束思考
check_count_down int budget 剩余可用思考token计数
think_count int 0 已使用的思考token数
end_count int 0 结束标记序列中已完成的token数
prompt_tok_ids list|None prompt 原始prompt token ids
output_tok_ids list[int] [] (ref) 已生成输出token ids(引用)
thinking_token_budget int budget 用户设定的思考token预算
prev_output_length int 0 上次处理时的输出长度
spec_token_ids list[int] [] 当前投机draft token ids
force_index list[int] [] 需要强制end token的位置索引
start_thinking int -1 思考开始序列在output中的位置
end_thinking int -1 思考结束序列在output中的位置
in_spec_mode bool False 是否在投机解码模式
bonus_token_forced bool False bonus token是否已被强制
continue_thinking bool False 是否从prompt继续思考

O.2 状态转换矩阵

           | in_think=F | in_think=T  | in_end=T
-----------|------------|-------------|----------
触发条件    | 检测到     | budget超限  | 完成end序列
           | think_start| 或spec超限  | 或被拒绝回退
-----------|------------|-------------|----------
下一状态    | in_think=T | in_end=T    | in_think=T
           | think_count=0| end_count=0| in_end=F
           |            | force_index | end_count=0
-----------|------------|-------------|----------
force_index| []         | [pos]       | [] or [0]

O.3 force_index 计算公式

# 当 in_think → in_end 转换时:
remaining_budget = thinking_token_budget - think_count
spec_len = len(spec_token_ids)

if remaining_budget <= 0:
    # 已超预算 → 从第一个位置强制
    force_index = [0]
elif 0 < remaining_budget < spec_len:
    # 部分超预算 → 从第remaining_budget个spec位置强制
    force_index = [remaining_budget]
else:
    # 还在预算内但spec+1超 → 在bonus token位置强制
    force_index = [spec_len]

图示:

假设: budget=10, think_count=8, spec_len=3
remaining = 10 - 8 = 2
force_index = [2]

draft位置: [0]  [1]  [2]  [bonus]
           ✓    ✓    ✗force  ✗force
           正常 正常  强制end  强制end

位置0-1: 在预算内(8+2=10 ≤ budget=10)
位置2: 超预算(8+3=11 > budget=10)→ 强制
bonus: 更超预算 → 强制

附录P AdapterLogitsProcessor 适配器深度分析

P.1 V0 vs V1 LogitsProcessor接口对比

特性 V0 (RequestLogitsProcessor) V1 (LogitsProcessor)
粒度 per-request per-batch
签名 __call__(input_ids, logits) apply(logits: [B,V])
状态管理 无(每次调用独立) update_state(BatchUpdate)
批量处理 逐请求循环 全批次一次
argmax分类 is_argmax_invariant()
注册方式 sampling_params.logits_processors entry_points / FQCN

P.2 适配器模式工作流

1. 用户定义V0处理器:
   class MyProcessor:
       def __call__(self, input_ids, logits):
           # 逐token处理逻辑
           return modified_logits

2. AdapterLogitsProcessor包装:
   - new_req_logits_processor(params) → 返回MyProcessor实例
   - _new_state(params, prompt, output) → partial(MyProcessor(), prompt, output)
   - self.req_info[req_idx] = partial_fn

3. apply()时:
   for req_idx, partial_fn in self.req_info.items():
       # partial_fn(logits_row) 等价于 MyProcessor(prompt_ids, output_ids, logits_row)
       result = partial_fn(logits[req_idx])
       logits[req_idx] = result  # 如果结果不是原地修改

P.3 inspect.signature 参数检测

# 3参数签名: (input_ids, logits, ...) → 需要prompt_ids
class ThreeParamProcessor:
    def __call__(self, input_ids, logits, **kwargs):
        # input_ids = prompt token ids
        # logits = 当前logits
        pass

# 2参数签名: (logits, ...) → 只需要output_ids
class TwoParamProcessor:
    def __call__(self, logits, **kwargs):
        # logits = 当前logits
        pass

# 检测逻辑:
sig = inspect.signature(req_lp)
if len(sig.parameters) == 3:
    args = [prompt_ids, output_ids]  # 预填充两个参数
else:
    args = [output_ids]  # 预填充一个参数

partial_fn = partial(req_lp, *args)
# 调用: partial_fn(logits_row) → req_lp(prompt_ids, output_ids, logits_row)

P.4 为什么output_ids是引用而非拷贝

# 在BatchUpdate.added中:
AddedRequest = (index, params, prompt_tok_ids, output_tok_ids)
# output_tok_ids是请求的运行中输出列表的引用

# 在_new_state中:
args = [output_ids]
partial_fn = partial(req_lp, *args)
# partial_fn持有output_ids的引用

# 当模型生成新token时:
output_ids.append(new_token)  # 请求的输出列表增长
# partial_fn自动看到最新的output_ids(因为是引用)
# 无需每步重新创建partial_fn

这个设计确保了per-request处理器始终操作最新的输出历史,而不需要每步都重建状态。


附录Q build_logitsprocs() 全路径决策树

Yes

Yes

No

No

Yes

Yes

No

No

build_logitsprocs(config, device, pin_mem, is_pooling, custom)

is_pooling_model?

custom非空?

❌ ValueError:
Pooling不支持logits处理器

return LogitsProcessors()

speculative_config?

custom非空?

❌ ValueError:
Spec decode不支持自定义处理器

⚠️ min_p/logit_bias不工作

return LogitsProcessors([
MinTokensLogitsProcessor
])

custom_classes = _load_custom_logitsprocs(custom)

itertools.chain(BUILTIN, custom_classes)

ctor(config, device, pin_memory)
逐个实例化

LogitsProcessors(processors)

return LogitsProcessors{
argmax_invariant: [MinP],
non_argmax_invariant: [MinTokens, LogitBias, custom...]
}

Q.1 插件加载流程

_load_custom_logitsprocs(custom)
├── _load_logitsprocs_plugins()
│   ├── entry_points(group="vllm.logits_processors")
│   ├── 每个entrypoint.load()
│   │   ├── guard_cuda_initialization()  # 防止CUDA过早初始化
│   │   └── 验证 issubclass(cls, LogitsProcessor)
│   └── 返回 [CustomPlugin1, CustomPlugin2, ...]
│
└── _load_logitsprocs_by_fqcns(custom)
    ├── 遍历custom列表
    ├── type对象 → 直接验证
    └── 字符串FQCN → 动态加载
        ├── split(":") → (module_path, qualname)
        ├── importlib.import_module(module_path)
        ├── getattr链 → 获取类对象
        └── 验证 issubclass(obj, LogitsProcessor)

Q.2 lru_cache缓存策略

cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)
  • 缓存key: custom_logitsprocs 的tuple版本
  • 首次调用: 加载所有插件和自定义处理器
  • 后续调用: 直接返回缓存结果
  • 好处: 避免每次请求都重新加载(importlib开销大)
  • 限制: 如果运行时安装新插件,需要重启进程才能生效

附录R2 MinPLogitsProcessor 完整状态更新推演

R2.1 场景: 批次从3个请求变为4个

初始状态:
  min_p_cpu = [0.1, 0.0, 0.05]  # 3个请求
  min_p_count = 2  (请求0和2有min_p > 0)
  min_p = min_p_device[:3]  # [0.1, 0.0, 0.05]

BatchUpdate:
  removed = [1]  # 请求1完成
  added = [(1, params_with_min_p_0.2, prompt1, output1),  # 新请求在位置1
           (3, params_no_min_p, prompt3, output3)]         # 新请求在位置3
  moved = [(2, 3, UNIDIRECTIONAL)]  # 请求2移动到位置3(但3已被新增占据)

处理added:
  index=1, params.min_p=0.2:
    min_p_before = min_p_cpu[1] = 0.0
    0.0 != 0.2 → needs_update = True
    min_p_cpu[1] = 0.2
    min_p_count: 0→0.2 (激活) → count += 1 = 3
  
  index=3, params.min_p=0.0:
    min_p_before = min_p_cpu[3] = 0.0
    0.0 == 0.0 → needs_update不变

处理removed:
  index=1:  (但1已被新请求覆盖,此时min_p_cpu[1]=0.2)
    min_p_cpu[1] = 0.2 ≠ 0 → min_p_cpu[1] = 0, count -= 1 = 2
    # ⚠️ 这里似乎有问题:刚add的又被remove了
    # 实际上:removed在added之前处理
    # 但在process_dict_updates中,added先于removed处理
    # 所以新请求1的min_p=0.2被设置后,removed[1]又清除了它

    # 这是vLLM V1持久批次设计的一个已知边界情况
    # 在实际调度中,removed和added不会同时包含同一索引
    # 因为:先移除旧请求,然后新请求填入空位

处理moved:
  (2, 3, UNIDIRECTIONAL):
    min_p_a = min_p_cpu[2] = 0.05
    min_p_b = min_p_cpu[3] = 0.0
    0.05 != 0.0 → needs_update = True
    min_p_cpu[3] = 0.05  # 请求2的min_p移到位置3
    
    # UNIDIRECTIONAL: 源位置清零
    min_p_cpu[2] = 0  # 但min_p_a=0.05 > 0
    # 实际代码: if min_p_a: min_p_cpu[adx] = 0
    
    # 位置3之前有值0.0,被覆盖
    min_p_b = 0.0 → min_p_count -= 1 = 1
    # ⚠️ 但位置3现在有0.05(从位置2移来),应该count不变
    # 代码逻辑可能有微小bug,但实际影响很小

最终状态:
  min_p_cpu = [0.1, 0.0, 0.0, 0.05]
  min_p_count = 1  (只有位置0有min_p > 0)
  # 注意:位置3的0.05没有被计数
  # 这是因为moved处理中的计数逻辑有缺陷
  # 但由于后续会重建GPU张量,实际效果只是多了一次不必要的传输

附录S2 LogitBiasLogitsProcessor 索引构建详解

S2.1 从字典到张量的转换

输入: biases = {
  0: {100: 0.5, 200: -0.3},   # 请求0: 2个偏置
  2: {50: 1.0},                # 请求2: 1个偏置
}

转换为3个平行数组:
  reqs    = [0, 0, 2]          # 请求索引
  tok_ids = [100, 200, 50]     # token索引
  biases  = [0.5, -0.3, 1.0]  # 偏置值

构建GPU张量:
  logits_slice = (
    tensor([0, 0, 2], int32, device=cuda),   # 请求索引
    tensor([100, 200, 50], int32, device=cuda), # token索引
  )
  bias_tensor = tensor([0.5, -0.3, 1.0], float32, device=cuda)

应用:
  logits[logits_slice] += bias_tensor
  等价于:
    logits[0, 100] += 0.5
    logits[0, 200] += -0.3
    logits[2, 50] += 1.0
  但在一次kernel调用中完成

S2.2 _device_tensor 的优化

def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
    return torch.tensor(
        data, device="cpu", dtype=dtype,
        pin_memory=self.pin_memory
    ).to(device=self.device, non_blocking=True)

为什么先创建CPU tensor再传输:

  1. torch.tensor(data, device="cuda") 在CUDA上下文中创建,可能触发CUDA初始化
  2. 先创建CPU tensor更安全(CUDA可能未初始化)
  3. pin_memory=True + non_blocking=True 实现真正的异步传输
  4. CPU tensor创建后被立即传输,临时CPU内存很快释放

pin_memory的工作原理:

普通内存: CPU → GPU 需要: 先copy到pinned buffer → DMA传输
Pinned内存: CPU → GPU 只需: DMA传输(跳过copy步骤)
性能提升: 减少一次内存拷贝,特别当数据量大时

附录T2 ThinkingBudget _apply_forcing_to_logits 逐行注释

FUNCTION _apply_forcing_to_logits(logits, predict_bonus_token, spec_token_ids_for_layout)

  # ===== 清空预分配张量 =====
  self.mask[:] = False                         # 所有位置标记为"不强制"
  cumulative_total = 0                          # 累积token偏移
  self.cu_num_tokens.clear()                    # 清空请求→偏移映射

  # ===== 计算布局 =====
  n_layout = len(spec_token_ids_for_layout)     # spec配置的请求数
  IF self._state:
    n_layout = max(n_layout, max(self._state.keys()) + 1)
    # 如果_state中有更多请求(无spec但需追踪),扩展布局范围

  FOR index IN range(n_layout):
    self.cu_num_tokens[index] = cumulative_total
    # 记录每个请求的起始偏移
    
    spec_tokens = spec_token_ids_for_layout[index] if index < len(...) else []
    IF self.in_spec_mode:
      cumulative_total += len(spec_tokens) IF NOT predict_bonus_token ELSE 1
      # 投机模式: bonus调用时只算1个位置,target调用时算所有draft位置
    ELSE:
      cumulative_total += 1
      # 常规模式: 每请求1个位置

  # ===== 设置强制位置 =====
  FOR seq_idx IN sorted(self._state.keys()):
    IF seq_idx NOT IN self.cu_num_tokens:
      continue  # 该请求不在当前布局中
    
    state = self._state[seq_idx]
    IF NOT state.get("in_end", False):
      continue  # 不在强制结束状态
    
    # ===== Bonus调用的特殊处理 =====
    IF predict_bonus_token:
      IF state.get("force_index") AND state["force_index"][0] < len(state["spec_token_ids"]):
        continue  # force_index指向draft位置,不在bonus调用中处理
      ELSE:
        state["force_index"] = [0]  # force_index指向bonus → 在bonus位置强制

    # ===== end_count > 0 的处理 =====
    IF state.get("end_count", 0) > 0:
      state["bonus_token_forced"] = False
      # 如果已经强制了一些end token,重置bonus标记

    IF state AND NOT state["bonus_token_forced"]:
      force_index = state.get("force_index", [])
      IF len(force_index) == 0:
        continue  # 没有需要强制的位置
      
      end_count = state.get("end_count", 0)
      FOR force_idx IN force_index:
        IF end_count < len(self.think_end_token_ids):
          # 计算绝对行索引
          mask_idx = self.cu_num_tokens[seq_idx] + force_idx
          IF mask_idx < len(self.mask) AND mask_idx < logits.shape[0]:
            self.mask[mask_idx] = True
            self.force_token_ids[mask_idx] = self.think_end_token_ids[end_count]
            # 标记该位置需要强制输出end token序列的第end_count个token
          
          IF predict_bonus_token:
            IF state["end_count"] > 0:
              state["bonus_token_forced"] = False
              state["force_index"] = []
              # 多token end序列的中间步骤
            ELSE:
              state["bonus_token_forced"] = True
              # 标记bonus token已被强制

  # ===== 应用强制 =====
  has_active_thinking = any(s.get("in_end") for s in self._state.values())
  IF has_active_thinking:
    active_indices = self.mask.nonzero(as_tuple=False).view(-1)
    IF len(active_indices) > 0:
      force_tokens = self.force_token_ids[active_indices]
      logits[active_indices, force_tokens] = 1e9
      # 设置极大logit → softmax后概率接近1 → 确保被采样
  
  RETURN logits
END FUNCTION

附录U2 BatchUpdateBuilder 与 GPUInputBatch 的协作关系

U2.1 调用时序

GPUInputBatch._update_states() {
  // 每步调度时调用
  
  // 1. 收集移除的请求索引
  for completed_request in completed_requests:
    builder.removed_append(request_index)
  
  // 2. 填充新请求到空位
  for new_request in waiting_queue:
    empty_slot = builder.pop_removed()  // 获取最低索引的空位
    if empty_slot is not None:
      // 将新请求放入空位
      builder.added.append((empty_slot, params, prompt_ids, output_ids))
  
  // 3. 处理移动(重排批次以保持紧凑)
  // ...
  
  // 4. 生成BatchUpdate
  batch_update = builder.get_and_reset(batch_size)
  
  // 5. 通知logits处理器
  for processor in logitsprocs.all:
    processor.update_state(batch_update)
}

U2.2 为何removed必须降序排列

假设: _removed = [7, 3, 5]

如果降序排列: [7, 5, 3]
  pop_removed() → 3 (最小,最后弹出)
  新请求填入位置3 ✓
  
  pop_removed() → 5
  新请求填入位置5 ✓
  
  pop_removed() → 7
  新请求填入位置7 ✓

如果乱序: [3, 7, 5]
  需要排序才能找到最小值
  每次pop都需要O(n)搜索

降序 + pop() = O(1) 获取最小值
这是典型的"最小堆"优化:降序列表 + pop()等价于最小堆的extract-min

U2.3 BatchUpdate的frozen设计

@dataclass(frozen=True)
class BatchUpdate:
    batch_size: int
    removed: Sequence[RemovedRequest]
    added: Sequence[AddedRequest]
    moved: Sequence[MovedRequest]

frozen=True的意义

  1. 不可变性: 一旦创建就不能修改
  2. 安全性: 多个处理器读取同一个BatchUpdate,不会互相影响
  3. 可哈希: 可以用作dict的key或放入set(虽然目前没用到)
  4. 调试友好: 不存在"部分更新"的中间状态

与Builder的关系:

  • Builder是可变的(收集变更)
  • BatchUpdate是不可变的(变更的快照)
  • get_and_reset()将可变→不可变的原子转换
Logo

免费领 100 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐