Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 51 additions & 6 deletions src/diffusers/hooks/group_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import safetensors.torch
import torch

from ..utils import get_logger, is_accelerate_available, is_torchao_available
from ..utils import get_logger, is_accelerate_available, is_optimum_quanto_available, is_torchao_available
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook

Expand Down Expand Up @@ -83,6 +83,31 @@ def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
getattr(param, attr_name).record_stream(stream)


def _is_quanto_tensor(tensor: torch.Tensor) -> bool:
    """Return True when `tensor` is an optimum-quanto `QTensor`.

    Returns False without importing anything when optimum-quanto is not installed.
    """
    if not is_optimum_quanto_available():
        return False

    # Imported lazily so this module does not require optimum-quanto at import time.
    from optimum.quanto import QTensor

    return isinstance(tensor, QTensor)


def _get_quanto_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
"""Get names of all internal tensor data attributes from a quanto QTensor (e.g. `_data`, `_scale`)."""
return list(tensor.__tensor_flatten__()[0])


def _restore_quanto_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
    """Restore internal tensor data of a quanto QTensor from `source` without mutating `source`.

    Every inner tensor attribute reported by `source` is rebound on `param` to the same object
    that `source` holds; `source` itself is only read.
    """
    for attr_name in _get_quanto_inner_tensor_names(source):
        inner = getattr(source, attr_name)
        setattr(param, attr_name, inner)


def _record_stream_quanto_tensor(param: torch.Tensor, stream) -> None:
    """Record stream for all internal tensors of a quanto QTensor.

    Calls `record_stream(stream)` on each inner tensor attribute of `param` in turn.
    """
    for attr_name in _get_quanto_inner_tensor_names(param):
        inner = getattr(param, attr_name)
        inner.record_stream(stream)


# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
Expand Down Expand Up @@ -174,10 +199,14 @@ def __init__(

@staticmethod
def _to_cpu(tensor, low_cpu_mem_usage):
    """Return a CPU copy of `tensor`, pinned unless pinning is disabled or unsupported.

    Args:
        tensor: The tensor to copy to CPU. May be a TorchAO or quanto tensor subclass.
        low_cpu_mem_usage: When truthy, the CPU copy is returned without pinning.

    Returns:
        The CPU tensor; pinned only for plain tensors when `low_cpu_mem_usage` is falsy.
    """
    is_quantized_subclass = _is_torchao_tensor(tensor) or _is_quanto_tensor(tensor)

    # For tensor subclasses (TorchAO / quanto), `.data` returns an incomplete wrapper without internal
    # attributes (e.g. `.qdata`/`.scale`, `._data`/`._scale`), so we must call `.cpu()` on the tensor directly.
    t = tensor.cpu() if is_quantized_subclass else tensor.data.cpu()

    # Subclass tensors (quanto / torchao) don't support `pin_memory()`/`is_pinned()` (quanto loses the
    # subclass, torchao raises on the unimplemented op), so skip pinning for them.
    if low_cpu_mem_usage or is_quantized_subclass:
        return t
    return t.pin_memory()

def _init_cpu_param_dict(self):
cpu_param_dict = {}
Expand All @@ -202,7 +231,9 @@ def _init_cpu_param_dict(self):
def _pinned_memory_tensors(self):
try:
pinned_dict = {
param: tensor.pin_memory() if not tensor.is_pinned() else tensor
param: tensor
if (_is_quanto_tensor(tensor) or _is_torchao_tensor(tensor) or tensor.is_pinned())
else tensor.pin_memory()
for param, tensor in self.cpu_param_dict.items()
}
yield pinned_dict
Expand All @@ -213,11 +244,15 @@ def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if _is_torchao_tensor(tensor):
_swap_torchao_tensor(tensor, moved)
elif _is_quanto_tensor(tensor):
torch.utils.swap_tensors(tensor, moved)
else:
tensor.data = moved
if self.record_stream:
if _is_torchao_tensor(tensor):
_record_stream_torchao_tensor(tensor, default_stream)
elif _is_quanto_tensor(tensor):
_record_stream_quanto_tensor(tensor, default_stream)
else:
tensor.data.record_stream(default_stream)

Expand Down Expand Up @@ -320,16 +355,22 @@ def _offload_to_memory(self):
for param in group_module.parameters():
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])
elif _is_quanto_tensor(param):
_restore_quanto_tensor(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for param in self.parameters:
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])
elif _is_quanto_tensor(param):
_restore_quanto_tensor(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
if _is_torchao_tensor(buffer):
_restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
elif _is_quanto_tensor(buffer):
_restore_quanto_tensor(buffer, self.cpu_param_dict[buffer])
else:
buffer.data = self.cpu_param_dict[buffer]
else:
Expand All @@ -339,12 +380,16 @@ def _offload_to_memory(self):
if _is_torchao_tensor(param):
moved = param.to(self.offload_device, non_blocking=False)
_swap_torchao_tensor(param, moved)
elif _is_quanto_tensor(param):
torch.utils.swap_tensors(param, param.to(self.offload_device, non_blocking=False))
else:
param.data = param.data.to(self.offload_device, non_blocking=False)
for buffer in self.buffers:
if _is_torchao_tensor(buffer):
moved = buffer.to(self.offload_device, non_blocking=False)
_swap_torchao_tensor(buffer, moved)
elif _is_quanto_tensor(buffer):
torch.utils.swap_tensors(buffer, buffer.to(self.offload_device, non_blocking=False))
else:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)

Expand Down
24 changes: 24 additions & 0 deletions tests/quantization/quanto/test_quanto.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,30 @@ def test_model_cpu_offload(self):
pipe.enable_model_cpu_offload(device=torch_device)
_ = pipe("a cat holding a sign that says hello", num_inference_steps=2)

def test_group_offloading(self):
    """Check that inference with group offloading matches the output obtained without offloading.

    Runs leaf-level offloading with and without `use_stream`, plus block-level offloading with
    `use_stream`, and compares each output against the non-offloaded reference.
    """
    inputs = self.get_dummy_inputs()
    model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()).to(torch_device)
    with torch.no_grad():
        output_without_offloading = model(**inputs).sample
    model.to("cpu")
    del model
    # Collect garbage first so memory held by unreachable objects can actually be returned
    # when the device cache is emptied.
    gc.collect()
    backend_empty_cache(torch_device)

    for offload_kwargs in (
        {"offload_type": "leaf_level"},
        {"offload_type": "leaf_level", "use_stream": True},
        {"offload_type": "block_level", "num_blocks_per_group": 1, "use_stream": True},
    ):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        model.enable_group_offload(torch_device, **offload_kwargs)
        with torch.no_grad():
            output = model(**inputs).sample
        # Name the failing configuration so a mismatch is diagnosable from the test report.
        assert torch.allclose(output_without_offloading, output, atol=1e-3, rtol=1e-3), (
            f"Group offloading output mismatch for {offload_kwargs}"
        )
        del model
        gc.collect()
        backend_empty_cache(torch_device)

def test_training(self):
quantization_config = QuantoConfig(**self.get_dummy_init_kwargs())
quantized_model = self.model_cls.from_pretrained(
Expand Down
27 changes: 27 additions & 0 deletions tests/quantization/torchao/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,33 @@ def test_sequential_cpu_offload(self):
inputs = self.get_dummy_inputs(torch_device)
_ = pipe(**inputs)

def test_group_offloading(self):
    r"""
    A test that checks if inference runs as expected when group offloading is enabled, including the
    `use_stream` path that pins tensors, which the quantized subclass tensors do not support.

    Each offloaded output is compared against a reference output computed without offloading.
    """
    inputs = self.get_dummy_tensor_inputs(torch_device)
    transformer = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()))["transformer"].to(torch_device)
    with torch.no_grad():
        output_without_offloading = transformer(**inputs)[0]
    del transformer
    # Collect garbage first so memory held by unreachable objects can actually be returned
    # when the device cache is emptied.
    gc.collect()
    backend_empty_cache(torch_device)

    for offload_kwargs in (
        {"offload_type": "leaf_level"},
        {"offload_type": "leaf_level", "use_stream": True},
        {"offload_type": "block_level", "num_blocks_per_group": 1, "use_stream": True},
    ):
        transformer = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()))["transformer"]
        transformer.enable_group_offload(torch_device, **offload_kwargs)
        with torch.no_grad():
            output = transformer(**inputs)[0]
        # Name the failing configuration so a mismatch is diagnosable from the test report.
        assert torch.allclose(output_without_offloading, output, atol=1e-3, rtol=1e-3), (
            f"Group offloading output mismatch for {offload_kwargs}"
        )
        del transformer
        gc.collect()
        backend_empty_cache(torch_device)

@require_torchao_version_greater_or_equal("0.15.0")
def test_aobase_config(self):
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
Expand Down
Loading