[NIXL] use Host buffer to support TP_ratio > 1 for XPU #27140
@@ -752,6 +752,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self._use_flashinfer = attn_backend == _Backend.FLASHINFER
         self._use_pallas = attn_backend == _Backend.PALLAS
         self.kv_cache_layout = get_kv_cache_layout()
+        self.host_buffer_kv_cache_layout = self.kv_cache_layout
         logger.debug("Detected attention backend %s", self.backend_name)
         logger.debug("Detected kv cache layout %s", self.kv_cache_layout)

@@ -864,6 +865,20 @@ def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
         for layer_name, kv_cache in kv_caches.items():
             kv_shape = kv_cache.shape
             kv_dtype = kv_cache.dtype
+            if (
+                self.kv_cache_layout == "NHD"
+                and self.vllm_config.kv_transfer_config is not None
+                and self.vllm_config.kv_transfer_config.enable_permute_local_kv
+            ):
+                logger.info_once(
+                    "'enable_permute_local_kv' flag is enabled while "
+                    "device KV Layout is NHD. Init host buffer with"
+                    " HND to better support Decode/Prefill TP_ratio > 1."
+                )
+                # Since NHD will not support Decode/Prefill TP_ratio > 1,
+                # we can leverage host_buffer for permute
+                self.host_buffer_kv_cache_layout = "HND"
Comment: Should we set it to […]

Reply: No, if you do want to keep using NHD for decode/prefill TP_ratio == 1, you can set `enable_permute_local_kv = False` in the kv_transfer_config, so it will not do the permute.
+                kv_shape = tuple(kv_shape[i] for i in [0, 1, 3, 2, 4])
             xfer_buffers[layer_name] = torch.empty(
                 kv_shape, dtype=kv_dtype, device="cpu"
             )

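For readers unfamiliar with the two layouts, here is a standalone sketch (not part of the PR; the concrete dimensions and the shape convention are assumptions for illustration) of what the `[0, 1, 3, 2, 4]` index permutation does when allocating the CPU-side buffer:

```python
import torch

# Assumed per-layer KV cache shapes (non-MLA case):
#   NHD: (2, num_blocks, block_size, num_heads, head_size)
#   HND: (2, num_blocks, num_heads, block_size, head_size)
num_blocks, block_size, num_heads, head_size = 4, 16, 8, 64
device_kv = torch.empty(2, num_blocks, block_size, num_heads, head_size)

# Reordering the dims with [0, 1, 3, 2, 4] swaps block_size and num_heads,
# i.e. it turns the NHD shape into the HND shape used for the host buffer.
hnd_shape = tuple(device_kv.shape[i] for i in [0, 1, 3, 2, 4])
host_buffer = torch.empty(hnd_shape, dtype=device_kv.dtype, device="cpu")

assert host_buffer.shape == (2, num_blocks, num_heads, block_size, head_size)
```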
@@ -1099,7 +1114,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             num_blocks=self.num_blocks,
             block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
-            kv_cache_layout=self.kv_cache_layout,
+            kv_cache_layout=self.kv_cache_layout
+            if not self.use_host_buffer
+            else self.host_buffer_kv_cache_layout,
         )
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(

@@ -1255,7 +1272,12 @@ def _validate_remote_agent_handshake(
         assert not self._use_pallas or tp_ratio == 1, (
             "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
         )
-        if not self.use_mla and nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
+        kv_cache_layout = (
+            self.kv_cache_layout
+            if not self.use_host_buffer
+            else self.host_buffer_kv_cache_layout
+        )
+        if not self.use_mla and nixl_agent_meta.kv_cache_layout != kv_cache_layout:
             if (
                 self.kv_transfer_config.enable_permute_local_kv
                 and nixl_agent_meta.kv_cache_layout == "HND"

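Both hunks above apply the same selection rule when advertising the layout to the remote agent and when validating the handshake. As a standalone sketch (the helper name is hypothetical, not from the PR), the effective layout seen by the remote side is:

```python
def effective_kv_cache_layout(
    device_layout: str, use_host_buffer: bool, host_buffer_layout: str
) -> str:
    """Layout advertised to (and validated against) the remote NIXL agent.

    When transfers are staged through a host buffer, the remote side deals
    with the host buffer's layout rather than the device layout.
    """
    return host_buffer_layout if use_host_buffer else device_layout


# XPU case targeted by this PR: device KV is NHD, host buffer is HND,
# so heterogeneous TP validation now runs against HND.
assert effective_kv_cache_layout("NHD", True, "HND") == "HND"
assert effective_kv_cache_layout("NHD", False, "HND") == "NHD"
```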
@@ -1281,9 +1303,6 @@ def _validate_remote_agent_handshake(
             )
             remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
         else:
-            if tp_ratio > 1 and self.device_type == "xpu":
-                # XPU uses NHD, hence it does not support splitting on H
-                raise ValueError("Heterogeneous TP is not supported on XPU")
             # When MLA is not used, this is a list of the same block length
             for block_len in nixl_agent_meta.block_lens:
                 assert block_len == remote_block_len, (

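The inline reply above notes that `enable_permute_local_kv = False` can be set in the kv_transfer_config to keep NHD end-to-end when prefill and decode run with the same TP size. A hedged sketch of that configuration follows; only `enable_permute_local_kv` comes from this diff, while the import path, connector name, and role are illustrative assumptions:

```python
from vllm.config import KVTransferConfig  # import path assumed

# Sketch only: keep the NHD device layout end-to-end by skipping the host
# buffer permute; meaningful when prefill/decode TP sizes match (TP_ratio == 1).
kv_transfer_config = KVTransferConfig(
    kv_connector="NixlConnector",      # assumed connector name
    kv_role="kv_both",                 # assumed role
    enable_permute_local_kv=False,     # flag discussed in this PR thread
)
```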
Comment: Regarding `self.vllm_config.kv_transfer_config is not None`: this should always be true; is mypy complaining?

Reply: Yes, this check is added to make mypy happy.
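As a hedged illustration of the mypy point (the class and function names below are made up; only the `is not None` narrowing pattern mirrors the diff):

```python
from typing import Optional


class KVTransferConfigSketch:
    enable_permute_local_kv: bool = False


class VllmConfigSketch:
    # kv_transfer_config is Optional in the real config, so mypy treats it as
    # possibly None even on code paths where the NIXL connector requires it.
    kv_transfer_config: Optional[KVTransferConfigSketch] = None


def wants_hnd_host_buffer(cfg: VllmConfigSketch, device_layout: str) -> bool:
    # Without the `is not None` check, mypy reports a union-attr error on the
    # attribute access below; the explicit check narrows the Optional to the
    # concrete type, which is why the PR keeps it even though it is always
    # true at runtime.
    return (
        device_layout == "NHD"
        and cfg.kv_transfer_config is not None
        and cfg.kv_transfer_config.enable_permute_local_kv
    )
```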