Skip to content

Fix group offloading for quanto-quantized models and the use_stream path for quantized tensor subclasses#14038

Open
Sunt-ing wants to merge 1 commit into
huggingface:mainfrom
Sunt-ing:0
Open

Fix group offloading for quanto-quantized models and the use_stream path for quantized tensor subclasses#14038
Sunt-ing wants to merge 1 commit into
huggingface:mainfrom
Sunt-ing:0

Conversation

@Sunt-ing

Copy link
Copy Markdown

What does this PR do?

Fixes #12610
Fixes #13281

Group offloading moves a group's parameters between CPU and the accelerator by reassigning param.data:

param.data = source_tensor.to(device)

This is correct for plain tensors but wrong for tensor subclasses (quantized weights), whose real payload lives in internal sub-tensors (quanto WeightQBytesTensor: _data/_scale; torchao AffineQuantizedTensor: qdata/scale/...). Reassigning .data only swaps the outer wrapper and leaves the inner tensors on the source device, so the next matmul fails with mat2 is on cpu, different from cuda:0.

#13276 fixed this for torchao by swapping the whole subclass via torch.utils.swap_tensors and restoring inner attributes one by one. Two gaps remained:

Changes (src/diffusers/hooks/group_offloading.py)

  • Add _is_quanto_tensor plus quanto helpers, and handle quanto next to the existing torchao branch in _transfer_tensor_to_device (onload), _offload_to_memory (restore / offload), and the record_stream path. Inner tensor names come from the standard subclass protocol __tensor_flatten__(); quanto onload uses torch.utils.swap_tensors instead of .data =.
  • In _to_cpu and _pinned_memory_tensors, skip pin_memory() / is_pinned() for quanto and torchao subclasses.
  • Plain tensors and the torchao non-stream path are untouched (zero behavior change).

Tests

Added test_group_offloading to the quanto and torchao quantization suites. Each loads a quantized tiny Flux transformer, offloads it across leaf_level / block_level and non-stream / use_stream, and asserts the output matches the non-offloaded quantized baseline.

  • tests/quantization/quanto/test_quanto.py (int8 and float8): both fail on main with the device mismatch, pass here.
  • tests/quantization/torchao/test_torchao.py::TorchAoTest::test_group_offloading: the use_stream=True cases fail on main with the aten.is_pinned error, pass here.
Reproduction and before/after

Environment: NVIDIA RTX 4090, torch==2.8.0+cu128, diffusers @ 2d0110f, optimum-quanto==0.2.7, torchao==0.17.0.

Minimal standalone repro for #12610 (quanto):

import torch
from diffusers import UNet2DConditionModel
from diffusers.hooks import apply_group_offloading
from optimum.quanto import quantize, freeze, qint8

m = UNet2DConditionModel.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-pipe", subfolder="unet"
).to(torch.float32).eval()
quantize(m, weights=qint8); freeze(m)
apply_group_offloading(
    m, onload_device=torch.device("cuda"), offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)
x = torch.randn(2, m.config.in_channels, m.config.sample_size, m.config.sample_size, device="cuda")
t = torch.tensor([10, 10], device="cuda")
e = torch.randn(2, 4, m.config.cross_attention_dim, device="cuda")
with torch.no_grad():
    m(x, t, e)  # main: RuntimeError: mat2 is on cpu, different from cuda:0

Running the new tests (RUN_NIGHTLY=1 RUN_SLOW=1):

# on main (fix reverted, tests kept)
quanto  FluxTransformerInt8WeightsTest::test_group_offloading    FAILED  (mat2 is on cpu, different from cuda:0)
quanto  FluxTransformerFloat8WeightsTest::test_group_offloading  FAILED  (mat2 is on cpu, different from cuda:0)
torchao TorchAoTest::test_group_offloading                       FAILED  (NotImplementedError: ... aten.is_pinned)

# with this PR
quanto  FluxTransformerInt8WeightsTest::test_group_offloading    PASSED
quanto  FluxTransformerFloat8WeightsTest::test_group_offloading  PASSED
torchao TorchAoTest::test_group_offloading                       PASSED

Across leaf_level / block_level × non-stream / use_stream / record_stream, the offloaded output is bit-identical (max abs diff = 0.0) to the fully-on-accelerator quantized baseline. A non-quantized group-offload equivalence sweep stays at 0.0 (plain-tensor path unchanged).

Relationship to other work

Who can review?

cc @sayakpaul

Before submitting

@github-actions github-actions Bot added fixes-issue size/M PR with diff < 200 LOC tests hooks and removed size/M PR with diff < 200 LOC labels Jun 22, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

1 participant