
Commit e133d6d

[BugFix] fix graph partition signature (#27139)
Signed-off-by: Boyuan Feng <[email protected]>
1 parent a1946c9 commit e133d6d

File tree

1 file changed: +151 -0 lines changed


vllm/env_override.py

Lines changed: 151 additions & 0 deletions
@@ -90,6 +90,156 @@ def get_output_names(graph_outputs) -> list[str]:
     assert len(planning_states) == 0
 
 
+# ===================================================
+# torch 2.9 Inductor get_graph_partition_signature monkeypatch
+# ===================================================
+# This change monkeypatches get_graph_partition_signature in pytorch 2.9.0 to
+# fix inductor partition + attention-nvfp4 quant fusion, tested in
+# `tests/compile/test_fusions_e2e.py::test_attn_quant`.
+# For more context, see https://github.com/pytorch/pytorch/pull/165815.
+
+
+def get_graph_partition_signature_patched(
+    self, partitions, skip_cudagraphs: list[bool]
+):
+    """
+    Gets the signature for each graph partition, including input nodes, output
+    nodes, and whether an input is deallocated within the graph partition.
+    """
+    from torch._inductor import dependencies
+    from torch._inductor.ir import GraphPartitionSignature, MutationOutput, NoneLayout
+    from torch._inductor.virtualized import V
+    from torch.utils._ordered_set import OrderedSet
+
+    signatures = []
+
+    unmet_output_names = OrderedSet(V.graph.get_output_names())
+    name_to_node = self.get_name_to_nodes()
+
+    def is_none_layout(buf_name: str) -> bool:
+        """
+        Checks if buf_name has NoneLayout. Buffers with NoneLayout are not
+        allocated, so a graph partition should not take them as inputs or outputs.
+        """
+        buf = self.name_to_buf.get(buf_name, None)
+
+        if buf is None:
+            return False
+
+        if isinstance(buf.node.layout, NoneLayout):
+            if isinstance(buf.node, MutationOutput) and (
+                real_name := self.mutation_real_name.get(buf_name, None)
+            ):
+                return is_none_layout(real_name)
+
+            return True
+
+        return False
+
+    for partition, skip_cudagraph in zip(
+        reversed(partitions), reversed(skip_cudagraphs)
+    ):
+        output_names: OrderedSet[str] = OrderedSet()
+
+        for node in partition:
+            output_names.update(node.outputs_by_name.keys())
+
+        returned_output_names = output_names.intersection(unmet_output_names)
+
+        # all reads/writes are partition inputs except those generated
+        # within the partition and tensor constants
+        read_writes = dependencies.ReadWrites.merge_list(
+            [node.read_writes for node in partition]
+        )
+
+        # WeakDep is a fake dependency on an unused buffer. It should not appear
+        # in partition_input_names for inputs that are actually read or written.
+        partition_input_names = (
+            OrderedSet(
+                [
+                    x.name
+                    for x in read_writes.reads | read_writes.writes
+                    if not is_none_layout(x.name)
+                ]
+            )
+            - output_names
+        )
+
+        partition_input_names = OrderedSet(
+            self.mutation_real_name.get(name, name) for name in partition_input_names
+        )
+
+        buffer_names_to_free: OrderedSet[str] = OrderedSet()
+        for node in partition:
+            buffer_names_to_free.update(node.last_usage)
+
+        # buffer_names_to_free may contain buffers allocated in previous
+        # graph partitions. These buffers should also be a partition
+        # input.
+        extra_input_names = [
+            name
+            for name in (buffer_names_to_free - output_names)
+            if name in name_to_node
+        ]
+        partition_input_names.update(extra_input_names)
+
+        input_nodes = {
+            name: name_to_node[name]
+            for name in partition_input_names
+            if name in name_to_node
+        }
+        input_deallocation = {
+            name: name in buffer_names_to_free
+            for name in partition_input_names
+            if name in name_to_node
+        }
+
+        # if an input tensor is not freed in the partition function, it should
+        # also be returned as an output. This brings benefits to cudagraph
+        # since the returned output tensor is a cudagraph managed tensor with
+        # a static tensor address.
+        extra_output_names = [
+            name
+            for name in partition_input_names
+            if name in name_to_node and name not in buffer_names_to_free
+        ]
+
+        returned_output_names.update(extra_output_names)
+
+        returned_output_names = OrderedSet(
+            self.mutation_real_name.get(name, name) for name in returned_output_names
+        )
+
+        output_nodes = [
+            name_to_node[name]
+            for name in returned_output_names
+            if not is_none_layout(name)
+        ]
+
+        constant_names = [
+            name for name in partition_input_names if name in V.graph.constants
+        ]
+
+        symbol_inputs = self.get_graph_partition_symbol_inputs(partition, input_nodes)
+
+        partition_signature = GraphPartitionSignature(
+            symbol_inputs,
+            input_nodes,
+            output_nodes,
+            input_deallocation,
+            skip_cudagraph,
+            constant_names,
+        )
+
+        signatures.append(partition_signature)
+
+        unmet_output_names = partition_input_names.union(
+            unmet_output_names - returned_output_names
+        )
+
+    return signatures[::-1]
+
+
 # ========================================
 # torch 2.9 Inductor Scheduler monkeypatch
 # ========================================
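The patched routine above is essentially a backward pass over the partitions: each partition's inputs become names that some earlier partition still has to return, which is what the unmet_output_names update at the end of the loop tracks. A minimal toy sketch of that propagation, using plain Python sets and hypothetical buffer names in place of Inductor's scheduler nodes and OrderedSet:

# Toy sketch (not Inductor code): partitions described by plain dicts of
# hypothetical buffer names.
partitions = [
    {"produces": {"buf0", "buf1"}, "reads": {"arg0"}},   # partition 0
    {"produces": {"buf2"}, "reads": {"buf0"}},           # partition 1
    {"produces": {"buf3"}, "reads": {"buf1", "buf2"}},   # partition 2
]
graph_output_names = {"buf3"}

signatures = []
unmet_output_names = set(graph_output_names)

# Walk partitions last-to-first, as the patched function does.
for part in reversed(partitions):
    # Outputs this partition must return: anything it produces that the graph
    # or a later partition still needs.
    returned = part["produces"] & unmet_output_names
    # Inputs: everything read that the partition does not produce itself.
    inputs = part["reads"] - part["produces"]
    signatures.append({"inputs": inputs, "outputs": returned})
    # These inputs must now be produced (and returned) by an earlier partition.
    unmet_output_names = inputs | (unmet_output_names - returned)

signatures.reverse()
# Partition 0 must return buf0 and buf1 because partitions 1 and 2 read them.
assert signatures[0]["outputs"] == {"buf0", "buf1"}

The real routine additionally returns unfreed inputs as extra outputs so cudagraphs see static tensor addresses, remaps mutated buffers through mutation_real_name, and filters out NoneLayout buffers.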
@@ -196,6 +346,7 @@ def _update_scheduler_patched(self) -> None:
     from torch._inductor.scheduler import Scheduler
 
     Scheduler.should_partition = should_partition_patched
+    Scheduler.get_graph_partition_signature = get_graph_partition_signature_patched
 
     with config.patch("triton.store_cubin", False):
         self.scheduler = Scheduler(self.operations)
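One subtle piece of get_graph_partition_signature_patched above is the is_none_layout helper: a MutationOutput buffer carries NoneLayout, but the buffer it mutates may be a real allocation, so the check follows mutation_real_name before deciding whether to drop a name from the partition signature. A toy illustration of that resolution, with hypothetical buffer names and plain dicts standing in for the scheduler's name_to_buf and mutation_real_name bookkeeping:

# Toy sketch (not Inductor code): which buffers lack a real allocation and how
# mutation markers map back to the buffers they mutate.
none_layout_buffers = {"buf2_mutation", "buf7"}   # buffers whose layout is NoneLayout
mutation_real_name = {"buf2_mutation": "buf2"}    # MutationOutput -> mutated buffer


def is_none_layout(name: str) -> bool:
    if name not in none_layout_buffers:
        return False  # a real allocation: valid as a partition input/output
    real_name = mutation_real_name.get(name)
    if real_name is not None:
        # A mutation marker: defer to the buffer it actually mutates.
        return is_none_layout(real_name)
    return True  # genuinely unallocated, so keep it out of the signature


# buf2_mutation resolves to buf2, which is allocated, so it stays in the
# signature; buf7 has no backing allocation and is dropped.
assert is_none_layout("buf2_mutation") is False
assert is_none_layout("buf7") is True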
