29 changes: 24 additions & 5 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -752,6 +752,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self._use_flashinfer = attn_backend == _Backend.FLASHINFER
self._use_pallas = attn_backend == _Backend.PALLAS
self.kv_cache_layout = get_kv_cache_layout()
self.host_buffer_kv_cache_layout = self.kv_cache_layout
logger.debug("Detected attention backend %s", self.backend_name)
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)

@@ -864,6 +865,20 @@ def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
for layer_name, kv_cache in kv_caches.items():
kv_shape = kv_cache.shape
kv_dtype = kv_cache.dtype
if (
self.kv_cache_layout == "NHD"
and self.vllm_config.kv_transfer_config is not None
and self.vllm_config.kv_transfer_config.enable_permute_local_kv
):
Comment on lines +868 to +872

Collaborator: `self.vllm_config.kv_transfer_config is not None`: this should always be true; is mypy complaining?

Contributor Author: Yes, this is added to make mypy happy.

logger.info_once(
"'enable_permute_local_kv' flag is enabled while "
"device KV Layout is NHD. Init host buffer with"
" HND to better support Decode/Prefill TP_ratio > 1."
)
# Since NHD will not support Decode/Prefill TP_ratio > 1,
# we can leverage host_buffer for permute
self.host_buffer_kv_cache_layout = "HND"
Contributor: Should we set it to NHD for homogeneous TP, i.e. tp_ratio == 1?

Contributor Author (@xuechendi, Oct 20, 2025): No. If you want to keep using NHD when the decode/prefill TP_ratio == 1, you can set `enable_permute_local_kv = False` in the kv_transfer_config, and no permute will be done. However, permute + memcpy versus plain memcpy makes little performance difference, so I think we can leave `enable_permute_local_kv = True` for XPU until the HND layout gets enabled there.
kv_shape = tuple(kv_shape[i] for i in [0, 1, 3, 2, 4])
xfer_buffers[layer_name] = torch.empty(
kv_shape, dtype=kv_dtype, device="cpu"
)
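
For context on the hunk above: the `[0, 1, 3, 2, 4]` index permutation swaps the block_size and num_heads dimensions, so the CPU transfer buffer is allocated in HND order while the device cache stays NHD. A minimal sketch with illustrative shapes (not vLLM's real cache shapes; the leading 2 standing in for the K and V planes is an assumption):

```python
import torch

# Assumed NHD device shape: (2, num_blocks, block_size, num_heads, head_dim).
device_kv = torch.empty(2, 4, 16, 8, 64)

# Same shape permutation as the host-buffer allocation above:
# swap block_size and num_heads so the buffer is laid out HND.
nhd_shape = device_kv.shape
hnd_shape = tuple(nhd_shape[i] for i in [0, 1, 3, 2, 4])
host_buffer = torch.empty(hnd_shape, dtype=device_kv.dtype, device="cpu")

print(tuple(nhd_shape))          # (2, 4, 16, 8, 64) -> NHD
print(tuple(host_buffer.shape))  # (2, 4, 8, 16, 64) -> HND
```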
@@ -1099,7 +1114,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
kv_cache_layout=self.kv_cache_layout,
kv_cache_layout=self.kv_cache_layout
if not self.use_host_buffer
else self.host_buffer_kv_cache_layout,
)
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
@@ -1255,7 +1272,12 @@ def _validate_remote_agent_handshake(
assert not self._use_pallas or tp_ratio == 1, (
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
)
if not self.use_mla and nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
kv_cache_layout = (
self.kv_cache_layout
if not self.use_host_buffer
else self.host_buffer_kv_cache_layout
)
if not self.use_mla and nixl_agent_meta.kv_cache_layout != kv_cache_layout:
if (
self.kv_transfer_config.enable_permute_local_kv
and nixl_agent_meta.kv_cache_layout == "HND"
@@ -1281,9 +1303,6 @@
)
remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
else:
if tp_ratio > 1 and self.device_type == "xpu":
# XPU uses NHD, hence it does not support splitting on H
raise ValueError("Heterogeneous TP is not supported on XPU")
# When MLA is not used, this is a list of the same block length
for block_len in nixl_agent_meta.block_lens:
assert block_len == remote_block_len, (
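
The removed XPU check above ("Heterogeneous TP is not supported on XPU") is what the HND host buffer is meant to work around. A rough, self-contained illustration of why the head-first layout helps with a prefill/decode TP ratio > 1 (dimensions are made up): splitting a cache block along the head dimension gives one contiguous region per block under HND, but a strided slice under NHD.

```python
import torch

block_size, num_heads, head_dim = 16, 8, 64
tp_ratio = 2
heads_per_remote_rank = num_heads // tp_ratio

# One KV cache block in each layout (sizes are illustrative only).
nhd_block = torch.zeros(block_size, num_heads, head_dim)  # NHD
hnd_block = torch.zeros(num_heads, block_size, head_dim)  # HND

# A remote rank's share of heads is a single contiguous slab per block
# under HND, but a strided slice under NHD.
print(hnd_block[:heads_per_remote_rank].is_contiguous())     # True
print(nhd_block[:, :heads_per_remote_rank].is_contiguous())  # False
```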
10 changes: 10 additions & 0 deletions vllm/platforms/xpu.py
@@ -160,6 +160,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# check and update parallel config
parallel_config = vllm_config.parallel_config
parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
if vllm_config.kv_transfer_config is not None:
vllm_config.kv_transfer_config.enable_permute_local_kv = True

if parallel_config.distributed_executor_backend is None:
if parallel_config.world_size > 1:
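
The two added lines above force the host-buffer permute on for XPU whenever a KV transfer config is present. For reference on the flag discussed in the review thread, a hypothetical sketch of setting it explicitly; the import path and the `kv_connector`/`kv_role` values follow vLLM's NixlConnector examples, and treating `enable_permute_local_kv` as a plain constructor argument is an assumption:

```python
# Sketch only: everything except the field name used in this diff is an assumption.
from vllm.config import KVTransferConfig

kv_transfer_config = KVTransferConfig(
    kv_connector="NixlConnector",
    kv_role="kv_both",
    # Opt out of the host-buffer permute (NHD host buffer, homogeneous TP only).
    # Note that on XPU, check_and_update_config above overrides this to True.
    enable_permute_local_kv=False,
)
```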
@@ -261,6 +263,10 @@ def insert_blocks_to_device(
) -> None:
"""Copy blocks from src_cache to dst_cache on XPU."""
_src_cache = src_cache[:, src_block_indices]
if _src_cache.shape[2:] != dst_cache.shape[2:]:
# To support TP_ratio, HOST KV might be initiated with HND
# while XPU device KV is with NHD
_src_cache = _src_cache.permute(0, 1, 3, 2, 4)
dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

@classmethod
@@ -273,4 +279,8 @@ def swap_out_blocks_to_host(
) -> None:
"""Copy blocks from XPU to host (CPU)."""
_src_cache = src_cache[:, src_block_indices]
if _src_cache.shape[2:] != dst_cache.shape[2:]:
# XPU device KV is with NHD while HOST KV
# might be initiated with HND for TP_ratio support
_src_cache = _src_cache.permute(0, 1, 3, 2, 4)
dst_cache[:, dst_block_indices] = _src_cache.cpu()
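
To see that the two permutes above compose correctly, a small CPU-only sketch that mirrors swap_out_blocks_to_host followed by insert_blocks_to_device (shapes are illustrative; the leading dimension of 2 for the K and V planes is an assumption):

```python
import torch

# Device KV cache in NHD: (2, num_blocks, block_size, num_heads, head_dim);
# host buffer in HND: (2, num_blocks, num_heads, block_size, head_dim).
src_cache = torch.randn(2, 4, 16, 8, 64)
dst_cache = torch.empty(2, 4, 8, 16, 64)
src_block_indices = [0, 2]
dst_block_indices = [1, 3]

# Mirror of swap_out_blocks_to_host: permute selected blocks into HND
# before copying them into the host buffer.
_src = src_cache[:, src_block_indices]
if _src.shape[2:] != dst_cache.shape[2:]:
    _src = _src.permute(0, 1, 3, 2, 4)
dst_cache[:, dst_block_indices] = _src

# Mirror of insert_blocks_to_device: permuting back recovers the original blocks.
restored = dst_cache[:, dst_block_indices].permute(0, 1, 3, 2, 4)
assert torch.equal(restored, src_cache[:, src_block_indices])
```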