Fix group offloading for quanto-quantized models and the use_stream path for quantized tensor subclasses#14038
Open
Sunt-ing wants to merge 1 commit into
Open
Fix group offloading for quanto-quantized models and the use_stream path for quantized tensor subclasses#14038Sunt-ing wants to merge 1 commit into
Sunt-ing wants to merge 1 commit into
Conversation
…ath for quantized tensor subclasses
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Fixes #12610
Fixes #13281
Group offloading moves a group's parameters between CPU and the accelerator by reassigning
param.data:This is correct for plain tensors but wrong for tensor subclasses (quantized weights), whose real payload lives in internal sub-tensors (quanto
WeightQBytesTensor:_data/_scale; torchaoAffineQuantizedTensor:qdata/scale/...). Reassigning.dataonly swaps the outer wrapper and leaves the inner tensors on the source device, so the next matmul fails withmat2 is on cpu, different from cuda:0.#13276 fixed this for torchao by swapping the whole subclass via
torch.utils.swap_tensorsand restoring inner attributes one by one. Two gaps remained:enable_group_offloadhits the wrapper-only.data =path and crashes with a device mismatch on the first forward, for bothleaf_levelandblock_level.use_stream=True,_to_cpu/_pinned_memory_tensorscallpin_memory()/is_pinned(), which neither subclass supports: quanto silently loses the subclass identity, and torchao raisesNotImplementedError: ... aten.is_pinned. So torchao +use_stream=Truecrashes even though its non-stream path was already fixed.Changes (
src/diffusers/hooks/group_offloading.py)_is_quanto_tensorplus quanto helpers, and handle quanto next to the existing torchao branch in_transfer_tensor_to_device(onload),_offload_to_memory(restore / offload), and therecord_streampath. Inner tensor names come from the standard subclass protocol__tensor_flatten__(); quanto onload usestorch.utils.swap_tensorsinstead of.data =._to_cpuand_pinned_memory_tensors, skippin_memory()/is_pinned()for quanto and torchao subclasses.Tests
Added
test_group_offloadingto the quanto and torchao quantization suites. Each loads a quantized tiny Flux transformer, offloads it acrossleaf_level/block_leveland non-stream /use_stream, and asserts the output matches the non-offloaded quantized baseline.tests/quantization/quanto/test_quanto.py(int8 and float8): both fail onmainwith the device mismatch, pass here.tests/quantization/torchao/test_torchao.py::TorchAoTest::test_group_offloading: theuse_stream=Truecases fail onmainwith theaten.is_pinnederror, pass here.Reproduction and before/after
Environment: NVIDIA RTX 4090,
torch==2.8.0+cu128,diffusers@2d0110f,optimum-quanto==0.2.7,torchao==0.17.0.Minimal standalone repro for #12610 (quanto):
Running the new tests (
RUN_NIGHTLY=1 RUN_SLOW=1):Across
leaf_level/block_level× non-stream /use_stream/record_stream, the offloaded output is bit-identical (max abs diff = 0.0) to the fully-on-accelerator quantized baseline. A non-quantized group-offload equivalence sweep stays at0.0(plain-tensor path unchanged).Relationship to other work
mx/nvfp4tensors.Int8WeightOnlyConfigAffineQuantizedTensorstill raisesaten.is_pinnedontorchao==0.17.0, so the streamed path is still broken for the common int8 case. Skipping pinning on the diffusers side fixes it regardless of the torchao version, and is also required for quanto, whose subclass tensors do not implement torch pinning at all._to_cpu,_pinned_memory_tensors,_swap_torchao_tensor) to add disk offload. They are orthogonal in intent (disk vs the memory device-mismatch / stream-pin crash here) but touch the same region, so this PR will need a rebase around whichever lands first.Who can review?
cc @sayakpaul
Before submitting
.ai/review-rules.md?