
Commit 0925d44

[PD Disaggregation] support different tp_size for prefill and decode (#5296)
* up
* up
* up
* fix
1 parent 54119cf commit 0925d44

13 files changed: +584 -36 lines changed

fastdeploy/cache_manager/cache_messager.py

Lines changed: 38 additions & 9 deletions
@@ -55,13 +55,13 @@ def parse_args():
         default="mixed",
         help="splitwise role, can be decode, prefill or mixed",
     )
-    parser.add_argument("--rank", type=int, default=0, help="current rank")
+    parser.add_argument("--rank", type=int, default=0, help="local tp rank id")
     parser.add_argument("--device_id", type=int, default=0, help="device id")
     parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
     parser.add_argument("--key_cache_shape", type=str, default="", help="key cache shape")
     parser.add_argument("--value_cache_shape", type=str, default="", help="value cache shape")
     parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
-    parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
+    parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel, i.e. tp_size, tp_num")
    parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
    parser.add_argument(
        "--protocol",
@@ -208,6 +208,8 @@ def __init__(
             max_block_num,
             block_bytes,
             rdma_port,
+            nranks,
+            rank,
         )

         self.gpu_id = gpu_id
@@ -507,6 +509,8 @@ def __init__(
             max_block_num,
             block_bytes,
             rdma_port,
+            nranks,
+            rank,
         )

         self.gpu_id = gpu_id
@@ -595,6 +599,7 @@ def prefill_layerwise_send_cache_thread(self):
                 block_id_end = prefilled_token_num // self.block_size  # [block_id_start, block_id_end)
                 block_start_end_list.append((block_id_start, block_id_end))
                 current_prefilled_token_num_list.append(prefilled_token_num)
+
             while True:  # from layer0 to last layer
                 sended_layer_idx = self.idx_cache_task_dict[batch_engine_signals[0][0]]["sended_layer_id"]
                 start_layer_idx = sended_layer_idx + 1
@@ -633,13 +638,27 @@ def prefill_layerwise_send_cache_thread(self):
                     current_transfer_protocol = task["transfer_protocol"]
                     if task["transfer_protocol"] == "rdma":
                         target_ip = task["ip"]
-                        target_id = int(task["rdma_ports"][self.rank])
+                        # Default decode_tp_size to prefill tp_size (self.nranks) if not specified
+                        decode_tp_size = task.get("decode_tp_size", self.nranks)
+                        if len(task["rdma_ports"]) == self.nranks:
+                            target_id = int(task["rdma_ports"][self.rank])
+                        elif len(task["rdma_ports"]) == 1:
+                            target_id = task["rdma_ports"][0]
+                        else:
+                            task["status"] = "the tp_size of prefill and decode is mismatch"
+                            continue
+
                         if "error" in task["status"]:
                             continue

                         # TODO: use is connected to check if the connection is still alive
-                        logger.debug(f"rdma, start connect decode, {target_ip}:{target_id}")
-                        status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
+                        logger.debug(
+                            f"rdma, start connect decode, {target_ip}:{target_id}, "
+                            f"prefill_tp_size:{self.nranks}, decode_tp_size:{decode_tp_size}"
+                        )
+                        status = self.messager[current_transfer_protocol].connect(
+                            target_ip, target_id, decode_tp_size
+                        )
                         if status:
                             logger.info(f"connect to {target_ip}:{target_id} success")
                         else:
@@ -762,12 +781,22 @@ def _handle_connect_task(self):
                 self.engine_worker_queue.connect_task_barrier.wait()
                 logger.info(f"_handle_connect_task recv task: {task}")
                 task_id = task["task_id"]
-                ip, rdma_port = task["ip"], task["rdma_ports"][self.rank]
-                status = self.messager["rdma"].connect(ip, rdma_port)
-                if not status:
+                ip = task["ip"]
+                # Default decode_tp_size to self.nranks (number of ranks) if not specified in the task.
+                decode_tp_size = task.get("decode_tp_size", self.nranks)
+                rdma_ports = task["rdma_ports"]
+                rdma_ports_len = len(rdma_ports)
+                if not (rdma_ports_len == 1 or rdma_ports_len == self.nranks):
+                    # TODO: support other cases
+                    logger.error(f"rdma_ports length should be 1 or equal to mp_num, but got {rdma_ports_len}")
                     response = {"task_id": task_id, "success": False}
                 else:
-                    response = {"task_id": task_id, "success": True}
+                    port = rdma_ports[0] if rdma_ports_len == 1 else rdma_ports[self.rank]
+                    status = self.messager["rdma"].connect(ip, port, decode_tp_size)
+                    if not status:
+                        response = {"task_id": task_id, "success": False}
+                    else:
+                        response = {"task_id": task_id, "success": True}
                 self.engine_worker_queue.connect_task_response_barrier.wait()
                 self.engine_worker_queue.put_connect_rdma_task_response(response)
             except Exception as e:
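
The port-selection rule introduced above boils down to: the decode side must expose either one rdma_port per prefill rank (equal tp sizes) or a single port (decode tp_size 1); anything else is a mismatch. A minimal Python sketch of that rule, with a hypothetical helper name that is not part of this commit:

# Illustrative sketch of the rdma_ports selection rule above; the helper name
# pick_target_rdma_port is hypothetical and not part of this commit.
def pick_target_rdma_port(rdma_ports: list, prefill_rank: int, prefill_tp_size: int) -> int:
    if len(rdma_ports) == prefill_tp_size:
        # Same tp_size on both sides: prefill rank i talks to decode rank i.
        return int(rdma_ports[prefill_rank])
    if len(rdma_ports) == 1:
        # decode tp_size == 1: every prefill rank connects to the single decode rank.
        return int(rdma_ports[0])
    raise ValueError("the tp_size of prefill and decode is mismatched")


# Example: prefill tp_size 4, decode tp_size 1 -> all ranks use port 9000.
assert pick_target_rdma_port(["9000"], prefill_rank=2, prefill_tp_size=4) == 9000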

fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_connection.h

Lines changed: 1 addition & 0 deletions
@@ -142,6 +142,7 @@ struct Connection {
   int wc_target_count;

   // Configuration
+  int decode_tp_size;
   int layer_number;
   int block_number;
   int block_byte_size;

fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_rdma.h

Lines changed: 8 additions & 2 deletions
@@ -24,11 +24,15 @@ class RDMACommunicator {
                    std::vector<int64_t> local_key_cache,
                    std::vector<int64_t> local_value_cache,
                    int block_number,
-                   int block_bytes);
+                   int block_bytes,
+                   int prefill_tp_size,
+                   int prefill_tp_idx);
   ~RDMACommunicator();

   // Connection management
-  int connect(const std::string& dst_ip, const std::string& dst_port);
+  int connect(const std::string& dst_ip,
+              const std::string& dst_port,
+              int dest_tp_size);
   bool is_connected(const std::string& dst_ip, const std::string& dst_port);

   // Core functionality
@@ -120,6 +124,8 @@ class RDMACommunicator {
   int block_number;      // Number of blocks
   int block_size_byte;   // Size of each block in bytes
   int layer_number;      // Number of layers
+  int prefill_tp_size;   // tensor parallelism size for prefill
+  int prefill_tp_idx;    // tensor parallelism index for prefill

   std::vector<std::vector<void*>>
       local_cache_key_ptr_per_layer;  // Per-layer key pointers

fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_rdma.cpp

Lines changed: 31 additions & 10 deletions
@@ -41,7 +41,7 @@
  * @param local_key_cache Vector of local key cache pointers
  * @param local_value_cache Vector of local value cache pointers
  * @param block_number Number of blocks in cache
- * @param block_bytes Size of each block in bytes
+ * @param block_bytes Bytes of each block in each tp rank
  *
  * @throws std::runtime_error If initialization fails
  */
@@ -51,14 +51,18 @@ RDMACommunicator::RDMACommunicator(std::string& role,
                                    std::vector<int64_t> local_key_cache,
                                    std::vector<int64_t> local_value_cache,
                                    int block_number,
-                                   int block_bytes)
+                                   int block_bytes,
+                                   int prefill_tp_size,
+                                   int prefill_tp_idx)
     : splitwise_role(role),
       gpu_idx(gpu_idx),
       port(port),
       local_cache_key_ptr_layer_head_(std::move(local_key_cache)),
       local_cache_value_ptr_layer_head_(std::move(local_value_cache)),
       block_number(block_number),
       block_size_byte(block_bytes),
+      prefill_tp_size(prefill_tp_size),
+      prefill_tp_idx(prefill_tp_idx),
       RDMACommunicator_status(0),
       rdma_event_channel_epoll_fd(-1) {
   try {
@@ -480,11 +484,14 @@ std::string RDMACommunicator::fetch_local_ip() {
  *
  * @param dst_ip Destination IP address
  * @param dst_port Destination port
+ * @param dest_tp_size Default 0: assumes dest has same tp_size as source;
+ *                     otherwise specifies decode tp_size
  * @return ConnStatus::kConnected ConnStatus::kError;
  */

 int RDMACommunicator::connect(const std::string& dst_ip,
-                              const std::string& dst_port) {
+                              const std::string& dst_port,
+                              int dest_tp_size = 0) {
   std::string url = dst_ip + ":" + dst_port;

   // Initialize IB devices if not already done
@@ -515,6 +522,10 @@ int RDMACommunicator::connect(const std::string& dst_ip,
   ctx->conn.layer_number = layer_number;
   ctx->conn.block_number = block_number;
   ctx->conn.block_byte_size = block_size_byte;
+  if (dest_tp_size > 0)
+    ctx->conn.decode_tp_size = dest_tp_size;
+  else
+    ctx->conn.decode_tp_size = prefill_tp_size;

   // Get port information for the connection
   if (get_port_info(ctx->context, ib_dev->port, &ctx->portinfo)) {
@@ -537,9 +548,6 @@ int RDMACommunicator::connect(const std::string& dst_ip,
     ERR("Couldn't getexchange port infodestinations");
     return static_cast<int>(ConnStatus::kError);
   } else {
-    std::lock_guard<std::mutex> lock(mutex_);
-    ctx->conn.connected = 1;
-    conn_map[url] = ctx;
     client_exchange_mr(ctx);
   }

@@ -589,6 +597,10 @@ int RDMACommunicator::connect(const std::string& dst_ip,
     }
   }

+  std::lock_guard<std::mutex> lock(mutex_);
+  ctx->conn.connected = 1;
+  conn_map[url] = ctx;
+
   WARN("connect end ....");
   return static_cast<int>(ConnStatus::kConnected);
 }
@@ -649,6 +661,7 @@ int RDMACommunicator::client_listener() {

 bool RDMACommunicator::is_connected(const std::string& dst_ip,
                                     const std::string& dst_port) {
+  std::lock_guard<std::mutex> lock(mutex_);
   std::string url = dst_ip + ":" + dst_port;
   return conn_map.find(url) != conn_map.end();
 }
@@ -889,17 +902,25 @@ int RDMACommunicator::write_cache(const std::string& ip,
   uint32_t cache_value_rkey =
       ctx->conn.write_cache_value_remote_rkey_list[layer_idx];
   uint32_t crc_cache_key_rkey, crc_cache_value_rkey;
+  bool pd_tp_size_is_same = prefill_tp_size == ctx->conn.decode_tp_size;
+  uint64_t offset_in_block =
+      pd_tp_size_is_same ? 0 : block_size_byte * prefill_tp_idx;
+  uint64_t total_block_size_byte =
+      pd_tp_size_is_same ? block_size_byte : block_size_byte * prefill_tp_size;

   for (size_t block_index = 0; block_index < block_num; ++block_index) {
     char* char_ptr = static_cast<char*>(
         ctx->conn.write_cache_key_remote_ptr_list[layer_idx]);
-    cache_key_remote_addr[block_index] =
-        (uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte));
+    cache_key_remote_addr[block_index] = (uint64_t(
+        char_ptr + remote_block_ids[block_index] * total_block_size_byte +
+        offset_in_block));
     char_ptr = static_cast<char*>(
         ctx->conn.write_cache_value_remote_ptr_list[layer_idx]);
-    cache_value_remote_addr[block_index] =
-        (uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte));
+    cache_value_remote_addr[block_index] = (uint64_t(
+        char_ptr + remote_block_ids[block_index] * total_block_size_byte +
+        offset_in_block));
   }
+
   ctx->conn.wc_target_count = 0;
   for (int i = 0; i < 2; ++i) {
     bool is_key = (i == 0);
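
The write_cache change above is the heart of the feature: when the tp sizes differ (decode runs tp_size 1), a decode-side block is prefill_tp_size times larger than a prefill-side block, so prefill rank prefill_tp_idx writes its slice at an offset inside the remote block. A small Python sketch of that address arithmetic, using made-up sizes:

# Sketch of the remote-address computation in write_cache (sizes are made up).
def remote_block_addr(base_ptr: int, remote_block_id: int, block_size_byte: int,
                      prefill_tp_size: int, prefill_tp_idx: int, decode_tp_size: int) -> int:
    pd_tp_size_is_same = prefill_tp_size == decode_tp_size
    # Offset of this prefill rank's slice inside the (larger) decode block.
    offset_in_block = 0 if pd_tp_size_is_same else block_size_byte * prefill_tp_idx
    # Stride between consecutive blocks on the decode side.
    total_block_size_byte = block_size_byte if pd_tp_size_is_same else block_size_byte * prefill_tp_size
    return base_ptr + remote_block_id * total_block_size_byte + offset_in_block


# prefill tp_size 4, decode tp_size 1, 256 KiB per prefill-rank block:
# rank 1 writes remote block 3 at base + 3 * 1 MiB + 256 KiB.
addr = remote_block_addr(0, 3, 256 * 1024, prefill_tp_size=4, prefill_tp_idx=1, decode_tp_size=1)
assert addr == 3 * 4 * 256 * 1024 + 1 * 256 * 1024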

fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/pybind.cpp

Lines changed: 33 additions & 4 deletions
@@ -14,10 +14,39 @@ PYBIND11_MODULE(rdma_comm, m) {
                     std::vector<int64_t>,
                     std::vector<int64_t>,
                     int,
-                    int>())
-      .def("connect", &RDMACommunicator::connect)
-      .def("is_connected", &RDMACommunicator::is_connected)
-      .def("write_cache", &RDMACommunicator::write_cache);
+                    int,
+                    int,
+                    int>(),
+           py::arg("splitwise_role"),
+           py::arg("gpu_idx"),
+           py::arg("port"),
+           py::arg("key_cache_ptrs"),
+           py::arg("value_cache_ptrs"),
+           py::arg("block_number"),
+           py::arg("block_bytes"),
+           py::arg("prefill_tp_size") = 1,
+           py::arg("prefill_tp_idx") = 0)
+      .def("connect",
+           &RDMACommunicator::connect,
+           py::arg("dst_ip"),
+           py::arg("dst_port"),
+           py::arg("dst_tp_size") =
+               0,  // Default 0: assumes dest has same tp_size as source;
+                   // otherwise specifies decode tp_size
+           py::call_guard<py::gil_scoped_release>())
+      .def("is_connected",
+           &RDMACommunicator::is_connected,
+           py::arg("dst_ip"),
+           py::arg("dst_port"),
+           py::call_guard<py::gil_scoped_release>())
+      .def("write_cache",
+           &RDMACommunicator::write_cache,
+           py::arg("dst_ip"),
+           py::arg("dst_port"),
+           py::arg("local_block_ids"),
+           py::arg("remote_block_ids"),
+           py::arg("layer_idx"),
+           py::call_guard<py::gil_scoped_release>());

 #ifdef VERSION_INFO
   m.attr("__version__") = VERSION_INFO;
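
With the explicit py::arg names above, the new parameters can be passed by keyword from Python. A rough usage sketch (the comm object is assumed to be an rdma_comm.RDMACommunicator already built with prefill_tp_size/prefill_tp_idx by the wrapper in rdma_cache_transfer.py below; the address, ports and block ids are placeholders):

# comm: an rdma_comm.RDMACommunicator constructed with prefill_tp_size/prefill_tp_idx
# by the wrapper in rdma_cache_transfer.py; values below are placeholders.
decode_ip, decode_port = "10.0.0.2", "8802"

# dst_tp_size=1 declares that the decode peer runs with tensor parallel size 1;
# dst_tp_size=0 (the default) keeps the old behaviour of assuming equal tp_size.
ret = comm.connect(decode_ip, decode_port, dst_tp_size=1)
if ret == 0 and comm.is_connected(decode_ip, decode_port):
    # Write local blocks 0 and 1 of layer 0 into remote blocks 5 and 6.
    comm.write_cache(decode_ip, decode_port, [0, 1], [5, 6], 0)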

fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py

Lines changed: 9 additions & 3 deletions
@@ -34,6 +34,8 @@ def __init__(
         max_block_num,
         block_bytes,
         rdma_port,
+        prefill_tp_size,
+        prefill_tp_idx,
     ):
         try:
             import rdma_comm
@@ -51,12 +53,16 @@ def __init__(
             cache_v_ptr_list,
             max_block_num,
             block_bytes,
+            prefill_tp_size,
+            prefill_tp_idx,
         )
         self.splitwise_role = splitwise_role
         self.connected_rdma = set()
-        logger.info(f"init rdma messager {gpu_id} {rdma_port}")
+        logger.info(
+            f"init rdma messager {gpu_id} {rdma_port}, prefill_tp_size: {prefill_tp_size}, prefill_tp_idx: {prefill_tp_idx}"
+        )

-    def connect(self, ip, port):
+    def connect(self, ip, port, tp_size):
         """
         Connect to remote gpu and write cache.
         """
@@ -65,7 +71,7 @@ def connect(self, ip, port):
         if ret:
             return True

-        ret = self.messager.connect(ip, str(port))
+        ret = self.messager.connect(ip, str(port), tp_size)
         logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}")
         return ret == 0

fastdeploy/config.py

Lines changed: 2 additions & 0 deletions
@@ -1888,6 +1888,7 @@ def init_cache_info(self):
         logger.info(f"disaggregate_info: {self.disaggregate_info}")

         if self.router_config:
+            # the information for registering this server to router
             self.register_info = {
                 "role": self.scheduler_config.splitwise_role,
                 "host_ip": self.host_ip,
@@ -1897,6 +1898,7 @@ def init_cache_info(self):
                 "engine_worker_queue_port": engine_worker_queue_port,
                 "device_ids": self.local_device_ids,
                 "transfer_protocol": self.cache_config.cache_transfer_protocol.split(","),
+                "tp_size": self.parallel_config.tensor_parallel_size,
             }
             logger.info(f"register_info: {self.register_info}")
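
The payload registered with the router therefore now carries the instance's tensor parallel size. Illustratively (values are placeholders; only keys visible in this hunk are shown):

# Placeholder example of the register_info payload after this change.
register_info = {
    "role": "prefill",
    "host_ip": "10.0.0.1",
    # ... other fields set in init_cache_info ...
    "engine_worker_queue_port": 8002,
    "device_ids": [0, 1, 2, 3],
    "transfer_protocol": ["rdma", "ipc"],
    "tp_size": 4,  # new: lets the router check prefill/decode tp compatibility
}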

fastdeploy/envs.py

Lines changed: 1 addition & 0 deletions
@@ -150,6 +150,7 @@
     "FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")),
     # "Number of tokens in the group for Mixture of Experts (MoE) computation processing on HPU"
     "FD_HPU_CHUNK_SIZE": lambda: int(os.getenv("FD_HPU_CHUNK_SIZE", "64")),
+    "FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")),
 }
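
The new variable is read with os.getenv and parsed as an integer number of seconds, defaulting to 30. Judging by its name it bounds how long a prefill instance waits for decode-side resources; for example it could be raised before launching (a sketch; the exact point where FastDeploy consumes this value is not shown in this diff):

import os

# Allow 60 seconds instead of the default 30 (value must be an integer string).
os.environ["FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS"] = "60"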

fastdeploy/router/router.py

Lines changed: 10 additions & 4 deletions
@@ -95,7 +95,7 @@ def __init__(self, args):
     async def register_instance(self, instance_info_dict: dict):
         """Register an instance asynchronously"""
         try:
-            inst_info = InstanceInfo(**instance_info_dict)
+            inst_info = InstanceInfo.from_dict(instance_info_dict)
         except Exception as e:
             logger.error(f"register instance failed: {e}")
             raise
@@ -173,11 +173,17 @@ async def handle_splitwise_request(self, request_data: dict, endpoint_name: str)
         logger.debug(f"Received request: {request_data}")
         prefill_server, decode_server = await self.select_pd()

+        if prefill_server.tp_size != decode_server.tp_size and decode_server.tp_size != 1:
+            raise HTTPException(
+                status_code=400,
+                detail="The tp_size of prefill and decode should be equal or the tp_size of decode is 1",
+            )
+
         # TODO: unify the disaggregate_info in server and remove redundancy params
         is_same_node = prefill_server.host_ip == decode_server.host_ip
-        use_ipc = (
-            is_same_node and "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol
-        )
+        is_support_ipc = "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol
+        is_same_tp_size = prefill_server.tp_size == decode_server.tp_size
+        use_ipc = is_same_node and is_support_ipc and is_same_tp_size

         cache_info = {}
         if use_ipc:
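
In short, the router now accepts a prefill/decode pair only if their tp sizes match or the decode instance runs tp_size 1, and it only picks IPC transfer when both instances are on the same node, both support IPC, and the tp sizes are equal; otherwise the RDMA path with the offset writes shown earlier is used. A compact restatement with hypothetical helper names:

# Hypothetical helpers restating the routing rules above (not part of this commit).
def pd_pair_is_routable(prefill_tp_size: int, decode_tp_size: int) -> bool:
    return prefill_tp_size == decode_tp_size or decode_tp_size == 1


def should_use_ipc(same_node: bool, prefill_protocols: list, decode_protocols: list,
                   prefill_tp_size: int, decode_tp_size: int) -> bool:
    is_support_ipc = "ipc" in prefill_protocols and "ipc" in decode_protocols
    is_same_tp_size = prefill_tp_size == decode_tp_size
    return same_node and is_support_ipc and is_same_tp_size


assert pd_pair_is_routable(4, 1)          # different tp_size allowed when decode tp is 1
assert not pd_pair_is_routable(4, 2)      # rejected with HTTP 400 by the router
assert not should_use_ipc(True, ["rdma", "ipc"], ["rdma", "ipc"], 4, 1)  # falls back to rdma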
