From ae09ac2dd6ef1b90b0a9b0773f5d101e39853b76 Mon Sep 17 00:00:00 2001 From: lilinsiman Date: Thu, 30 Oct 2025 09:50:10 +0800 Subject: [PATCH 1/2] add new test case for aclgraph capture and replay Signed-off-by: lilinsiman --- .github/workflows/_e2e_test.yaml | 1 + .../multicard/test_aclgraph_capture_replay.py | 177 ++++++++++++++++++ 2 files changed, 178 insertions(+) create mode 100644 tests/e2e/multicard/test_aclgraph_capture_replay.py diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index be5b43e6373..20779f0442c 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -179,6 +179,7 @@ jobs: VLLM_USE_MODELSCOPE: True if: ${{ inputs.type == 'full' }} run: | + pytest -sv tests/e2e/multicard/test_aclgraph_capture_replay.py pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py pytest -sv tests/e2e/multicard/test_full_graph_mode.py pytest -sv tests/e2e/multicard/test_data_parallel.py diff --git a/tests/e2e/multicard/test_aclgraph_capture_replay.py b/tests/e2e/multicard/test_aclgraph_capture_replay.py new file mode 100644 index 00000000000..bcd27170311 --- /dev/null +++ b/tests/e2e/multicard/test_aclgraph_capture_replay.py @@ -0,0 +1,177 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import contextlib +import gc +import math +import multiprocessing +import os +import sys +from time import sleep +from unittest.mock import patch + +import pytest +import torch +from vllm import LLM, SamplingParams +from vllm.distributed.parallel_state import ( # noqa E402 + destroy_distributed_environment, destroy_model_parallel) + +MODELS = ["Qwen/Qwen3-0.6B", "vllm-ascend/DeepSeek-V2-Lite-W8A8"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_tokens", [4]) +@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"}) +def test_aclgraph_capture_replay_dp2( + model: str, + max_tokens: int, +) -> None: + # HCCL_OP_EXPANSION_MODE determines how max_num_batch_sizes is computed. 
+ if 'VLLM_WORKER_MULTIPROC_METHOD' in os.environ: + del os.environ["VLLM_WORKER_MULTIPROC_METHOD"] + if 'HCCL_OP_EXPANSION_MODE' in os.environ: + del os.environ['HCCL_OP_EXPANSION_MODE'] + dp_size = 2 + tp_size = 1 + replay_counter = multiprocessing.Value("i", 0) + capture_counter = multiprocessing.Value("i", 0) + num_hidden_layers_shared = multiprocessing.Value("i", -1) + num_execute_model_shared = multiprocessing.Value("i", 0) + dp_master_ip = "127.0.0.1" + dp_master_port = 11011 + + def dp_rank_main(global_dp_rank: int, local_dp_rank: int): + os.environ["VLLM_DP_RANK"] = str(global_dp_rank) + os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) + os.environ["VLLM_DP_SIZE"] = str(dp_size) + os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip + os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port) + + original_replay = torch.npu.NPUGraph.replay + + def replay_wrapper(self): + with replay_counter.get_lock(): + replay_counter.value += 1 + return original_replay(self) + + original_init = torch.npu.NPUGraph.__init__ + + def init_wrapper(self, *args, **kwargs): + with capture_counter.get_lock(): + capture_counter.value += 1 + return original_init(self, *args, **kwargs) + + with patch.object(torch.npu.NPUGraph, "replay", replay_wrapper), \ + patch.object(torch.npu.NPUGraph, "__init__", init_wrapper): + prompts = [ + "Hello, my name is", "The president of the United States is", + "The capital of France is", "The future of AI is" + ] + chunk_size = len(prompts) // dp_size + start = global_dp_rank * chunk_size + end = start + chunk_size if global_dp_rank < dp_size - 1 else len( + prompts) + my_prompts = prompts[start:end] + sampling_params = SamplingParams(max_tokens=max_tokens, + temperature=0.0) + + def trace_calls(frame, event, arg): + if event == 'call': + code = frame.f_code + func_name = code.co_name + file_name = code.co_filename + if func_name == 'dispatch' and 'cudagraph_dispatcher.py' in file_name: + with num_execute_model_shared.get_lock(): + num_execute_model_shared.value += 1 + return trace_calls + + sys.settrace(trace_calls) + if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8": + llm = LLM( + model=model, + quantization="ascend", + tensor_parallel_size=tp_size, + trust_remote_code=True, + ) + else: + llm = LLM( + model=model, + tensor_parallel_size=tp_size, + trust_remote_code=True, + ) + num_hidden_layers_shared.value = llm.llm_engine.model_config.hf_config.num_hidden_layers + _ = llm.generate(my_prompts, sampling_params) + sys.settrace(None) + + # Give engines time to pause their processing loops before exiting. 
+ sleep(5) + del llm + cleanup_env_and_memory() + + processes = [] + for local_dp_rank in range(dp_size): + global_dp_rank = local_dp_rank + p = multiprocessing.Process(target=dp_rank_main, + args=(global_dp_rank, local_dp_rank)) + p.start() + processes.append(p) + + for p in processes: + p.join(timeout=900) + if p.exitcode != 0: + if p.exitcode is None: + p.kill() + raise RuntimeError(f"Process {p.pid} timed out") + else: + raise RuntimeError( + f"Process failed with exit code {p.exitcode}") + + actual_capture = capture_counter.value + actual_replay = replay_counter.value + num_hidden_layers = num_hidden_layers_shared.value + num_execute_model = num_execute_model_shared.value + + num_acl_graphs = num_hidden_layers + 1 + num_comm_groups = sum(size > 1 for size in [ + dp_size, + tp_size, + ]) + max_num_batch_sizes = math.floor( + (1800 - num_comm_groups * 40) / num_acl_graphs / + (1 + num_comm_groups * 2)) + expected_total_capture = max_num_batch_sizes * num_acl_graphs * dp_size + assert actual_capture == expected_total_capture, ( + f"capture count mismatch. Expected: {expected_total_capture}, Got: {actual_capture}" + ) + + num_inference_steps = max_tokens + 1 # first token + max_tokens + expected_total_replay = num_acl_graphs * num_inference_steps * dp_size + num_execute_model * num_acl_graphs + assert actual_replay == expected_total_replay, ( + f"Replay count mismatch. Expected: {expected_total_replay}, Got: {actual_replay}" + ) + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = 'spawn' + sleep(600) + + +def cleanup_env_and_memory(): + destroy_model_parallel() + destroy_distributed_environment() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() From 0061f87b03cf58f3b79d0962d9a8d38bacf9ee7d Mon Sep 17 00:00:00 2001 From: Yizhou Liu Date: Tue, 18 Nov 2025 20:31:50 +0800 Subject: [PATCH 2/2] Refactor and improve ACL graph capture/replay DP test Restructures the multi-card ACL graph test for improved clarity, robustness, and accuracy. Key improvements include: - Replaces fragile `sys.settrace` and manual patching with a clean, reusable spy installer using `unittest.mock.patch`. - Introduces more precise metrics by tracking `NPUModelRunner.execute_model` and `_dummy_run` calls directly. - Rewrites assertions to be more accurate and provides clear explanations for the expected counts of graph captures, replays, model executions, and dummy runs. - Simplifies the overall test structure by separating the worker logic into a dedicated function. - Removes a long, unnecessary sleep at the end of the test. - Expands test coverage by adding a larger `max_tokens` parameter. Signed-off-by: Yizhou Liu --- .../multicard/test_aclgraph_capture_replay.py | 338 +++++++++++------- 1 file changed, 199 insertions(+), 139 deletions(-) diff --git a/tests/e2e/multicard/test_aclgraph_capture_replay.py b/tests/e2e/multicard/test_aclgraph_capture_replay.py index bcd27170311..f4dd496551b 100644 --- a/tests/e2e/multicard/test_aclgraph_capture_replay.py +++ b/tests/e2e/multicard/test_aclgraph_capture_replay.py @@ -1,4 +1,3 @@ -# # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. # Copyright 2023 The vLLM team. # @@ -13,165 +12,226 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# import contextlib import gc import math import multiprocessing import os -import sys -from time import sleep +from typing import Any from unittest.mock import patch import pytest import torch -from vllm import LLM, SamplingParams -from vllm.distributed.parallel_state import ( # noqa E402 - destroy_distributed_environment, destroy_model_parallel) -MODELS = ["Qwen/Qwen3-0.6B", "vllm-ascend/DeepSeek-V2-Lite-W8A8"] +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.11.0"): + from vllm.utils import get_open_port +else: + from vllm.utils.network_utils import get_open_port + +MODELS = [ + "Qwen/Qwen3-0.6B", + "vllm-ascend/DeepSeek-V2-Lite-W8A8", +] + + +def _install_spies(counters: dict[str, Any]) -> contextlib.ExitStack: + """Installs thread-safe spies on NPU methods to track invocation counts.""" + from vllm_ascend.worker.model_runner_v1 import NPUModelRunner + + def make_spy(cls, method_name, counter): + original = getattr(cls, method_name) + def spy(self, *args, **kwargs): + with counter.get_lock(): + counter.value += 1 + return original(self, *args, **kwargs) + return spy + + stack = contextlib.ExitStack() + hooks = [ + (torch.npu.NPUGraph, "replay", counters["replay"]), + (torch.npu.NPUGraph, "__init__", counters["capture"]), + (NPUModelRunner, "execute_model", counters["exec_model"]), + (NPUModelRunner, "_dummy_run", counters["dummy_run"]), + ] + + for cls, method, counter in hooks: + stack.enter_context( + patch.object(cls, method, make_spy(cls, method, counter))) + + return stack + + +def _run_worker_process( + rank: int, + local_rank: int, + world_size: int, + master_ip: str, + master_port: int, + counters: dict[str, Any], + model_path: str, + max_tokens: int, +): + """Main entry point for the worker process.""" + os.environ.update({ + "VLLM_DP_RANK": str(rank), + "VLLM_DP_RANK_LOCAL": str(local_rank), + "VLLM_DP_SIZE": str(world_size), + "VLLM_DP_MASTER_IP": master_ip, + "VLLM_DP_MASTER_PORT": str(master_port), + }) + + # Import vLLM only after environment setup + from vllm import LLM, SamplingParams + from vllm.distributed.parallel_state import ( + destroy_distributed_environment, destroy_model_parallel) + + # Apply hooks and run inference + with _install_spies(counters): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # Simple data sharding + chunk_size = len(prompts) // world_size + start_idx = rank * chunk_size + end_idx = start_idx + chunk_size if rank < world_size - 1 else len( + prompts) + local_prompts = prompts[start_idx:end_idx] + + llm = LLM( + model=model_path, + quantization="ascend" if "W8A8" in model_path else None, + # enable_expert_parallel=True if "DeepSeek" in model_path else False, + trust_remote_code=True, + ) + + # Expose model config to the main test process + counters["hidden_layers"].value = ( + llm.llm_engine.model_config.hf_config.num_hidden_layers) + + llm.generate(local_prompts, + SamplingParams(max_tokens=max_tokens, temperature=0.0)) + + # Explicit cleanup is mandatory in multi-process vLLM tests + del llm + + destroy_model_parallel() + destroy_distributed_environment() + + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() + + +# @patch.dict(os.environ, clear=["HCCL_OP_EXPANSION_MODE","VLLM_WORKER_MULTIPROC_METHOD"]) @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("max_tokens", [4]) 
+@pytest.mark.parametrize("max_tokens", [4, 36]) @patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"}) def test_aclgraph_capture_replay_dp2( model: str, max_tokens: int, + monkeypatch: pytest.MonkeyPatch, ) -> None: - # HCCL_OP_EXPANSION_MODE determines how max_num_batch_sizes is computed. - if 'VLLM_WORKER_MULTIPROC_METHOD' in os.environ: - del os.environ["VLLM_WORKER_MULTIPROC_METHOD"] - if 'HCCL_OP_EXPANSION_MODE' in os.environ: - del os.environ['HCCL_OP_EXPANSION_MODE'] + # Counter doesn't work in default "spawn" mode + monkeypatch.delenv("VLLM_WORKER_MULTIPROC_METHOD", raising=False) + + # Shared counters for cross-process assertion + counters = { + "replay": multiprocessing.Value("i", 0), + "capture": multiprocessing.Value("i", 0), + "exec_model": multiprocessing.Value("i", 0), + "dummy_run": multiprocessing.Value("i", 0), + "hidden_layers": multiprocessing.Value("i", -1), + } + dp_size = 2 - tp_size = 1 - replay_counter = multiprocessing.Value("i", 0) - capture_counter = multiprocessing.Value("i", 0) - num_hidden_layers_shared = multiprocessing.Value("i", -1) - num_execute_model_shared = multiprocessing.Value("i", 0) - dp_master_ip = "127.0.0.1" - dp_master_port = 11011 - - def dp_rank_main(global_dp_rank: int, local_dp_rank: int): - os.environ["VLLM_DP_RANK"] = str(global_dp_rank) - os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) - os.environ["VLLM_DP_SIZE"] = str(dp_size) - os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip - os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port) - - original_replay = torch.npu.NPUGraph.replay - - def replay_wrapper(self): - with replay_counter.get_lock(): - replay_counter.value += 1 - return original_replay(self) - - original_init = torch.npu.NPUGraph.__init__ - - def init_wrapper(self, *args, **kwargs): - with capture_counter.get_lock(): - capture_counter.value += 1 - return original_init(self, *args, **kwargs) - - with patch.object(torch.npu.NPUGraph, "replay", replay_wrapper), \ - patch.object(torch.npu.NPUGraph, "__init__", init_wrapper): - prompts = [ - "Hello, my name is", "The president of the United States is", - "The capital of France is", "The future of AI is" - ] - chunk_size = len(prompts) // dp_size - start = global_dp_rank * chunk_size - end = start + chunk_size if global_dp_rank < dp_size - 1 else len( - prompts) - my_prompts = prompts[start:end] - sampling_params = SamplingParams(max_tokens=max_tokens, - temperature=0.0) - - def trace_calls(frame, event, arg): - if event == 'call': - code = frame.f_code - func_name = code.co_name - file_name = code.co_filename - if func_name == 'dispatch' and 'cudagraph_dispatcher.py' in file_name: - with num_execute_model_shared.get_lock(): - num_execute_model_shared.value += 1 - return trace_calls - - sys.settrace(trace_calls) - if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8": - llm = LLM( - model=model, - quantization="ascend", - tensor_parallel_size=tp_size, - trust_remote_code=True, - ) - else: - llm = LLM( - model=model, - tensor_parallel_size=tp_size, - trust_remote_code=True, - ) - num_hidden_layers_shared.value = llm.llm_engine.model_config.hf_config.num_hidden_layers - _ = llm.generate(my_prompts, sampling_params) - sys.settrace(None) - - # Give engines time to pause their processing loops before exiting. 
- sleep(5) - del llm - cleanup_env_and_memory() - - processes = [] - for local_dp_rank in range(dp_size): - global_dp_rank = local_dp_rank - p = multiprocessing.Process(target=dp_rank_main, - args=(global_dp_rank, local_dp_rank)) + port = get_open_port() + + # Launch workers + workers = [] + for rank in range(dp_size): + p = multiprocessing.Process( + target=_run_worker_process, + args=(rank, rank, dp_size, "127.0.0.1", port, counters, model, + max_tokens), + ) p.start() - processes.append(p) + workers.append(p) - for p in processes: + # Supervision loop + for p in workers: p.join(timeout=900) if p.exitcode != 0: - if p.exitcode is None: - p.kill() - raise RuntimeError(f"Process {p.pid} timed out") - else: - raise RuntimeError( - f"Process failed with exit code {p.exitcode}") - - actual_capture = capture_counter.value - actual_replay = replay_counter.value - num_hidden_layers = num_hidden_layers_shared.value - num_execute_model = num_execute_model_shared.value - - num_acl_graphs = num_hidden_layers + 1 - num_comm_groups = sum(size > 1 for size in [ - dp_size, - tp_size, - ]) - max_num_batch_sizes = math.floor( - (1800 - num_comm_groups * 40) / num_acl_graphs / - (1 + num_comm_groups * 2)) - expected_total_capture = max_num_batch_sizes * num_acl_graphs * dp_size - assert actual_capture == expected_total_capture, ( - f"capture count mismatch. Expected: {expected_total_capture}, Got: {actual_capture}" - ) - - num_inference_steps = max_tokens + 1 # first token + max_tokens - expected_total_replay = num_acl_graphs * num_inference_steps * dp_size + num_execute_model * num_acl_graphs - assert actual_replay == expected_total_replay, ( - f"Replay count mismatch. Expected: {expected_total_replay}, Got: {actual_replay}" - ) - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = 'spawn' - sleep(600) - - -def cleanup_env_and_memory(): - destroy_model_parallel() - destroy_distributed_environment() - with contextlib.suppress(AssertionError): - torch.distributed.destroy_process_group() - gc.collect() - torch.npu.empty_cache() - torch.npu.reset_peak_memory_stats() + for k in workers: + if k.is_alive(): + k.kill() + raise RuntimeError( + f"Worker {p.pid} failed with exit code {p.exitcode}") + + actual_capture = counters["capture"].value + actual_replay = counters["replay"].value + num_execute_model = counters["exec_model"].value + num_dummy_run = counters["dummy_run"].value + num_layers = counters["hidden_layers"].value + + num_acl_graphs = num_layers + 1 + num_comm_groups = sum(1 for s in [dp_size, 1] + if s > 1) # dp_size=2, tp_size=1 + + # Metric 1: Graph Capture (ACL Graph Construction) + # Ref: vllm_ascend.utils.update_aclgraph_sizes + max_batch_sizes = math.floor((1800 - num_comm_groups * 40) / + num_acl_graphs / (1 + num_comm_groups * 2)) + + expected_capture = max_batch_sizes * num_acl_graphs * dp_size + assert ( + actual_capture == expected_capture + ), f"Capture count mismatch. Expected: {expected_capture}, Got: {actual_capture}" + + # Metric 2: Model Execution (NPUModelRunner.execute_model) + # vLLM Step Breakdown: + # 1. First step (prefill, 1 prompt) + # 2. Generation steps (max_tokens) + # 3. Final step (likely EOS/idle step), no replay here + total_steps = max_tokens + 1 # this includes the 1 and 2 above + expected_exec_model = (total_steps + 1) * dp_size + + assert ( + num_execute_model == expected_exec_model + ), f"Model execution count mismatch. Expected: {expected_exec_model}, Got: {num_execute_model}" + + # Metric 3: Dummy Runs (Warmup & Alignment) + # vLLM synchronizes globally every 32 steps. 
+ # Ref: vllm.v1.engine.core.DPEngineCoreProc._has_global_unfinished_reqs + aligned_steps = (total_steps + 31) // 32 * 32 + + # Part A: Warmup runs (Profile run + 2 runs per captured graph) + warmup_runs = 1 + (2 * max_batch_sizes) + + # Part B: Alignment padding (Empty runs to hit the 32-step boundary) + padding_runs = aligned_steps - total_steps + + expected_dummy_run = (warmup_runs + padding_runs) * dp_size + + assert ( + num_dummy_run == expected_dummy_run + ), f"Dummy run count mismatch. Expected: {expected_dummy_run}, Got: {num_dummy_run}" + + # Metric 4: Graph Replay (Inference Execution) + # Replays happen for every aligned step across all graphs. + expected_replay = num_acl_graphs * aligned_steps * dp_size + + assert ( + actual_replay == expected_replay + ), f"Replay count mismatch. Expected: {expected_replay}, Got: {actual_replay}"
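
For reference, the expected-count formulas asserted in the refactored test can be checked in isolation. The snippet below is a minimal sketch of that arithmetic only, assuming a hypothetical 28-layer model (the real layer count is read from hf_config at runtime), dp_size=2, tp_size=1, and max_tokens=4; the 1800/40 budget constants and the 32-step alignment are taken from the test itself, and the printed numbers are illustrative, not authoritative.

import math

dp_size, tp_size = 2, 1
num_layers = 28   # hypothetical; the test reads the real value from hf_config
max_tokens = 4

num_acl_graphs = num_layers + 1                        # one graph per layer, plus one
num_comm_groups = sum(s > 1 for s in (dp_size, tp_size))

# Metric 1: captures per rank = graphs per batch size * number of batch sizes
max_batch_sizes = math.floor(
    (1800 - num_comm_groups * 40) / num_acl_graphs / (1 + num_comm_groups * 2))
expected_capture = max_batch_sizes * num_acl_graphs * dp_size

# Metric 2: execute_model calls = prefill + decode steps, plus one final step
total_steps = max_tokens + 1
expected_exec_model = (total_steps + 1) * dp_size

# Metric 3: dummy runs = warmup (profile run + 2 per captured graph) + DP padding
aligned_steps = (total_steps + 31) // 32 * 32          # 32-step DP alignment
expected_dummy_run = (1 + 2 * max_batch_sizes + aligned_steps - total_steps) * dp_size

# Metric 4: replays = one per graph per aligned step, per rank
expected_replay = num_acl_graphs * aligned_steps * dp_size

print(expected_capture, expected_exec_model, expected_dummy_run, expected_replay)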