[Feature] Support PD Disaggregation #1329

@ZiyiTsang

This feature will maintain backward compatibility with the current APIs in areal/api/.

Background

Rollout takes up most of the training time. PD disaggregation can manage rollout resources more efficiently and is an important performance optimization in production environments.

AReaL currently uses homogeneous inference workers (each worker handles both prefill and decode). Prefill-Decode (PD) disaggregation splits prefill and decode onto different GPU pools, transferring KV cache via RDMA/NVLink, which significantly improves throughput for RL post-training workloads (short prompts, long decode sequences).

SGLang natively supports PD disaggregation. This feature can be added in a relatively non-intrusive way by leveraging SGLang's native APIs (ServerArgs, sgl-router REST API if used) without modifying the engine source code. However, integrating PD into AReaL still requires changes to the router/gateway layer, as the current custom Python router does not support PD-aware routing or typed worker registration.

Potential Solution

1. Config Layer (areal/api/cli_args.py)

Add disaggregation_mode, disaggregation_transfer_backend, disaggregation_bootstrap_port, and prefill_round_robin_balance fields to SGLangConfig, transparently passing them to SGLang's native ServerArgs.

To Implement:

  • disaggregation_mode: Literal["null", "prefill", "decode"] = "null"
  • disaggregation_transfer_backend: str | None = None
  • disaggregation_bootstrap_port: int = 8998
  • InferenceEngineConfig.pd_disaggregation: bool = False
  • Validation: requires sglang backend, v2 controller, PP=1, DP=2
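A minimal sketch of how these fields and the validation rule might look, assuming plain dataclasses. The field names and defaults come from the list above; the `validate_pd` helper and its arguments (`backend`, `controller_version`, `pp_size`) are hypothetical names used only for illustration, and the DP constraint is left out because its exact form is not spelled out here.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal


@dataclass
class SGLangConfig:
    # Fields from the list above; passed through to SGLang's ServerArgs.
    disaggregation_mode: Literal["null", "prefill", "decode"] = "null"
    disaggregation_transfer_backend: str | None = None
    disaggregation_bootstrap_port: int = 8998


@dataclass
class InferenceEngineConfig:
    pd_disaggregation: bool = False


def validate_pd(
    cfg: InferenceEngineConfig,
    backend: str,
    controller_version: str,
    pp_size: int,
) -> None:
    """Hypothetical validation helper; argument names are assumptions."""
    if not cfg.pd_disaggregation:
        return  # Nothing to check when PD is off (backward compatible).
    if backend != "sglang":
        raise ValueError("PD disaggregation requires the sglang backend")
    if controller_version != "v2":
        raise ValueError("PD disaggregation requires the v2 controller")
    if pp_size != 1:
        raise ValueError("PD disaggregation requires PP=1")
```

Keeping `pd_disaggregation` off by default means existing configs are never touched by the new checks.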

2. Resource Allocation Syntax (areal/api/alloc_mode.py)

In the current DSL, the + operator means "training-inference GPU pool separation". Extend the DSL to support multiple inference groups within the same model (e.g., prefill + decode).

Introduce InferenceGroupConfig to encapsulate worker_type + parallel + overrides. All CLI paths converge to a unified data structure.
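One way the unified data structure could look, as a rough sketch. Only `InferenceGroupConfig` and its three parts (worker_type, parallel, overrides) come from the proposal; `ParallelSpec`, its fields, and `total_inference_gpus` are hypothetical stand-ins for whatever AReaL already uses to describe parallelism.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Literal


@dataclass
class ParallelSpec:
    """Hypothetical parallelism description for one inference group."""
    dp: int = 1
    tp: int = 1
    pp: int = 1

    @property
    def world_size(self) -> int:
        return self.dp * self.tp * self.pp


@dataclass
class InferenceGroupConfig:
    worker_type: Literal["regular", "prefill", "decode"]
    parallel: ParallelSpec
    # Per-group overrides applied on top of the shared server config.
    overrides: dict[str, Any] = field(default_factory=dict)


def total_inference_gpus(groups: list[InferenceGroupConfig]) -> int:
    """Sum the GPUs needed by all inference groups of one model."""
    return sum(g.parallel.world_size for g in groups)


# Example: 2 prefill servers with TP=2, plus 4 decode servers with TP=1.
groups = [
    InferenceGroupConfig("prefill", ParallelSpec(dp=2, tp=2)),
    InferenceGroupConfig("decode", ParallelSpec(dp=4, tp=1)),
]
```

A homogeneous deployment is then just a single group with `worker_type="regular"`, so the existing path does not need a separate code branch.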

3. Launcher (areal/infra/launcher/sglang_server.py)

Change from "single-group homogeneous launch" to "multi-group heterogeneous launch", where prefill and decode each get their own independent GPU bundles.

Use "port cursor relay" for port allocation: prefill gets an extra bootstrap_port, decode does not, avoiding port conflicts across groups.
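The port cursor relay can be sketched as a single running counter handed from one group to the next. This is a simplified illustration: it assumes consecutive port numbers are free, whereas a real launcher would likely probe for available ports. The function name and the tuple/dict shapes are hypothetical.

```python
def allocate_ports(groups, base_port):
    """Assign ports group by group using one shared cursor.

    groups: list of (worker_type, num_servers) tuples.
    Returns (assignments, next_cursor); next_cursor is where a
    following group would continue, which is the "relay" part.
    """
    cursor = base_port
    assignments = []
    for worker_type, num_servers in groups:
        for _ in range(num_servers):
            entry = {"worker_type": worker_type, "port": cursor}
            cursor += 1
            if worker_type == "prefill":
                # Only prefill servers consume an extra bootstrap port.
                entry["bootstrap_port"] = cursor
                cursor += 1
            assignments.append(entry)
    return assignments, cursor


assignments, next_cursor = allocate_ports(
    [("prefill", 2), ("decode", 2)], base_port=30000
)
```

Because every port is taken from the same cursor, no two servers can receive the same port regardless of how many groups are launched.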

4. Controller (areal/experimental/inference_service/controller/controller.py)

Change _async_fork_inf_servers from launching a single server type to iterating over multiple groups: prefill with bootstrap_port, decode with prefill_round_robin_balance.

Implemented:

  • Parallel fork of prefill (group 0) and decode (group 1)
  • Bootstrap port tracked per prefill group
  • Data proxies launched with backend_addrs override
  • _inf_addrs populated for downstream compatibility

Also support an arbitrary number of servers per group (prefill DP > 1, decode DP > 1).
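The multi-group fork could be structured roughly as follows. This is a sketch, not the actual `_async_fork_inf_servers`: `fork_server` is a stand-in coroutine for the real fork call, and the argument dictionaries simply reuse the option names mentioned in this issue.

```python
import asyncio


async def fork_server(worker_type, index, extra_args):
    """Stand-in for the real server fork; returns a description only."""
    await asyncio.sleep(0)  # Yield control, as a real async launch would.
    return {"worker_type": worker_type, "index": index, "args": extra_args}


async def fork_all_groups(groups, bootstrap_port):
    """Fork every server of every group concurrently.

    groups: list of (worker_type, num_servers) tuples.
    """
    tasks = []
    for worker_type, num_servers in groups:
        for index in range(num_servers):
            if worker_type == "prefill":
                extra = {
                    "disaggregation_mode": "prefill",
                    "disaggregation_bootstrap_port": bootstrap_port,
                }
            elif worker_type == "decode":
                extra = {
                    "disaggregation_mode": "decode",
                    "prefill_round_robin_balance": True,
                }
            else:
                extra = {}  # Regular workers keep the existing arguments.
            tasks.append(fork_server(worker_type, index, extra))
    # gather preserves submission order, so prefill results come first.
    return await asyncio.gather(*tasks)
```

Usage would look like `asyncio.run(fork_all_groups([("prefill", 1), ("decode", 2)], 8998))`, which returns one entry per launched server.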

5. Router/Gateway

Files modified:

  • areal/experimental/inference_service/router/state.py
  • areal/experimental/inference_service/router/app.py
  • areal/experimental/inference_service/gateway/streaming.py
  • areal/experimental/inference_service/gateway/app.py
  • areal/experimental/inference_service/gateway/__main__.py
  • areal/experimental/inference_service/gateway/config.py

Router:

  • worker_type in registration ("regular" | "prefill" | "decode")
  • bootstrap_port for prefill workers
  • _type_workers bucketing in WorkerRegistry
  • /route_pd endpoint: picks one prefill + one decode, generates bootstrap_room
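A minimal in-memory sketch of the typed registry and the pairing logic behind /route_pd. `WorkerRegistry` and `_type_workers` are named in the list above; the round-robin selection, the method names, the returned dictionary keys, and the use of a random 63-bit integer for `bootstrap_room` are assumptions made for illustration.

```python
import random
from collections import defaultdict


class WorkerRegistry:
    def __init__(self):
        # worker_type -> list of registered workers
        self._type_workers = defaultdict(list)
        # worker_type -> round-robin position (assumed selection policy)
        self._cursors = defaultdict(int)

    def register(self, addr, worker_type="regular", bootstrap_port=None):
        if worker_type not in ("regular", "prefill", "decode"):
            raise ValueError(f"unknown worker_type: {worker_type}")
        if worker_type == "prefill" and bootstrap_port is None:
            raise ValueError("prefill workers must report a bootstrap_port")
        self._type_workers[worker_type].append(
            {"addr": addr, "bootstrap_port": bootstrap_port}
        )

    def _next(self, worker_type):
        workers = self._type_workers[worker_type]
        if not workers:
            raise LookupError(f"no {worker_type} worker registered")
        position = self._cursors[worker_type] % len(workers)
        self._cursors[worker_type] += 1
        return workers[position]

    def route_pd(self):
        """Pick one prefill and one decode worker and tag the pair."""
        prefill = self._next("prefill")
        decode = self._next("decode")
        return {
            "prefill_addr": prefill["addr"],
            "decode_addr": decode["addr"],
            "bootstrap_port": prefill["bootstrap_port"],
            # Shared id that lets both sides match the KV transfer.
            "bootstrap_room": random.getrandbits(63),
        }
```

Defaulting `worker_type` to "regular" keeps existing registration calls valid without changes.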

Gateway:

  • pd_disaggregation CLI flag and config
  • query_router_pd(): get PDPair
  • pd_dual_dispatch(): concurrent dispatch with bootstrap injection
  • Reject streaming in PD mode (400)
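The gateway side might be sketched like this. The `send` callable is injected so the example stays self-contained; in practice it would be an HTTP client call. The `pair` keys, the exception class, and the injected field names (`bootstrap_host`, `bootstrap_port`, `bootstrap_room`) are assumptions about the request shape, not a confirmed interface.

```python
import asyncio


class StreamingNotSupported(Exception):
    """Maps to an HTTP 400 response in the gateway."""
    status_code = 400


async def pd_dual_dispatch(payload, pair, send):
    """Send the same request to a prefill and a decode worker at once.

    payload: the original generate request (dict).
    pair:    routing result with addresses and bootstrap info.
    send:    async callable (addr, body) -> response dict.
    """
    if payload.get("stream"):
        raise StreamingNotSupported("streaming is not supported in PD mode")

    # Inject bootstrap info so both workers can pair up the KV transfer.
    body = dict(
        payload,
        bootstrap_host=pair["prefill_host"],
        bootstrap_port=pair["bootstrap_port"],
        bootstrap_room=pair["bootstrap_room"],
    )
    _, decode_result = await asyncio.gather(
        send(pair["prefill_addr"], body),
        send(pair["decode_addr"], body),
    )
    # The generated tokens come from the decode worker.
    return decode_result
```

Copying the payload before injection leaves the caller's request untouched, which matters if the gateway retries with a different pair.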

6. Weight Synchronization

Build a flat NCCL group with the training rank 0 as the broadcast source and all inference engines (prefill + decode) as receivers.

Compute cumulative rank_offset from engine_gpu_counts to naturally support heterogeneous TP (prefill TP != decode TP).

The training engine is completely unaware of prefill/decode roles: it only sees "engine list + GPU counts", with no knowledge of downstream roles.
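The cumulative offset computation can be illustrated with a short sketch. It assumes rank 0 of the flat group is reserved for training rank 0 (the broadcast source), so inference ranks start at 1; the function name and return shape are hypothetical.

```python
def compute_rank_offsets(engine_gpu_counts):
    """Return (rank_offsets, world_size) for a flat broadcast group.

    engine_gpu_counts: GPUs per inference engine, in engine-list order.
    Engine i owns ranks [rank_offsets[i], rank_offsets[i] + count_i).
    """
    offsets = []
    next_rank = 1  # Rank 0 is the training-side broadcast source.
    for count in engine_gpu_counts:
        offsets.append(next_rank)
        next_rank += count
    return offsets, next_rank


# Two prefill engines with TP=2 followed by four decode engines with TP=1.
offsets, world_size = compute_rank_offsets([2, 2, 1, 1, 1, 1])
# offsets == [1, 3, 5, 6, 7, 8], world_size == 9
```

Because only GPU counts enter the calculation, prefill and decode engines with different TP sizes need no special handling.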

Additional Information

Reference: ServerGroupConfig(worker_type) pattern, port cursor relay, _compute_server_args three-way branching, flat NCCL group + cumulative rank_offset.
