1 | | -# |
2 | | -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. |
3 | | -# Copyright 2023 The vLLM team. |
4 | | -# |
5 | | -# Licensed under the Apache License, Version 2.0 (the "License"); |
6 | | -# you may not use this file except in compliance with the License. |
7 | | -# You may obtain a copy of the License at |
8 | | -# |
9 | | -# http://www.apache.org/licenses/LICENSE-2.0 |
10 | | -# |
11 | | -# Unless required by applicable law or agreed to in writing, software |
12 | | -# distributed under the License is distributed on an "AS IS" BASIS, |
13 | | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | | -# See the License for the specific language governing permissions and |
15 | | -# limitations under the License. |
16 | | -# |
17 | | - |
18 | 1 | import contextlib |
19 | 2 | import gc |
20 | 3 | import math |
21 | 4 | import multiprocessing |
22 | 5 | import os |
23 | | -import sys |
24 | 6 | from time import sleep |
| 7 | +from typing import Any |
25 | 8 | from unittest.mock import patch |
26 | 9 |
|
27 | 10 | import pytest |
28 | 11 | import torch |
29 | | -from vllm import LLM, SamplingParams |
30 | | -from vllm.distributed.parallel_state import ( # noqa E402 |
31 | | - destroy_distributed_environment, destroy_model_parallel) |
32 | 12 |
|
33 | | -MODELS = ["Qwen/Qwen3-0.6B", "vllm-ascend/DeepSeek-V2-Lite-W8A8"] |
| 13 | +from vllm_ascend.utils import vllm_version_is |
| 14 | + |
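| | +# vLLM 0.11.0 exposes get_open_port via vllm.utils; newer releases moved it
| | +# to vllm.utils.network_utils, hence the version-gated import below.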
| 15 | +if vllm_version_is("0.11.0"): |
| 16 | + from vllm.utils import get_open_port |
| 17 | +else: |
| 18 | + from vllm.utils.network_utils import get_open_port |
| 19 | + |
| 20 | +MODELS = [ |
| 21 | + "Qwen/Qwen3-0.6B", |
| 22 | + "vllm-ascend/DeepSeek-V2-Lite-W8A8", |
| 23 | +] |
| 24 | + |
| 25 | + |
| 26 | +def _install_spies(counters: dict[str, Any]) -> contextlib.ExitStack: |
| 27 | + """Installs thread-safe spies on NPU methods to track invocation counts.""" |
| 28 | + from vllm_ascend.worker.model_runner_v1 import NPUModelRunner |
| 29 | + |
| 30 | + def make_spy(cls, method_name, counter): |
| 31 | + original = getattr(cls, method_name) |
| 32 | + |
| 33 | + def spy(self, *args, **kwargs): |
| 34 | + with counter.get_lock(): |
| 35 | + counter.value += 1 |
| 36 | + return original(self, *args, **kwargs) |
| 37 | + |
| 38 | + return spy |
| 39 | + |
| 40 | + stack = contextlib.ExitStack() |
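| | +    # NPUGraph.__init__ is treated as "one graph captured" and NPUGraph.replay
| | +    # as "one graph replayed"; the NPUModelRunner hooks count real and dummy
| | +    # model-execution steps.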
| 41 | + hooks = [ |
| 42 | + (torch.npu.NPUGraph, "replay", counters["replay"]), |
| 43 | + (torch.npu.NPUGraph, "__init__", counters["capture"]), |
| 44 | + (NPUModelRunner, "execute_model", counters["exec_model"]), |
| 45 | + (NPUModelRunner, "_dummy_run", counters["dummy_run"]), |
| 46 | + ] |
| 47 | + |
| 48 | + for cls, method, counter in hooks: |
| 49 | + stack.enter_context( |
| 50 | + patch.object(cls, method, make_spy(cls, method, counter))) |
| 51 | + |
| 52 | + return stack |
| 53 | + |
| 54 | + |
| 55 | +def _run_worker_process( |
| 56 | + rank: int, |
| 57 | + local_rank: int, |
| 58 | + world_size: int, |
| 59 | + master_ip: str, |
| 60 | + master_port: int, |
| 61 | + counters: dict[str, Any], |
| 62 | + model_path: str, |
| 63 | + max_tokens: int, |
| 64 | +): |
| 65 | + """Main entry point for the worker process.""" |
| 66 | + os.environ.update({ |
| 67 | + "VLLM_DP_RANK": str(rank), |
| 68 | + "VLLM_DP_RANK_LOCAL": str(local_rank), |
| 69 | + "VLLM_DP_SIZE": str(world_size), |
| 70 | + "VLLM_DP_MASTER_IP": master_ip, |
| 71 | + "VLLM_DP_MASTER_PORT": str(master_port), |
| 72 | + }) |
| 73 | + |
| 74 | + # Import vLLM only after environment setup |
| 75 | + from vllm import LLM, SamplingParams |
| 76 | + from vllm.distributed.parallel_state import ( |
| 77 | + destroy_distributed_environment, destroy_model_parallel) |
| 78 | + |
| 79 | + # Apply hooks and run inference |
| 80 | + with _install_spies(counters): |
| 81 | + prompts = [ |
| 82 | + "Hello, my name is", |
| 83 | + "The president of the United States is", |
| 84 | + "The capital of France is", |
| 85 | + "The future of AI is", |
| 86 | + ] |
| 87 | + |
| 88 | + # Simple data sharding |
| 89 | + chunk_size = len(prompts) // world_size |
| 90 | + start_idx = rank * chunk_size |
| 91 | + end_idx = start_idx + chunk_size if rank < world_size - 1 else len( |
| 92 | + prompts) |
| 93 | + local_prompts = prompts[start_idx:end_idx] |
| 94 | + |
| 95 | + llm = LLM( |
| 96 | + model=model_path, |
| 97 | + quantization="ascend" if "W8A8" in model_path else None, |
| 98 | + enable_expert_parallel=True if "DeepSeek" in model_path else False, |
| 99 | + trust_remote_code=True, |
| 100 | + ) |
| 101 | + |
| 102 | + # Expose model config to the main test process |
| 103 | + counters["hidden_layers"].value = ( |
| 104 | + llm.llm_engine.model_config.hf_config.num_hidden_layers) |
| 105 | + |
| 106 | + llm.generate(local_prompts, |
| 107 | + SamplingParams(max_tokens=max_tokens, temperature=0.0)) |
| 108 | + |
| 109 | + # Explicit cleanup is mandatory in multi-process vLLM tests |
| 110 | + sleep(5) |
| 111 | + del llm |
| 112 | + |
| 113 | + destroy_model_parallel() |
| 114 | + destroy_distributed_environment() |
| 115 | + |
| 116 | + with contextlib.suppress(AssertionError): |
| 117 | + torch.distributed.destroy_process_group() |
| 118 | + |
| 119 | + gc.collect() |
| 120 | + torch.npu.empty_cache() |
| 121 | + torch.npu.reset_peak_memory_stats() |
34 | 122 |
|
35 | 123 |
|
36 | 124 | @pytest.mark.parametrize("model", MODELS) |
37 | | -@pytest.mark.parametrize("max_tokens", [4]) |
| 125 | +@pytest.mark.parametrize("max_tokens", [4, 36]) |
38 | 126 | @patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"}) |
39 | | -def test_aclgraph_capture_replay_dp2( |
40 | | - model: str, |
41 | | - max_tokens: int, |
42 | | -) -> None: |
43 | | - # HCCL_OP_EXPANSION_MODE determines how max_num_batch_sizes is computed. |
44 | | - if 'VLLM_WORKER_MULTIPROC_METHOD' in os.environ: |
45 | | - del os.environ["VLLM_WORKER_MULTIPROC_METHOD"] |
46 | | - if 'HCCL_OP_EXPANSION_MODE' in os.environ: |
47 | | - del os.environ['HCCL_OP_EXPANSION_MODE'] |
| 127 | +@patch.dict(os.environ, {"VLLM_WORKER_MULTIPROC_METHOD": "fork"}) |
| 128 | +def test_aclgraph_capture_replay_dp(model: str, max_tokens: int) -> None: |
| 129 | + # Shared counters for cross-process assertion |
| 130 | + counters = { |
| 131 | + "replay": multiprocessing.Value("i", 0), |
| 132 | + "capture": multiprocessing.Value("i", 0), |
| 133 | + "exec_model": multiprocessing.Value("i", 0), |
| 134 | + "dummy_run": multiprocessing.Value("i", 0), |
| 135 | + "hidden_layers": multiprocessing.Value("i", -1), |
| 136 | + } |
| 137 | + |
48 | 138 | dp_size = 2 |
49 | | - tp_size = 1 |
50 | | - replay_counter = multiprocessing.Value("i", 0) |
51 | | - capture_counter = multiprocessing.Value("i", 0) |
52 | | - num_hidden_layers_shared = multiprocessing.Value("i", -1) |
53 | | - num_execute_model_shared = multiprocessing.Value("i", 0) |
54 | | - dp_master_ip = "127.0.0.1" |
55 | | - dp_master_port = 11011 |
56 | | - |
57 | | - def dp_rank_main(global_dp_rank: int, local_dp_rank: int): |
58 | | - os.environ["VLLM_DP_RANK"] = str(global_dp_rank) |
59 | | - os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) |
60 | | - os.environ["VLLM_DP_SIZE"] = str(dp_size) |
61 | | - os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip |
62 | | - os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port) |
63 | | - |
64 | | - original_replay = torch.npu.NPUGraph.replay |
65 | | - |
66 | | - def replay_wrapper(self): |
67 | | - with replay_counter.get_lock(): |
68 | | - replay_counter.value += 1 |
69 | | - return original_replay(self) |
70 | | - |
71 | | - original_init = torch.npu.NPUGraph.__init__ |
72 | | - |
73 | | - def init_wrapper(self, *args, **kwargs): |
74 | | - with capture_counter.get_lock(): |
75 | | - capture_counter.value += 1 |
76 | | - return original_init(self, *args, **kwargs) |
77 | | - |
78 | | - with patch.object(torch.npu.NPUGraph, "replay", replay_wrapper), \ |
79 | | - patch.object(torch.npu.NPUGraph, "__init__", init_wrapper): |
80 | | - prompts = [ |
81 | | - "Hello, my name is", "The president of the United States is", |
82 | | - "The capital of France is", "The future of AI is" |
83 | | - ] |
84 | | - chunk_size = len(prompts) // dp_size |
85 | | - start = global_dp_rank * chunk_size |
86 | | - end = start + chunk_size if global_dp_rank < dp_size - 1 else len( |
87 | | - prompts) |
88 | | - my_prompts = prompts[start:end] |
89 | | - sampling_params = SamplingParams(max_tokens=max_tokens, |
90 | | - temperature=0.0) |
91 | | - |
92 | | - def trace_calls(frame, event, arg): |
93 | | - if event == 'call': |
94 | | - code = frame.f_code |
95 | | - func_name = code.co_name |
96 | | - file_name = code.co_filename |
97 | | - if func_name == 'dispatch' and 'cudagraph_dispatcher.py' in file_name: |
98 | | - with num_execute_model_shared.get_lock(): |
99 | | - num_execute_model_shared.value += 1 |
100 | | - return trace_calls |
101 | | - |
102 | | - sys.settrace(trace_calls) |
103 | | - if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8": |
104 | | - llm = LLM( |
105 | | - model=model, |
106 | | - quantization="ascend", |
107 | | - tensor_parallel_size=tp_size, |
108 | | - trust_remote_code=True, |
109 | | - ) |
110 | | - else: |
111 | | - llm = LLM( |
112 | | - model=model, |
113 | | - tensor_parallel_size=tp_size, |
114 | | - trust_remote_code=True, |
115 | | - ) |
116 | | - num_hidden_layers_shared.value = llm.llm_engine.model_config.hf_config.num_hidden_layers |
117 | | - _ = llm.generate(my_prompts, sampling_params) |
118 | | - sys.settrace(None) |
119 | | - |
120 | | - # Give engines time to pause their processing loops before exiting. |
121 | | - sleep(5) |
122 | | - del llm |
123 | | - cleanup_env_and_memory() |
124 | | - |
125 | | - processes = [] |
126 | | - for local_dp_rank in range(dp_size): |
127 | | - global_dp_rank = local_dp_rank |
128 | | - p = multiprocessing.Process(target=dp_rank_main, |
129 | | - args=(global_dp_rank, local_dp_rank)) |
| 139 | + workers = [] |
| 140 | + port = get_open_port() |
| 141 | + |
| 142 | + # Launch workers |
| 143 | + for rank in range(dp_size): |
| 144 | + p = multiprocessing.Process( |
| 145 | + target=_run_worker_process, |
| 146 | + args=(rank, rank, dp_size, "127.0.0.1", port, counters, model, |
| 147 | + max_tokens), |
| 148 | + ) |
130 | 149 | p.start() |
131 | | - processes.append(p) |
| 150 | + workers.append(p) |
132 | 151 |
|
133 | | - for p in processes: |
134 | | - p.join(timeout=900) |
| 152 | + # Supervision loop |
| 153 | + for p in workers: |
| 154 | + p.join(timeout=60) |
135 | 155 | if p.exitcode != 0: |
136 | | - if p.exitcode is None: |
137 | | - p.kill() |
138 | | - raise RuntimeError(f"Process {p.pid} timed out") |
139 | | - else: |
140 | | - raise RuntimeError( |
141 | | - f"Process failed with exit code {p.exitcode}") |
142 | | - |
143 | | - actual_capture = capture_counter.value |
144 | | - actual_replay = replay_counter.value |
145 | | - num_hidden_layers = num_hidden_layers_shared.value |
146 | | - num_execute_model = num_execute_model_shared.value |
147 | | - |
148 | | - num_acl_graphs = num_hidden_layers + 1 |
149 | | - num_comm_groups = sum(size > 1 for size in [ |
150 | | - dp_size, |
151 | | - tp_size, |
152 | | - ]) |
153 | | - max_num_batch_sizes = math.floor( |
154 | | - (1800 - num_comm_groups * 40) / num_acl_graphs / |
155 | | - (1 + num_comm_groups * 2)) |
156 | | - expected_total_capture = max_num_batch_sizes * num_acl_graphs * dp_size |
157 | | - assert actual_capture == expected_total_capture, ( |
158 | | - f"capture count mismatch. Expected: {expected_total_capture}, Got: {actual_capture}" |
159 | | - ) |
160 | | - |
161 | | - num_inference_steps = max_tokens + 1 # first token + max_tokens |
162 | | - expected_total_replay = num_acl_graphs * num_inference_steps * dp_size + num_execute_model * num_acl_graphs |
163 | | - assert actual_replay == expected_total_replay, ( |
164 | | - f"Replay count mismatch. Expected: {expected_total_replay}, Got: {actual_replay}" |
165 | | - ) |
166 | | - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = 'spawn' |
167 | | - sleep(600) |
168 | | - |
169 | | - |
170 | | -def cleanup_env_and_memory(): |
171 | | - destroy_model_parallel() |
172 | | - destroy_distributed_environment() |
173 | | - with contextlib.suppress(AssertionError): |
174 | | - torch.distributed.destroy_process_group() |
175 | | - gc.collect() |
176 | | - torch.npu.empty_cache() |
177 | | - torch.npu.reset_peak_memory_stats() |
| 156 | + for k in workers: |
| 157 | + if k.is_alive(): |
| 158 | + k.kill() |
| 159 | +            raise RuntimeError(f"Worker {p.pid} failed with exit code "
| 160 | +                               f"{p.exitcode} (None means join() timed out)")
| 161 | + |
| 162 | + actual_capture = counters["capture"].value |
| 163 | + actual_replay = counters["replay"].value |
| 164 | + num_execute_model = counters["exec_model"].value |
| 165 | + num_dummy_run = counters["dummy_run"].value |
| 166 | + num_layers = counters["hidden_layers"].value |
| 167 | + |
| 168 | + num_acl_graphs = num_layers + 1 |
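| | +    # Piecewise ACL graph capture splits the model at each attention layer,
| | +    # so each captured batch size yields num_layers + 1 graphs (assumption
| | +    # based on the sizing helper referenced below).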
| 169 | +    tp_size = 1  # LLM() is launched with the default tensor_parallel_size
| 170 | +    num_comm_groups = sum(1 for s in (dp_size, tp_size) if s > 1)
| 171 | + |
| 172 | + # Metric 1: Graph Capture (ACL Graph Construction) |
| 173 | + # Ref: vllm_ascend.utils.update_aclgraph_sizes |
| 174 | + max_batch_sizes = math.floor((1800 - num_comm_groups * 40) / |
| 175 | + num_acl_graphs / (1 + num_comm_groups * 2)) |
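| | +    # Assumed semantics of the constants: 1800 is the per-device capture
| | +    # budget and each comm group reserves extra capacity; the Ref above is
| | +    # authoritative.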
| 176 | + |
| 177 | + expected_capture = max_batch_sizes * num_acl_graphs * dp_size |
| 178 | + assert ( |
| 179 | + actual_capture == expected_capture |
| 180 | + ), f"Capture count mismatch. Expected: {expected_capture}, Got: {actual_capture}" |
| 181 | + |
| 182 | + # Metric 2: Model Execution (NPUModelRunner.execute_model) |
| 183 | + # vLLM Step Breakdown: |
| 184 | +    # 1. One prefill step for the rank's local prompts
| 185 | + # 2. Generation steps (max_tokens) |
| 186 | +    # 3. One final step (likely the EOS / idle check); no graph replay here
| 187 | +    total_steps = max_tokens + 1
| 188 | +    expected_exec_model = (total_steps + 1) * dp_size
| 189 | + |
| 190 | + assert ( |
| 191 | + num_execute_model == expected_exec_model |
| 192 | + ), f"Model execution count mismatch. Expected: {expected_exec_model}, Got: {num_execute_model}" |
| 193 | + |
| 194 | + # Metric 3: Dummy Runs (Warmup & Alignment) |
| 195 | + # vLLM synchronizes globally every 32 steps. |
| 196 | + # Ref: vllm.v1.engine.core.DPEngineCoreProc._has_global_unfinished_reqs |
| 197 | + aligned_steps = (total_steps + 31) // 32 * 32 |
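| | +    # e.g. max_tokens=4  -> total_steps=5  -> aligned_steps=32;
| | +    #      max_tokens=36 -> total_steps=37 -> aligned_steps=64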
| 198 | + |
| 199 | + # Part A: Warmup runs (Profile run + 2 runs per captured graph) |
| 200 | + warmup_runs = 1 + (2 * max_batch_sizes) |
| 201 | + |
| 202 | + # Part B: Alignment padding (Empty runs to hit the 32-step boundary) |
| 203 | + padding_runs = aligned_steps - total_steps |
| 204 | + |
| 205 | + expected_dummy_run = (warmup_runs + padding_runs) * dp_size |
| 206 | + |
| 207 | + assert ( |
| 208 | + num_dummy_run == expected_dummy_run |
| 209 | + ), f"Dummy run count mismatch. Expected: {expected_dummy_run}, Got: {num_dummy_run}" |
| 210 | + |
| 211 | + # Metric 4: Graph Replay (Inference Execution) |
| 212 | + # Replays happen for every aligned step across all graphs. |
| 213 | + expected_replay = num_acl_graphs * aligned_steps * dp_size |
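| | +    # e.g. for a hypothetical 28-layer model with max_tokens=4:
| | +    #      29 graphs * 32 aligned steps * 2 ranks = 1856 replays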
| 214 | + |
| 215 | + assert ( |
| 216 | + actual_replay == expected_replay |
| 217 | + ), f"Replay count mismatch. Expected: {expected_replay}, Got: {actual_replay}" |