Skip to content

Commit 4c77dcd

Browse files
refactor autoencoder tests (vq, kvae_video, oobleck, consistency_decoder, tiny, vidtok) (#13849)
* refactor vq tests
* refactor autoencoder_kl_kvae_video tests
* refactor autoencoder_oobleck tests
* refactor consistency_decoder_vae tests
* refactor autoencoder_tiny tests
* refactor autoencoder_vidtok tests
* remove unused base_precision and test_outputs_equivalence skips
---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 334ef1a commit 4c77dcd

6 files changed

Lines changed: 335 additions & 365 deletions

tests/models/autoencoders/test_models_autoencoder_kl_kvae_video.py

Lines changed: 61 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,48 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
16+
import pytest
17+
import torch
1718

1819
from diffusers import AutoencoderKLKVAEVideo
20+
from diffusers.utils.torch_utils import randn_tensor
1921

20-
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
21-
from ..test_modeling_common import ModelTesterMixin
22-
from .testing_utils import AutoencoderTesterMixin
22+
from ...testing_utils import enable_full_determinism, torch_device
23+
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
24+
from .testing_utils import NewAutoencoderTesterMixin
2325

2426

2527
enable_full_determinism()
2628

2729

28-
class AutoencoderKLKVAEVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
29-
model_class = AutoencoderKLKVAEVideo
30-
main_input_name = "sample"
31-
base_precision = 1e-2
30+
def _run_nondeterministic(fn):
31+
# reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
32+
# temporarily relax the requirement for tests that do backward passes.
33+
torch.use_deterministic_algorithms(False)
34+
try:
35+
fn()
36+
finally:
37+
torch.use_deterministic_algorithms(True)
3238

33-
def get_autoencoder_kl_kvae_video_config(self):
39+
40+
class AutoencoderKLKVAEVideoTesterConfig(BaseModelTesterConfig):
41+
@property
42+
def model_class(self):
43+
return AutoencoderKLKVAEVideo
44+
45+
@property
46+
def main_input_name(self) -> str:
47+
return "sample"
48+
49+
@property
50+
def output_shape(self) -> tuple:
51+
return (3, 3, 16, 16)
52+
53+
@property
54+
def generator(self):
55+
return torch.Generator("cpu").manual_seed(0)
56+
57+
def get_init_dict(self) -> dict:
3458
return {
3559
"ch": 32,
3660
"ch_mult": (1, 2),
@@ -41,78 +65,53 @@ def get_autoencoder_kl_kvae_video_config(self):
4165
"temporal_compress_times": 2,
4266
}
4367

44-
@property
45-
def dummy_input(self):
68+
def get_dummy_inputs(self) -> dict:
4669
batch_size = 2
4770
num_frames = 3 # satisfies (T-1) % temporal_compress_times == 0 with temporal_compress_times=2
4871
num_channels = 3
4972
sizes = (16, 16)
50-
51-
video = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
52-
73+
video = randn_tensor(
74+
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
75+
)
5376
return {"sample": video}
5477

55-
@property
56-
def input_shape(self):
57-
return (3, 3, 16, 16)
58-
59-
@property
60-
def output_shape(self):
61-
return (3, 3, 16, 16)
62-
63-
def prepare_init_args_and_inputs_for_common(self):
64-
init_dict = self.get_autoencoder_kl_kvae_video_config()
65-
inputs_dict = self.dummy_input
66-
return init_dict, inputs_dict
67-
68-
def test_gradient_checkpointing_is_applied(self):
69-
expected_set = {
70-
"KVAECachedEncoder3D",
71-
"KVAECachedDecoder3D",
72-
}
73-
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
74-
75-
@unittest.skip("Unsupported test.")
76-
def test_outputs_equivalence(self):
77-
pass
7878

79-
@unittest.skip(
79+
class TestAutoencoderKLKVAEVideo(AutoencoderKLKVAEVideoTesterConfig, ModelTesterMixin):
80+
@pytest.mark.skip(
8081
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
8182
)
8283
def test_model_parallelism(self):
83-
pass
84+
super().test_model_parallelism()
8485

85-
@unittest.skip(
86-
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
87-
)
88-
def test_sharded_checkpoints_device_map(self):
89-
pass
9086

91-
def _run_nondeterministic(self, fn):
92-
# reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
93-
# temporarily relax the requirement for training tests that do backward passes.
94-
import torch
87+
class TestAutoencoderKLKVAEVideoTraining(AutoencoderKLKVAEVideoTesterConfig, TrainingTesterMixin):
88+
"""Training tests for AutoencoderKLKVAEVideo."""
9589

96-
torch.use_deterministic_algorithms(False)
97-
try:
98-
fn()
99-
finally:
100-
torch.use_deterministic_algorithms(True)
90+
def test_gradient_checkpointing_is_applied(self):
91+
expected_set = {"KVAECachedEncoder3D", "KVAECachedDecoder3D"}
92+
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
10193

10294
def test_training(self):
103-
self._run_nondeterministic(super().test_training)
95+
_run_nondeterministic(super().test_training)
10496

105-
def test_ema_training(self):
106-
self._run_nondeterministic(super().test_ema_training)
97+
def test_training_with_ema(self):
98+
_run_nondeterministic(super().test_training_with_ema)
10799

108-
@unittest.skip(
100+
@pytest.mark.skip(
109101
"Gradient checkpointing recomputes the forward pass, but the model uses a stateful cache_dict "
110102
"that is mutated during the first forward. On recomputation the cache is already populated, "
111-
"causing a different execution path and numerically different gradients. "
112-
"GC still reduces peak memory usage; gradient correctness in the presence of GC is a known limitation."
103+
"causing a different execution path and numerically different gradients."
113104
)
114-
def test_effective_gradient_checkpointing(self):
115-
pass
105+
def test_gradient_checkpointing_equivalence(self):
106+
super().test_gradient_checkpointing_equivalence()
116107

117108
def test_layerwise_casting_training(self):
118-
self._run_nondeterministic(super().test_layerwise_casting_training)
109+
_run_nondeterministic(super().test_layerwise_casting_training)
110+
111+
112+
class TestAutoencoderKLKVAEVideoMemory(AutoencoderKLKVAEVideoTesterConfig, MemoryTesterMixin):
113+
"""Memory optimization tests for AutoencoderKLKVAEVideo."""
114+
115+
116+
class TestAutoencoderKLKVAEVideoSlicingTiling(AutoencoderKLKVAEVideoTesterConfig, NewAutoencoderTesterMixin):
117+
"""Slicing and tiling tests for AutoencoderKLKVAEVideo."""

0 commit comments

Comments
 (0)