Skip to content

Commit e54fc90

Browse files
committed
fix cuda tests for models.
1 parent 178c4cb commit e54fc90

8 files changed

Lines changed: 47 additions & 11 deletions

File tree

src/diffusers/models/downsampling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,15 +227,15 @@ def _downsample_2d(
227227
stride_value = [factor, factor]
228228
upfirdn_input = upfirdn2d_native(
229229
hidden_states,
230-
torch.tensor(kernel, device=hidden_states.device),
230+
kernel.to(device=hidden_states.device, dtype=hidden_states.dtype),
231231
pad=((pad_value + 1) // 2, pad_value // 2),
232232
)
233233
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
234234
else:
235235
pad_value = kernel.shape[0] - factor
236236
output = upfirdn2d_native(
237237
hidden_states,
238-
torch.tensor(kernel, device=hidden_states.device),
238+
kernel.to(device=hidden_states.device, dtype=hidden_states.dtype),
239239
down=factor,
240240
pad=((pad_value + 1) // 2, pad_value // 2),
241241
)
@@ -392,7 +392,7 @@ def downsample_2d(
392392
pad_value = kernel.shape[0] - factor
393393
output = upfirdn2d_native(
394394
hidden_states,
395-
kernel.to(device=hidden_states.device),
395+
kernel.to(device=hidden_states.device, dtype=hidden_states.dtype),
396396
down=factor,
397397
pad=((pad_value + 1) // 2, pad_value // 2),
398398
)

src/diffusers/models/upsampling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,14 +300,14 @@ def _upsample_2d(
300300

301301
output = upfirdn2d_native(
302302
inverse_conv,
303-
torch.tensor(kernel, device=inverse_conv.device),
303+
kernel.to(device=inverse_conv.device, dtype=inverse_conv.dtype),
304304
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
305305
)
306306
else:
307307
pad_value = kernel.shape[0] - factor
308308
output = upfirdn2d_native(
309309
hidden_states,
310-
torch.tensor(kernel, device=hidden_states.device),
310+
kernel.to(device=hidden_states.device, dtype=hidden_states.dtype),
311311
up=factor,
312312
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
313313
)
@@ -508,7 +508,7 @@ def upsample_2d(
508508
pad_value = kernel.shape[0] - factor
509509
output = upfirdn2d_native(
510510
hidden_states,
511-
kernel.to(device=hidden_states.device),
511+
kernel.to(device=hidden_states.device, dtype=hidden_states.dtype),
512512
up=factor,
513513
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
514514
)

tests/models/autoencoders/test_models_autoencoder_tiny.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,12 @@ def get_dummy_inputs(self) -> dict:
7676

7777

7878
class TestAutoencoderTiny(AutoencoderTinyTesterConfig, ModelTesterMixin):
79-
pass
79+
@pytest.mark.skip(
80+
"`forward` round-trips the latents through a uint8 byte tensor (`.byte()` / `/ 255.0`), which upcasts to "
81+
"float32 regardless of the model dtype, so full fp16/bf16 forward inference is not possible."
82+
)
83+
def test_from_save_pretrained_dtype_inference(self):
84+
pass
8085

8186

8287
class TestAutoencoderTinyTraining(AutoencoderTinyTesterConfig, TrainingTesterMixin):

tests/models/autoencoders/test_models_consistency_decoder_vae.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import gc
1717

1818
import numpy as np
19+
import pytest
1920
import torch
2021

2122
from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline
@@ -86,7 +87,13 @@ def get_dummy_inputs(self) -> dict:
8687

8788

8889
class TestConsistencyDecoderVAE(ConsistencyDecoderVAETesterConfig, ModelTesterMixin):
89-
pass
90+
@pytest.mark.skip(
91+
"`forward` decodes through an iterative, RNG-driven consistency-decoding loop whose output is not "
92+
"reproducible across two model instances and amplifies fp16/bf16 nondeterminism, so a low-precision "
93+
"output-equivalence check is not meaningful."
94+
)
95+
def test_from_save_pretrained_dtype_inference(self):
96+
pass
9097

9198

9299
class TestConsistencyDecoderVAETraining(ConsistencyDecoderVAETesterConfig, TrainingTesterMixin):

tests/models/controlnets/test_models_controlnet_cosmos.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,10 @@ def test_training(self):
283283
def test_training_with_ema(self):
284284
super().test_training_with_ema()
285285

286+
@pytest.mark.skip("ControlNet outputs list of control blocks, not single tensor for MSE loss.")
287+
def test_mixed_precision_training(self):
288+
super().test_mixed_precision_training()
289+
286290
@pytest.mark.skip("ControlNet output doesn't have .sample attribute.")
287291
def test_gradient_checkpointing_equivalence(self):
288292
super().test_gradient_checkpointing_equivalence()

tests/models/testing_utils/common.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,9 @@ def cast_inputs_to_dtype(inputs, current_dtype, target_dtype):
135135
return inputs.to(target_dtype) if inputs.dtype == current_dtype else inputs
136136
if isinstance(inputs, dict):
137137
return {k: cast_inputs_to_dtype(v, current_dtype, target_dtype) for k, v in inputs.items()}
138-
if isinstance(inputs, list):
139-
return [cast_inputs_to_dtype(v, current_dtype, target_dtype) for v in inputs]
138+
if isinstance(inputs, (list, tuple)):
139+
# Preserve the container type so models that branch on it (e.g. `isinstance(..., tuple)`) still see a tuple.
140+
return type(inputs)(cast_inputs_to_dtype(v, current_dtype, target_dtype) for v in inputs)
140141

141142
return inputs
142143

@@ -479,7 +480,11 @@ def test_keep_in_fp32_modules(self, tmp_path):
479480
)
480481
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
481482
@torch.no_grad()
482-
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4, rtol=0):
483+
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
484+
# Low-precision inference is inherently lossy, and models that keep some modules in fp32 diverge further from
485+
# the fully-cast reference. Tolerances reflect the dtype's precision rather than a tight fp32-style threshold.
486+
atol = 3e-2 if dtype == torch.bfloat16 else 1e-2
487+
rtol = 0
483488
model = self.model_class(**self.get_init_dict())
484489
model.to(torch_device)
485490
fp32_modules = model._keep_in_fp32_modules or []

tests/models/transformers/test_models_transformer_z_image.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,10 @@ def test_training(self):
250250
def test_training_with_ema(self):
251251
pass
252252

253+
@pytest.mark.skip("Model output `sample` is a list of tensors; mixed-precision training computes MSE loss on it.")
254+
def test_mixed_precision_training(self):
255+
pass
256+
253257
@pytest.mark.skip("Test is not supported for handling main inputs that are lists.")
254258
def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
255259
pass

tests/testing_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,17 @@ def assert_tensors_close(
165165
if not is_torch_available():
166166
raise ValueError("PyTorch needs to be installed to use this function.")
167167

168+
# Some models (e.g. Z-Image, Cosmos ControlNet) return a list/tuple of tensors as their output. Compare these
169+
# element-wise so the same helper works regardless of whether the output is a single tensor or a sequence.
170+
if isinstance(actual, (list, tuple)) or isinstance(expected, (list, tuple)):
171+
if not (isinstance(actual, (list, tuple)) and isinstance(expected, (list, tuple))):
172+
raise AssertionError(f"{msg} Type mismatch: actual {type(actual)} vs expected {type(expected)}")
173+
if len(actual) != len(expected):
174+
raise AssertionError(f"{msg} Length mismatch: actual {len(actual)} vs expected {len(expected)}")
175+
for i, (a, e) in enumerate(zip(actual, expected)):
176+
assert_tensors_close(a, e, atol=atol, rtol=rtol, msg=f"{msg} [element {i}]")
177+
return
178+
168179
if actual.shape != expected.shape:
169180
raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}")
170181

0 commit comments

Comments
 (0)