@@ -90,6 +90,156 @@ def get_output_names(graph_outputs) -> list[str]:
     assert len(planning_states) == 0


+# ===================================================
+# torch 2.9 Inductor get_graph_partition_signature monkeypatch
+# ===================================================
+# This change monkeypatches get_graph_partition_signature in PyTorch 2.9.0 to
+# fix Inductor partition + attention-nvfp4 quant fusion, tested in
+# `tests/compile/test_fusions_e2e.py::test_attn_quant`.
+# For more context, see https://github.com/pytorch/pytorch/pull/165815.
+
+
+def get_graph_partition_signature_patched(
+    self, partitions, skip_cudagraphs: list[bool]
+):
105+ """
106+ Gets signature for each graph partition, including input nodes, output nodes, and
107+ whether deallocating an input within graph partition.
108+ """
+    from torch._inductor import dependencies
+    from torch._inductor.ir import GraphPartitionSignature, MutationOutput, NoneLayout
+    from torch._inductor.virtualized import V
+    from torch.utils._ordered_set import OrderedSet
+
+    signatures = []
+
+    unmet_output_names = OrderedSet(V.graph.get_output_names())
+    name_to_node = self.get_name_to_nodes()
+
+    def is_none_layout(buf_name: str) -> bool:
+        """
+        Checks whether buf_name has a NoneLayout. Buffers with NoneLayout are
+        not allocated, so a graph partition should not take them as inputs or
+        outputs.
+        """
+        buf = self.name_to_buf.get(buf_name, None)
+
+        if buf is None:
+            return False
+
+        if isinstance(buf.node.layout, NoneLayout):
+            if isinstance(buf.node, MutationOutput) and (
+                real_name := self.mutation_real_name.get(buf_name, None)
+            ):
+                return is_none_layout(real_name)
+
+            return True
+
+        return False
+
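+    # Partitions are processed in reverse order so that unmet_output_names
+    # always holds the names still needed by later partitions or by the
+    # graph outputs when each partition is visited.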
+    for partition, skip_cudagraph in zip(
+        reversed(partitions), reversed(skip_cudagraphs)
+    ):
+        output_names: OrderedSet[str] = OrderedSet()
+
+        for node in partition:
+            output_names.update(node.outputs_by_name.keys())
+
+        returned_output_names = output_names.intersection(unmet_output_names)
+
+        # All reads/writes are partition inputs, except those generated
+        # within the partition and tensor constants.
+        read_writes = dependencies.ReadWrites.merge_list(
+            [node.read_writes for node in partition]
+        )
+
+        # WeakDep is a fake dependency on an unused buffer. It should not
+        # appear in partition_input_names, which is for inputs that are
+        # actually read or written.
+        partition_input_names = (
+            OrderedSet(
+                [
+                    x.name
+                    for x in read_writes.reads | read_writes.writes
+                    if not is_none_layout(x.name)
+                ]
+            )
+            - output_names
+        )
+
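+        # Remap names through mutation_real_name so partition inputs refer to
+        # the actual underlying buffers rather than mutation aliases.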
+        partition_input_names = OrderedSet(
+            self.mutation_real_name.get(name, name) for name in partition_input_names
+        )
+
+        buffer_names_to_free: OrderedSet[str] = OrderedSet()
+        for node in partition:
+            buffer_names_to_free.update(node.last_usage)
+
+        # buffer_names_to_free may contain buffers allocated in previous
+        # graph partitions. These buffers should also be partition inputs.
+        extra_input_names = [
+            name
+            for name in (buffer_names_to_free - output_names)
+            if name in name_to_node
+        ]
+        partition_input_names.update(extra_input_names)
+
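+        # Collect the IR nodes for the partition inputs and record whether
+        # each input is deallocated within this partition.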
+        input_nodes = {
+            name: name_to_node[name]
+            for name in partition_input_names
+            if name in name_to_node
+        }
+        input_deallocation = {
+            name: name in buffer_names_to_free
+            for name in partition_input_names
+            if name in name_to_node
+        }
+
+        # If an input tensor is not freed in the partition function, it should
+        # also be returned as an output. This benefits cudagraphs, since the
+        # returned output tensor is a cudagraph-managed tensor with a static
+        # tensor address.
+        extra_output_names = [
+            name
+            for name in partition_input_names
+            if name in name_to_node and name not in buffer_names_to_free
+        ]
+
+        returned_output_names.update(extra_output_names)
+
+        returned_output_names = OrderedSet(
+            self.mutation_real_name.get(name, name) for name in returned_output_names
+        )
+
+        output_nodes = [
+            name_to_node[name]
+            for name in returned_output_names
+            if not is_none_layout(name)
+        ]
+
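+        # Graph constants read by the partition are recorded separately in
+        # the signature.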
+        constant_names = [
+            name for name in partition_input_names if name in V.graph.constants
+        ]
+
+        symbol_inputs = self.get_graph_partition_symbol_inputs(partition, input_nodes)
+
+        partition_signature = GraphPartitionSignature(
+            symbol_inputs,
+            input_nodes,
+            output_nodes,
+            input_deallocation,
+            skip_cudagraph,
+            constant_names,
+        )
+
+        signatures.append(partition_signature)
+
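+        # Propagate this partition's inputs, together with any still-unmet
+        # output names, backwards so earlier partitions know which of their
+        # outputs are consumed later.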
+        unmet_output_names = partition_input_names.union(
+            unmet_output_names - returned_output_names
+        )
+
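+    # Signatures were appended while iterating partitions in reverse, so flip
+    # the list back to the original partition order.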
+    return signatures[::-1]
+
+
 # ========================================
 # torch 2.9 Inductor Scheduler monkeypatch
 # ========================================
@@ -196,6 +346,7 @@ def _update_scheduler_patched(self) -> None:
     from torch._inductor.scheduler import Scheduler

     Scheduler.should_partition = should_partition_patched
+    Scheduler.get_graph_partition_signature = get_graph_partition_signature_patched

     with config.patch("triton.store_cubin", False):
         self.scheduler = Scheduler(self.operations)