LLM 推理优化实战:从 KV Cache 到连续批处理,降低推理成本的工程路径

cover

一、推理成本:大模型落地的最大瓶颈

大模型应用落地的核心矛盾是推理成本与响应速度。以一个 70B 参数模型为例,单次推理的 GPU 显存占用约 140GB(FP16),需要 2 张 A100-80G。如果 QPS 需求为 10,则需要 4 6 张 A100 才能满足延迟要求。按云 GPU 价格计算,月成本在 35 万美元。

更具体的痛点:一个智能客服系统,高峰期 QPS 达 50,平均输入 token 数 800、输出 token 数 300。使用 GPT-4 级别模型,单次推理延迟 3~5 秒,月 API 费用超过 10 万美元。这个成本在大多数业务场景下无法持续。

推理优化的目标很明确:在保证输出质量的前提下,降低单次推理的 GPU 时间。下面从底层机制到工程实践,逐层拆解优化手段。

二、推理瓶颈定位:显存带宽才是天花板

LLM 推理分为两个阶段:Prefill(预填充)和 Decode(解码)。Prefill 阶段并行处理所有输入 token,计算密集;Decode 阶段逐 token 生成输出,每步都需要读取全部 KV Cache,显存带宽密集。

graph LR
    subgraph 推理两阶段
        A[Prefill 阶段<br/>处理输入 tokens] -->|计算密集| B[生成 KV Cache]
        B --> C[Decode 阶段<br/>逐 token 生成]
        C -->|显存带宽密集| D[每步读取 KV Cache]
        D --> C
    end

    subgraph 瓶颈分析
        E[Prefill: 算力瓶颈<br/>GPU 利用率高] --> F[优化方向: 算子融合/Flash Attention]
        G[Decode: 带宽瓶颈<br/>GPU 利用率低] --> H[优化方向: KV Cache 压缩/连续批处理]
    end

    style A fill:#e1f5fe
    style C fill:#fff3e0
    style E fill:#e8f5e9
    style G fill:#ffebee

关键数据:Decode 阶段每生成一个 token,需要从显存读取整个 KV Cache(与序列长度成正比)。对于 70B 模型,序列长度 4096 时,KV Cache 约 5GB。A100 的显存带宽为 2TB/s,单次读取耗时约 2.5ms。如果算力是瓶颈,GPU 利用率会很高;但实际 Decode 阶段 GPU 利用率通常只有 10%~30%,说明显存带宽才是真正的瓶颈。

因此,推理优化的核心策略是:减少不必要的显存读取提高 GPU 利用率

三、工程实践:四层优化策略与代码实现

3.1 KV Cache 优化:PagedAttention

传统实现中,每个请求预分配最大序列长度的 KV Cache,造成大量显存浪费。vLLM 的 PagedAttention 借鉴操作系统虚拟内存的分页机制,将 KV Cache 划分为固定大小的 block(如 16 tokens),按需分配。

# PagedAttention 的核心思想(伪代码)
class PagedKVCache:
    """分页 KV Cache 管理器,减少显存碎片"""

    def __init__(self, block_size: int = 16, num_blocks: int = 1024):
        self.block_size = block_size
        # 预分配所有 block 的显存池
        self.kv_pool = self._allocate_kv_pool(num_blocks)
        # 空闲 block 链表
        self.free_blocks = list(range(num_blocks))
        # 每个序列占用的 block 映射
        self.seq_blocks: dict[str, list[int]] = {}

    def _allocate_kv_pool(self, num_blocks: int):
        """预分配显存池,避免运行时动态分配"""
        # 实际实现中,这里分配 GPU 显存
        return [None] * num_blocks

    def allocate(self, seq_id: str, num_tokens: int) -> list[int]:
        """为序列分配 KV Cache block"""
        num_needed = (num_tokens + self.block_size - 1) // self.block_size
        if len(self.free_blocks) < num_needed:
            raise RuntimeError(
                f"显存不足: 需要 {num_needed} blocks, "
                f"可用 {len(self.free_blocks)} blocks"
            )
        blocks = []
        for _ in range(num_needed):
            block_id = self.free_blocks.pop()
            blocks.append(block_id)
        self.seq_blocks[seq_id] = blocks
        return blocks

    def free(self, seq_id: str):
        """序列完成后释放 block"""
        if seq_id in self.seq_blocks:
            for block_id in self.seq_blocks[seq_id]:
                self.free_blocks.append(block_id)
            del self.seq_blocks[seq_id]

    def get_utilization(self) -> float:
        """显存利用率"""
        total = len(self.free_blocks) + sum(
            len(b) for b in self.seq_blocks.values()
        )
        used = sum(len(b) for b in self.seq_blocks.values())
        return used / total if total > 0 else 0.0

PagedAttention 的收益:显存利用率从传统方案的 20% 40% 提升到 90%+,同等显存下可并发处理的请求数增加 24 倍。

3.2 连续批处理(Continuous Batching)

传统批处理是静态的:等所有请求的 Prefill 完成后才开始 Decode,一个请求生成完毕也要等整批完成。连续批处理在迭代级别动态调度:每个迭代步,已完成生成的请求被移出,新请求加入,GPU 始终满载。

import time
from dataclasses import dataclass
from typing import Optional

@dataclass
class Request:
    """推理请求"""
    req_id: str
    input_tokens: list[int]
    max_output_tokens: int
    generated_tokens: list[int] = None
    kv_blocks: list[int] = None

    def __post_init__(self):
        if self.generated_tokens is None:
            self.generated_tokens = []

    @property
    def is_finished(self) -> bool:
        return len(self.generated_tokens) >= self.max_output_tokens


class ContinuousBatcher:
    """连续批处理调度器"""

    def __init__(self, max_batch_size: int = 32):
        self.max_batch_size = max_batch_size
        self.running: list[Request] = []
        self.waiting: list[Request] = []

    def add_request(self, req: Request):
        """添加新请求到等待队列"""
        self.waiting.append(req)

    def schedule(self) -> list[Request]:
        """调度一个批次的请求"""
        # 移除已完成的请求
        self.running = [r for r in self.running if not r.is_finished]

        # 从等待队列补充请求,直到批次满
        while (
            len(self.running) < self.max_batch_size
            and self.waiting
        ):
            req = self.waiting.pop(0)
            self.running.append(req)

        return self.running

    def step(self):
        """执行一步 Decode(模拟)"""
        batch = self.schedule()
        if not batch:
            return

        # 模拟一步推理:每个请求生成一个 token
        for req in batch:
            # 实际实现中,这里调用模型推理
            req.generated_tokens.append(0)  # placeholder

    def get_stats(self) -> dict:
        return {
            "running": len(self.running),
            "waiting": len(self.waiting),
            "batch_utilization": (
                f"{len(self.running) / self.max_batch_size:.1%}"
            ),
        }

连续批处理的收益:GPU 利用率从静态批处理的 30% 50% 提升到 80%+,吞吐量提升 23 倍。

3.3 量化:用精度换速度

量化是最直接的降本手段。INT8 量化将模型权重从 FP16 压缩到 INT8,显存减半,推理速度提升 1.5~2 倍(得益于 INT8 Tensor Core 的高吞吐)。INT4 量化进一步压缩,但精度损失需要评估。

量化方案 显存占用 推理速度 精度损失
FP16(基线) 100% 1x 0
INT8(W8A8) 50% 1.5~2x < 1%
INT4(GPTQ) 25% 2~3x 1%~3%
INT4(AWQ) 25% 2~3x < 2%

生产建议:对精度敏感场景(如数学推理、代码生成)用 INT8;对精度不敏感场景(如对话、摘要)可用 INT4。量化前必须在业务数据集上评估精度损失。

3.4 Speculative Decoding:用小模型猜大模型

Speculative Decoding 的思路:用一个 7B 小模型快速生成 K 个候选 token,大模型一次前向传播验证这 K 个 token。如果全部正确,相当于一次推理生成 K 个 token;如果有错误,从第一个错误位置重新生成。

def speculative_decode_step(
    draft_model,   # 小模型(快速)
    target_model,  # 大模型(准确)
    input_ids: list[int],
    num_speculate: int = 5,
) -> list[int]:
    """单步 Speculative Decoding"""
    # 1. 小模型快速生成 K 个候选 token
    draft_tokens = []
    current_ids = input_ids[:]
    for _ in range(num_speculate):
        next_token = draft_model.generate_one(current_ids)
        draft_tokens.append(next_token)
        current_ids.append(next_token)

    # 2. 大模型一次前向传播,验证所有候选 token
    #    大模型同时输出每个位置的概率分布
    target_probs = target_model.forward(current_ids)

    # 3. 从左到右验证,找到第一个不匹配的位置
    accepted = []
    for i, draft_token in enumerate(draft_tokens):
        # 大模型在该位置的概率最高的 token
        target_token = argmax(target_probs[len(input_ids) + i])
        if draft_token == target_token:
            accepted.append(draft_token)
        else:
            # 用大模型的 token 替代
            accepted.append(target_token)
            break
    else:
        # 所有候选都正确,额外采样一个 token
        last_token = sample(target_probs[-1])
        accepted.append(last_token)

    return accepted

Speculative Decoding 的收益:在候选接受率 80% 时(对话场景常见),推理速度提升 2~3 倍。代价是需要额外部署一个小模型,增加显存占用和工程复杂度。

四、优化策略的权衡与适用边界

4.1 优化策略的 ROI 排序

按投入产出比排序:量化 > 连续批处理 > KV Cache 优化 > Speculative Decoding。量化改动最小(一行配置),收益最确定。连续批处理需要推理框架支持(vLLM/TGI 已内置)。KV Cache 优化同样由框架提供。Speculative Decoding 工程复杂度最高,适合对延迟极度敏感的场景。

4.2 精度与速度的平衡

量化引入精度损失,必须在业务数据集上评估。评估指标不是通用的 perplexity,而是业务相关的准确率。例如,代码生成场景用 pass@k,对话场景用人工评估或 LLM-as-Judge。

4.3 框架选型

框架 连续批处理 PagedAttention 量化支持 Speculative Decoding
vLLM 支持 支持 GPTQ/AWQ 支持
TGI 支持 支持 GPTQ/AWQ/bitsandbytes 支持
TensorRT-LLM 支持 支持 INT8/INT4 支持
llama.cpp 不支持 不支持 GGUF 全系列 不支持

4.4 禁用场景

  • Speculative Decoding 在候选接受率低于 50% 时反而更慢(小模型与目标模型差异大时)。
  • INT4 量化在数学推理任务上精度损失可能超过 5%,需谨慎评估。
  • 连续批处理在单请求场景下无收益,反而增加调度开销。

五、总结

LLM 推理优化的核心是减少显存带宽瓶颈和提高 GPU 利用率。PagedAttention 减少显存碎片,连续批处理提高 GPU 利用率,量化用精度换速度,Speculative Decoding 用小模型加速大模型。按 ROI 排序:量化 > 连续批处理 > KV Cache 优化 > Speculative Decoding。所有优化手段都必须在业务数据集上验证精度损失,通用指标(如 perplexity)不能替代业务评估。框架选型上,vLLM 和 TGI 是当前最成熟的开源方案。

Logo

免费领 200 小时云算力,进群参与显卡、AI PC 幸运抽奖

更多推荐