diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index d73e05562951..9a1057e8c051 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -752,6 +752,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self._use_flashinfer = attn_backend == _Backend.FLASHINFER
         self._use_pallas = attn_backend == _Backend.PALLAS
         self.kv_cache_layout = get_kv_cache_layout()
+        self.host_buffer_kv_cache_layout = self.kv_cache_layout
         logger.debug("Detected attention backend %s", self.backend_name)
         logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
 
@@ -864,6 +865,20 @@ def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> Non
             for layer_name, kv_cache in kv_caches.items():
                 kv_shape = kv_cache.shape
                 kv_dtype = kv_cache.dtype
+                if (
+                    self.kv_cache_layout == "NHD"
+                    and self.vllm_config.kv_transfer_config is not None
+                    and self.vllm_config.kv_transfer_config.enable_permute_local_kv
+                ):
+                    logger.info_once(
+                        "'enable_permute_local_kv' flag is enabled while "
+                        "device KV Layout is NHD. Init host buffer with"
+                        " HND to better support Decode/Prefill TP_ratio > 1."
+                    )
+                    # Since NHD will not support Decode/Prefill TP_ratio > 1,
+                    # we can leverage host_buffer for permute
+                    self.host_buffer_kv_cache_layout = "HND"
+                    kv_shape = tuple(kv_shape[i] for i in [0, 1, 3, 2, 4])
                 xfer_buffers[layer_name] = torch.empty(
                     kv_shape, dtype=kv_dtype, device="cpu"
                 )
@@ -1099,7 +1114,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             num_blocks=self.num_blocks,
             block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
-            kv_cache_layout=self.kv_cache_layout,
+            kv_cache_layout=self.kv_cache_layout
+            if not self.use_host_buffer
+            else self.host_buffer_kv_cache_layout,
         )
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
@@ -1255,7 +1272,12 @@ def _validate_remote_agent_handshake(
         assert not self._use_pallas or tp_ratio == 1, (
            "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
         )
-        if not self.use_mla and nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
+        kv_cache_layout = (
+            self.kv_cache_layout
+            if not self.use_host_buffer
+            else self.host_buffer_kv_cache_layout
+        )
+        if not self.use_mla and nixl_agent_meta.kv_cache_layout != kv_cache_layout:
             if (
                 self.kv_transfer_config.enable_permute_local_kv
                 and nixl_agent_meta.kv_cache_layout == "HND"
@@ -1281,9 +1303,6 @@ def _validate_remote_agent_handshake(
             )
             remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
         else:
-            if tp_ratio > 1 and self.device_type == "xpu":
-                # XPU uses NHD, hence it does not support splitting on H
-                raise ValueError("Heterogeneous TP is not supported on XPU")
             # When MLA is not used, this is a list of the same block length
             for block_len in nixl_agent_meta.block_lens:
                 assert block_len == remote_block_len, (
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 5799f97b8038..44bb73acbb6e 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -160,6 +160,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # check and update parallel config
         parallel_config = vllm_config.parallel_config
         parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
+        if vllm_config.kv_transfer_config is not None:
+            vllm_config.kv_transfer_config.enable_permute_local_kv = True
 
         if parallel_config.distributed_executor_backend is None:
             if parallel_config.world_size > 1:
@@ -261,6 +263,10 @@ def insert_blocks_to_device(
     ) -> None:
         """Copy blocks from src_cache to dst_cache on XPU."""
         _src_cache = src_cache[:, src_block_indices]
+        if _src_cache.shape[2:] != dst_cache.shape[2:]:
+            # To support TP_ratio, HOST KV might be initiated with HND
+            # while XPU device KV is with NHD
+            _src_cache = _src_cache.permute(0, 1, 3, 2, 4)
         dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
 
     @classmethod
@@ -273,4 +279,8 @@ def swap_out_blocks_to_host(
     ) -> None:
         """Copy blocks from XPU to host (CPU)."""
         _src_cache = src_cache[:, src_block_indices]
+        if _src_cache.shape[2:] != dst_cache.shape[2:]:
+            # XPU device KV is with NHD while HOST KV
+            # might be initiated with HND for TP_ratio support
+            _src_cache = _src_cache.permute(0, 1, 3, 2, 4)
         dst_cache[:, dst_block_indices] = _src_cache.cpu()
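
For reference, here is a minimal sketch (not part of the patch) of the NHD/HND conversion the host buffer relies on. It assumes the usual non-MLA KV cache shape of (2, num_blocks, block_size, num_kv_heads, head_size) for an NHD layout; the dimension sizes and variable names are illustrative, and everything stays on CPU so the sketch runs anywhere. The permutation (0, 1, 3, 2, 4) swaps the token (block_size) and head axes, which is what initialize_host_xfer_buffer applies to the host buffer shape and what insert_blocks_to_device / swap_out_blocks_to_host apply to the data when host and device shapes differ.

```python
# Illustrative sketch only, not part of the patch. Shapes are assumptions:
# a non-MLA KV cache laid out as NHD, i.e.
# (2, num_blocks, block_size, num_kv_heads, head_size).
import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 64

# "Device-side" cache in NHD layout (kept on CPU here so the sketch runs anywhere).
device_kv_nhd = torch.randn(2, num_blocks, block_size, num_kv_heads, head_size)

# Host buffer allocated in HND, as initialize_host_xfer_buffer does when
# enable_permute_local_kv is set and the device layout is NHD.
hnd_shape = tuple(device_kv_nhd.shape[i] for i in [0, 1, 3, 2, 4])
host_kv_hnd = torch.empty(hnd_shape, device="cpu")

# Swap out to host: permute(0, 1, 3, 2, 4) reorders NHD -> HND,
# mirroring swap_out_blocks_to_host.
host_kv_hnd.copy_(device_kv_nhd.permute(0, 1, 3, 2, 4))

# Swap back in: the same permute is its own inverse (HND -> NHD),
# mirroring insert_blocks_to_device.
restored = host_kv_hnd.permute(0, 1, 3, 2, 4)
assert torch.equal(restored, device_kv_nhd)
```

Because swapping axes 2 and 3 is its own inverse, the same permute serves both copy directions, which is why both XPU helpers can share the simple shape-mismatch check. With HND the head axis is the outer axis within each block, so a block can be split contiguously along H when the decode side runs at a higher TP degree; that is the case the removed XPU-specific ValueError previously rejected because the device layout is NHD.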