Paged Attention(分页注意力)

!!! warning
本文为历史文档,基于 vLLM 原始论文
不再描述当前 vLLM 的代码实现。

当前 vLLM 使用自有实现的多头 query attention 内核(csrc/attention/attention_kernels.cu)。
该内核专为 vLLM 的分页 KV 缓存设计,key 和 value 缓存分别存储在独立块中(注意该“块”概念与 GPU 线程块不同,本文称 vLLM 分页注意力的块为“块”,GPU 线程块为“线程块”)。

为实现高性能,该内核依赖特殊设计的内存布局与访问方式,尤其在线程从全局内存读数据到共享内存时。本文旨在逐步高层解释内核实现,帮助理解 vLLM 多头 query attention 内核。读完本文,你将更容易理解源码细节。

注意本文可能未涵盖全部细节(如索引计算或点乘实现),但理解高层逻辑后阅读源码会更容易。

输入

内核函数接收一组参数,每个线程根据分工处理数据。三大重要参数为输入指针 qk_cachev_cache,分别指向全局内存中的 query、key、value 数据。输出指针 out 指向全局内存,用于写入计算结果。四个指针实际指向多维数组,但每线程只访问分配到的数据。其他运行参数略。

template<typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, int PARTITION_SIZE = 0>
__device__ void paged_attention_kernel(
    ... // 其他参数
    const scalar_t* __restrict__ out,       // [num_seqs, num_heads, max_num_partitions, head_size]
    const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
    const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
    ... // 其他参数
)

函数模板参数为编译期确定。scalar_t为数据类型(如FP16),HEAD_SIZE为每头元素数,BLOCK_SIZE为每块token数,NUM_THREADS为每线程块线程数,PARTITION_SIZE为张量并行GPU数(假设为0,禁用并行)。

这些参数用于准备,如计算当前头索引、块索引等,本文先略,直接看主流程,理解后细节会更清晰。

概念

在进入计算流程前,介绍一些后文用到的概念。如遇生僻术语可回查本节。

  • Sequence(序列):表示一个客户端请求。如 q 形状为 [num_seqs, num_heads, head_size],即有 num_seqs 个 query 序列。此内核为单query注意力,每序列仅一个query token,所以 num_seqs 等于批次处理的token数。
  • Context(上下文):由序列生成的token组成。例如 ["What", "is", "your"] 为上下文,输入query为 "name",模型可能生成 "?"
  • Vec(向量块):一次批量读取和计算的元素组。query和key的vec大小(VEC_SIZE)设计为每线程组可一次处理16字节数据。value的vec大小(V_VEC_SIZE)设计为每线程可一次处理16字节。例如FP16时(2字节),THREAD_GROUP_SIZE为2,则VEC_SIZE=4,V_VEC_SIZE=8。
  • Thread group(线程组):一小组线程(THREAD_GROUP_SIZE),一次处理一个query和key token,每线程只处理部分token数据。线程组处理的总元素数记作x。如组含2线程,head size为8,则线程0处理索引0,2,4,6,线程1处理索引1,3,5,7。
  • Block(块):vLLM的key/value缓存按块分割。每块存储一头的固定token数(BLOCK_SIZE),可能只含部分上下文token。如块大小16、头大小128,则一块存128*16=2048元素。
  • Warp(束):一组32线程(WARP_SIZE),在SM上同时执行。本内核中,每束一次处理一个query与一块key token(可能多块,多轮迭代)。如4束6块,每束分配0/4、1/5、2、3块。
  • Thread block(线程块):一组线程(NUM_THREADS),可共享内存。每块含多个束(NUM_WARPS),本内核每块处理一个query与整个上下文key token。
  • Grid(网格):线程块集合,定义分布形状。本内核形状为 (num_heads, num_seqs, max_num_partitions),每块处理一个头、序列、分区。

Query

本节介绍query数据的内存布局及线程读取方式。每线程组读取一个query token,每线程只处理部分数据。同一束内各线程组读取同一query token,但会与不同key token相乘。

const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
![外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传](https://img-home.csdnimg.cn/images/20230724024159.png?origin_url=..%2Fassets%2Fdesign%2Fpaged_attention%2Fquery.png&pos_id=img-EpSVE4PI-1759887769714){ align="center" alt="query" width="70%" }

每线程定义自己的 q_ptr,指向分配的query token数据。如VEC_SIZE=4,HEAD_SIZE=128,则q_ptr指向128元素,可分为32个vec。

![外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传](https://img-home.csdnimg.cn/images/20230724024159.png?origin_url=..%2Fassets%2Fdesign%2Fpaged_attention%2Fq_vecs.png&pos_id=img-olULMYne-1759887769715){ align="center" alt="q_vecs" width="70%" }
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];

接下来将q_ptr指向的全局数据读到共享内存的q_vecs。每个vec分配到不同行,如THREAD_GROUP_SIZE=2,线程0管0行vec,线程1管1行vec。这样邻近线程访问邻近内存,实现内存合并提升性能。

Key

与Query类似,本节介绍key的内存布局及分配。每线程组一次只处理一个query token,但可跨多轮迭代处理多个key token。每束会多轮处理多个key块,保证整个上下文token都被计算。"处理"即query与key做点乘。

const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                    + kv_head_idx * kv_head_stride
                    + physical_block_offset * x;

q_ptr不同,k_ptr每轮迭代指向不同key token数据。上述代码按分配的块、头、token定位key数据。

![外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传](https://img-home.csdnimg.cn/images/20230724024159.png?origin_url=..%2Fassets%2Fdesign%2Fpaged_attention%2Fkey.png&pos_id=img-wSLCYo4W-1759887769716){ align="center" alt="key" width="70%" }

上图假定BLOCK_SIZE=16, HEAD_SIZE=128, x=8, THREAD_GROUP_SIZE=2, 共4束。每个矩形代表一个key token所有元素由一个线程组处理。左半为束0的16块key,右半为其他束或迭代的key。每矩形含32个vec(128元素),由2线程组分工。

![外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传](https://img-home.csdnimg.cn/images/20230724024159.png?origin_url=..%2Fassets%2Fdesign%2Fpaged_attention%2Fk_vecs.png&pos_id=img-O9JAEfbT-1759887769716){ align="center" alt="k_vecs" width="70%" }
K_vec k_vecs[NUM_VECS_PER_THREAD]

接下来从k_ptr读取key token数据到寄存器内存k_vecs。用寄存器因只被本线程一次访问,而q_vecs被多线程多次访问。每个k_vecs含多组向量供后续计算。vec分配保证束内邻近线程读邻近内存,提升合并性能。如线程0读vec0,线程1读vec1,下轮线程0读vec2,线程1读vec3。

如仍疑惑整体流程,可继续阅读下节 QK,会更清晰。

QK

如下伪码,进入总循环前,先读取一个token的query数据存为q_vecs。外循环遍历不同k_ptr,内循环准备k_vecs,最后将q_vecs与各k_vecs做点乘。

q_vecs = ...
for ... {
    k_ptr = ...
    for ... {
        k_vecs[i] = ...
    }
    ...
    float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
}

如前所述,每线程一次只处理部分query和key token数据,但Qk_dot<>::dot会跨线程组归约,返回的qk是整个query与key的全量点乘。

举例,HEAD_SIZE=128, THREAD_GROUP_SIZE=2,则每线程k_vecs含64元素,但qk是128元素query与128元素key的点乘。细节可查Qk_dot<>::dot源码,本文不展开。

Softmax

接下来需对所有qk计算归一化 softmax。即每个 x x xqk。需归约得到qk_max m ( x ) m(x) m(x))和各qkexp_sum ℓ ( x ) \ell(x) (x))。归约在整个线程块完成,涵盖query与所有上下文key token的结果。

m ( x ) : = max ⁡ i x i f ( x ) : = [ e x 1 − m ( x ) … e x B − m ( x ) ] ℓ ( x ) : = ∑ i f ( x ) i softmax ⁡ ( x ) : = f ( x ) ℓ ( x ) \begin{gather*} m(x):=\max _i \quad x_i \\ \quad f(x):=\left[\begin{array}{lll}e^{x_1-m(x)} & \ldots & e^{x_B-m(x)}\end{array}\right]\\ \quad \ell(x):=\sum_i f(x)_i \\ \quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)} \end{gather*} m(x):=imaxxif(x):=[ex1m(x)exBm(x)](x):=if(x)isoftmax(x):=(x)f(x)

qk_maxlogits

拿到qk后可用它设临时logits结果(最终存归一化softmax结果)。同时比较并收集所有qkqk_max

if (thread_group_offset == 0) {
    const bool mask = token_idx >= context_len;
    logits[token_idx - start_token_idx] = mask ? 0.f : qk;
    qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}

注意logits在共享内存,每线程组设置自己分配的context token字段,总体大小等于context token数。

for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}

if (lane == 0) {
    red_smem[warp_idx] = qk_max;
}

然后跨束归约qk_max。主要思想是束内线程通信,得到最终最大qk

for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
qk_max = VLLM_SHFL_SYNC(qk_max, 0);

最后比较所有束的qk_max,全线程块归约,结果广播给每线程。

exp_sum

qk_max类似,需对整个线程块归约求和。

for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
}
...
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

先对每线程组求exp和,同时将logitsqk变为exp(qk - qk_max)。注意这里的qk_max已经是全线程块最大值。然后像qk_max一样归约exp_sum

const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
}

最后,有了归约后的qk_maxexp_sum,即可得到归一化后的softmax结果logits。该变量用于后续与value数据做点乘,现在保存的是所有分配context token的qk归一化结果。

Value

![外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传](https://img-home.csdnimg.cn/images/20230724024159.png?origin_url=..%2Fassets%2Fdesign%2Fpaged_attention%2Fvalue.png&pos_id=img-bXfIt5hG-1759887769716){ align="center" alt="value" width="70%" }
![外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传](https://img-home.csdnimg.cn/images/20230724024159.png?origin_url=..%2Fassets%2Fdesign%2Fpaged_attention%2Flogits_vec.png&pos_id=img-epa1ni1Z-1759887769716){ align="center" alt="logits_vec" width="50%" }
![外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传](https://img-home.csdnimg.cn/images/20230724024159.png?origin_url=..%2Fassets%2Fdesign%2Fpaged_attention%2Fv_vec.png&pos_id=img-YfJNpeX4-1759887769716){ align="center" alt="v_vec" width="70%" }

接下来需读取 value 数据,与logits做点乘。与query/key不同,value数据无线程组概念。如下图,value token的内存布局与key不同,同列元素对应同一个value token。一个value块有HEAD_SIZE行、BLOCK_SIZE列,分成多个v_vecs

每线程每次读取V_VEC_SIZE个token的元素。即一个线程一次从不同行同一列读取多个v_vec,多轮内循环实现。每个v_vec需与同样大小的logits_vec点乘。多轮内循环后,每束处理一个value块;多轮外循环处理整个上下文的value token。

float accs[NUM_ROWS_PER_THREAD];
for ... { // 不同块的迭代
    logits_vec = ...
    for ... { // 不同行的迭代
        v_vec = ...
        ...
        accs[i] += dot(logits_vec, v_vec);
    }
}

如伪码所示,外循环(如k_ptrlogits_vec遍历不同块,每次读V_VEC_SIZElogits元素。内循环每线程读同列的v_vec做点乘。每次内循环,线程读同token不同头位置元素。结果累加到accs。每个accs条目对应分配到的头位置。

BLOCK_SIZE=16, V_VEC_SIZE=8,线程每次取8个value元素,分别是同头不同token。HEAD_SIZE=128, WARP_SIZE=32,则每次内循环束需读328=256个元素,总共12816/256=8次内循环处理一块value。每线程accs有8个条目,对应8头位置、累加所有分配token的结果。

LV

然后对每个束内部的accs做归约,每线程累加分配到的头位置所有token的结果。

for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
        acc += VLLM_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
}

然后跨束归约accs,每线程拿到分配头位置所有上下文token的累加。注意每线程的accs只存部分头位置全token的累加,但最终所有输出已算出,只是分布在不同线程寄存器里。

??? code

```cpp
float* out_smem = reinterpret_cast<float*>(shared_mem);
for (int i = NUM_WARPS; i > 1; i /= 2) {
    // 上半束写入共享内存
    ...
    float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        ...
        dst[row_idx] = accs[i];
    }

    // 下半束更新输出
    const float* src = &out_smem[warp_idx * HEAD_SIZE];
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        ...
        accs[i] += src[row_idx];
    }

    // 写入 accs.
}
```

输出

最后将本地寄存器的结果写回全局输出内存。

scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                + head_idx * max_num_partitions * HEAD_SIZE
                + partition_idx * HEAD_SIZE;

首先定义out_ptr,指向分配序列和头的起始地址。

for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
    if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
    }
}

最后遍历分配到的头位置,将累加结果写入out_ptr

引用

@inproceedings{kwon2023efficient,
  title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
  author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
  booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
  year={2023}
}

原文地址:https://docs.vllm.ai/en/latest/design/paged_attention.html
当前更新时间:2025-10-08

Logo

更多推荐