86 changes: 86 additions & 0 deletions docs/source/user_guide/feature_guide/external_dp.md
@@ -0,0 +1,86 @@
# External DP

For larger scale deployments especially, it can make sense to handle the orchestration and load balancing of data parallel ranks externally.

In this case, it's more convenient to treat each DP rank like a separate vLLM deployment, with its own endpoint, and have an external router balance HTTP requests between them, making use of appropriate real-time telemetry from each server for routing decisions.

## Getting Started

[External DP](https://docs.vllm.ai/en/latest/serving/data_parallel_deployment/?h=external#external-load-balancing) is natively supported by vLLM. On top of it, vllm-ascend provides two enhancements:

1. A launch script that starts multiple vLLM instances with a single command.
2. A request-length-aware load-balancing proxy for external DP.

This tutorial walks through how to use both.

### Prerequisites
- Python 3.10+
- Install the dependencies needed by the load-balancing proxy server:

```
pip install fastapi httpx uvicorn
```

## Starting External DP Servers
First, you need at least two vLLM servers running in data parallel. These can be mock servers or actual vLLM servers. The proxy also works with a single vLLM server, but it then falls back to direct request forwarding, which defeats the purpose of load balancing.

You can start the external vLLM DP servers one by one manually, or use the launch script in `examples/external_online_dp`. For large DP sizes spanning multiple nodes, we recommend the launch script for convenience.

### Manually Launch

```
# This example shows how to manually launch a vLLM service with DP size 2 in one node.
vllm serve --host 0.0.0.0 --port 8100 --data-parallel-size 2 --data-parallel-rank 0 ... # vLLM DP0
vllm serve --host 0.0.0.0 --port 8101 --data-parallel-size 2 --data-parallel-rank 1 ... # vLLM DP1
```

### Use Launch Script
First, modify `examples/external_online_dp/run_dp_template.sh` to match your vLLM configuration. Then use `examples/external_online_dp/launch_online_dp.py` to launch multiple vLLM instances with one command per node. It internally calls `examples/external_online_dp/run_dp_template.sh` for each DP rank with the proper DP-related parameters.
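Internally, each local instance `i` on a node is assigned DP rank `dp_rank_start + i` and, by default, an HTTP port starting from 9000. A minimal sketch of that per-rank derivation (the function and constant names here are illustrative assumptions, not the script's actual API):

```python
# Hypothetical sketch of per-rank parameter derivation; the real logic
# lives in examples/external_online_dp/launch_online_dp.py.
BASE_PORT = 9000  # default starting port used by launch_online_dp.py

def plan_ranks(dp_size_local: int, dp_rank_start: int) -> list[tuple[int, int]]:
    """Return (dp_rank, http_port) pairs for the vLLM instances on one node."""
    return [(dp_rank_start + i, BASE_PORT + i) for i in range(dp_size_local)]
```

For instance, with `--dp-size-local 2 --dp-rank-start 2`, this yields DP ranks 2 and 3 listening on ports 9000 and 9001 on that node.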

An example of running external DP on a single node:

```
cd examples/external_online_dp
# running DP4 TP4 in a node with 16 NPUs
python launch_online_dp.py --dp-size 4 --tp-size 4 --dp-size-local 4 --dp-rank-start 0 --dp-address x.x.x.x --dp-rpc-port 12342
```

An example of running external DP across two nodes:

```
cd examples/external_online_dp
# running DP4 TP4 in two nodes with 8 NPUs each
# Node 0 holds DP0 DP1 and node 1 holds DP2 DP3
# Here x.x.x.x:12342 serves as the common data parallel RPC address

# On node 0:
python launch_online_dp.py --dp-size 4 --tp-size 4 --dp-size-local 2 --dp-rank-start 0 --dp-address x.x.x.x --dp-rpc-port 12342

# On node 1:
python launch_online_dp.py --dp-size 4 --tp-size 4 --dp-size-local 2 --dp-rank-start 2 --dp-address x.x.x.x --dp-rpc-port 12342
```

## Starting the Load-balance Proxy Server
After all vLLM DP instances are launched, you can start the load-balancing proxy server, which serves as the entrypoint for incoming requests and balances them across the vLLM DP instances.

The proxy server has the following features:
- Load balances requests to multiple vLLM servers based on request length.
- Supports OpenAI-compatible `/v1/completions` and `/v1/chat/completions` endpoints.
- Streams responses from backend servers to clients.
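The request-length-aware routing rule can be sketched as follows. This is a simplified illustration, not the proxy's actual code; the class and method names are assumptions. The idea is to track each backend's outstanding prompt length and send new requests to the least-loaded backend:

```python
# Illustrative sketch of request-length-aware load balancing.
# Names are assumptions, not the real proxy's API.
class LengthAwareBalancer:
    def __init__(self, backends: list[str]):
        # backends: list of "host:port" strings; value = outstanding prompt length
        self.pending = {b: 0 for b in backends}

    def acquire(self, prompt_len: int) -> str:
        """Pick the backend with the least outstanding prompt length."""
        backend = min(self.pending, key=self.pending.get)
        self.pending[backend] += prompt_len
        return backend

    def release(self, backend: str, prompt_len: int) -> None:
        """Call when a request finishes, returning its load to the pool."""
        self.pending[backend] -= prompt_len
```

The provided `dp_load_balance_proxy_server.py` additionally handles retries and response streaming on top of a policy like this.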

To run the proxy server, you need to specify the host and port of each vLLM DP instance:

```
# For example, suppose we have already started two DP instances on a single node:
# python launch_online_dp.py --dp-size 2 --tp-size 8 --dp-size-local 2 --dp-rank-start 0 --dp-address x.x.x.x --dp-rpc-port 12342
# By default, launch_online_dp.py launches vLLM instances starting from port 9000,
# so the vLLM ports for DP0 and DP1 are 9000 and 9001 respectively.
# Then you can start the load-balance proxy server by:
cd examples/external_online_dp
python dp_load_balance_proxy_server.py \
    --host 0.0.0.0 --port 8000 \
    --dp-hosts 127.0.0.1 127.0.0.1 \
    --dp-ports 9000 9001
```

After this, you can send requests directly to the proxy server and run DP with external load balancing.
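For example, a minimal Python client for the proxy might look like this (the model name is a placeholder, and the actual send is commented out since it requires the proxy to be running):

```python
import json
from urllib import request

PROXY_URL = "http://127.0.0.1:8000/v1/completions"  # proxy from the example above

def build_request(prompt: str, model: str = "your-model-name") -> request.Request:
    """Build an OpenAI-compatible completion request addressed to the proxy."""
    payload = {"model": model, "prompt": prompt, "max_tokens": 64}
    return request.Request(
        PROXY_URL,
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )

# With the proxy running:
#   resp = request.urlopen(build_request("Hello, world"))
#   print(json.load(resp))
```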
1 change: 1 addition & 0 deletions docs/source/user_guide/feature_guide/index.md
@@ -14,4 +14,5 @@ eplb_swift_balancer
netloader
dynamic_batch
kv_pool_mooncake
external_dp
:::
@@ -15,7 +15,7 @@
# - Streams responses from backend servers to clients.
#
# Prerequisites:
# - Python 3.8+
# - Python 3.10+
# - Install dependencies:
# pip install fastapi httpx uvicorn vllm
#
18 changes: 17 additions & 1 deletion examples/external_online_dp/README.md
@@ -1,4 +1,4 @@
This example shows how to use `launch_online_dp.py` to launch external DP vLLM servers. Users can easily launch external DP servers by following the steps below:

### Modify parameters in `run_dp_template.sh`
`run_dp_template.sh` is a template script used to launch each DP vLLM instance separately. It is called by `launch_online_dp.py` in multiple threads, and most of its configuration is set by `launch_online_dp.py`. Parameters you need to set manually include:
@@ -36,3 +36,19 @@ python launch_online_dp.py --dp-size 4 --tp-size 4 --dp-size-local 2 --dp-rank-s
python launch_online_dp.py --dp-size 4 --tp-size 4 --dp-size-local 2 --dp-rank-start 2 --dp-address x.x.x.x --dp-rpc-port 12342
```

### (Optional) Run `dp_load_balance_proxy_server.py` to load-balance requests between external DP servers
External DP means you must handle load balancing between multiple DP instances outside vLLM by implementing your own proxy server. Here we provide an example request-length-aware DP load-balancing proxy server. The arguments of `dp_load_balance_proxy_server.py` are:

1. `--port`: port of the proxy server, default 8000.
2. `--host`: host address of the proxy server, default localhost.
3. `--dp-hosts`: host addresses of the external DP servers.
4. `--dp-ports`: ports of the external DP servers; the number of ports must match the number of hosts.
5. `--max-retries`: maximum number of retries for HTTP requests, default 3.
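The pairing constraint in item 4 can be validated up front. A small sketch of such a check (the real script's argument handling may differ; `pair_backends` is a hypothetical helper):

```python
def pair_backends(dp_hosts: list[str], dp_ports: list[int]) -> list[str]:
    """Zip hosts and ports into backend addresses, enforcing equal counts."""
    if len(dp_hosts) != len(dp_ports):
        raise ValueError(
            f"--dp-ports count ({len(dp_ports)}) must match "
            f"--dp-hosts count ({len(dp_hosts)})"
        )
    return [f"{host}:{port}" for host, port in zip(dp_hosts, dp_ports)]
```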

For example, if you have two external DP servers running at x.x.x.a:10001 and x.x.x.b:10002, you can start the proxy server with:

```
python dp_load_balance_proxy_server.py --host x.x.x.c --port 8000 --dp-hosts x.x.x.a x.x.x.b --dp-ports 10001 10002
```

The proxy then serves as the entrypoint for inference requests at x.x.x.c:8000 and load-balances incoming requests between the two external DP servers according to request length.