vllm/env_override.py (151 additions, 0 deletions)
@@ -90,6 +90,156 @@ def get_output_names(graph_outputs) -> list[str]:
    assert len(planning_states) == 0


# ===================================================
# torch 2.9 Inductor get_graph_partition_signature monkeypatch
# ===================================================
# This change monkeypatches get_graph_partition_signature in pytorch 2.9.0 to
# fix inductor partition + attention-nvfp4 quant fusion, tested in
# `tests/compile/test_fusions_e2e.py::test_attn_quant`.
# For more context, see https://github.com/pytorch/pytorch/pull/165815.
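# The patch is applied in _update_scheduler_patched below, which assigns this
# function onto torch._inductor.scheduler.Scheduler.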


def get_graph_partition_signature_patched(
    self, partitions, skip_cudagraphs: list[bool]
):
    """
    Gets the signature for each graph partition, including its input nodes, its
    output nodes, and whether each input is deallocated within the partition.
    """
    from torch._inductor import dependencies
    from torch._inductor.ir import GraphPartitionSignature, MutationOutput, NoneLayout
    from torch._inductor.virtualized import V
    from torch.utils._ordered_set import OrderedSet

    signatures = []

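    # unmet_output_names tracks buffers that later partitions (or the graph
    # outputs themselves) still need; it is updated as partitions are visited
    # in reverse order below.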
    unmet_output_names = OrderedSet(V.graph.get_output_names())
    name_to_node = self.get_name_to_nodes()

    def is_none_layout(buf_name: str) -> bool:
        """
        Checks whether buf_name refers to a buffer with NoneLayout. Buffers
        with NoneLayout are not allocated, so graph partitions should not take
        them as inputs or outputs.
        """
        buf = self.name_to_buf.get(buf_name, None)

        if buf is None:
            return False

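        # A MutationOutput with NoneLayout aliases another buffer; decide
        # based on the layout of the real buffer it mutates.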
        if isinstance(buf.node.layout, NoneLayout):
            if isinstance(buf.node, MutationOutput) and (
                real_name := self.mutation_real_name.get(buf_name, None)
            ):
                return is_none_layout(real_name)

            return True

        return False

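    # Visit partitions in reverse so that, by the time a partition is
    # processed, unmet_output_names reflects everything later partitions read.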
    for partition, skip_cudagraph in zip(
        reversed(partitions), reversed(skip_cudagraphs)
    ):
        output_names: OrderedSet[str] = OrderedSet()

        for node in partition:
            output_names.update(node.outputs_by_name.keys())

        returned_output_names = output_names.intersection(unmet_output_names)

        # all reads/writes are partition inputs, except buffers generated
        # within the partition and tensor constants
        read_writes = dependencies.ReadWrites.merge_list(
            [node.read_writes for node in partition]
        )

        # WeakDep is a fake dependency on an unused buffer; it should not
        # appear in partition_input_names, which only contains inputs that are
        # actually read or written.
        partition_input_names = (
            OrderedSet(
                [
                    x.name
                    for x in read_writes.reads | read_writes.writes
                    if not is_none_layout(x.name)
                ]
            )
            - output_names
        )

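        # Remap mutation outputs to the original buffers they mutate, so that
        # partition inputs name real allocations.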
        partition_input_names = OrderedSet(
            self.mutation_real_name.get(name, name) for name in partition_input_names
        )

        buffer_names_to_free: OrderedSet[str] = OrderedSet()
        for node in partition:
            buffer_names_to_free.update(node.last_usage)

        # buffer_names_to_free may contain buffers allocated in previous
        # graph partitions. These buffers should also be partition
        # inputs.
        extra_input_names = [
            name
            for name in (buffer_names_to_free - output_names)
            if name in name_to_node
        ]
        partition_input_names.update(extra_input_names)

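        # input_deallocation marks, for each input, whether its last use falls
        # inside this partition and it can therefore be freed here.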
        input_nodes = {
            name: name_to_node[name]
            for name in partition_input_names
            if name in name_to_node
        }
        input_deallocation = {
            name: name in buffer_names_to_free
            for name in partition_input_names
            if name in name_to_node
        }

        # if an input tensor is not freed in the partition function, it should
        # also be returned as an output. This benefits cudagraphs, since the
        # returned output tensor is then a cudagraph-managed tensor with a
        # static address.
        extra_output_names = [
            name
            for name in partition_input_names
            if name in name_to_node and name not in buffer_names_to_free
        ]

        returned_output_names.update(extra_output_names)

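        # Apply the same mutation-name remapping to the returned outputs.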
        returned_output_names = OrderedSet(
            self.mutation_real_name.get(name, name) for name in returned_output_names
        )

        output_nodes = [
            name_to_node[name]
            for name in returned_output_names
            if not is_none_layout(name)
        ]

        constant_names = [
            name for name in partition_input_names if name in V.graph.constants
        ]

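        # Collect the symbolic inputs (e.g. dynamic shape sizes) the partition
        # needs, separately from its tensor inputs.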
        symbol_inputs = self.get_graph_partition_symbol_inputs(
            partition, input_nodes
        )

        partition_signature = GraphPartitionSignature(
            symbol_inputs,
            input_nodes,
            output_nodes,
            input_deallocation,
            skip_cudagraph,
            constant_names,
        )

        signatures.append(partition_signature)

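        # Everything this partition reads must be produced earlier; everything
        # it returns is no longer unmet for the partitions before it.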
        unmet_output_names = partition_input_names.union(
            unmet_output_names - returned_output_names
        )

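    # Partitions were processed in reverse; restore the original order.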
    return signatures[::-1]


# ========================================
# torch 2.9 Inductor Scheduler monkeypatch
# ========================================
@@ -196,6 +346,7 @@ def _update_scheduler_patched(self) -> None:
    from torch._inductor.scheduler import Scheduler

    Scheduler.should_partition = should_partition_patched
    Scheduler.get_graph_partition_signature = get_graph_partition_signature_patched

    with config.patch("triton.store_cubin", False):
        self.scheduler = Scheduler(self.operations)