
Commit 19408a4

Refactor and improve ACL graph capture/replay DP test
Restructures the multi-card ACL graph test for improved clarity, robustness, and accuracy. Key improvements include:

- Replaces fragile sys.settrace and manual patching with a clean, reusable spy installer using unittest.mock.patch.
- Introduces more precise metrics by tracking NPUModelRunner.execute_model and _dummy_run calls directly.
- Rewrites assertions to be more accurate and provides clear explanations for the expected counts of graph captures, replays, model executions, and dummy runs.
- Simplifies the overall test structure by separating the worker logic into a dedicated function.
- Removes a long, unnecessary sleep at the end of the test.
- Expands test coverage by adding a larger max_tokens parameter.

Signed-off-by: Yizhou Liu <[email protected]>
1 parent ae09ac2 commit 19408a4
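
The core of the change is the spy installer mentioned in the first bullet. Below is a minimal, self-contained sketch of that pattern (the Graph class, worker function, and make_spy helper are illustrative stand-ins, not the actual NPUGraph or NPUModelRunner hooks from the diff): patch.object wraps a method so each call increments a multiprocessing.Value, which the parent process reads back after the child exits.

    import multiprocessing
    from unittest.mock import patch


    class Graph:
        # Stand-in for torch.npu.NPUGraph; only here to demonstrate the pattern.
        def replay(self):
            return "replayed"


    def make_spy(cls, method_name, counter):
        # Wrap cls.method_name so every call bumps the shared counter.
        original = getattr(cls, method_name)

        def spy(self, *args, **kwargs):
            with counter.get_lock():  # safe across processes
                counter.value += 1
            return original(self, *args, **kwargs)

        return patch.object(cls, method_name, spy)


    def worker(counter):
        with make_spy(Graph, "replay", counter):
            g = Graph()
            g.replay()
            g.replay()


    if __name__ == "__main__":
        replay_count = multiprocessing.Value("i", 0)  # shared int counter
        p = multiprocessing.Process(target=worker, args=(replay_count,))
        p.start()
        p.join()
        assert replay_count.value == 2

The test in the diff generalizes this by pushing several such patches onto a contextlib.ExitStack, one per tracked method.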

1 file changed: 198 additions & 158 deletions
@@ -1,177 +1,217 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# Copyright 2023 The vLLM team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
 import contextlib
 import gc
 import math
 import multiprocessing
 import os
-import sys
 from time import sleep
+from typing import Any
 from unittest.mock import patch
 
 import pytest
 import torch
-from vllm import LLM, SamplingParams
-from vllm.distributed.parallel_state import ( # noqa E402
-    destroy_distributed_environment, destroy_model_parallel)
 
-MODELS = ["Qwen/Qwen3-0.6B", "vllm-ascend/DeepSeek-V2-Lite-W8A8"]
+from vllm_ascend.utils import vllm_version_is
+
+if vllm_version_is("0.11.0"):
+    from vllm.utils import get_open_port
+else:
+    from vllm.utils.network_utils import get_open_port
+
+MODELS = [
+    "Qwen/Qwen3-0.6B",
+    "vllm-ascend/DeepSeek-V2-Lite-W8A8",
+]
+
+
+def _install_spies(counters: dict[str, Any]) -> contextlib.ExitStack:
+    """Installs thread-safe spies on NPU methods to track invocation counts."""
+    from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
+
+    def make_spy(cls, method_name, counter):
+        original = getattr(cls, method_name)
+
+        def spy(self, *args, **kwargs):
+            with counter.get_lock():
+                counter.value += 1
+            return original(self, *args, **kwargs)
+
+        return spy
+
+    stack = contextlib.ExitStack()
+    hooks = [
+        (torch.npu.NPUGraph, "replay", counters["replay"]),
+        (torch.npu.NPUGraph, "__init__", counters["capture"]),
+        (NPUModelRunner, "execute_model", counters["exec_model"]),
+        (NPUModelRunner, "_dummy_run", counters["dummy_run"]),
+    ]
+
+    for cls, method, counter in hooks:
+        stack.enter_context(
+            patch.object(cls, method, make_spy(cls, method, counter)))
+
+    return stack
+
+
+def _run_worker_process(
+    rank: int,
+    local_rank: int,
+    world_size: int,
+    master_ip: str,
+    master_port: int,
+    counters: dict[str, Any],
+    model_path: str,
+    max_tokens: int,
+):
+    """Main entry point for the worker process."""
+    os.environ.update({
+        "VLLM_DP_RANK": str(rank),
+        "VLLM_DP_RANK_LOCAL": str(local_rank),
+        "VLLM_DP_SIZE": str(world_size),
+        "VLLM_DP_MASTER_IP": master_ip,
+        "VLLM_DP_MASTER_PORT": str(master_port),
+    })
+
+    # Import vLLM only after environment setup
+    from vllm import LLM, SamplingParams
+    from vllm.distributed.parallel_state import (
+        destroy_distributed_environment, destroy_model_parallel)
+
+    # Apply hooks and run inference
+    with _install_spies(counters):
+        prompts = [
+            "Hello, my name is",
+            "The president of the United States is",
+            "The capital of France is",
+            "The future of AI is",
+        ]
+
+        # Simple data sharding
+        chunk_size = len(prompts) // world_size
+        start_idx = rank * chunk_size
+        end_idx = start_idx + chunk_size if rank < world_size - 1 else len(
+            prompts)
+        local_prompts = prompts[start_idx:end_idx]
+
+        llm = LLM(
+            model=model_path,
+            quantization="ascend" if "W8A8" in model_path else None,
+            enable_expert_parallel=True if "DeepSeek" in model_path else False,
+            trust_remote_code=True,
+        )
+
+        # Expose model config to the main test process
+        counters["hidden_layers"].value = (
+            llm.llm_engine.model_config.hf_config.num_hidden_layers)
+
+        llm.generate(local_prompts,
+                     SamplingParams(max_tokens=max_tokens, temperature=0.0))
+
+    # Explicit cleanup is mandatory in multi-process vLLM tests
+    sleep(5)
+    del llm
+
+    destroy_model_parallel()
+    destroy_distributed_environment()
+
+    with contextlib.suppress(AssertionError):
+        torch.distributed.destroy_process_group()
+
+    gc.collect()
+    torch.npu.empty_cache()
+    torch.npu.reset_peak_memory_stats()
 
 
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("max_tokens", [4])
+@pytest.mark.parametrize("max_tokens", [4, 36])
 @patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"})
-def test_aclgraph_capture_replay_dp2(
-        model: str,
-        max_tokens: int,
-) -> None:
-    # HCCL_OP_EXPANSION_MODE determines how max_num_batch_sizes is computed.
-    if 'VLLM_WORKER_MULTIPROC_METHOD' in os.environ:
-        del os.environ["VLLM_WORKER_MULTIPROC_METHOD"]
-    if 'HCCL_OP_EXPANSION_MODE' in os.environ:
-        del os.environ['HCCL_OP_EXPANSION_MODE']
+@patch.dict(os.environ, {"VLLM_WORKER_MULTIPROC_METHOD": "fork"})
+def test_aclgraph_capture_replay_dp(model: str, max_tokens: int) -> None:
+    # Shared counters for cross-process assertion
+    counters = {
+        "replay": multiprocessing.Value("i", 0),
+        "capture": multiprocessing.Value("i", 0),
+        "exec_model": multiprocessing.Value("i", 0),
+        "dummy_run": multiprocessing.Value("i", 0),
+        "hidden_layers": multiprocessing.Value("i", -1),
+    }
+
     dp_size = 2
-    tp_size = 1
-    replay_counter = multiprocessing.Value("i", 0)
-    capture_counter = multiprocessing.Value("i", 0)
-    num_hidden_layers_shared = multiprocessing.Value("i", -1)
-    num_execute_model_shared = multiprocessing.Value("i", 0)
-    dp_master_ip = "127.0.0.1"
-    dp_master_port = 11011
-
-    def dp_rank_main(global_dp_rank: int, local_dp_rank: int):
-        os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
-        os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
-        os.environ["VLLM_DP_SIZE"] = str(dp_size)
-        os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
-        os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
-
-        original_replay = torch.npu.NPUGraph.replay
-
-        def replay_wrapper(self):
-            with replay_counter.get_lock():
-                replay_counter.value += 1
-            return original_replay(self)
-
-        original_init = torch.npu.NPUGraph.__init__
-
-        def init_wrapper(self, *args, **kwargs):
-            with capture_counter.get_lock():
-                capture_counter.value += 1
-            return original_init(self, *args, **kwargs)
-
-        with patch.object(torch.npu.NPUGraph, "replay", replay_wrapper), \
-             patch.object(torch.npu.NPUGraph, "__init__", init_wrapper):
-            prompts = [
-                "Hello, my name is", "The president of the United States is",
-                "The capital of France is", "The future of AI is"
-            ]
-            chunk_size = len(prompts) // dp_size
-            start = global_dp_rank * chunk_size
-            end = start + chunk_size if global_dp_rank < dp_size - 1 else len(
-                prompts)
-            my_prompts = prompts[start:end]
-            sampling_params = SamplingParams(max_tokens=max_tokens,
-                                             temperature=0.0)
-
-            def trace_calls(frame, event, arg):
-                if event == 'call':
-                    code = frame.f_code
-                    func_name = code.co_name
-                    file_name = code.co_filename
-                    if func_name == 'dispatch' and 'cudagraph_dispatcher.py' in file_name:
-                        with num_execute_model_shared.get_lock():
-                            num_execute_model_shared.value += 1
-                return trace_calls
-
-            sys.settrace(trace_calls)
-            if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8":
-                llm = LLM(
-                    model=model,
-                    quantization="ascend",
-                    tensor_parallel_size=tp_size,
-                    trust_remote_code=True,
-                )
-            else:
-                llm = LLM(
-                    model=model,
-                    tensor_parallel_size=tp_size,
-                    trust_remote_code=True,
-                )
-            num_hidden_layers_shared.value = llm.llm_engine.model_config.hf_config.num_hidden_layers
-            _ = llm.generate(my_prompts, sampling_params)
-            sys.settrace(None)
-
-        # Give engines time to pause their processing loops before exiting.
-        sleep(5)
-        del llm
-        cleanup_env_and_memory()
-
-    processes = []
-    for local_dp_rank in range(dp_size):
-        global_dp_rank = local_dp_rank
-        p = multiprocessing.Process(target=dp_rank_main,
-                                    args=(global_dp_rank, local_dp_rank))
+    workers = []
+    port = get_open_port()
+
+    # Launch workers
+    for rank in range(dp_size):
+        p = multiprocessing.Process(
+            target=_run_worker_process,
+            args=(rank, rank, dp_size, "127.0.0.1", port, counters, model,
+                  max_tokens),
+        )
         p.start()
-        processes.append(p)
+        workers.append(p)
 
-    for p in processes:
-        p.join(timeout=900)
+    # Supervision loop
+    for p in workers:
+        p.join(timeout=60)
         if p.exitcode != 0:
-            if p.exitcode is None:
-                p.kill()
-                raise RuntimeError(f"Process {p.pid} timed out")
-            else:
-                raise RuntimeError(
-                    f"Process failed with exit code {p.exitcode}")
-
-    actual_capture = capture_counter.value
-    actual_replay = replay_counter.value
-    num_hidden_layers = num_hidden_layers_shared.value
-    num_execute_model = num_execute_model_shared.value
-
-    num_acl_graphs = num_hidden_layers + 1
-    num_comm_groups = sum(size > 1 for size in [
-        dp_size,
-        tp_size,
-    ])
-    max_num_batch_sizes = math.floor(
-        (1800 - num_comm_groups * 40) / num_acl_graphs /
-        (1 + num_comm_groups * 2))
-    expected_total_capture = max_num_batch_sizes * num_acl_graphs * dp_size
-    assert actual_capture == expected_total_capture, (
-        f"capture count mismatch. Expected: {expected_total_capture}, Got: {actual_capture}"
-    )
-
-    num_inference_steps = max_tokens + 1  # first token + max_tokens
-    expected_total_replay = num_acl_graphs * num_inference_steps * dp_size + num_execute_model * num_acl_graphs
-    assert actual_replay == expected_total_replay, (
-        f"Replay count mismatch. Expected: {expected_total_replay}, Got: {actual_replay}"
-    )
-    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = 'spawn'
-    sleep(600)
-
-
-def cleanup_env_and_memory():
-    destroy_model_parallel()
-    destroy_distributed_environment()
-    with contextlib.suppress(AssertionError):
-        torch.distributed.destroy_process_group()
-    gc.collect()
-    torch.npu.empty_cache()
-    torch.npu.reset_peak_memory_stats()
+            for k in workers:
+                if k.is_alive():
+                    k.kill()
+            raise RuntimeError(
+                f"Worker {p.pid} failed with exit code {p.exitcode}")
+
+    actual_capture = counters["capture"].value
+    actual_replay = counters["replay"].value
+    num_execute_model = counters["exec_model"].value
+    num_dummy_run = counters["dummy_run"].value
+    num_layers = counters["hidden_layers"].value
+
+    num_acl_graphs = num_layers + 1
+    num_comm_groups = sum(1 for s in [dp_size, 1]
+                          if s > 1)  # dp_size=2, tp_size=1
+
+    # Metric 1: Graph Capture (ACL Graph Construction)
+    # Ref: vllm_ascend.utils.update_aclgraph_sizes
+    max_batch_sizes = math.floor((1800 - num_comm_groups * 40) /
+                                 num_acl_graphs / (1 + num_comm_groups * 2))
+
+    expected_capture = max_batch_sizes * num_acl_graphs * dp_size
+    assert (
+        actual_capture == expected_capture
+    ), f"Capture count mismatch. Expected: {expected_capture}, Got: {actual_capture}"
+
+    # Metric 2: Model Execution (NPUModelRunner.execute_model)
+    # vLLM Step Breakdown:
+    # 1. First step (prefill, 1 prompt)
+    # 2. Generation steps (max_tokens)
+    # 3. Final step (likely EOS/idle step), no replay here
+    total_steps = max_tokens + 1
+    expected_exec_model = (1 + max_tokens + 1) * dp_size
+
+    assert (
+        num_execute_model == expected_exec_model
+    ), f"Model execution count mismatch. Expected: {expected_exec_model}, Got: {num_execute_model}"
+
+    # Metric 3: Dummy Runs (Warmup & Alignment)
+    # vLLM synchronizes globally every 32 steps.
+    # Ref: vllm.v1.engine.core.DPEngineCoreProc._has_global_unfinished_reqs
+    aligned_steps = (total_steps + 31) // 32 * 32
+
+    # Part A: Warmup runs (Profile run + 2 runs per captured graph)
+    warmup_runs = 1 + (2 * max_batch_sizes)
+
+    # Part B: Alignment padding (Empty runs to hit the 32-step boundary)
+    padding_runs = aligned_steps - total_steps
+
+    expected_dummy_run = (warmup_runs + padding_runs) * dp_size
+
+    assert (
+        num_dummy_run == expected_dummy_run
+    ), f"Dummy run count mismatch. Expected: {expected_dummy_run}, Got: {num_dummy_run}"
+
+    # Metric 4: Graph Replay (Inference Execution)
+    # Replays happen for every aligned step across all graphs.
+    expected_replay = num_acl_graphs * aligned_steps * dp_size
+
+    assert (
+        actual_replay == expected_replay
+    ), f"Replay count mismatch. Expected: {expected_replay}, Got: {actual_replay}"
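
As a rough sanity check of the counting logic asserted above, here is a small standalone sketch of the arithmetic for one configuration. The concrete values are assumptions, not taken from the diff: 28 hidden layers (what Qwen/Qwen3-0.6B is expected to report), dp_size=2, tp_size=1, max_tokens=4.

    import math

    # Assumed inputs -- 28 hidden layers is an assumption for Qwen/Qwen3-0.6B.
    num_layers, dp_size, max_tokens = 28, 2, 4
    num_comm_groups = 1                      # only dp_size > 1

    num_acl_graphs = num_layers + 1          # 29
    max_batch_sizes = math.floor((1800 - num_comm_groups * 40) /
                                 num_acl_graphs / (1 + num_comm_groups * 2))  # 20

    expected_capture = max_batch_sizes * num_acl_graphs * dp_size   # 1160
    total_steps = max_tokens + 1                                    # 5
    expected_exec_model = (1 + max_tokens + 1) * dp_size            # 12
    aligned_steps = (total_steps + 31) // 32 * 32                   # 32
    warmup_runs = 1 + 2 * max_batch_sizes                           # 41
    padding_runs = aligned_steps - total_steps                      # 27
    expected_dummy_run = (warmup_runs + padding_runs) * dp_size     # 136
    expected_replay = num_acl_graphs * aligned_steps * dp_size      # 1856

    print(expected_capture, expected_exec_model, expected_dummy_run, expected_replay)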
