SGLang PD分离流程细节
SGLang PD分离设计文档
https://docs.google.com/document/d/1rQXJwKd5b9b1aOzLh98mnyMhBMhlxXA5ATZTHoQrwvc/edit?tab=t.0
[Roadmap] Prefill and Decoding Disaggregation · Issue #4655 · sgl-project/sglang · GitHub
为什么要PD分离

从做了PD分离的马后炮角度来看:
P是算力密集型,需要高算力,对内存容量需求相对没有那么高,也不需要很多请求组batch计算。
D是访存密集型,相比于计算性能,更需要高内存带宽和大内存容量。需要很多请求组成一个大的batch来计算,以提升throughput。
如果不对P/D做针对性优化,只是简单地做PD分离,那么在GPU数量相同的情况下,P/D其实并不能带来更大的throughput,反而可能降低throughput;但PD分离的好处是可以灵活调控P/D的比例,从而得到显著不同的TTFT/ITL特性。
要让P/D分离带来更大的throughput,必须在P/D上实施非PD部署所不具备的针对性优化手段。其中最为关键的是DP attention,此外还有DeepEP等针对P/D的优化。
非PD部署与DP attention并不十分兼容,开启DP attention后性能反而可能变差。
P和D对DP attention的需求不同。对于MLA模型,DP attention可以在一定程度上提升P的throughput,但提升并不显著,大约在30%量级。而且使用DP attention时,特别是DP=TP时,单个GPU独立计算一个请求,算力较低,TTFT相比不用DP会成倍增长,往往得不偿失。
但DP attention对D的throughput是成倍提升的,基本上DP size是多少就提升多少倍。DP对ITL虽有一定牺牲,却不像TTFT那样成倍恶化,ITL大约只恶化30%量级。原因可能在于D是访存密集型,需要显著更大的batch size才能进入计算密集区间;而对于MoE模型,这个batch size要求非常高,才能使每个专家获得足够的token输入,让throughput达到饱和。TP并行难以达到很高的batch size,而DP attention下每个GPU都计算一个大的batch,整个节点的batch size相比TP直接扩大8倍(以单节点8卡为例),throughput因此显著提升。
那么P的TP + D的DP组合能够同时带来好的TTFT,ITL,throughput,这一点是非PD基本上无法实现的。业界的实践参考蚂蚁工作:
Together with SGLang: Best Practices for Serving DeepSeek-R1 on H20-96G
SGLang PD分离请求流程


美中不足:当前设计不支持NVIDIA Dynamo最初版本所支持的那种模式,即允许某些场景下的请求只在decode节点处理、而不依赖prefill节点。当前设计中,请求必须先经过prefill处理,再将KV cache传输到decode节点进行解码。
Mooncake通信库处理逻辑

NIXL通信库处理逻辑

BootstrapServer
BootstrapServer通过调用start_disagg_service启动,在prefill的TokenizerManager创建时调用。
def start_disagg_service(
    server_args: ServerArgs,
):
    """Create the KV bootstrap server when this process runs in prefill mode.

    The disaggregation mode and transfer backend are parsed from ``server_args``.
    A bootstrap server is constructed only for ``DisaggregationMode.PREFILL``.
    Its class comes from ``get_kv_class(transfer_backend,
    KVClassType.BOOTSTRAP_SERVER)``, and it is bound to ``server_args.host`` and
    ``server_args.disaggregation_bootstrap_port``.

    NOTE: excerpt. ``bootstrap_server`` is created but never used within the
    lines shown, so the original function presumably continues, for example to
    return the server. That remainder is not visible here.
    """
    # Start kv bootstrap server on prefill
    disagg_mode = DisaggregationMode(server_args.disaggregation_mode)
    transfer_backend = TransferBackend(server_args.disaggregation_transfer_backend)
    if disagg_mode == DisaggregationMode.PREFILL:
        # only start bootstrap server on prefill tm (TokenizerManager)
        # Resolve the backend-specific bootstrap-server class.
        kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
            transfer_backend, KVClassType.BOOTSTRAP_SERVER
        )
        bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
            host=server_args.host,
            port=server_args.disaggregation_bootstrap_port,
        )


utils
poll_and_all_reduce:对同一个attention DP组内各TP worker的KV状态做reduce_min同步。
kv cache传输状态迁移:

数据传输准备
class KVArgs:
    """Plain attribute bag describing the buffers a KV transfer backend moves.

    Every field is a bare class-level annotation with no default, so this class
    itself assigns none of them. Bootstrap code such as ``_init_kv_manager`` and
    ``setup_state_kv_args`` sets each attribute explicitly. A backend-specific
    subclass may add defaults (not shown here).
    """

    engine_rank: int  # the prefill path sets this from the worker's TP rank
    # KV cache buffers: target-model buffers first, then draft-model buffers
    # when a draft pool exists. The three lists are presumably parallel
    # (one entry per buffer); confirm units of *_lens / *_item_lens.
    kv_data_ptrs: List[int]
    kv_data_lens: List[int]
    kv_item_lens: List[int]
    # Auxiliary buffers carrying first-output-token metadata
    # (filled from MetadataBuffers.get_buf_infos()).
    aux_data_ptrs: List[int]
    aux_data_lens: List[int]
    aux_item_lens: List[int]
    # Extra per-model state buffers, filled by setup_state_kv_args.
    state_data_ptrs: List[int]
    state_data_lens: List[int]
    state_item_lens: List[int]
    state_type: str  # "none", "mamba", "swa", or "nsa"
    # for mamba state different tp slice transfer
    state_dim_per_tensor: List[int]  # dimension to slice for each state tensor
    ib_device: str  # from server_args.disaggregation_ib_device
    ib_traffic_class: str
    gpu_id: int
    # Head counts are only assigned for non-MLA attention backends.
    kv_head_num: int
    total_kv_head_num: int
    page_size: int
    # for pp prefill
    pp_rank: int
    prefill_start_layer: int
    # for system dp
    system_dp_rank: int
三个数据部分:
kv_data_ptrs: List[int], kv_data_lens: List[int], kv_item_lens: List[int]:
主模型和speculative draft model的kv cache。
aux_data_ptrs: List[int], aux_data_lens: List[int], aux_item_lens: List[int]:
transfer the metadata of first output token to decode
class MetadataBuffers:
    # NOTE: excerpt. Only the start of get_buf_infos is shown; the constructor
    # and the rest of the class are omitted.
    def get_buf_infos(self):
        """Describe the metadata tensors that carry the first output token to decode.

        Callers unpack the result as ``(aux_data_ptrs, aux_data_lens,
        aux_item_lens)``, so the complete method returns three lists. Only the
        pointer list is shown in this excerpt. The code building the two length
        lists and the ``return`` statement are omitted.
        """
        # Raw base addresses (Tensor.data_ptr()) of each metadata tensor, in a
        # fixed order. The receiver presumably relies on the same order; confirm.
        ptrs = [
            self.output_ids.data_ptr(),
            self.cached_tokens.data_ptr(),
            self.output_token_logprobs_val.data_ptr(),
            self.output_token_logprobs_idx.data_ptr(),
            self.output_top_logprobs_val.data_ptr(),
            self.output_top_logprobs_idx.data_ptr(),
            self.output_topk_p.data_ptr(),
            self.output_topk_index.data_ptr(),
            self.output_hidden_states.data_ptr(),
            # presumably identifies which request the slot belongs to; confirm
            self.bootstrap_room.data_ptr(),
        ]
state_data_ptrs: List[int], state_data_lens: List[int], state_item_lens: List[int]:
SWA、Mamba的状态,以及DeepSeek V3.2模型(及其MTP draft模型)中DSA所需的额外KV cache信息。
def setup_state_kv_args(
    kv_args: KVArgs,
    token_to_kv_pool,
    draft_token_to_kv_pool=None,
) -> None:
    """Populate ``kv_args`` state-buffer fields from the given pool.

    Shared by prefill and decode bootstrap paths so the state_type dispatch
    lives in one place.

    ``state_data_ptrs`` / ``state_data_lens`` / ``state_item_lens`` come from
    ``token_to_kv_pool.get_state_buf_infos()``. ``state_type`` is chosen by
    the pool class:

    * ``SWAKVPool``          -> ``"swa"``
    * ``HybridLinearKVPool`` -> ``"mamba"``. ``state_dim_per_tensor`` is also
      set when the pool provides ``get_state_dim_per_tensor`` and is used for
      cross-TP slice transfer.
    * ``NSATokenToKVPool``   -> ``"nsa"``. If ``draft_token_to_kv_pool`` is
      also an NSA pool, its state buffers are appended after the target's.
    * anything else          -> ``"none"``

    Args:
        kv_args: Object whose ``state_*`` attributes are assigned in place.
        token_to_kv_pool: Target-model KV pool; must provide
            ``get_state_buf_infos()`` returning three sequences.
        draft_token_to_kv_pool: Optional speculative draft-model pool. It is
            only consulted when the target pool is an NSA pool.
    """
    from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, NSATokenToKVPool
    from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool

    state_data_ptrs, state_data_lens, state_item_lens = (
        token_to_kv_pool.get_state_buf_infos()
    )
    kv_args.state_data_ptrs = state_data_ptrs
    kv_args.state_data_lens = state_data_lens
    kv_args.state_item_lens = state_item_lens

    if isinstance(token_to_kv_pool, SWAKVPool):
        kv_args.state_type = "swa"
    elif isinstance(token_to_kv_pool, HybridLinearKVPool):
        kv_args.state_type = "mamba"
        # Get state dimension info for cross-TP slice transfer
        if hasattr(token_to_kv_pool, "get_state_dim_per_tensor"):
            kv_args.state_dim_per_tensor = token_to_kv_pool.get_state_dim_per_tensor()
    elif isinstance(token_to_kv_pool, NSATokenToKVPool):
        kv_args.state_type = "nsa"
        # isinstance() is False for None, so this also covers "no draft pool".
        if isinstance(draft_token_to_kv_pool, NSATokenToKVPool):
            (
                draft_state_data_ptrs,
                draft_state_data_lens,
                draft_state_item_lens,
            ) = draft_token_to_kv_pool.get_state_buf_infos()
            # Build fresh lists rather than using `+=`. In-place extension
            # would mutate the sequences returned by the target pool's
            # get_state_buf_infos(), which is unsafe if the pool hands out
            # lists it keeps or reuses.
            kv_args.state_data_ptrs = [*state_data_ptrs, *draft_state_data_ptrs]
            kv_args.state_data_lens = [*state_data_lens, *draft_state_data_lens]
            kv_args.state_item_lens = [*state_item_lens, *draft_state_item_lens]
    else:
        kv_args.state_type = "none"
初始化传输的数据信息
class PrefillBootstrapQueue:
    # NOTE: excerpt. Only _init_kv_manager is shown, and it is cut off before
    # any KV manager is built or returned. The signature promises a
    # CommonKVManager, but no `return` is visible.
    def _init_kv_manager(self) -> CommonKVManager:
        """Fill a backend-specific ``KVArgs`` describing this prefill worker's transfer buffers.

        Sets rank/layout info, the target KV cache buffers plus the optional
        draft-model buffers, and the first-output-token metadata buffers. It
        also sets the IB device and GPU id, and the model-specific state
        buffers via ``setup_state_kv_args``.
        """
        # Resolve and instantiate the KVArgs class for the configured backend.
        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
        kv_args = kv_args_class()
        # Rank information for this worker (TP, PP, and system-level DP).
        kv_args.engine_rank = self.tp_rank
        kv_args.pp_rank = self.pp_rank
        kv_args.system_dp_rank = self.scheduler.dp_rank
        # presumably the first layer held by this PP stage; confirm
        kv_args.prefill_start_layer = self.token_to_kv_pool.start_layer
        kv_data_ptrs, kv_data_lens, kv_item_lens = (
            self.token_to_kv_pool.get_contiguous_buf_infos()
        )
        if self.draft_token_to_kv_pool is not None:
            # We should also transfer draft model kv cache. The indices are
            # always shared with a target model.
            draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
                self.draft_token_to_kv_pool.get_contiguous_buf_infos()
            )
            # NOTE(review): `+=` extends the lists returned by the target pool
            # in place. This is harmless only if get_contiguous_buf_infos()
            # returns fresh lists on every call; confirm.
            kv_data_ptrs += draft_kv_data_ptrs
            kv_data_lens += draft_kv_data_lens
            kv_item_lens += draft_kv_item_lens
        kv_args.kv_data_ptrs = kv_data_ptrs
        kv_args.kv_data_lens = kv_data_lens
        kv_args.kv_item_lens = kv_item_lens
        # Head counts are only recorded for non-MLA backends. Presumably MLA's
        # compressed cache has no per-head split to slice across TP; confirm.
        if not self.is_mla_backend:
            kv_args.kv_head_num = self.token_to_kv_pool.head_num
            kv_args.total_kv_head_num = (
                self.scheduler.model_config.get_total_num_kv_heads()
            )
        kv_args.page_size = self.token_to_kv_pool.page_size
        # Auxiliary buffers: metadata of the first output token sent to decode.
        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
            self.metadata_buffers.get_buf_infos()
        )
        kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
        kv_args.gpu_id = self.scheduler.gpu_id
        # SWA / Mamba / NSA state buffers (including draft NSA state, if any).
        setup_state_kv_args(kv_args, self.token_to_kv_pool, self.draft_token_to_kv_pool)
mooncake注册
class MooncakeKVManager(CommonKVManager):
    def register_buffer_to_engine(self):
        """Register the KV, auxiliary, and state buffer groups with the transfer engine.

        Groups are processed in the order KV -> aux -> state. Each group is
        registered with one ``self.engine.batch_register(ptrs, lens)`` call.
        A group is skipped when its pointer list is falsy (e.g. empty), or
        otherwise when its length list is falsy.
        """
        # (pointer attribute, length attribute) on self.kv_args, in registration order.
        buffer_groups = (
            ("kv_data_ptrs", "kv_data_lens"),  # KV cache data
            ("aux_data_ptrs", "aux_data_lens"),  # auxiliary metadata
            ("state_data_ptrs", "state_data_lens"),  # state / extra pool data
        )
        for ptrs_attr, lens_attr in buffer_groups:
            group_ptrs = getattr(self.kv_args, ptrs_attr)
            if not group_ptrs:
                continue
            # Only read the lengths once the pointers are known to be present,
            # mirroring the short-circuit of `ptrs and lens`.
            group_lens = getattr(self.kv_args, lens_attr)
            if not group_lens:
                continue
            self.engine.batch_register(group_ptrs, group_lens)
更多推荐


所有评论(0)