1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import unittest
16+ import pytest
17+ import torch
1718
1819from diffusers import AutoencoderKLKVAEVideo
20+ from diffusers .utils .torch_utils import randn_tensor
1921
20- from ...testing_utils import enable_full_determinism , floats_tensor , torch_device
21- from ..test_modeling_common import ModelTesterMixin
22- from .testing_utils import AutoencoderTesterMixin
22+ from ...testing_utils import enable_full_determinism , torch_device
23+ from ..testing_utils import BaseModelTesterConfig , MemoryTesterMixin , ModelTesterMixin , TrainingTesterMixin
24+ from .testing_utils import NewAutoencoderTesterMixin
2325
2426
2527enable_full_determinism ()
2628
2729
28- class AutoencoderKLKVAEVideoTests (ModelTesterMixin , AutoencoderTesterMixin , unittest .TestCase ):
29- model_class = AutoencoderKLKVAEVideo
30- main_input_name = "sample"
31- base_precision = 1e-2
30+ def _run_nondeterministic (fn ):
31+ # reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
32+ # temporarily relax the requirement for tests that do backward passes.
33+ torch .use_deterministic_algorithms (False )
34+ try :
35+ fn ()
36+ finally :
37+ torch .use_deterministic_algorithms (True )
3238
33- def get_autoencoder_kl_kvae_video_config (self ):
39+
40+ class AutoencoderKLKVAEVideoTesterConfig (BaseModelTesterConfig ):
41+ @property
42+ def model_class (self ):
43+ return AutoencoderKLKVAEVideo
44+
45+ @property
46+ def main_input_name (self ) -> str :
47+ return "sample"
48+
49+ @property
50+ def output_shape (self ) -> tuple :
51+ return (3 , 3 , 16 , 16 )
52+
53+ @property
54+ def generator (self ):
55+ return torch .Generator ("cpu" ).manual_seed (0 )
56+
57+ def get_init_dict (self ) -> dict :
3458 return {
3559 "ch" : 32 ,
3660 "ch_mult" : (1 , 2 ),
@@ -41,78 +65,53 @@ def get_autoencoder_kl_kvae_video_config(self):
4165 "temporal_compress_times" : 2 ,
4266 }
4367
44- @property
45- def dummy_input (self ):
68+ def get_dummy_inputs (self ) -> dict :
4669 batch_size = 2
4770 num_frames = 3 # satisfies (T-1) % temporal_compress_times == 0 with temporal_compress_times=2
4871 num_channels = 3
4972 sizes = (16 , 16 )
50-
51- video = floats_tensor (( batch_size , num_channels , num_frames ) + sizes ). to ( torch_device )
52-
73+ video = randn_tensor (
74+ ( batch_size , num_channels , num_frames , * sizes ), generator = self . generator , device = torch_device
75+ )
5376 return {"sample" : video }
5477
55- @property
56- def input_shape (self ):
57- return (3 , 3 , 16 , 16 )
58-
59- @property
60- def output_shape (self ):
61- return (3 , 3 , 16 , 16 )
62-
63- def prepare_init_args_and_inputs_for_common (self ):
64- init_dict = self .get_autoencoder_kl_kvae_video_config ()
65- inputs_dict = self .dummy_input
66- return init_dict , inputs_dict
67-
68- def test_gradient_checkpointing_is_applied (self ):
69- expected_set = {
70- "KVAECachedEncoder3D" ,
71- "KVAECachedDecoder3D" ,
72- }
73- super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
74-
75- @unittest .skip ("Unsupported test." )
76- def test_outputs_equivalence (self ):
77- pass
7878
79- @unittest .skip (
79+ class TestAutoencoderKLKVAEVideo (AutoencoderKLKVAEVideoTesterConfig , ModelTesterMixin ):
80+ @pytest .mark .skip (
8081 "Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
8182 )
8283 def test_model_parallelism (self ):
83- pass
84+ super (). test_model_parallelism ()
8485
85- @unittest .skip (
86- "Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
87- )
88- def test_sharded_checkpoints_device_map (self ):
89- pass
9086
91- def _run_nondeterministic (self , fn ):
92- # reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
93- # temporarily relax the requirement for training tests that do backward passes.
94- import torch
87+ class TestAutoencoderKLKVAEVideoTraining (AutoencoderKLKVAEVideoTesterConfig , TrainingTesterMixin ):
88+ """Training tests for AutoencoderKLKVAEVideo."""
9589
96- torch .use_deterministic_algorithms (False )
97- try :
98- fn ()
99- finally :
100- torch .use_deterministic_algorithms (True )
90+ def test_gradient_checkpointing_is_applied (self ):
91+ expected_set = {"KVAECachedEncoder3D" , "KVAECachedDecoder3D" }
92+ super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
10193
10294 def test_training (self ):
103- self . _run_nondeterministic (super ().test_training )
95+ _run_nondeterministic (super ().test_training )
10496
105- def test_ema_training (self ):
106- self . _run_nondeterministic (super ().test_ema_training )
97+ def test_training_with_ema (self ):
98+ _run_nondeterministic (super ().test_training_with_ema )
10799
108- @unittest .skip (
100+ @pytest . mark .skip (
109101 "Gradient checkpointing recomputes the forward pass, but the model uses a stateful cache_dict "
110102 "that is mutated during the first forward. On recomputation the cache is already populated, "
111- "causing a different execution path and numerically different gradients. "
112- "GC still reduces peak memory usage; gradient correctness in the presence of GC is a known limitation."
103+ "causing a different execution path and numerically different gradients."
113104 )
114- def test_effective_gradient_checkpointing (self ):
115- pass
105+ def test_gradient_checkpointing_equivalence (self ):
106+ super (). test_gradient_checkpointing_equivalence ()
116107
117108 def test_layerwise_casting_training (self ):
118- self ._run_nondeterministic (super ().test_layerwise_casting_training )
109+ _run_nondeterministic (super ().test_layerwise_casting_training )
110+
111+
112+ class TestAutoencoderKLKVAEVideoMemory (AutoencoderKLKVAEVideoTesterConfig , MemoryTesterMixin ):
113+ """Memory optimization tests for AutoencoderKLKVAEVideo."""
114+
115+
116+ class TestAutoencoderKLKVAEVideoSlicingTiling (AutoencoderKLKVAEVideoTesterConfig , NewAutoencoderTesterMixin ):
117+ """Slicing and tiling tests for AutoencoderKLKVAEVideo."""
0 commit comments