Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 66 additions & 38 deletions thunder/transforms/quantization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Sequence
from thunder.core.trace_interpreter import TraceSubstitutionProcessor

import thunder
from thunder.core.transform_common import Transform
Expand Down Expand Up @@ -217,49 +218,76 @@ def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilo
if psym.shape != csym.shape or psym.dtype != csym.dtype
}

new_computation_trace = trace_with_replaced_proxy_metadata(computation_trace, computation_proxy_map)
bound_symbols = new_computation_trace.bound_symbols
new_computation_trace.bound_symbols = []

new_computation_trace.args = (*new_computation_trace.args, *new_compute_inputs)
new_computation_trace.names.update(i.name for i in new_compute_inputs)
new_computation_trace._siginfo.args = [(a.name, None) for a in new_computation_trace.args]
# Add new compute inputs to the trace args before processing
computation_trace.args = (*computation_trace.args, *new_compute_inputs)
computation_trace.names.update(i.name for i in new_compute_inputs)
computation_trace._siginfo.args = [(a.name, None) for a in computation_trace.args]

with tracectx(new_computation_trace):
# Add unpack_trivial bindings for new inputs in the correct position
with tracectx(computation_trace):
new_bindings = [
thunder.core.prims.unpack_trivial.bind(i, output=i, name=i.name) for i in new_compute_inputs
]

for idx, bsym in enumerate(bound_symbols):
if bsym.sym != prims.unpack_trivial:
break
new_computation_trace.bound_symbols.append(bsym.from_bsym())
new_computation_trace.bound_symbols += new_bindings

for bsym in bound_symbols[idx:]:
if bsym.sym == thunder.torch.linear and bsym.args[1].name in quantized_proxies:
assert len(bsym.args) == 3 # torch.linear(input, weight, bias)
n = quantized_proxies[bsym.args[1].name]
qs = self.quant_states[n]
# signature of the new symbol:
# bnb_matmul_nf4(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape)
new_args = (
*bsym.args[:3],
additional_proxies[f"{n}.absmax"],
additional_proxies[f"{n}.code"],
qs["blocksize"],
qs["dtype"],
qs["shape"],
)
mm_bsym = bsym.from_bsym(
sym=bnb_matmul_nf4,
subsymbols=[],
args=new_args,
)

new_computation_trace.bound_symbols.append(mm_bsym)
else:
new_computation_trace.bound_symbols.append(bsym.from_bsym())
# Insert the new bindings after the existing unpack_trivial bindings to maintain arg order
# Find the last unpack_trivial binding and insert after it
insert_idx = len(computation_trace.bound_symbols)
for i, bsym in enumerate(computation_trace.bound_symbols):
if bsym.sym.id == prims.PrimIDs.UNPACK_TRIVIAL:
insert_idx = i + 1

computation_trace.bound_symbols[insert_idx:insert_idx] = new_bindings

# Now update metadata for the complete trace
new_computation_trace = trace_with_replaced_proxy_metadata(computation_trace, computation_proxy_map)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need this btw or could we absorb it into the QuantizationProcessor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might be better to keep them separate, because the TraceSubstitutionProcessor is only concerned with symbol-by-symbol replacement. I tried to experiment with it and ran into more errors.


class QuantizationProcessor(TraceSubstitutionProcessor):
    """Trace processor that swaps quantized ``linear`` calls for ``bnb_matmul_nf4``.

    Each bound symbol of the trace is handled by :meth:`process_bsym`, which
    emits exactly one bound symbol in its place and reports the original
    symbol's output as the result.
    """

    def __init__(self, trace, quantized_proxies, additional_proxies, quant_states, new_compute_inputs):
        """Store the lookup tables used while rewriting the trace.

        Args:
            trace: Trace handed to the ``TraceSubstitutionProcessor`` base class.
            quantized_proxies: Mapping from a weight proxy name to the key used
                to look up ``quant_states`` and ``additional_proxies``.
            additional_proxies: Mapping queried with ``"<key>.absmax"`` and
                ``"<key>.code"``.
            quant_states: Mapping from key to a mapping that provides
                ``"blocksize"``, ``"dtype"`` and ``"shape"``.
            new_compute_inputs: Kept on the instance; not read by
                ``process_bsym`` in this class.
        """
        super().__init__(trace)
        self.quantized_proxies = quantized_proxies
        self.additional_proxies = additional_proxies
        self.quant_states = quant_states
        self.new_compute_inputs = new_compute_inputs

    def process_bsym(self, bsym):
        """Emit the replacement for ``bsym`` and register its output as the result.

        * ``thunder.torch.linear`` whose second argument's name is a key of
          ``quantized_proxies`` becomes a ``bnb_matmul_nf4`` bound symbol.
        * ``prims.python_return`` gets a copy of its dict argument in which
          ``"flat_args"`` is the list of the new trace's args.
        * Anything else is copied via ``from_bsym()``.
        """
        is_quantized_linear = (
            bsym.sym == thunder.torch.linear and bsym.args[1].name in self.quantized_proxies
        )

        if is_quantized_linear:
            assert len(bsym.args) == 3  # torch.linear(input, weight, bias)
            key = self.quantized_proxies[bsym.args[1].name]
            state = self.quant_states[key]
            # Target signature:
            # bnb_matmul_nf4(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape)
            matmul_args = (
                *bsym.args[:3],
                self.additional_proxies[f"{key}.absmax"],
                self.additional_proxies[f"{key}.code"],
                state["blocksize"],
                state["dtype"],
                state["shape"],
            )
            replacement = bsym.from_bsym(sym=bnb_matmul_nf4, subsymbols=[], args=matmul_args)
        elif bsym.sym == prims.python_return:
            assert len(bsym.args) == 1 and isinstance(bsym.args[0], dict)
            return_dict = bsym.args[0].copy()
            # The original author notes the trace args are already flat here.
            return_dict["flat_args"] = list(self.new_trace.args)
            replacement = bsym.from_bsym(args=(return_dict,))
        else:
            # Unrelated symbol: carry it over as a fresh copy.
            replacement = bsym.from_bsym()

        self.add_processed_bsyms([replacement])
        self.set_result(bsym.output)

# Process the trace using the QuantizationProcessor
processor = QuantizationProcessor(
new_computation_trace, quantized_proxies, additional_proxies, self.quant_states, new_compute_inputs
)

# Now process the trace
new_computation_trace, _ = processor()
new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("quant pass"))
return prologue_trace, new_computation_trace, epilogue_trace
Loading