Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 67 additions & 33 deletions tests/models/testing_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size

from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator

from ...testing_utils import assert_tensors_close, torch_device
from ...testing_utils import (
assert_tensors_close,
require_accelerator,
require_torch_multi_accelerator,
torch_device,
)


def named_persistent_module_tensors(
Expand Down Expand Up @@ -258,7 +262,39 @@ def get_dummy_inputs(self) -> Dict[str, Any]:
raise NotImplementedError("Subclasses must implement `get_dummy_inputs()`.")


class ModelTesterMixin:
class BaseModelOutputMixin:
"""Provides the class-scoped `base_model_output` fixture shared across tester mixins.

Kept separate from `BaseModelTesterConfig` — which only declares the testing contract and performs no
computation — so any mixin that needs the cached reference output (`ModelTesterMixin`, the memory
offload mixins, ...) can inherit it without duplicating the build-and-forward.
"""

@pytest.fixture(scope="class")
def base_model_output(self):
"""Class-scoped reference forward output, built once and reused across the class.

Building the model and running its forward pass is fully deterministic (`torch.manual_seed(0)`
plus the deterministic `get_dummy_inputs` contract), so the reference ("base") output is
identical for every test in the class. The save/load, parallelism, and memory-offload tests
compare a reloaded/offloaded model against this output; computing it a single time here — instead
of rebuilding the model and re-running the forward in each test — removes that redundant work and
speeds up the suite.

The hardware-gated tests that consume this fixture use `pytest.mark.skipif` (via the `require_*`
decorators), which pytest evaluates before fixture setup, so skipping on a machine without the
required accelerators never triggers this forward.

Tests that still need a live model (e.g. to save or offload it) build their own with the same
seed, so the reloaded model's weights match this cached output.
"""
torch.manual_seed(0)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should BaseModelOutputMixin also expose the seed it uses to create the base_model_output fixture? I think this would allow tests to more easily match its behavior.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about it but then this change should also be propagated to the generator property, e.g.

@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)

So, I think this should be done in a separate PR.

model = self.model_class(**self.get_init_dict()).eval().to(torch_device)
with torch.no_grad():
return model(**self.get_dummy_inputs(), return_dict=False)[0]


class ModelTesterMixin(BaseModelOutputMixin):
"""
Base mixin class for model testing with common test methods.

Expand All @@ -279,7 +315,7 @@ class TestMyModel(MyModelTestConfig, ModelTesterMixin):
"""

@torch.no_grad()
def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
def test_from_save_pretrained(self, base_model_output, tmp_path, atol=5e-5, rtol=5e-5):
torch.manual_seed(0)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
Expand All @@ -296,13 +332,15 @@ def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
)

image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]

assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
assert_tensors_close(
base_model_output, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes."
)

@torch.no_grad()
def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
def test_from_save_pretrained_variant(self, base_model_output, tmp_path, atol=5e-5, rtol=0):
torch.manual_seed(0)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
Expand All @@ -317,10 +355,11 @@ def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):

new_model.to(torch_device)

image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]

assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
assert_tensors_close(
base_model_output, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes."
)

@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
def test_from_save_pretrained_dtype(self, tmp_path, dtype):
Expand Down Expand Up @@ -360,13 +399,8 @@ def test_determinism(self, atol=1e-5, rtol=0):
)

@torch.no_grad()
def test_output(self, expected_output_shape=None):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()

inputs_dict = self.get_dummy_inputs()
output = model(**inputs_dict, return_dict=False)[0]
def test_output(self, base_model_output, expected_output_shape=None):
output = base_model_output

assert output is not None, "Model output is None"
assert output[0].shape == expected_output_shape or self.output_shape, (
Expand Down Expand Up @@ -509,14 +543,12 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4,

@require_accelerator
@torch.no_grad()
def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0):
def test_sharded_checkpoints(self, base_model_output, tmp_path, atol=1e-5, rtol=0):
torch.manual_seed(0)
config = self.get_init_dict()
model = self.model_class(**config).eval()
model = model.to(torch_device)

base_output = model(**self.get_dummy_inputs(), return_dict=False)[0]

model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small

Expand All @@ -537,19 +569,17 @@ def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0):
new_output = new_model(**self.get_dummy_inputs(), return_dict=False)[0]

assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
base_model_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
)

@require_accelerator
@torch.no_grad()
def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0):
def test_sharded_checkpoints_with_variant(self, base_model_output, tmp_path, atol=1e-5, rtol=0):
torch.manual_seed(0)
config = self.get_init_dict()
model = self.model_class(**config).eval()
model = model.to(torch_device)

base_output = model(**self.get_dummy_inputs(), return_dict=False)[0]

model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
variant = "fp16"
Expand All @@ -575,20 +605,22 @@ def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0):
new_output = new_model(**self.get_dummy_inputs(), return_dict=False)[0]

assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"
base_model_output,
new_output,
atol=atol,
rtol=rtol,
msg="Output should match after variant sharded save/load",
)

@torch.no_grad()
def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rtol=0):
def test_sharded_checkpoints_with_parallel_loading(self, base_model_output, tmp_path, atol=1e-5, rtol=0):
from diffusers.utils import constants

torch.manual_seed(0)
config = self.get_init_dict()
model = self.model_class(**config).eval()
model = model.to(torch_device)

base_output = model(**self.get_dummy_inputs(), return_dict=False)[0]

model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small

Expand Down Expand Up @@ -624,7 +656,11 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt
output_parallel = model_parallel(**self.get_dummy_inputs(), return_dict=False)[0]

assert_tensors_close(
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"
base_model_output,
output_parallel,
atol=atol,
rtol=rtol,
msg="Output should match with parallel loading",
)

finally:
Expand All @@ -635,19 +671,17 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt

@require_torch_multi_accelerator
@torch.no_grad()
def test_model_parallelism(self, tmp_path, atol=1e-5, rtol=0):
def test_model_parallelism(self, base_model_output, tmp_path, atol=1e-5, rtol=0):
if self.model_class._no_split_modules is None:
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")

torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()

model = model.to(torch_device)

torch.manual_seed(0)
base_output = model(**inputs_dict, return_dict=False)[0]

model_size = compute_module_sizes(model)[""]
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]

Expand All @@ -665,5 +699,5 @@ def test_model_parallelism(self, tmp_path, atol=1e-5, rtol=0):
new_output = new_model(**inputs_dict, return_dict=False)[0]

assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match with model parallelism"
base_model_output, new_output, atol=atol, rtol=rtol, msg="Output should match with model parallelism"
)
38 changes: 14 additions & 24 deletions tests/models/testing_utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
require_accelerator,
torch_device,
)
from .common import cast_inputs_to_dtype, check_device_map_is_respected
from .common import BaseModelOutputMixin, cast_inputs_to_dtype, check_device_map_is_respected


def require_offload_support(func):
Expand Down Expand Up @@ -69,7 +69,7 @@ def wrapper(self, *args, **kwargs):


@is_cpu_offload
class CPUOffloadTesterMixin:
class CPUOffloadTesterMixin(BaseModelOutputMixin):
"""
Mixin class for testing CPU offloading functionality.

Expand All @@ -94,16 +94,14 @@ def model_split_percents(self) -> list[float]:

@require_offload_support
@torch.no_grad()
def test_cpu_offload(self, tmp_path, atol=1e-5, rtol=0):
def test_cpu_offload(self, base_model_output, tmp_path, atol=1e-5, rtol=0):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()

model = model.to(torch_device)

torch.manual_seed(0)
base_output = model(**inputs_dict)

model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
Expand All @@ -120,21 +118,19 @@ def test_cpu_offload(self, tmp_path, atol=1e-5, rtol=0):
new_output = new_model(**inputs_dict)

assert_tensors_close(
base_output[0], new_output[0], atol=atol, rtol=rtol, msg="Output should match with CPU offloading"
base_model_output, new_output[0], atol=atol, rtol=rtol, msg="Output should match with CPU offloading"
)

@require_offload_support
@torch.no_grad()
def test_disk_offload_without_safetensors(self, tmp_path, atol=1e-5, rtol=0):
def test_disk_offload_without_safetensors(self, base_model_output, tmp_path, atol=1e-5, rtol=0):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()

model = model.to(torch_device)

torch.manual_seed(0)
base_output = model(**inputs_dict)

model_size = compute_module_sizes(model)[""]
max_size = int(self.model_split_percents[0] * model_size)
# Force disk offload by setting very small CPU memory
Expand All @@ -154,21 +150,19 @@ def test_disk_offload_without_safetensors(self, tmp_path, atol=1e-5, rtol=0):
new_output = new_model(**inputs_dict)

assert_tensors_close(
base_output[0], new_output[0], atol=atol, rtol=rtol, msg="Output should match with disk offloading"
base_model_output, new_output[0], atol=atol, rtol=rtol, msg="Output should match with disk offloading"
)

@require_offload_support
@torch.no_grad()
def test_disk_offload_with_safetensors(self, tmp_path, atol=1e-5, rtol=0):
def test_disk_offload_with_safetensors(self, base_model_output, tmp_path, atol=1e-5, rtol=0):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()

model = model.to(torch_device)

torch.manual_seed(0)
base_output = model(**inputs_dict)

model_size = compute_module_sizes(model)[""]
model.cpu().save_pretrained(str(tmp_path))

Expand All @@ -183,7 +177,7 @@ def test_disk_offload_with_safetensors(self, tmp_path, atol=1e-5, rtol=0):
new_output = new_model(**inputs_dict)

assert_tensors_close(
base_output[0],
base_model_output,
new_output[0],
atol=atol,
rtol=rtol,
Expand All @@ -192,7 +186,7 @@ def test_disk_offload_with_safetensors(self, tmp_path, atol=1e-5, rtol=0):


@is_group_offload
class GroupOffloadTesterMixin:
class GroupOffloadTesterMixin(BaseModelOutputMixin):
"""
Mixin class for testing group offloading functionality.

Expand All @@ -209,10 +203,9 @@ class GroupOffloadTesterMixin:

@require_group_offload_support
@pytest.mark.parametrize("record_stream", [False, True])
def test_group_offloading(self, record_stream, atol=1e-5, rtol=0):
def test_group_offloading(self, base_model_output, record_stream, atol=1e-5, rtol=0):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)

@torch.no_grad()
def run_forward(model):
Expand All @@ -224,10 +217,7 @@ def run_forward(model):
model.eval()
return model(**inputs_dict)[0]

model = self.model_class(**init_dict)

model.to(torch_device)
output_without_group_offloading = run_forward(model)
output_without_group_offloading = base_model_output

torch.manual_seed(0)
model = self.model_class(**init_dict)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:


class TestHunyuanDiT(HunyuanDiTTesterConfig, ModelTesterMixin):
def test_output(self):
def test_output(self, base_model_output):
batch_size = self.get_dummy_inputs()[self.main_input_name].shape[0]
super().test_output(expected_output_shape=(batch_size,) + self.output_shape)
super().test_output(base_model_output, expected_output_shape=(batch_size,) + self.output_shape)


class TestHunyuanDiTTraining(HunyuanDiTTesterConfig, TrainingTesterMixin):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:


class TestHunyuanVideoI2VTransformer(HunyuanVideoI2VTransformerTesterConfig, ModelTesterMixin):
def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))
def test_output(self, base_model_output):
super().test_output(base_model_output, expected_output_shape=(1, *self.output_shape))


# ======================== HunyuanVideo Token Replace Image-to-Video ========================
Expand Down Expand Up @@ -299,5 +299,5 @@ def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:


class TestHunyuanVideoTokenReplaceTransformer(HunyuanVideoTokenReplaceTransformerTesterConfig, ModelTesterMixin):
def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))
def test_output(self, base_model_output):
super().test_output(base_model_output, expected_output_shape=(1, *self.output_shape))
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,11 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestWanAnimateTransformer3D(WanAnimateTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan Animate Transformer 3D."""

def test_output(self):
def test_output(self, base_model_output):
# Override test_output because the transformer output is expected to have less channels
# than the main transformer input.
expected_output_shape = (1, 4, 21, 16, 16)
super().test_output(expected_output_shape=expected_output_shape)
super().test_output(base_model_output, expected_output_shape=expected_output_shape)

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
Expand Down
Loading