From 8ab88ee1df0b3eb0fee7afe78d35073b8bbf87ec Mon Sep 17 00:00:00 2001
From: Ting Sun
Date: Mon, 22 Jun 2026 04:25:55 +0800
Subject: [PATCH] Fix group offloading for quanto-quantized models and the use_stream path for quantized tensor subclasses

---
 src/diffusers/hooks/group_offloading.py    | 57 +++++++++++++++++++---
 tests/quantization/quanto/test_quanto.py   | 24 +++++++++
 tests/quantization/torchao/test_torchao.py | 27 ++++++++++
 3 files changed, 102 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index f3d1f3389bb7..77d0b241c5f0 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -22,7 +22,7 @@
 import safetensors.torch
 import torch
 
-from ..utils import get_logger, is_accelerate_available, is_torchao_available
+from ..utils import get_logger, is_accelerate_available, is_optimum_quanto_available, is_torchao_available
 from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
 from .hooks import HookRegistry, ModelHook
 
@@ -83,6 +83,31 @@ def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
         getattr(param, attr_name).record_stream(stream)
 
 
+def _is_quanto_tensor(tensor: torch.Tensor) -> bool:
+    if not is_optimum_quanto_available():
+        return False
+    from optimum.quanto import QTensor
+
+    return isinstance(tensor, QTensor)
+
+
+def _get_quanto_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
+    """Get names of all internal tensor data attributes from a quanto QTensor (e.g. `_data`, `_scale`)."""
+    return list(tensor.__tensor_flatten__()[0])
+
+
+def _restore_quanto_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
+    """Restore internal tensor data of a quanto QTensor from `source` without mutating `source`."""
+    for attr_name in _get_quanto_inner_tensor_names(source):
+        setattr(param, attr_name, getattr(source, attr_name))
+
+
+def _record_stream_quanto_tensor(param: torch.Tensor, stream) -> None:
+    """Record stream for all internal tensors of a quanto QTensor."""
+    for attr_name in _get_quanto_inner_tensor_names(param):
+        getattr(param, attr_name).record_stream(stream)
+
+
 # fmt: off
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
@@ -174,10 +199,14 @@ def __init__(
 
     @staticmethod
     def _to_cpu(tensor, low_cpu_mem_usage):
-        # For TorchAO tensors, `.data` returns an incomplete wrapper without internal attributes
-        # (e.g. `.qdata`, `.scale`), so we must call `.cpu()` on the tensor directly.
-        t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu()
-        return t if low_cpu_mem_usage else t.pin_memory()
+        # For tensor subclasses (TorchAO / quanto), `.data` returns an incomplete wrapper without internal
+        # attributes (e.g. `.qdata`/`.scale`, `._data`/`._scale`), so we must call `.cpu()` on the tensor directly.
+        t = tensor.cpu() if (_is_torchao_tensor(tensor) or _is_quanto_tensor(tensor)) else tensor.data.cpu()
+        # Subclass tensors (quanto / torchao) don't support `pin_memory()`/`is_pinned()` (quanto loses the
+        # subclass, torchao raises on the unimplemented op), so skip pinning for them.
+        if low_cpu_mem_usage or _is_quanto_tensor(tensor) or _is_torchao_tensor(tensor):
+            return t
+        return t.pin_memory()
 
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
@@ -202,7 +231,9 @@ def _init_cpu_param_dict(self):
     def _pinned_memory_tensors(self):
         try:
             pinned_dict = {
-                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                param: tensor
+                if (_is_quanto_tensor(tensor) or _is_torchao_tensor(tensor) or tensor.is_pinned())
+                else tensor.pin_memory()
                 for param, tensor in self.cpu_param_dict.items()
             }
             yield pinned_dict
@@ -213,11 +244,15 @@ def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
         moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
         if _is_torchao_tensor(tensor):
             _swap_torchao_tensor(tensor, moved)
+        elif _is_quanto_tensor(tensor):
+            torch.utils.swap_tensors(tensor, moved)
         else:
             tensor.data = moved
         if self.record_stream:
             if _is_torchao_tensor(tensor):
                 _record_stream_torchao_tensor(tensor, default_stream)
+            elif _is_quanto_tensor(tensor):
+                _record_stream_quanto_tensor(tensor, default_stream)
             else:
                 tensor.data.record_stream(default_stream)
 
@@ -320,16 +355,22 @@ def _offload_to_memory(self):
                 for param in group_module.parameters():
                     if _is_torchao_tensor(param):
                         _restore_torchao_tensor(param, self.cpu_param_dict[param])
+                    elif _is_quanto_tensor(param):
+                        _restore_quanto_tensor(param, self.cpu_param_dict[param])
                     else:
                         param.data = self.cpu_param_dict[param]
             for param in self.parameters:
                 if _is_torchao_tensor(param):
                     _restore_torchao_tensor(param, self.cpu_param_dict[param])
+                elif _is_quanto_tensor(param):
+                    _restore_quanto_tensor(param, self.cpu_param_dict[param])
                 else:
                     param.data = self.cpu_param_dict[param]
             for buffer in self.buffers:
                 if _is_torchao_tensor(buffer):
                     _restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
+                elif _is_quanto_tensor(buffer):
+                    _restore_quanto_tensor(buffer, self.cpu_param_dict[buffer])
                 else:
                     buffer.data = self.cpu_param_dict[buffer]
         else:
@@ -339,12 +380,16 @@ def _offload_to_memory(self):
                 if _is_torchao_tensor(param):
                     moved = param.to(self.offload_device, non_blocking=False)
                     _swap_torchao_tensor(param, moved)
+                elif _is_quanto_tensor(param):
+                    torch.utils.swap_tensors(param, param.to(self.offload_device, non_blocking=False))
                 else:
                     param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
                 if _is_torchao_tensor(buffer):
                     moved = buffer.to(self.offload_device, non_blocking=False)
                     _swap_torchao_tensor(buffer, moved)
+                elif _is_quanto_tensor(buffer):
+                    torch.utils.swap_tensors(buffer, buffer.to(self.offload_device, non_blocking=False))
                 else:
                     buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
 
diff --git a/tests/quantization/quanto/test_quanto.py b/tests/quantization/quanto/test_quanto.py
index e3463f136f94..0d203a83bb44 100644
--- a/tests/quantization/quanto/test_quanto.py
+++ b/tests/quantization/quanto/test_quanto.py
@@ -273,6 +273,30 @@ def test_model_cpu_offload(self):
         pipe.enable_model_cpu_offload(device=torch_device)
         _ = pipe("a cat holding a sign that says hello", num_inference_steps=2)
 
+    def test_group_offloading(self):
+        inputs = self.get_dummy_inputs()
+        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()).to(torch_device)
+        with torch.no_grad():
+            output_without_offloading = model(**inputs).sample
+        model.to("cpu")
+        del model
+        backend_empty_cache(torch_device)
+        gc.collect()
+
+        for offload_kwargs in (
+            {"offload_type": "leaf_level"},
+            {"offload_type": "leaf_level", "use_stream": True},
+            {"offload_type": "block_level", "num_blocks_per_group": 1, "use_stream": True},
+        ):
+            model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+            model.enable_group_offload(torch_device, **offload_kwargs)
+            with torch.no_grad():
+                output = model(**inputs).sample
+            assert torch.allclose(output_without_offloading, output, atol=1e-3, rtol=1e-3)
+            del model
+            backend_empty_cache(torch_device)
+            gc.collect()
+
     def test_training(self):
         quantization_config = QuantoConfig(**self.get_dummy_init_kwargs())
         quantized_model = self.model_cls.from_pretrained(
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index 8a811cfc1c73..f61e1c05a788 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -522,6 +522,33 @@ def test_sequential_cpu_offload(self):
         inputs = self.get_dummy_inputs(torch_device)
         _ = pipe(**inputs)
 
+    def test_group_offloading(self):
+        r"""
+        A test that checks if inference runs as expected when group offloading is enabled, including the
+        `use_stream` path that pins tensors, which the quantized subclass tensors do not support.
+        """
+        inputs = self.get_dummy_tensor_inputs(torch_device)
+        transformer = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()))["transformer"].to(torch_device)
+        with torch.no_grad():
+            output_without_offloading = transformer(**inputs)[0]
+        del transformer
+        backend_empty_cache(torch_device)
+        gc.collect()
+
+        for offload_kwargs in (
+            {"offload_type": "leaf_level"},
+            {"offload_type": "leaf_level", "use_stream": True},
+            {"offload_type": "block_level", "num_blocks_per_group": 1, "use_stream": True},
+        ):
+            transformer = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()))["transformer"]
+            transformer.enable_group_offload(torch_device, **offload_kwargs)
+            with torch.no_grad():
+                output = transformer(**inputs)[0]
+            assert torch.allclose(output_without_offloading, output, atol=1e-3, rtol=1e-3)
+            del transformer
+            backend_empty_cache(torch_device)
+            gc.collect()
+
     @require_torchao_version_greater_or_equal("0.15.0")
     def test_aobase_config(self):
         quantization_config = TorchAoConfig(Int8WeightOnlyConfig())