diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py index 91f63c4b56c4..e3bbbfb632d9 100644 --- a/tests/lora/test_lora_layers_auraflow.py +++ b/tests/lora/test_lora_layers_auraflow.py @@ -13,16 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys -import unittest +import pytest import torch from transformers import AutoTokenizer, UMT5EncoderModel -from diffusers import ( - AuraFlowPipeline, - AuraFlowTransformer2DModel, - FlowMatchEulerDiscreteScheduler, -) +from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, FlowMatchEulerDiscreteScheduler from ..testing_utils import ( floats_tensor, @@ -40,7 +36,7 @@ @require_peft_backend -class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestAuraFlowLoRA(PeftLoraLoaderMixinTests): pipeline_class = AuraFlowPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} @@ -103,34 +99,34 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - @unittest.skip("Not supported in AuraFlow.") + @pytest.mark.skip("Not supported in AuraFlow.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - @unittest.skip("Not supported in AuraFlow.") + @pytest.mark.skip("Not supported in AuraFlow.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Not supported in AuraFlow.") + @pytest.mark.skip("Not supported in AuraFlow.") def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") + @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.") def test_simple_inference_with_partial_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") + @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.") def test_simple_inference_with_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") + @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.") def test_simple_inference_with_text_lora_and_scale(self): pass - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") + @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.") def test_simple_inference_with_text_lora_fused(self): pass - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") + @pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.") def test_simple_inference_with_text_lora_save_load(self): pass diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index fa57b4c9c2f9..ad2943b6816a 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -13,10 +13,9 @@ # limitations under the License. 
import sys -import unittest +import pytest import torch -from parameterized import parameterized from transformers import AutoTokenizer, T5EncoderModel from diffusers import ( @@ -39,7 +38,7 @@ @require_peft_backend -class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestCogVideoXLoRA(PeftLoraLoaderMixinTests): pipeline_class = CogVideoXPipeline scheduler_cls = CogVideoXDPMScheduler scheduler_kwargs = {"timestep_spacing": "trailing"} @@ -119,54 +118,59 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3, pipe=pipe) - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe): + super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3) - def test_lora_scale_kwargs_match_fusion(self): - super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3) + def test_lora_scale_kwargs_match_fusion(self, base_pipe_output): + super().test_lora_scale_kwargs_match_fusion( + base_pipe_output=base_pipe_output, expected_atol=9e-3, expected_rtol=9e-3 + ) - @parameterized.expand([("block_level", True), ("leaf_level", False)]) + @pytest.mark.parametrize( + "offload_type, use_stream", + [("block_level", True), ("leaf_level", False)], + ) @require_torch_accelerator - def test_group_offloading_inference_denoiser(self, offload_type, use_stream): + def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe): # TODO: We don't run the (leaf_level, True) test here that is enabled for other models. 
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338 - super()._test_group_offloading_inference_denoiser(offload_type, use_stream) + super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe) - @unittest.skip("Not supported in CogVideoX.") + @pytest.mark.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - @unittest.skip("Not supported in CogVideoX.") + @pytest.mark.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Not supported in CogVideoX.") + @pytest.mark.skip("Not supported in CogVideoX.") def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") + @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.") def test_simple_inference_with_partial_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") + @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.") def test_simple_inference_with_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") + @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.") def test_simple_inference_with_text_lora_and_scale(self): pass - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") + @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.") def test_simple_inference_with_text_lora_fused(self): pass - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") + @pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.") def test_simple_inference_with_text_lora_save_load(self): pass - @unittest.skip("Not supported in CogVideoX.") + @pytest.mark.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index 30eb8fbb6367..d3902730678d 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -13,12 +13,9 @@ # limitations under the License. 
import sys -import tempfile -import unittest -import numpy as np +import pytest import torch -from parameterized import parameterized from transformers import AutoTokenizer, GlmModel from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler @@ -28,7 +25,6 @@ require_peft_backend, require_torch_accelerator, skip_mps, - torch_device, ) @@ -47,7 +43,7 @@ def from_pretrained(*args, **kwargs): @require_peft_backend @skip_mps -class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestCogView4LoRA(PeftLoraLoaderMixinTests): pipeline_class = CogView4Pipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} @@ -113,72 +109,50 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3) - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe): + super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3) - def test_simple_inference_save_pretrained(self): - """ - Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained - """ - components, _, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - with tempfile.TemporaryDirectory() as tmpdirname: - pipe.save_pretrained(tmpdirname) - - pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) - pipe_from_pretrained.to(torch_device) - - images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", - ) - - @parameterized.expand([("block_level", True), ("leaf_level", False)]) + @pytest.mark.parametrize( + "offload_type, use_stream", + [("block_level", True), ("leaf_level", False)], + ) @require_torch_accelerator - def test_group_offloading_inference_denoiser(self, offload_type, use_stream): + def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe): # TODO: We don't run the (leaf_level, True) test here that is enabled for other models. 
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338 - super()._test_group_offloading_inference_denoiser(offload_type, use_stream) + super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe) - @unittest.skip("Not supported in CogView4.") + @pytest.mark.skip("Not supported in CogView4.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - @unittest.skip("Not supported in CogView4.") + @pytest.mark.skip("Not supported in CogView4.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Not supported in CogView4.") + @pytest.mark.skip("Not supported in CogView4.") def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in CogView4.") + @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.") def test_simple_inference_with_partial_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in CogView4.") + @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.") def test_simple_inference_with_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in CogView4.") + @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.") def test_simple_inference_with_text_lora_and_scale(self): pass - @unittest.skip("Text encoder LoRA is not supported in CogView4.") + @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.") def test_simple_inference_with_text_lora_fused(self): pass - @unittest.skip("Text encoder LoRA is not supported in CogView4.") + @pytest.mark.skip("Text encoder LoRA is not supported in CogView4.") def test_simple_inference_with_text_lora_save_load(self): pass diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index b840d7ac72ce..3defa9ea9678 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -16,13 +16,11 @@ import gc import os import sys -import tempfile -import unittest import numpy as np +import pytest import safetensors.torch import torch -from parameterized import parameterized from PIL import Image from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel @@ -46,14 +44,12 @@ if is_peft_available(): from peft.utils import get_peft_model_state_dict - sys.path.append(".") - -from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 +from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set @require_peft_backend -class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestFluxLoRA(PeftLoraLoaderMixinTests): pipeline_class = FluxPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} @@ -115,165 +111,134 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_with_alpha_in_state_dict(self): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + def test_with_alpha_in_state_dict(self, tmpdirname, pipe): + _, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe.transformer.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") + 
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - with tempfile.TemporaryDirectory() as tmpdirname: - denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) - self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) + denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) + self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - # modify the state dict to have alpha values following - # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors - state_dict_with_alpha = safetensors.torch.load_file( - os.path.join(tmpdirname, "pytorch_lora_weights.safetensors") - ) - alpha_dict = {} - for k, v in state_dict_with_alpha.items(): - # only do for `transformer` and for the k projections -- should be enough to test. - if "transformer" in k and "to_k" in k and "lora_A" in k: - alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=())) - state_dict_with_alpha.update(alpha_dict) + # modify the state dict to have alpha values following + # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors + state_dict_with_alpha = safetensors.torch.load_file( + os.path.join(tmpdirname, "pytorch_lora_weights.safetensors") + ) + alpha_dict = {} + for k, v in state_dict_with_alpha.items(): + if "transformer" in k and "to_k" in k and ("lora_A" in k): + alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=())) + state_dict_with_alpha.update(alpha_dict) images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" pipe.unload_lora_weights() pipe.load_lora_weights(state_dict_with_alpha) images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images - - self.assertTrue( - np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", + assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), ( + "Loading from saved checkpoints should give same results." 
) - self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3)) + assert not np.allclose(images_lora_with_alpha, images_lora, atol=0.001, rtol=0.001) - def test_lora_expansion_works_for_absent_keys(self): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + def test_lora_expansion_works_for_absent_keys(self, base_pipe_output, tmpdirname, pipe): + _, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.get_base_pipe_output() - # Modify the config to have a layer which won't be present in the second LoRA we will load. modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) modified_denoiser_lora_config.target_modules.add("x_embedder") pipe.transformer.add_adapter(modified_denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - self.assertFalse( - np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3), - "LoRA should lead to different results.", + assert not np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), ( + "LoRA should lead to different results." ) + denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) + self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) - with tempfile.TemporaryDirectory() as tmpdirname: - denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) - self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one") + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - # Modify the state dict to exclude "x_embedder" related LoRA params. 
- lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k} + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one") + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k} pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two") pipe.set_adapters(["one", "two"]) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" - self.assertFalse( - np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), - "Different LoRAs should lead to different results.", + images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images + assert not np.allclose(images_lora, images_lora_with_absent_keys, atol=0.001, rtol=0.001), ( + "Different LoRAs should lead to different results." ) - self.assertFalse( - np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3), - "LoRA should lead to different results.", + assert not np.allclose(base_pipe_output, images_lora_with_absent_keys, atol=0.001, rtol=0.001), ( + "LoRA should lead to different results." ) - def test_lora_expansion_works_for_extra_keys(self): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname, pipe): + _, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.get_base_pipe_output() - # Modify the config to have a layer which won't be present in the first LoRA we will load. modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) modified_denoiser_lora_config.target_modules.add("x_embedder") - pipe.transformer.add_adapter(modified_denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - self.assertFalse( - np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3), - "LoRA should lead to different results.", + assert not np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), ( + "LoRA should lead to different results." 
) - with tempfile.TemporaryDirectory() as tmpdirname: - denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) - self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) + denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) + self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - pipe.unload_lora_weights() - # Modify the state dict to exclude "x_embedder" related LoRA params. - lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k} - pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one") - - # Load state dict with `x_embedder`. - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two") + pipe.unload_lora_weights() + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k} + pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one") + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two") pipe.set_adapters(["one", "two"]) - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") - images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer" - self.assertFalse( - np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3), - "Different LoRAs should lead to different results.", + images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images + assert not np.allclose(images_lora, images_lora_with_extra_keys, atol=0.001, rtol=0.001), ( + "Different LoRAs should lead to different results." ) - self.assertFalse( - np.allclose(output_no_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3), - "LoRA should lead to different results.", + assert not np.allclose(base_pipe_output, images_lora_with_extra_keys, atol=0.001, rtol=0.001), ( + "LoRA should lead to different results." 
) - @unittest.skip("Not supported in Flux.") + @pytest.mark.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - @unittest.skip("Not supported in Flux.") + @pytest.mark.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Not supported in Flux.") + @pytest.mark.skip("Not supported in Flux.") def test_modify_padding_mode(self): pass - @unittest.skip("Not supported in Flux.") + @pytest.mark.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass -class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestFluxControlLoRA(PeftLoraLoaderMixinTests): pipeline_class = FluxControlPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} @@ -338,12 +303,7 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_with_norm_in_state_dict(self): - components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + def test_with_norm_in_state_dict(self, pipe): _, _, inputs = self.get_dummy_inputs(with_generator=False) logger = logging.get_logger("diffusers.loaders.lora_pipeline") @@ -364,39 +324,32 @@ def test_with_norm_in_state_dict(self): pipe.load_lora_weights(norm_state_dict) lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( + assert ( "The provided state dict contains normalization layers in addition to LoRA layers" in cap_logger.out ) - self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0) + assert len(pipe.transformer._transformer_norm_layers) > 0 pipe.unload_lora_weights() lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(pipe.transformer._transformer_norm_layers is None) - self.assertTrue(np.allclose(original_output, lora_unload_output, atol=1e-5, rtol=1e-5)) - self.assertFalse( - np.allclose(original_output, lora_load_output, atol=1e-6, rtol=1e-6), f"{norm_layer} is tested" + assert pipe.transformer._transformer_norm_layers is None + assert np.allclose(original_output, lora_unload_output, atol=1e-05, rtol=1e-05) + assert not np.allclose(original_output, lora_load_output, atol=1e-06, rtol=1e-06), ( + f"{norm_layer} is tested" ) with CaptureLogger(logger) as cap_logger: for key in list(norm_state_dict.keys()): norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key) pipe.load_lora_weights(norm_state_dict) + assert "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out - self.assertTrue( - "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out - ) - - def test_lora_parameter_expanded_shapes(self): + def test_lora_parameter_expanded_shapes(self, pipe): components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.DEBUG) @@ -405,24 +358,21 @@ def 
test_lora_parameter_expanded_shapes(self): transformer = FluxTransformer2DModel.from_config( components["transformer"].config, in_channels=num_channels_without_control ).to(torch_device) - self.assertTrue( - transformer.config.in_channels == num_channels_without_control, - f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", + assert transformer.config.in_channels == num_channels_without_control, ( + f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}" ) original_transformer_state_dict = pipe.transformer.state_dict() x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight") incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False) - self.assertTrue( - "x_embedder.weight" in incompatible_keys.missing_keys, - "Could not find x_embedder.weight in the missing keys.", + assert "x_embedder.weight" in incompatible_keys.missing_keys, ( + "Could not find x_embedder.weight in the missing keys." ) + transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control]) pipe.transformer = transformer - out_features, in_features = pipe.transformer.x_embedder.weight.shape rank = 4 - dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) lora_state_dict = { @@ -431,15 +381,13 @@ def test_lora_parameter_expanded_shapes(self): } with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") - - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) - self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) - self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) + assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001) + assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features + assert pipe.transformer.config.in_channels == 2 * in_features + assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") # Testing opposite direction where the LoRA params are zero-padded. 
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) @@ -454,15 +402,13 @@ def test_lora_parameter_expanded_shapes(self): } with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") - - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) - self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) - self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) + assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001) + assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features + assert pipe.transformer.config.in_channels == 2 * in_features + assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out def test_normal_lora_with_expanded_lora_raises_error(self): # Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then @@ -494,32 +440,28 @@ def test_normal_lora_with_expanded_lora_raises_error(self): } with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") - - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - self.assertTrue(pipe.get_active_adapters() == ["adapter-1"]) - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) - self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) - self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) - - _, _, inputs = self.get_dummy_inputs(with_generator=False) + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + assert pipe.get_active_adapters() == ["adapter-1"] + assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features + assert pipe.transformer.config.in_channels == 2 * in_features + assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") + (_, _, inputs) = self.get_dummy_inputs(with_generator=False) lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) lora_state_dict = { "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, } - with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) - self.assertTrue(pipe.get_active_adapters() == ["adapter-2"]) + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out + assert pipe.get_active_adapters() == ["adapter-2"] lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, 
rtol=1e-3)) + assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001) # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features. # This should raise a runtime error on input shapes being incompatible. @@ -540,32 +482,24 @@ def test_normal_lora_with_expanded_lora_raises_error(self): out_features, in_features = pipe.transformer.x_embedder.weight.shape rank = 4 - lora_state_dict = { "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, } pipe.load_lora_weights(lora_state_dict, "adapter-1") - - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features) - self.assertTrue(pipe.transformer.config.in_channels == in_features) + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" + assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features + assert pipe.transformer.config.in_channels == in_features lora_state_dict = { "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight, "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, } - # We should check for input shapes being incompatible here. But because above mentioned issue is # not a supported use case, and because of the PEFT renaming, we will currently have a shape # mismatch error. - self.assertRaisesRegex( - RuntimeError, - "size mismatch for x_embedder.lora_A.adapter-2.weight", - pipe.load_lora_weights, - lora_state_dict, - "adapter-2", - ) + with pytest.raises(RuntimeError, match="size mismatch for x_embedder.lora_A.adapter-2.weight"): + pipe.load_lora_weights(lora_state_dict, "adapter-2") def test_fuse_expanded_lora_with_regular_lora(self): # This test checks if it works when a lora with expanded shapes (like control loras) but @@ -597,7 +531,7 @@ def test_fuse_expanded_lora_with_regular_lora(self): "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight, } pipe.load_lora_weights(lora_state_dict, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" _, _, inputs = self.get_dummy_inputs(with_generator=False) lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -610,54 +544,44 @@ def test_fuse_expanded_lora_with_regular_lora(self): } pipe.load_lora_weights(lora_state_dict, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0]) lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3)) - self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3)) - self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3)) + assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001) + assert not np.allclose(lora_output, lora_output_3, atol=0.001, rtol=0.001) + assert not np.allclose(lora_output_2, lora_output_3, atol=0.001, rtol=0.001) pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"]) 
lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3)) + assert np.allclose(lora_output_3, lora_output_4, atol=0.001, rtol=0.001) - def test_load_regular_lora(self): + def test_load_regular_lora(self, base_pipe_output, pipe): # This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded # into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those # transformers include Flux Fill, Flux Control, etc. - components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - out_features, in_features = pipe.transformer.x_embedder.weight.shape rank = 4 - in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA. + in_features = in_features // 2 normal_lora_A = torch.nn.Linear(in_features, rank, bias=False) normal_lora_B = torch.nn.Linear(rank, out_features, bias=False) lora_state_dict = { "transformer.x_embedder.lora_A.weight": normal_lora_A.weight, "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, } - logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.INFO) with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out) - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2) - self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3)) + assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out + assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2 + assert not np.allclose(base_pipe_output, lora_output, atol=0.001, rtol=0.001) def test_lora_unload_with_parameter_expanded_shapes(self): components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) @@ -670,9 +594,8 @@ def test_lora_unload_with_parameter_expanded_shapes(self): transformer = FluxTransformer2DModel.from_config( components["transformer"].config, in_channels=num_channels_without_control ).to(torch_device) - self.assertTrue( - transformer.config.in_channels == num_channels_without_control, - f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", + assert transformer.config.in_channels == num_channels_without_control, ( + f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}" ) # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. 
@@ -697,33 +620,31 @@ def test_lora_unload_with_parameter_expanded_shapes(self): } with CaptureLogger(logger) as cap_logger: control_pipe.load_lora_weights(lora_state_dict, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" inputs["control_image"] = control_image lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) - self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) - self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) + assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001) + assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features + assert pipe.transformer.config.in_channels == 2 * in_features + assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") control_pipe.unload_lora_weights(reset_to_overwritten_params=True) - self.assertTrue( - control_pipe.transformer.config.in_channels == num_channels_without_control, - f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}", + assert control_pipe.transformer.config.in_channels == num_channels_without_control, ( + f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}" ) + loaded_pipe = FluxPipeline.from_pipe(control_pipe) - self.assertTrue( - loaded_pipe.transformer.config.in_channels == num_channels_without_control, - f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}", + assert loaded_pipe.transformer.config.in_channels == num_channels_without_control, ( + f"Expected {num_channels_without_control} channels in the modified transformer but has loaded_pipe.transformer.config.in_channels={loaded_pipe.transformer.config.in_channels!r}" ) + inputs.pop("control_image") unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4)) - self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4)) - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features) - self.assertTrue(pipe.transformer.config.in_channels == in_features) + assert not np.allclose(unloaded_lora_out, lora_out, rtol=0.0001, atol=0.0001) + assert np.allclose(unloaded_lora_out, original_out, atol=0.0001, rtol=0.0001) + assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features + assert pipe.transformer.config.in_channels == in_features def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) @@ -731,14 +652,12 @@ def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): logger = logging.get_logger("diffusers.loaders.lora_pipeline") logger.setLevel(logging.DEBUG) - # Change the transformer config to mimic a real use case. 
num_channels_without_control = 4 transformer = FluxTransformer2DModel.from_config( components["transformer"].config, in_channels=num_channels_without_control ).to(torch_device) - self.assertTrue( - transformer.config.in_channels == num_channels_without_control, - f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", + assert transformer.config.in_channels == num_channels_without_control, ( + f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}" ) # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. @@ -763,40 +682,38 @@ def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): } with CaptureLogger(logger) as cap_logger: control_pipe.load_lora_weights(lora_state_dict, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser" inputs["control_image"] = control_image lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) - self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) - self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) + assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001) + assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features + assert pipe.transformer.config.in_channels == 2 * in_features + assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module") control_pipe.unload_lora_weights(reset_to_overwritten_params=False) - self.assertTrue( - control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, - f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}", + assert control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, ( + f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}" ) - no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4)) - self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2) - self.assertTrue(pipe.transformer.config.in_channels == in_features * 2) + no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] + assert not np.allclose(no_lora_out, lora_out, rtol=0.0001, atol=0.0001) + assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2 + assert pipe.transformer.config.in_channels == in_features * 2 - @unittest.skip("Not supported in Flux.") + @pytest.mark.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - @unittest.skip("Not supported in Flux.") + @pytest.mark.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Not supported in Flux.") + @pytest.mark.skip("Not supported in Flux.") def test_modify_padding_mode(self): pass - @unittest.skip("Not 
supported in Flux.") + @pytest.mark.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass @@ -806,7 +723,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): @require_torch_accelerator @require_peft_backend @require_big_accelerator -class FluxLoRAIntegrationTests(unittest.TestCase): +class TestFluxLoRAIntegration: """internal note: The integration slices were obtained on audace. torch: 2.6.0.dev20241006+cu124 with CUDA 12.5. Need the same setup for the @@ -816,33 +733,27 @@ class FluxLoRAIntegrationTests(unittest.TestCase): num_inference_steps = 10 seed = 0 - def setUp(self): - super().setUp() - + @pytest.fixture(scope="function") + def pipeline(self): gc.collect() backend_empty_cache(torch_device) - - self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) - - def tearDown(self): - super().tearDown() - - del self.pipeline - gc.collect() - backend_empty_cache(torch_device) - - def test_flux_the_last_ben(self): - self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - # Instead of calling `enable_model_cpu_offload()`, we do a accelerator placement here because the CI - # run supports it. We have about 34GB RAM in the CI runner which kills the test when run with - # `enable_model_cpu_offload()`. We repeat this for the other tests, too. - self.pipeline = self.pipeline.to(torch_device) - + pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to( + torch_device + ) + try: + yield pipe + finally: + del pipe + gc.collect() + backend_empty_cache(torch_device) + + def test_flux_the_last_ben(self, pipeline): + pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline = pipeline.to(torch_device) prompt = "jon snow eating pizza with ketchup" - - out = self.pipeline( + out = pipeline( prompt, num_inference_steps=self.num_inference_steps, guidance_scale=4.0, @@ -851,71 +762,57 @@ def test_flux_the_last_ben(self): ).images out_slice = out[0, -3:, -3:, -1].flatten() expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246]) - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + assert max_diff < 0.001 - assert max_diff < 1e-3 - - def test_flux_kohya(self): - self.pipeline.load_lora_weights("Norod78/brain-slug-flux") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline = self.pipeline.to(torch_device) - + def test_flux_kohya(self, pipeline): + pipeline.load_lora_weights("Norod78/brain-slug-flux") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline = pipeline.to(torch_device) prompt = "The cat with a brain slug earring" - out = self.pipeline( + out = pipeline( prompt, num_inference_steps=self.num_inference_steps, guidance_scale=4.5, output_type="np", generator=torch.manual_seed(self.seed), ).images - out_slice = out[0, -3:, -3:, -1].flatten() expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484]) - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + assert max_diff < 0.001 - assert max_diff < 1e-3 - - def test_flux_kohya_with_text_encoder(self): - self.pipeline.load_lora_weights("cocktailpeanut/optimus", 
weight_name="optimus.safetensors") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline = self.pipeline.to(torch_device) - + def test_flux_kohya_with_text_encoder(self, pipeline): + pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline = pipeline.to(torch_device) prompt = "optimus is cleaning the house with broomstick" - out = self.pipeline( + out = pipeline( prompt, num_inference_steps=self.num_inference_steps, guidance_scale=4.5, output_type="np", generator=torch.manual_seed(self.seed), ).images - out_slice = out[0, -3:, -3:, -1].flatten() expected_slice = np.array([0.4023, 0.4023, 0.4023, 0.3965, 0.3984, 0.3965, 0.3926, 0.3906, 0.4219]) - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + assert max_diff < 0.001 - assert max_diff < 1e-3 - - def test_flux_kohya_embedders_conversion(self): + def test_flux_kohya_embedders_conversion(self, pipeline): """Test that embedders load without throwing errors""" - self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora") - self.pipeline.unload_lora_weights() - - assert True - - def test_flux_xlabs(self): - self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors") - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline = self.pipeline.to(torch_device) - + pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora") + pipeline.unload_lora_weights() + + def test_flux_xlabs(self, pipeline): + pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline = pipeline.to(torch_device) prompt = "A blue jay standing on a large basket of rainbow macarons, disney style" - - out = self.pipeline( + out = pipeline( prompt, num_inference_steps=self.num_inference_steps, guidance_scale=3.5, @@ -923,23 +820,17 @@ def test_flux_xlabs(self): generator=torch.manual_seed(self.seed), ).images out_slice = out[0, -3:, -3:, -1].flatten() - expected_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980]) - + expected_slice = np.array([0.3965, 0.418, 0.4434, 0.4082, 0.4375, 0.459, 0.4141, 0.4375, 0.498]) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + assert max_diff < 0.001 - assert max_diff < 1e-3 - - def test_flux_xlabs_load_lora_with_single_blocks(self): - self.pipeline.load_lora_weights( - "salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors" - ) - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline.enable_model_cpu_offload() - + def test_flux_xlabs_load_lora_with_single_blocks(self, pipeline): + pipeline.load_lora_weights("salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors") + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline.enable_model_cpu_offload() prompt = "a wizard mouse playing chess" - - out = self.pipeline( + out = pipeline( prompt, num_inference_steps=self.num_inference_steps, guidance_scale=3.5, @@ -951,40 +842,43 @@ def test_flux_xlabs_load_lora_with_single_blocks(self): [0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625] ) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) - - assert max_diff < 1e-3 + assert max_diff < 0.001 @nightly @require_torch_accelerator 
 @require_peft_backend
 @require_big_accelerator
-class FluxControlLoRAIntegrationTests(unittest.TestCase):
+class TestFluxControlLoRAIntegration:
     num_inference_steps = 10
     seed = 0
     prompt = "A robot made of exotic candies and chocolates of different kinds."
 
-    def setUp(self):
-        super().setUp()
-
-        gc.collect()
-        backend_empty_cache(torch_device)
-
-        self.pipeline = FluxControlPipeline.from_pretrained(
-            "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
-        ).to(torch_device)
-
-    def tearDown(self):
-        super().tearDown()
-
+    @pytest.fixture(scope="function")
+    def pipeline(self):
         gc.collect()
         backend_empty_cache(torch_device)
-
-    @parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
-    def test_lora(self, lora_ckpt_id):
-        self.pipeline.load_lora_weights(lora_ckpt_id)
-        self.pipeline.fuse_lora()
-        self.pipeline.unload_lora_weights()
+        pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(
+            torch_device
+        )
+        try:
+            yield pipe
+        finally:
+            del pipe
+            gc.collect()
+            backend_empty_cache(torch_device)
+
+    @pytest.mark.parametrize(
+        "lora_ckpt_id",
+        [
+            "black-forest-labs/FLUX.1-Canny-dev-lora",
+            "black-forest-labs/FLUX.1-Depth-dev-lora",
+        ],
+    )
+    def test_lora(self, pipeline, lora_ckpt_id):
+        pipeline.load_lora_weights(lora_ckpt_id)
+        pipeline.fuse_lora()
+        pipeline.unload_lora_weights()
 
         if "Canny" in lora_ckpt_id:
             control_image = load_image(
@@ -995,7 +889,7 @@ def test_lora(self, lora_ckpt_id):
                 "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
             )
 
-        image = self.pipeline(
+        image = pipeline(
             prompt=self.prompt,
             control_image=control_image,
             height=1024,
@@ -1016,12 +910,18 @@ def test_lora(self, lora_ckpt_id):
 
         assert max_diff < 1e-3
 
-    @parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
-    def test_lora_with_turbo(self, lora_ckpt_id):
-        self.pipeline.load_lora_weights(lora_ckpt_id)
-        self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
-        self.pipeline.fuse_lora()
-        self.pipeline.unload_lora_weights()
+    @pytest.mark.parametrize(
+        "lora_ckpt_id",
+        [
+            "black-forest-labs/FLUX.1-Canny-dev-lora",
+            "black-forest-labs/FLUX.1-Depth-dev-lora",
+        ],
+    )
+    def test_lora_with_turbo(self, pipeline, lora_ckpt_id):
+        pipeline.load_lora_weights(lora_ckpt_id)
+        pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
+        pipeline.fuse_lora()
+        pipeline.unload_lora_weights()
 
         if "Canny" in lora_ckpt_id:
             control_image = load_image(
diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py
index cfd5d3146a91..1d3e3dbf6a38 100644
--- a/tests/lora/test_lora_layers_hunyuanvideo.py
+++ b/tests/lora/test_lora_layers_hunyuanvideo.py
@@ -14,9 +14,9 @@
 import gc
 import sys
-import unittest
 
 import numpy as np
+import pytest
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
 
@@ -48,7 +48,7 @@
 
 @require_peft_backend
 @skip_mps
-class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class TestHunyuanVideoLoRA(PeftLoraLoaderMixinTests):
     pipeline_class = HunyuanVideoPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
@@ -149,46 +149,41 @@ def get_dummy_inputs(self, with_generator=True):
 
         return noise, input_ids, pipeline_inputs
 
-    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
-        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+    def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
+        super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
 
-    def test_simple_inference_with_text_denoiser_lora_unfused(self):
-        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
+    def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
+        super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
 
-    # TODO(aryan): Fix the following test
-    @unittest.skip("This test fails with an error I haven't been able to debug yet.")
-    def test_simple_inference_save_pretrained(self):
-        pass
-
-    @unittest.skip("Not supported in HunyuanVideo.")
+    @pytest.mark.skip("Not supported in HunyuanVideo.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in HunyuanVideo.")
+    @pytest.mark.skip("Not supported in HunyuanVideo.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Not supported in HunyuanVideo.")
+    @pytest.mark.skip("Not supported in HunyuanVideo.")
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
     def test_simple_inference_with_partial_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
    def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
 
@@ -197,7 +192,7 @@ def test_simple_inference_with_text_lora_save_load(self):
 @require_torch_accelerator
 @require_peft_backend
 @require_big_accelerator
-class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
+class TestHunyuanVideoLoRAIntegration:
     """internal note: The integration slices were obtained on DGX.
 
     torch: 2.5.1+cu124 with CUDA 12.5.
Need the same setup for the @@ -207,9 +202,8 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase): num_inference_steps = 10 seed = 0 - def setUp(self): - super().setUp() - + @pytest.fixture(scope="function") + def pipeline(self): gc.collect() backend_empty_cache(torch_device) @@ -217,27 +211,27 @@ def setUp(self): transformer = HunyuanVideoTransformer3DModel.from_pretrained( model_id, subfolder="transformer", torch_dtype=torch.bfloat16 ) - self.pipeline = HunyuanVideoPipeline.from_pretrained( - model_id, transformer=transformer, torch_dtype=torch.float16 - ).to(torch_device) - - def tearDown(self): - super().tearDown() - - gc.collect() - backend_empty_cache(torch_device) - - def test_original_format_cseti(self): - self.pipeline.load_lora_weights( + pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16).to( + torch_device + ) + try: + yield pipe + finally: + del pipe + gc.collect() + backend_empty_cache(torch_device) + + def test_original_format_cseti(self, pipeline): + pipeline.load_lora_weights( "Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors" ) - self.pipeline.fuse_lora() - self.pipeline.unload_lora_weights() - self.pipeline.vae.enable_tiling() + pipeline.fuse_lora() + pipeline.unload_lora_weights() + pipeline.vae.enable_tiling() prompt = "CSETIARCANE. A cat walks on the grass, realistic" - out = self.pipeline( + out = pipeline( prompt=prompt, height=320, width=512, diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py index 6ab51a5e513f..2ffc39ef2b41 100644 --- a/tests/lora/test_lora_layers_ltx_video.py +++ b/tests/lora/test_lora_layers_ltx_video.py @@ -13,8 +13,8 @@ # limitations under the License. import sys -import unittest +import pytest import torch from transformers import AutoTokenizer, T5EncoderModel @@ -34,7 +34,7 @@ @require_peft_backend -class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestLTXVideoLoRA(PeftLoraLoaderMixinTests): pipeline_class = LTXPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} @@ -108,40 +108,40 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3) - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe): + super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3) - @unittest.skip("Not supported in LTXVideo.") + @pytest.mark.skip("Not supported in LTXVideo.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - @unittest.skip("Not supported in LTXVideo.") + @pytest.mark.skip("Not supported in LTXVideo.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Not supported in LTXVideo.") + @pytest.mark.skip("Not supported in LTXVideo.") def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") + @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.") def 
test_simple_inference_with_partial_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") + @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.") def test_simple_inference_with_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") + @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.") def test_simple_inference_with_text_lora_and_scale(self): pass - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") + @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.") def test_simple_inference_with_text_lora_fused(self): pass - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") + @pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.") def test_simple_inference_with_text_lora_save_load(self): pass diff --git a/tests/lora/test_lora_layers_lumina2.py b/tests/lora/test_lora_layers_lumina2.py index 0417b05b33a1..b0f6ab6039f0 100644 --- a/tests/lora/test_lora_layers_lumina2.py +++ b/tests/lora/test_lora_layers_lumina2.py @@ -13,7 +13,6 @@ # limitations under the License. import sys -import unittest import numpy as np import pytest @@ -36,7 +35,7 @@ @require_peft_backend -class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestLumina2LoRA(PeftLoraLoaderMixinTests): pipeline_class = Lumina2Pipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} @@ -101,35 +100,35 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - @unittest.skip("Not supported in Lumina2.") + @pytest.mark.skip("Not supported in Lumina2.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - @unittest.skip("Not supported in Lumina2.") + @pytest.mark.skip("Not supported in Lumina2.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Not supported in Lumina2.") + @pytest.mark.skip("Not supported in Lumina2.") def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") + @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.") def test_simple_inference_with_partial_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") + @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.") def test_simple_inference_with_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") + @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.") def test_simple_inference_with_text_lora_and_scale(self): pass - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") + @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.") def test_simple_inference_with_text_lora_fused(self): pass - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") + @pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.") def test_simple_inference_with_text_lora_save_load(self): pass @@ -139,20 +138,17 @@ def test_simple_inference_with_text_lora_save_load(self): reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", strict=False, ) - def test_lora_fuse_nan(self): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + def test_lora_fuse_nan(self, pipe): + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() 
_, _, inputs = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." # corrupt one LoRA weight with `inf` values with torch.no_grad(): @@ -166,4 +162,4 @@ def test_lora_fuse_nan(self): pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) out = pipe(**inputs)[0] - self.assertTrue(np.isnan(out).all()) + assert np.isnan(out).all() diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py index 7be81273db77..9b81e220b28f 100644 --- a/tests/lora/test_lora_layers_mochi.py +++ b/tests/lora/test_lora_layers_mochi.py @@ -13,8 +13,8 @@ # limitations under the License. import sys -import unittest +import pytest import torch from transformers import AutoTokenizer, T5EncoderModel @@ -34,7 +34,7 @@ @require_peft_backend @skip_mps -class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestMochiLoRA(PeftLoraLoaderMixinTests): pipeline_class = MochiPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} @@ -99,44 +99,44 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3) - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe): + super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3) - @unittest.skip("Not supported in Mochi.") + @pytest.mark.skip("Not supported in Mochi.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - @unittest.skip("Not supported in Mochi.") + @pytest.mark.skip("Not supported in Mochi.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Not supported in Mochi.") + @pytest.mark.skip("Not supported in Mochi.") def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Mochi.") + @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.") def test_simple_inference_with_partial_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in Mochi.") + @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.") def test_simple_inference_with_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in Mochi.") + @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.") def test_simple_inference_with_text_lora_and_scale(self): pass - @unittest.skip("Text encoder LoRA is not supported in Mochi.") + @pytest.mark.skip("Text encoder 
LoRA is not supported in Mochi.")
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
 
-    @unittest.skip("Not supported in CogVideoX.")
+    @pytest.mark.skip("Not supported in CogVideoX.")
     def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
         pass
diff --git a/tests/lora/test_lora_layers_qwenimage.py b/tests/lora/test_lora_layers_qwenimage.py
index 51de2f8e20e1..c24464653072 100644
--- a/tests/lora/test_lora_layers_qwenimage.py
+++ b/tests/lora/test_lora_layers_qwenimage.py
@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import sys
-import unittest
+import pytest
 import torch
 from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
 
@@ -34,7 +34,7 @@
 
 @require_peft_backend
-class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class TestQwenImageLoRA(PeftLoraLoaderMixinTests):
     pipeline_class = QwenImagePipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
@@ -96,34 +96,34 @@ def get_dummy_inputs(self, with_generator=True):
 
         return noise, input_ids, pipeline_inputs
 
-    @unittest.skip("Not supported in Qwen Image.")
+    @pytest.mark.skip("Not supported in Qwen Image.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in Qwen Image.")
+    @pytest.mark.skip("Not supported in Qwen Image.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Not supported in Qwen Image.")
+    @pytest.mark.skip("Not supported in Qwen Image.")
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
     def test_simple_inference_with_partial_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
     def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py
index 3cdb28de75fb..5977aeb9a53c 100644
--- a/tests/lora/test_lora_layers_sana.py
+++ b/tests/lora/test_lora_layers_sana.py
@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import sys
-import unittest
+import pytest
 import torch
 from transformers import Gemma2Model, GemmaTokenizer
 
@@ -29,7 +29,7 @@
 
 @require_peft_backend
-class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class TestSanaLoRA(PeftLoraLoaderMixinTests):
     pipeline_class = SanaPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {"shift": 7.0}
@@ -105,34 +105,34 @@ def get_dummy_inputs(self, with_generator=True):
 
         return noise, input_ids, pipeline_inputs
 
-    @unittest.skip("Not supported in SANA.")
+    @pytest.mark.skip("Not supported in SANA.")
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Not supported in SANA.")
+    @pytest.mark.skip("Not supported in SANA.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in SANA.")
+    @pytest.mark.skip("Not supported in SANA.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in SANA.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_partial_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in SANA.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in SANA.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in SANA.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in SANA.")
+    @pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py
index 933bf2336a59..a5e640c0b736 100644
--- a/tests/lora/test_lora_layers_sd.py
+++ b/tests/lora/test_lora_layers_sd.py
@@ -14,9 +14,9 @@
 # limitations under the License.
 import gc
 import sys
-import unittest
 
 import numpy as np
+import pytest
 import torch
 import torch.nn as nn
 from huggingface_hub import hf_hub_download
 
@@ -55,7 +55,7 @@
     from accelerate.utils import release_memory
 
 
-class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
+class TestStableDiffusionLoRA(PeftLoraLoaderMixinTests):
     pipeline_class = StableDiffusionPipeline
     scheduler_cls = DDIMScheduler
     scheduler_kwargs = {
@@ -91,16 +91,6 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     def output_shape(self):
         return (1, 64, 64, 3)
 
-    def setUp(self):
-        super().setUp()
-        gc.collect()
-        backend_empty_cache(torch_device)
-
-    def tearDown(self):
-        super().tearDown()
-        gc.collect()
-        backend_empty_cache(torch_device)
-
     # Keeping this test here makes sense because it doesn't look any integration
     # (value assertions on logits).
@slow @@ -114,15 +104,8 @@ def test_integration_move_lora_cpu(self): pipe.load_lora_weights(lora_id, adapter_name="adapter-2") pipe = pipe.to(torch_device) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), - "Lora not correctly set in text encoder", - ) - - self.assertTrue( - check_if_lora_correctly_set(pipe.unet), - "Lora not correctly set in unet", - ) + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet" # We will offload the first adapter in CPU and check if the offloading # has been performed correctly @@ -130,35 +113,35 @@ def test_integration_move_lora_cpu(self): for name, module in pipe.unet.named_modules(): if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)): - self.assertTrue(module.weight.device == torch.device("cpu")) + assert module.weight.device == torch.device("cpu") elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)): - self.assertTrue(module.weight.device != torch.device("cpu")) + assert module.weight.device != torch.device("cpu") for name, module in pipe.text_encoder.named_modules(): if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)): - self.assertTrue(module.weight.device == torch.device("cpu")) + assert module.weight.device == torch.device("cpu") elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)): - self.assertTrue(module.weight.device != torch.device("cpu")) + assert module.weight.device != torch.device("cpu") pipe.set_lora_device(["adapter-1"], 0) for n, m in pipe.unet.named_modules(): if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)): - self.assertTrue(m.weight.device != torch.device("cpu")) + assert m.weight.device != torch.device("cpu") for n, m in pipe.text_encoder.named_modules(): if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)): - self.assertTrue(m.weight.device != torch.device("cpu")) + assert m.weight.device != torch.device("cpu") pipe.set_lora_device(["adapter-1", "adapter-2"], torch_device) for n, m in pipe.unet.named_modules(): if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)): - self.assertTrue(m.weight.device != torch.device("cpu")) + assert m.weight.device != torch.device("cpu") for n, m in pipe.text_encoder.named_modules(): if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)): - self.assertTrue(m.weight.device != torch.device("cpu")) + assert m.weight.device != torch.device("cpu") @slow @require_torch_accelerator @@ -181,15 +164,9 @@ def test_integration_move_lora_dora_cpu(self): pipe.unet.add_adapter(unet_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), - "Lora not correctly set in text encoder", - ) + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - self.assertTrue( - check_if_lora_correctly_set(pipe.unet), - "Lora not correctly set in unet", - ) + assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet" for name, param in pipe.unet.named_parameters(): if "lora_" in name: @@ -225,17 +202,14 @@ def test_integration_set_lora_device_different_target_layers(self): pipe.unet.add_adapter(config1, adapter_name="adapter-1") pipe = pipe.to(torch_device) - self.assertTrue( - check_if_lora_correctly_set(pipe.unet), - "Lora 
not correctly set in unet", - ) + assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet" # sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")} modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")} - self.assertNotEqual(modules_adapter_0, modules_adapter_1) - self.assertTrue(modules_adapter_0 - modules_adapter_1) - self.assertTrue(modules_adapter_1 - modules_adapter_0) + assert modules_adapter_0 != modules_adapter_1 + assert modules_adapter_0 - modules_adapter_1 + assert modules_adapter_1 - modules_adapter_0 # setting both separately works pipe.set_lora_device(["adapter-0"], "cpu") @@ -243,32 +217,30 @@ def test_integration_set_lora_device_different_target_layers(self): for name, module in pipe.unet.named_modules(): if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)): - self.assertTrue(module.weight.device == torch.device("cpu")) + assert module.weight.device == torch.device("cpu") elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)): - self.assertTrue(module.weight.device == torch.device("cpu")) + assert module.weight.device == torch.device("cpu") # setting both at once also works pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device) for name, module in pipe.unet.named_modules(): if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)): - self.assertTrue(module.weight.device != torch.device("cpu")) + assert module.weight.device != torch.device("cpu") elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)): - self.assertTrue(module.weight.device != torch.device("cpu")) + assert module.weight.device != torch.device("cpu") @slow @nightly @require_torch_accelerator @require_peft_backend -class LoraIntegrationTests(unittest.TestCase): - def setUp(self): - super().setUp() +class TestSDLoraIntegration: + @pytest.fixture(autouse=True) + def _gc_and_cache_cleanup(self, torch_device): gc.collect() backend_empty_cache(torch_device) - - def tearDown(self): - super().tearDown() + yield gc.collect() backend_empty_cache(torch_device) @@ -280,10 +252,7 @@ def test_integration_logits_with_scale(self): pipe.load_lora_weights(lora_id) pipe = pipe.to(torch_device) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), - "Lora not correctly set in text encoder", - ) + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" prompt = "a red sks dog" @@ -312,10 +281,7 @@ def test_integration_logits_no_scale(self): pipe.load_lora_weights(lora_id) pipe = pipe.to(torch_device) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), - "Lora not correctly set in text encoder", - ) + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" prompt = "a red sks dog" @@ -587,8 +553,8 @@ def test_unload_kohya_lora(self): ).images unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten() - self.assertFalse(np.allclose(initial_images, lora_images)) - self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3)) + assert not np.allclose(initial_images, lora_images) + assert np.allclose(initial_images, unloaded_lora_images, atol=1e-3) release_memory(pipe) @@ -625,8 +591,8 @@ def test_load_unload_load_kohya_lora(self): ).images unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, 
-1].flatten() - self.assertFalse(np.allclose(initial_images, lora_images)) - self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3)) + assert not np.allclose(initial_images, lora_images) + assert np.allclose(initial_images, unloaded_lora_images, atol=1e-3) # make sure we can load a LoRA again after unloading and they don't have # any undesired effects. @@ -637,7 +603,7 @@ def test_load_unload_load_kohya_lora(self): ).images lora_images_again = lora_images_again[0, -3:, -3:, -1].flatten() - self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3)) + assert np.allclose(lora_images, lora_images_again, atol=1e-3) release_memory(pipe) def test_not_empty_state_dict(self): @@ -651,7 +617,7 @@ def test_not_empty_state_dict(self): lcm_lora = load_file(cached_file) pipe.load_lora_weights(lcm_lora, adapter_name="lcm") - self.assertTrue(lcm_lora != {}) + assert lcm_lora != {} release_memory(pipe) def test_load_unload_load_state_dict(self): @@ -666,11 +632,11 @@ def test_load_unload_load_state_dict(self): previous_state_dict = lcm_lora.copy() pipe.load_lora_weights(lcm_lora, adapter_name="lcm") - self.assertDictEqual(lcm_lora, previous_state_dict) + assert lcm_lora == previous_state_dict pipe.unload_lora_weights() pipe.load_lora_weights(lcm_lora, adapter_name="lcm") - self.assertDictEqual(lcm_lora, previous_state_dict) + assert lcm_lora == previous_state_dict release_memory(pipe) diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py index 228460eaad90..a44f6887f41a 100644 --- a/tests/lora/test_lora_layers_sd3.py +++ b/tests/lora/test_lora_layers_sd3.py @@ -14,9 +14,9 @@ # limitations under the License. import gc import sys -import unittest import numpy as np +import pytest import torch from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel @@ -51,7 +51,7 @@ @require_peft_backend -class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestSD3LoRA(PeftLoraLoaderMixinTests): pipeline_class = StableDiffusion3Pipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} @@ -113,19 +113,19 @@ def test_sd3_lora(self): lora_filename = "lora_peft_format.safetensors" pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - @unittest.skip("Not supported in SD3.") + @pytest.mark.skip("Not supported in SD3.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - @unittest.skip("Not supported in SD3.") + @pytest.mark.skip("Not supported in SD3.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass - @unittest.skip("Not supported in SD3.") + @pytest.mark.skip("Not supported in SD3.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Not supported in SD3.") + @pytest.mark.skip("Not supported in SD3.") def test_modify_padding_mode(self): pass @@ -138,17 +138,15 @@ def test_multiple_wrong_adapter_name_raises_error(self): @require_torch_accelerator @require_peft_backend @require_big_accelerator -class SD3LoraIntegrationTests(unittest.TestCase): +class TestSD3LoraIntegration: pipeline_class = StableDiffusion3Img2ImgPipeline repo_id = "stabilityai/stable-diffusion-3-medium-diffusers" - def setUp(self): - super().setUp() + @pytest.fixture(autouse=True) + def _gc_and_cache_cleanup(self, torch_device): gc.collect() backend_empty_cache(torch_device) - - def tearDown(self): - super().tearDown() + yield gc.collect() backend_empty_cache(torch_device) diff --git 
a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index ac1d65abdaa7..e1bc6e8ecb73 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -17,9 +17,9 @@ import importlib import sys import time -import unittest import numpy as np +import pytest import torch from packaging import version from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer @@ -59,7 +59,7 @@ from accelerate.utils import release_memory -class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): +class TestStableDiffusionXLLoRA(PeftLoraLoaderMixinTests): has_two_text_encoders = True pipeline_class = StableDiffusionXLPipeline scheduler_cls = EulerDiscreteScheduler @@ -104,21 +104,11 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): def output_shape(self): return (1, 64, 64, 3) - def setUp(self): - super().setUp() - gc.collect() - backend_empty_cache(torch_device) - - def tearDown(self): - super().tearDown() - gc.collect() - backend_empty_cache(torch_device) - @is_flaky def test_multiple_wrong_adapter_name_raises_error(self): super().test_multiple_wrong_adapter_name_raises_error() - def test_simple_inference_with_text_denoiser_lora_unfused(self): + def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe): if torch.cuda.is_available(): expected_atol = 9e-2 expected_rtol = 9e-2 @@ -127,10 +117,10 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self): expected_rtol = 1e-3 super().test_simple_inference_with_text_denoiser_lora_unfused( - expected_atol=expected_atol, expected_rtol=expected_rtol + pipe=pipe, expected_atol=expected_atol, expected_rtol=expected_rtol ) - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): + def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): if torch.cuda.is_available(): expected_atol = 9e-2 expected_rtol = 9e-2 @@ -139,10 +129,10 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): expected_rtol = 1e-3 super().test_simple_inference_with_text_lora_denoiser_fused_multi( - expected_atol=expected_atol, expected_rtol=expected_rtol + pipe=pipe, expected_atol=expected_atol, expected_rtol=expected_rtol ) - def test_lora_scale_kwargs_match_fusion(self): + def test_lora_scale_kwargs_match_fusion(self, base_pipe_output): if torch.cuda.is_available(): expected_atol = 9e-2 expected_rtol = 9e-2 @@ -150,21 +140,21 @@ def test_lora_scale_kwargs_match_fusion(self): expected_atol = 1e-3 expected_rtol = 1e-3 - super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol) + super().test_lora_scale_kwargs_match_fusion( + base_pipe_output=base_pipe_output, expected_atol=expected_atol, expected_rtol=expected_rtol + ) @slow @nightly @require_torch_accelerator @require_peft_backend -class LoraSDXLIntegrationTests(unittest.TestCase): - def setUp(self): - super().setUp() +class TestLoraSDXLIntegration: + @pytest.fixture(autouse=True) + def _gc_and_cache_cleanup(self, torch_device): gc.collect() backend_empty_cache(torch_device) - - def tearDown(self): - super().tearDown() + yield gc.collect() backend_empty_cache(torch_device) @@ -383,7 +373,7 @@ def test_sdxl_1_0_lora_fusion_efficiency(self): end_time = time.time() elapsed_time_fusion = end_time - start_time - self.assertTrue(elapsed_time_fusion < elapsed_time_non_fusion) + assert elapsed_time_fusion < elapsed_time_non_fusion release_memory(pipe) @@ -439,14 +429,14 @@ def remap_key(key, sd): for key, value in 
text_encoder_1_sd.items(): key = remap_key(key, fused_te_state_dict) - self.assertTrue(torch.allclose(fused_te_state_dict[key], value)) + assert torch.allclose(fused_te_state_dict[key], value) for key, value in text_encoder_2_sd.items(): key = remap_key(key, fused_te_2_state_dict) - self.assertTrue(torch.allclose(fused_te_2_state_dict[key], value)) + assert torch.allclose(fused_te_2_state_dict[key], value) for key, value in unet_state_dict.items(): - self.assertTrue(torch.allclose(unet_state_dict[key], value)) + assert torch.allclose(unet_state_dict[key], value) pipe.fuse_lora() pipe.unload_lora_weights() @@ -589,7 +579,7 @@ def test_integration_logits_multi_adapter(self): pipe.load_lora_weights(lora_id, weight_name="toy_face_sdxl.safetensors", adapter_name="toy") pipe = pipe.to(torch_device) - self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet") + assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet" prompt = "toy_face of a hacker with a hoodie" diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 5734509b410f..2dfe91d6d578 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -13,8 +13,8 @@ # limitations under the License. import sys -import unittest +import pytest import torch from transformers import AutoTokenizer, T5EncoderModel @@ -39,7 +39,7 @@ @require_peft_backend @skip_mps -class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestWanLoRA(PeftLoraLoaderMixinTests): pipeline_class = WanPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} @@ -104,40 +104,40 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3) - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe): + super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3) - @unittest.skip("Not supported in Wan.") + @pytest.mark.skip("Not supported in Wan.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - @unittest.skip("Not supported in Wan.") + @pytest.mark.skip("Not supported in Wan.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Not supported in Wan.") + @pytest.mark.skip("Not supported in Wan.") def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan.") + @pytest.mark.skip("Text encoder LoRA is not supported in Wan.") def test_simple_inference_with_partial_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan.") + @pytest.mark.skip("Text encoder LoRA is not supported in Wan.") def test_simple_inference_with_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan.") + @pytest.mark.skip("Text encoder LoRA is not supported in Wan.") def test_simple_inference_with_text_lora_and_scale(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan.") + @pytest.mark.skip("Text encoder LoRA is not 
supported in Wan.") def test_simple_inference_with_text_lora_fused(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan.") + @pytest.mark.skip("Text encoder LoRA is not supported in Wan.") def test_simple_inference_with_text_lora_save_load(self): pass diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py index ab1f57bfc9da..48017120ed83 100644 --- a/tests/lora/test_lora_layers_wanvace.py +++ b/tests/lora/test_lora_layers_wanvace.py @@ -14,10 +14,9 @@ import os import sys -import tempfile -import unittest import numpy as np +import pytest import safetensors.torch import torch from PIL import Image @@ -32,7 +31,6 @@ require_peft_backend, require_peft_version_greater, skip_mps, - torch_device, ) @@ -47,7 +45,7 @@ @require_peft_backend @skip_mps @is_flaky(max_attempts=10, description="very flaky class") -class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class TestWanVACELoRA(PeftLoraLoaderMixinTests): pipeline_class = WanVACEPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} @@ -121,56 +119,51 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3) - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe): + super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3) - @unittest.skip("Not supported in Wan VACE.") + @pytest.mark.skip("Not supported in Wan VACE.") def test_simple_inference_with_text_denoiser_block_scale(self): pass - @unittest.skip("Not supported in Wan VACE.") + @pytest.mark.skip("Not supported in Wan VACE.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Not supported in Wan VACE.") + @pytest.mark.skip("Not supported in Wan VACE.") def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") + @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") def test_simple_inference_with_partial_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") + @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") def test_simple_inference_with_text_lora(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") + @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") def test_simple_inference_with_text_lora_and_scale(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") + @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") def test_simple_inference_with_text_lora_fused(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") + @pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.") def test_simple_inference_with_text_lora_save_load(self): pass - def test_layerwise_casting_inference_denoiser(self): - super().test_layerwise_casting_inference_denoiser() - @require_peft_version_greater("0.13.2") - def test_lora_exclude_modules_wanvace(self): + def 
test_lora_exclude_modules_wanvace(self, base_pipe_output, tmpdirname, pipe): exclude_module_name = "vace_blocks.0.proj_out" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components).to(torch_device) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.get_base_pipe_output() - self.assertTrue(output_no_lora.shape == self.output_shape) + assert base_pipe_output.shape == self.output_shape # only supported for `denoiser` now denoiser_lora_config.target_modules = ["proj_out"] @@ -180,36 +173,30 @@ def test_lora_exclude_modules_wanvace(self): ) # The state dict shouldn't contain the modules to be excluded from LoRA. state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default") - self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model)) - self.assertTrue(any("proj_out" in k for k in state_dict_from_model)) + assert not any(exclude_module_name in k for k in state_dict_from_model) + assert any("proj_out" in k for k in state_dict_from_model) output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdir: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts) - pipe.unload_lora_weights() - - # Check in the loaded state dict. - loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) - self.assertTrue(not any(exclude_module_name in k for k in loaded_state_dict)) - self.assertTrue(any("proj_out" in k for k in loaded_state_dict)) - - # Check in the state dict obtained after loading LoRA. - pipe.load_lora_weights(tmpdir) - state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0") - self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model)) - self.assertTrue(any("proj_out" in k for k in state_dict_from_model)) - - output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), - "LoRA should change outputs.", - ) - self.assertTrue( - np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), - "Lora outputs should match.", - ) - - def test_simple_inference_with_text_denoiser_lora_and_scale(self): - super().test_simple_inference_with_text_denoiser_lora_and_scale() + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts) + pipe.unload_lora_weights() + + # Check in the loaded state dict. + loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + assert not any(exclude_module_name in k for k in loaded_state_dict) + assert any("proj_out" in k for k in loaded_state_dict) + + # Check in the state dict obtained after loading LoRA. 
+ pipe.load_lora_weights(tmpdirname) + state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0") + assert not any(exclude_module_name in k for k in state_dict_from_model) + assert any("proj_out" in k for k in state_dict_from_model) + + output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert not np.allclose(base_pipe_output, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), ( + "LoRA should change outputs." + ) + assert np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), ( + "Lora outputs should match." + ) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 3d4344bb86a9..bfb242c74df4 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -15,19 +15,13 @@ import inspect import os import re -import tempfile -import unittest from itertools import product import numpy as np import pytest import torch -from parameterized import parameterized -from diffusers import ( - AutoencoderKL, - UNet2DConditionModel, -) +from diffusers import AutoencoderKL, UNet2DConditionModel from diffusers.utils import logging from diffusers.utils.import_utils import is_peft_available @@ -81,7 +75,7 @@ def check_module_lora_metadata(parsed_metadata: dict, lora_metadatas: dict, modu def initialize_dummy_state_dict(state_dict): - if not all(v.device.type == "meta" for _, v in state_dict.items()): + if not all((v.device.type == "meta" for _, v in state_dict.items())): raise ValueError("`state_dict` has non-meta values.") return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()} @@ -126,12 +120,25 @@ class PeftLoraLoaderMixinTests: text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] - cached_non_lora_output = None + @property + def output_shape(self): + raise NotImplementedError + + @pytest.fixture(scope="class") + def base_pipe_output(self): + return self._compute_baseline_output() - def get_base_pipe_output(self): - if self.cached_non_lora_output is None: - self.cached_non_lora_output = self._compute_baseline_output() - return self.cached_non_lora_output + @pytest.fixture(scope="function") + def tmpdirname(self, tmp_path_factory): + return tmp_path_factory.mktemp("tmp") + + @pytest.fixture(scope="function") + def pipe(self): + components, _, _ = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + return pipe def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): if self.unet_kwargs and self.transformer_kwargs: @@ -153,7 +160,6 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No torch.manual_seed(0) vae = self.vae_cls(**self.vae_kwargs) - text_encoder = self.text_encoder_cls.from_pretrained( self.text_encoder_id, subfolder=self.text_encoder_subfolder ) @@ -190,7 +196,6 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No init_lora_weights=False, use_dora=use_dora, ) - pipeline_components = { "scheduler": scheduler, "vae": vae, @@ -220,10 +225,6 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No return pipeline_components, text_lora_config, denoiser_lora_config - @property - def output_shape(self): - raise NotImplementedError - def get_dummy_inputs(self, with_generator=True): batch_size = 1 sequence_length = 10 @@ -233,7 +234,6 @@ def get_dummy_inputs(self, 
with_generator=True): generator = torch.manual_seed(0) noise = floats_tensor((batch_size, num_channels) + sizes) input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) - pipeline_inputs = { "prompt": "A painting of a squirrel eating a burger", "num_inference_steps": 5, @@ -243,10 +243,30 @@ def get_dummy_inputs(self, with_generator=True): if with_generator: pipeline_inputs.update({"generator": generator}) - return noise, input_ids, pipeline_inputs + return (noise, input_ids, pipeline_inputs) + + def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"): + if text_lora_config is not None: + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name) + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + + if denoiser_lora_config is not None: + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name) + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + else: + denoiser = None + + if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name) + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + + return pipe, denoiser def _compute_baseline_output(self): - components, _, _ = self.get_dummy_components(self.scheduler_cls) + components, _, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) @@ -277,7 +297,7 @@ def _get_modules_to_save(self, pipe, has_denoiser=False): if ( "text_encoder" in lora_loadable_modules and hasattr(pipe, "text_encoder") - and getattr(pipe.text_encoder, "peft_config", None) is not None + and (getattr(pipe.text_encoder, "peft_config", None) is not None) ): modules_to_save["text_encoder"] = pipe.text_encoder @@ -291,309 +311,210 @@ def _get_modules_to_save(self, pipe, has_denoiser=False): if has_denoiser: if "unet" in lora_loadable_modules and hasattr(pipe, "unet"): modules_to_save["unet"] = pipe.unet - if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"): modules_to_save["transformer"] = pipe.transformer return modules_to_save - def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"): - if text_lora_config is not None: - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - - if denoiser_lora_config is not None: - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - else: - denoiser = None - - if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name) - 
self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - return pipe, denoiser - - def test_simple_inference(self): + def test_simple_inference(self, base_pipe_output): """ Tests a simple inference and makes sure it works as expected """ - output_no_lora = self.get_base_pipe_output() - assert output_no_lora.shape == self.output_shape + assert base_pipe_output.shape == self.output_shape - def test_simple_inference_with_text_lora(self): + def test_simple_inference_with_text_lora(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached on the text encoder and makes sure it works as expected """ - components, text_lora_config, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, _ = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" - ) + assert not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output" @require_peft_version_greater("0.13.1") - def test_low_cpu_mem_usage_with_injection(self): + def test_low_cpu_mem_usage_with_injection(self, pipe): """Tests if we can inject LoRA state dict with low_cpu_mem_usage.""" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() if "text_encoder" in self.pipeline_class._lora_loadable_modules: inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder.") - self.assertTrue( - "meta" in {p.device.type for p in pipe.text_encoder.parameters()}, - "The LoRA params should be on 'meta' device.", + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder." + assert "meta" in {p.device.type for p in pipe.text_encoder.parameters()}, ( + "The LoRA params should be on 'meta' device." ) te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder)) set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True) - self.assertTrue( - "meta" not in {p.device.type for p in pipe.text_encoder.parameters()}, - "No param should be on 'meta' device.", + assert "meta" not in {p.device.type for p in pipe.text_encoder.parameters()}, ( + "No param should be on 'meta' device." ) denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - self.assertTrue( - "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device." - ) + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + assert "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device." 
denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser)) set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True) - self.assertTrue( - "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device." - ) + assert "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device." if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - self.assertTrue( - "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()}, - "The LoRA params should be on 'meta' device.", + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + assert "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()}, ( + "The LoRA params should be on 'meta' device." ) te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2)) set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True) - self.assertTrue( - "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()}, - "No param should be on 'meta' device.", + assert "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()}, ( + "No param should be on 'meta' device." ) _, _, inputs = self.get_dummy_inputs() output_lora = pipe(**inputs)[0] - self.assertTrue(output_lora.shape == self.output_shape) + assert output_lora.shape == self.output_shape @require_peft_version_greater("0.13.1") @require_transformers_version_greater("4.45.2") - def test_low_cpu_mem_usage_with_loading(self): + def test_low_cpu_mem_usage_with_loading(self, tmpdirname, pipe): """Tests if we can load LoRA state dict with low_cpu_mem_usage.""" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts) - for module_name, module in modules_to_save.items(): - self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - images_lora_from_pretrained = pipe(**inputs, 
generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", - ) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False) + for module_name, module in modules_to_save.items(): + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" - # Now, check for `low_cpu_mem_usage.` - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True) + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), ( + "Loading from saved checkpoints should give same results." + ) - for module_name, module in modules_to_save.items(): - self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True) + for module_name, module in modules_to_save.items(): + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" - images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.", - ) + images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), ( + "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results." 
+ ) - def test_simple_inference_with_text_lora_and_scale(self): + def test_simple_inference_with_text_lora_and_scale(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached on the text encoder + scale argument and makes sure it works as expected """ attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - components, text_lora_config, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, _ = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" - ) + assert not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output" attention_kwargs = {attention_kwargs_name: {"scale": 0.5}} output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - - self.assertTrue( - not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", + assert not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), ( + "Lora + scale should change the output" ) attention_kwargs = {attention_kwargs_name: {"scale": 0.0}} output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - - self.assertTrue( - np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3), - "Lora + 0 scale should lead to same result as no LoRA", + assert np.allclose(base_pipe_output, output_lora_0_scale, atol=1e-3, rtol=1e-3), ( + "Lora + 0 scale should lead to same result as no LoRA" ) - def test_simple_inference_with_text_lora_fused(self): + def test_simple_inference_with_text_lora_fused(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected """ - components, text_lora_config, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, _ = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) pipe.fuse_lora() - # Fusing should still keep the LoRA layers - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" + assert not np.allclose(ouput_fused, base_pipe_output, atol=1e-3, 
rtol=1e-3), ( + "Fused lora should change the output" ) - def test_simple_inference_with_text_lora_unloaded(self): + def test_simple_inference_with_text_lora_unloaded(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached to text encoder, then unloads the lora weights and makes sure it works as expected """ - components, text_lora_config, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, _ = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) pipe.unload_lora_weights() - # unloading should remove the LoRA layers - self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder") + assert not check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder" if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertFalse( - check_if_lora_correctly_set(pipe.text_encoder_2), - "Lora not correctly unloaded in text encoder 2", + assert not check_if_lora_correctly_set(pipe.text_encoder_2), ( + "Lora not correctly unloaded in text encoder 2" ) ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3), - "Fused lora should change the output", + assert np.allclose(ouput_unloaded, base_pipe_output, atol=1e-3, rtol=1e-3), ( + "Unloading lora should match the base pipe output" ) - def test_simple_inference_with_text_lora_save_load(self): + def test_simple_inference_with_text_lora_save_load(self, tmpdirname, pipe): """ Tests a simple usecase where users could use saving utilities for LoRA. 
""" - components, text_lora_config, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, _ = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + modules_to_save = self._get_modules_to_save(pipe) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) for module_name, module in modules_to_save.items(): - self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", + assert np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), ( + "Loading from saved checkpoints should give same results." ) - def test_simple_inference_with_partial_text_lora(self): + def test_simple_inference_with_partial_text_lora(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached on the text encoder with different ranks and some adapters removed and makes sure it works as expected """ - components, _, _ = self.get_dummy_components() - # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). text_lora_config = LoraConfig( r=4, rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)}, @@ -602,274 +523,197 @@ def test_simple_inference_with_partial_text_lora(self): init_lora_weights=False, use_dora=False, ) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() + _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - state_dict = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: - # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder` - # supports missing layers (PR#8324). 
state_dict = { f"text_encoder.{module_name}": param - for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items() + for (module_name, param) in get_peft_model_state_dict(pipe.text_encoder).items() if "text_model.encoder.layers.4" not in module_name } - if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: state_dict.update( { f"text_encoder_2.{module_name}": param - for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items() + for (module_name, param) in get_peft_model_state_dict(pipe.text_encoder_2).items() if "text_model.encoder.layers.4" not in module_name } ) - output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" - ) + assert not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output" - # Unload lora and load it back using the pipe.load_lora_weights machinery pipe.unload_lora_weights() pipe.load_lora_weights(state_dict) - output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3), - "Removing adapters should change the output", + assert not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3), ( + "Removing adapters should change the output" ) - def test_simple_inference_save_pretrained_with_text_lora(self): + def test_simple_inference_save_pretrained_with_text_lora(self, tmpdirname, pipe): """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ - components, text_lora_config, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, _ = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - pipe.save_pretrained(tmpdirname) - - pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) - pipe_from_pretrained.to(torch_device) + pipe.save_pretrained(tmpdirname) + pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) + pipe_from_pretrained.to(torch_device) if "text_encoder" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe_from_pretrained.text_encoder), - "Lora not correctly set in text encoder", + assert check_if_lora_correctly_set(pipe_from_pretrained.text_encoder), ( + "Lora not correctly set in text encoder" ) - if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2), - "Lora not correctly set in text encoder 2", + assert check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2), ( + "Lora not correctly set in text encoder 2" ) images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", + assert np.allclose(images_lora, 
images_lora_save_pretrained, atol=1e-3, rtol=1e-3), ( + "Loading from saved checkpoints should give same results." ) - def test_simple_inference_with_text_denoiser_lora_save_load(self): + def test_simple_inference_with_text_denoiser_lora_save_load(self, tmpdirname, pipe): """ Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) for module_name, module in modules_to_save.items(): - self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", + assert np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), ( + "Loading from saved checkpoints should give same results." 
) - def test_simple_inference_with_text_denoiser_lora_and_scale(self): + def test_simple_inference_with_text_denoiser_lora_and_scale(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached on the text encoder + Unet + scale argument and makes sure it works as expected """ attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" - ) + assert not np.allclose(output_lora, base_pipe_output, atol=1e-3, rtol=1e-3), "Lora should change the output" attention_kwargs = {attention_kwargs_name: {"scale": 0.5}} output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - - self.assertTrue( - not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", + assert not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), ( + "Lora + scale should change the output" ) attention_kwargs = {attention_kwargs_name: {"scale": 0.0}} output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - - self.assertTrue( - np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3), - "Lora + 0 scale should lead to same result as no LoRA", + assert np.allclose(base_pipe_output, output_lora_0_scale, atol=1e-3, rtol=1e-3), ( + "Lora + 0 scale should lead to same result as no LoRA" ) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0, - "The scaling parameter has not been correctly restored!", + assert pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0, ( + "The scaling parameter has not been correctly restored!" 
) - def test_simple_inference_with_text_lora_denoiser_fused(self): + def test_simple_inference_with_text_lora_denoiser_fused(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected - with unet """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) - # Fusing should still keep the LoRA layers if "text_encoder" in self.pipeline_class._lora_loadable_modules: - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") - - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser") - + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser" if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" + assert not np.allclose(output_fused, base_pipe_output, atol=1e-3, rtol=1e-3), ( + "Fused lora should change the output" ) - def test_simple_inference_with_text_denoiser_lora_unloaded(self): + def test_simple_inference_with_text_denoiser_lora_unloaded(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights and makes sure it works as expected """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() - pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.unload_lora_weights() - # unloading should remove the LoRA layers - self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder") - self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser") + assert not check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder" + assert not check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser" if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertFalse( - check_if_lora_correctly_set(pipe.text_encoder_2), - "Lora not correctly unloaded in text encoder 2", + 
assert not check_if_lora_correctly_set(pipe.text_encoder_2), ( + "Lora not correctly unloaded in text encoder 2" ) output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3), - "Fused lora should change the output", + assert np.allclose(output_unloaded, base_pipe_output, atol=1e-3, rtol=1e-3), ( + "Unloading lora should match the base pipe output" ) def test_simple_inference_with_text_denoiser_lora_unfused( - self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 + self, pipe, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 ): """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights and makes sure it works as expected """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) - self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") - output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert pipe.num_fused_loras == 1, ( + f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" + ) + output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) - self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") - output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert pipe.num_fused_loras == 0, ( + f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" + ) - # unloading should remove the LoRA layers + output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] if "text_encoder" in self.pipeline_class._lora_loadable_modules: - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers") - - self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers") - + assert check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers" + assert check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers" if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" - ) - - # Fuse and unfuse should lead to the same results - self.assertTrue( - np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol), - "Fused lora should not change the output", + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" + assert np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol), ( + "Fused lora should not change the output" ) - def test_simple_inference_with_text_denoiser_multi_adapter(self): + def test_simple_inference_with_text_denoiser_multi_adapter(self, base_pipe_output): """ Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters
and set them @@ -877,104 +721,83 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self): components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.get_base_pipe_output() - if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" pipe.set_adapters("adapter-1") output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3), - "Adapter outputs should be different.", + assert not np.allclose(base_pipe_output, output_adapter_1, atol=1e-3, rtol=1e-3), ( + "Adapter outputs should be different." ) pipe.set_adapters("adapter-2") output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter outputs should be different.", + assert not np.allclose(base_pipe_output, output_adapter_2, atol=1e-3, rtol=1e-3), ( + "Adapter outputs should be different." ) pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter outputs should be different.", + assert not np.allclose(base_pipe_output, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter outputs should be different." 
) - # Fuse and unfuse should lead to the same results - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", + assert not np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and 2 should give different results" ) - - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", + assert not np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and mixed adapters should give different results" ) - - self.assertFalse( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", + assert not np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter 2 and mixed adapters should give different results" ) pipe.disable_lora() output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", + assert np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), ( + "output with no lora and output with lora disabled should give same results" ) def test_wrong_adapter_name_raises_error(self): adapter_name = "adapter-1" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name ) - - with self.assertRaises(ValueError) as err_context: + with pytest.raises(ValueError) as err_context: pipe.set_adapters("test") + assert "not in the list of present adapters" in str(err_context.value) - self.assertTrue("not in the list of present adapters" in str(err_context.exception)) - - # test this works. pipe.set_adapters(adapter_name) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - def test_multiple_wrong_adapter_name_raises_error(self): + def test_multiple_wrong_adapter_name_raises_error(self, pipe): adapter_name = "adapter-1" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name ) @@ -984,105 +807,81 @@ def test_multiple_wrong_adapter_name_raises_error(self): logger.setLevel(30) with CaptureLogger(logger) as cap_logger: pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components) - wrong_components = sorted(set(scale_with_wrong_components.keys())) msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. " - self.assertTrue(msg in str(cap_logger.out)) - - # test this works. 
+ assert msg in str(cap_logger.out) pipe.set_adapters(adapter_name) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - def test_simple_inference_with_text_denoiser_block_scale(self): + def test_simple_inference_with_text_denoiser_block_scale(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached to text encoder and unet, attaches one adapter and set different weights for different blocks (i.e. block lora) """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.get_base_pipe_output() - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" weights_1 = {"text_encoder": 2, "unet": {"down": 5}} pipe.set_adapters("adapter-1", weights_1) output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - weights_2 = {"unet": {"up": 5}} pipe.set_adapters("adapter-1", weights_2) output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3), - "LoRA weights 1 and 2 should give different results", + assert not np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3), ( + "LoRA weights 1 and 2 should give different results" ) - self.assertFalse( - np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3), - "No adapter and LoRA weights 1 should give different results", + assert not np.allclose(base_pipe_output, output_weights_1, atol=1e-3, rtol=1e-3), ( + "No adapter and LoRA weights 1 should give different results" ) - self.assertFalse( - np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3), - "No adapter and LoRA weights 2 should give different results", + assert not np.allclose(base_pipe_output, output_weights_2, atol=1e-3, rtol=1e-3), ( + "No adapter and LoRA weights 2 should give different results" ) pipe.disable_lora() output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", + assert np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), ( + "output with no lora and output with lora disabled should give same results" ) - def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): + def 
test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set different weights for different blocks (i.e. block lora) """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.get_base_pipe_output() - if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" scales_1 = {"text_encoder": 2, "unet": {"down": 5}} scales_2 = {"unet": {"down": 5, "mid": 5}} - pipe.set_adapters("adapter-1", scales_1) output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -1092,35 +891,25 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - # Fuse and unfuse should lead to the same results - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", + assert not np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and 2 should give different results" ) - - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", + assert not np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and mixed adapters should give different results" ) - - self.assertFalse( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", + assert not np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter 2 and mixed adapters should give different results" ) pipe.disable_lora() output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same 
results", + assert np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), ( + "output with no lora and output with lora disabled should give same results" ) - - # a mismatching number of adapter_names and adapter_weights should raise an error - with self.assertRaises(ValueError): + with pytest.raises(ValueError): pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1]) - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self, pipe): """Tests that any valid combination of lora block scales can be used in pipe.set_adapter""" def updown_options(blocks_with_tf, layers_per_block, value): @@ -1130,13 +919,11 @@ def updown_options(blocks_with_tf, layers_per_block, value): """ num_val = value list_val = [value] * layers_per_block - node_opts = [None, num_val, list_val] node_opts_foreach_block = [node_opts] * len(blocks_with_tf) - updown_opts = [num_val] for nodes in product(*node_opts_foreach_block): - if all(n is None for n in nodes): + if all((n is None for n in nodes)): continue opt = {} for b, n in zip(blocks_with_tf, nodes): @@ -1150,30 +937,24 @@ def all_possible_dict_opts(unet, value): Generate every possible combination for how a lora weight dict can be. E.g. 2, {"unet: {"down": 2}}, {"unet: {"down": [2,2,2]}}, {"unet: {"mid": 2, "up": [2,2,2]}}, ... """ - - down_blocks_with_tf = [i for i, d in enumerate(unet.down_blocks) if hasattr(d, "attentions")] - up_blocks_with_tf = [i for i, u in enumerate(unet.up_blocks) if hasattr(u, "attentions")] - + down_blocks_with_tf = [i for (i, d) in enumerate(unet.down_blocks) if hasattr(d, "attentions")] + up_blocks_with_tf = [i for (i, u) in enumerate(unet.up_blocks) if hasattr(u, "attentions")] layers_per_block = unet.config.layers_per_block - text_encoder_opts = [None, value] text_encoder_2_opts = [None, value] mid_opts = [None, value] down_opts = [None] + updown_options(down_blocks_with_tf, layers_per_block, value) up_opts = [None] + updown_options(up_blocks_with_tf, layers_per_block + 1, value) - opts = [] - for t1, t2, d, m, u in product(text_encoder_opts, text_encoder_2_opts, down_opts, mid_opts, up_opts): - if all(o is None for o in (t1, t2, d, m, u)): + if all((o is None for o in (t1, t2, d, m, u))): continue opt = {} if t1 is not None: opt["text_encoder"] = t1 if t2 is not None: opt["text_encoder_2"] = t2 - if all(o is None for o in (d, m, u)): - # no unet scaling + if all((o is None for o in (d, m, u))): continue opt["unet"] = {} if d is not None: @@ -1183,14 +964,9 @@ def all_possible_dict_opts(unet, value): if u is not None: opt["unet"]["up"] = u opts.append(opt) - return opts - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls) pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") @@ -1206,40 +982,32 @@ def all_possible_dict_opts(unet, value): # test if lora block scales can be set with this scale_dict if not self.has_two_text_encoders and "text_encoder_2" in scale_dict: del scale_dict["text_encoder_2"] - pipe.set_adapters("adapter-1", scale_dict) # test will fail if this line throws an error - def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self): + def 
test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set/delete them """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.get_base_pipe_output() - if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules if "text_encoder_2" in lora_loadable_components: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" pipe.set_adapters("adapter-1") output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -1250,35 +1018,26 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self): pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", + assert not np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and 2 should give different results" ) - - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", + assert not np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and mixed adapters should give different results" ) - - self.assertFalse( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", + assert not np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter 2 and mixed adapters should give different results" ) pipe.delete_adapters("adapter-1") output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", + assert np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and 2 should give different results" ) 
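Reviewer note: the delete-adapter hunk above exercises the standard multi-adapter management calls. A condensed usage sketch, with each line shown independently and assuming `pipe` already has "adapter-1" and "adapter-2" attached as in the test:

pipe.set_adapters("adapter-1")                    # only adapter-1 contributes to the output
pipe.set_adapters(["adapter-1", "adapter-2"])     # both adapters active with default weights
pipe.disable_lora()                               # adapters stay attached but stop affecting the output
pipe.delete_adapters("adapter-1")                 # adapter-1 removed; only adapter-2 remains
pipe.delete_adapters(["adapter-1", "adapter-2"])  # a list removes several adapters at once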
pipe.delete_adapters("adapter-2") output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", + assert np.allclose(base_pipe_output, output_deleted_adapters, atol=1e-3, rtol=1e-3), ( + "output with no lora and output with lora disabled should give same results" ) if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -1288,49 +1047,39 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self): denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." pipe.set_adapters(["adapter-1", "adapter-2"]) pipe.delete_adapters(["adapter-1", "adapter-2"]) - output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", + assert np.allclose(base_pipe_output, output_deleted_adapters, atol=1e-3, rtol=1e-3), ( + "output with no lora and output with lora disabled should give same results" ) - def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): + def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self, base_pipe_output, pipe): """ Tests a simple inference with lora attached to text encoder and unet, attaches multiple adapters and set them """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.get_base_pipe_output() - if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") denoiser.add_adapter(denoiser_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules if "text_encoder_2" in lora_loadable_components: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" pipe.set_adapters("adapter-1") output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -1340,37 +1089,26 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - - # Fuse and unfuse should lead to the same results - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", + assert not np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and 2 should give different results" ) - - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", + assert not np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter 1 and mixed adapters should give different results" ) - - self.assertFalse( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", + assert not np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Adapter 2 and mixed adapters should give different results" ) pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6]) output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse( - np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Weighted adapter and mixed adapter should give different results", + assert not np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3), ( + "Weighted adapter and mixed adapter should give different results" ) pipe.disable_lora() output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", + assert np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), ( + "output with no lora and output with lora disabled should give same results" ) @skip_mps @@ -1379,29 +1117,25 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", strict=False, ) - def test_lora_fuse_nan(self): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + def test_lora_fuse_nan(self, pipe): + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not 
correctly set in text encoder") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - # corrupt one LoRA weight with `inf` values with torch.no_grad(): if self.unet_kwargs: pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float( "inf" ) else: - named_modules = [name for name, _ in pipe.transformer.named_modules()] + named_modules = [name for (name, _) in pipe.transformer.named_modules()] possible_tower_names = [ "transformer_blocks", "blocks", @@ -1416,66 +1150,53 @@ def test_lora_fuse_nan(self): raise ValueError(reason) for tower_name in filtered_tower_names: transformer_tower = getattr(pipe.transformer, tower_name) - has_attn1 = any("attn1" in name for name in named_modules) + has_attn1 = any(("attn1" in name for name in named_modules)) if has_attn1: transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf") else: transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") - - # with `safe_fusing=True` we should see an Error - with self.assertRaises(ValueError): + with pytest.raises(ValueError): pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) - # without we should not see an error, but every image will be black pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) out = pipe(**inputs)[0] + assert np.isnan(out).all() - self.assertTrue(np.isnan(out).all()) - - def test_get_adapters(self): + def test_get_adapters(self, pipe): """ Tests a simple usecase where we attach multiple adapters and check if the results are the expected results """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") adapter_names = pipe.get_active_adapters() - self.assertListEqual(adapter_names, ["adapter-1"]) + assert adapter_names == ["adapter-1"] pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") denoiser.add_adapter(denoiser_lora_config, "adapter-2") adapter_names = pipe.get_active_adapters() - self.assertListEqual(adapter_names, ["adapter-2"]) + assert adapter_names == ["adapter-2"] pipe.set_adapters(["adapter-1", "adapter-2"]) - self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"]) + assert sorted(pipe.get_active_adapters()) == ["adapter-1", "adapter-2"] - def test_get_list_adapters(self): + def test_get_list_adapters(self, pipe): """ Tests a simple usecase where we attach multiple adapters and check if the results are the expected results """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() 
# 1. dicts_to_be_checked = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") dicts_to_be_checked = {"text_encoder": ["adapter-1"]} - if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") dicts_to_be_checked.update({"unet": ["adapter-1"]}) @@ -1483,84 +1204,72 @@ def test_get_list_adapters(self): pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") dicts_to_be_checked.update({"transformer": ["adapter-1"]}) - self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) + assert pipe.get_list_adapters() == dicts_to_be_checked # 2. dicts_to_be_checked = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]} - if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-2") dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]}) else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2") dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]}) - - self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) + assert pipe.get_list_adapters() == dicts_to_be_checked # 3. pipe.set_adapters(["adapter-1", "adapter-2"]) - dicts_to_be_checked = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]} - if self.unet_kwargs is not None: dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]}) else: dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]}) - - self.assertDictEqual( - pipe.get_list_adapters(), - dicts_to_be_checked, - ) + assert pipe.get_list_adapters() == dicts_to_be_checked # 4. 
dicts_to_be_checked = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]} - if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-3") dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]}) else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3") dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]}) - - self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) + assert pipe.get_list_adapters() == dicts_to_be_checked def test_simple_inference_with_text_lora_denoiser_fused_multi( - self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 + self, + pipe, + expected_atol: float = 1e-3, + expected_rtol: float = 1e-3, ): """ Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected - with unet and multi-adapter case """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
denoiser.add_adapter(denoiser_lora_config, "adapter-2") if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules if "text_encoder_2" in lora_loadable_components: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") # set them to multi-adapter inference mode @@ -1571,45 +1280,46 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi( outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"]) - self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") + assert pipe.num_fused_loras == 1, ( + f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" + ) # Fusing should still keep the LoRA layers so output should remain the same outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol), - "Fused lora should not change the output", + assert np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol), ( + "Fused lora should not change the output" ) pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) - self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") + assert pipe.num_fused_loras == 0, ( + f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" + ) if "text_encoder" in self.pipeline_class._lora_loadable_modules: - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers") - - self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers") - + assert check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers" + assert check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers" if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" - ) - + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]) - self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") + assert pipe.num_fused_loras == 2, ( + f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" + ) - # Fusing should still keep the LoRA layers output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol), - "Fused lora should not change the output", + assert np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol), ( + "Fused lora should not change the output" ) + pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) - self.assertTrue(pipe.num_fused_loras == 0, 
f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") + assert pipe.num_fused_loras == 0, ( + f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" + ) - def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3): + def test_lora_scale_kwargs_match_fusion( + self, base_pipe_output, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 + ): attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - for lora_scale in [1.0, 0.8]: components, text_lora_config, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -1617,26 +1327,19 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.get_base_pipe_output() - if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules if "text_encoder_2" in lora_loadable_components: pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), - "Lora not correctly set in text encoder 2", - ) + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" pipe.set_adapters(["adapter-1"]) attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}} @@ -1647,166 +1350,119 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec adapter_names=["adapter-1"], lora_scale=lora_scale, ) - self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") + assert pipe.num_fused_loras == 1, ( + f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}" + ) outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol), - "Fused lora should not change the output", + assert np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol), ( + "Fused lora should not change the output" ) - self.assertFalse( - np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol), - "LoRA should change the output", + assert not np.allclose(base_pipe_output, outputs_lora_1, atol=expected_atol, rtol=expected_rtol), ( + "LoRA should change the output" ) - def test_simple_inference_with_dora(self): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + def test_simple_inference_with_dora(self, pipe): + _, text_lora_config, denoiser_lora_config = 
self.get_dummy_components(use_dora=True) _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_dora_lora.shape == self.output_shape) - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + assert output_no_dora_lora.shape == self.output_shape + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse( - np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3), - "DoRA lora should change the output", + assert not np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3), ( + "DoRA lora should change the output" ) - def test_missing_keys_warning(self): - # Skip text encoder check for now as that is handled with `transformers`. - components, _, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + def test_missing_keys_warning(self, tmpdirname, pipe): + _, _, denoiser_lora_config = self.get_dummy_components() denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - pipe.unload_lora_weights() - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts) + pipe.unload_lora_weights() + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) - # To make things dynamic since we cannot settle with a single key for all the models where we - # offer PEFT support. missing_key = [k for k in state_dict if "lora_A" in k][0] del state_dict[missing_key] - logger = logging.get_logger("diffusers.utils.peft_utils") logger.setLevel(30) with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(state_dict) - # Since the missing key won't contain the adapter name ("default_0"). - # Also strip out the component prefix (such as "unet." from `missing_key`). component = list({k.split(".")[0] for k in state_dict})[0] - self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", "")) + assert missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", "") - def test_unexpected_keys_warning(self): - # Skip text encoder check for now as that is handled with `transformers`. 
- components, _, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + def test_unexpected_keys_warning(self, tmpdirname, pipe): + _, _, denoiser_lora_config = self.get_dummy_components() denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - pipe.unload_lora_weights() - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts) + pipe.unload_lora_weights() + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True) unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat" state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device) - logger = logging.get_logger("diffusers.utils.peft_utils") logger.setLevel(30) with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(state_dict) + assert ".diffusers_cat" in cap_logger.out - self.assertTrue(".diffusers_cat" in cap_logger.out) - - @unittest.skip("This is failing for now - need to investigate") - def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): + @pytest.mark.skip("This is failing for now - need to investigate") + def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self, pipe): """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights and makes sure it works as expected """ - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True) if self.has_two_text_encoders or self.has_three_text_encoders: pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True) - - # Just makes sure it works. 
_ = pipe(**inputs, generator=torch.manual_seed(0))[0] - def test_modify_padding_mode(self): + def test_modify_padding_mode(self, pipe): def set_pad_mode(network, mode="circular"): for _, module in network.named_modules(): if isinstance(module, torch.nn.Conv2d): module.padding_mode = mode - components, _, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) _pad_mode = "circular" set_pad_mode(pipe.vae, _pad_mode) set_pad_mode(pipe.unet, _pad_mode) - _, _, inputs = self.get_dummy_inputs() _ = pipe(**inputs)[0] - def test_logs_info_when_no_lora_keys_found(self): - # Skip text encoder check for now as that is handled with `transformers`. - components, _, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + def test_logs_info_when_no_lora_keys_found(self, base_pipe_output, pipe): _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.get_base_pipe_output() no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)} logger = logging.get_logger("diffusers.loaders.peft") logger.setLevel(logging.WARNING) - with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(no_op_state_dict) - out_after_lora_attempt = pipe(**inputs, generator=torch.manual_seed(0))[0] + out_after_lora_attempt = pipe(**inputs, generator=torch.manual_seed(0))[0] denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer") - self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}")) - self.assertTrue(np.allclose(output_no_lora, out_after_lora_attempt, atol=1e-5, rtol=1e-5)) + assert cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}") + assert np.allclose(base_pipe_output, out_after_lora_attempt, atol=1e-05, rtol=1e-05) - # test only for text encoder for lora_module in self.pipeline_class._lora_loadable_modules: if "text_encoder" in lora_module: text_encoder = getattr(pipe, lora_module) @@ -1814,109 +1470,80 @@ def test_logs_info_when_no_lora_keys_found(self): prefix = "text_encoder" elif lora_module == "text_encoder_2": prefix = "text_encoder_2" - logger = logging.get_logger("diffusers.loaders.lora_base") logger.setLevel(logging.WARNING) - with CaptureLogger(logger) as cap_logger: self.pipeline_class.load_lora_into_text_encoder( no_op_state_dict, network_alphas=None, text_encoder=text_encoder, prefix=prefix ) + assert cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}") - self.assertTrue( - cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}") - ) - - def test_set_adapters_match_attention_kwargs(self): + def test_set_adapters_match_attention_kwargs(self, base_pipe_output, tmpdirname, pipe): """Test to check if outputs after `set_adapters()` and attention kwargs match.""" attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = self.get_base_pipe_output() pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) lora_scale = 0.5 attention_kwargs = 
{attention_kwargs_name: {"scale": lora_scale}} output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - self.assertFalse( - np.allclose(output_no_lora, output_lora_scale, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", + assert not np.allclose(base_pipe_output, output_lora_scale, atol=1e-3, rtol=1e-3), ( + "Lora + scale should change the output" ) pipe.set_adapters("default", lora_scale) output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_no_lora, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", + assert not np.allclose(base_pipe_output, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3), ( + "Lora + scale should change the output" ) - self.assertTrue( - np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3), - "Lora + scale should match the output of `set_adapters()`.", + assert np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3), ( + "Lora + scale should match the output of `set_adapters()`." ) - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts - ) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - for module_name, module in modules_to_save.items(): - self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") - - output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - self.assertTrue( - not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", - ) - self.assertTrue( - np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results as attention_kwargs.", - ) - self.assertTrue( - np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results as set_adapters().", - ) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - @require_peft_version_greater("0.13.2") - def test_lora_B_bias(self): - # Currently, this test is only relevant for Flux Control LoRA as we are not - # aware of any other LoRA checkpoint that has its `lora_B` biases trained. 
- components, _, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + for module_name, module in modules_to_save.items(): + assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" - # keep track of the bias values of the base layers to perform checks later. + output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] + assert not np.allclose(base_pipe_output, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), ( + "Lora + scale should change the output" + ) + assert np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), ( + "Loading from saved checkpoints should give same results as attention_kwargs." + ) + assert np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3), ( + "Loading from saved checkpoints should give same results as set_adapters()." + ) + + @require_peft_version_greater("0.13.2") + def test_lora_B_bias(self, base_pipe_output, pipe): + _, _, denoiser_lora_config = self.get_dummy_components() bias_values = {} denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer for name, module in denoiser.named_modules(): - if any(k in name for k in self.denoiser_target_modules): + if any((k in name for k in self.denoiser_target_modules)): if module.bias is not None: bias_values[name] = module.bias.data.clone() _, _, inputs = self.get_dummy_inputs(with_generator=False) - original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - denoiser_lora_config.lora_bias = False if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.delete_adapters("adapter-1") + pipe.delete_adapters("adapter-1") denoiser_lora_config.lora_bias = True if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") @@ -1924,38 +1551,29 @@ def test_lora_B_bias(self): pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3)) - self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) - self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) + assert not np.allclose(base_pipe_output, lora_bias_false_output, atol=1e-3, rtol=1e-3) + assert not np.allclose(base_pipe_output, lora_bias_true_output, atol=1e-3, rtol=1e-3) + assert not np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3) - def test_correct_lora_configs_with_different_ranks(self): - components, _, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + def test_correct_lora_configs_with_different_ranks(self, base_pipe_output, pipe): + _, _, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] - if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") 
else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] if self.unet_kwargs is not None: pipe.unet.delete_adapters("adapter-1") else: pipe.transformer.delete_adapters("adapter-1") - denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer for name, _ in denoiser.named_modules(): - if "to_k" in name and "attn" in name and "lora" not in name: + if "to_k" in name and "attn" in name and ("lora" not in name): module_name_to_rank_update = name.replace(".base_layer.", ".") break - - # change the rank_pattern updated_rank = denoiser_lora_config.r * 2 denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank} @@ -1965,35 +1583,31 @@ def test_correct_lora_configs_with_different_ranks(self): else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern - - self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank}) + assert updated_rank_pattern == {module_name_to_rank_update: updated_rank} lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3)) - self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + assert not np.allclose(base_pipe_output, lora_output_same_rank, atol=1e-3, rtol=1e-3) + assert not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3) if self.unet_kwargs is not None: pipe.unet.delete_adapters("adapter-1") else: pipe.transformer.delete_adapters("adapter-1") - # similarly change the alpha_pattern updated_alpha = denoiser_lora_config.lora_alpha * 2 denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha} + if self.unet_kwargs is not None: pipe.unet.add_adapter(denoiser_lora_config, "adapter-1") - self.assertTrue( - pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} - ) + assert pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") - self.assertTrue( - pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} - ) - + assert pipe.transformer.peft_config["adapter-1"].alpha_pattern == { + module_name_to_rank_update: updated_alpha + } lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) - self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + assert not np.allclose(base_pipe_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3) + assert not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3) def test_layerwise_casting_inference_denoiser(self): from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS @@ -2007,12 +1621,12 @@ def check_linear_dtype(module, storage_dtype, compute_dtype): if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS): continue dtype_to_check = storage_dtype - if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check): + if "lora" in name or any((re.search(pattern, name) for pattern in patterns_to_check)): dtype_to_check = compute_dtype if getattr(submodule, "weight", None) is not None: - 
self.assertEqual(submodule.weight.dtype, dtype_to_check) + assert submodule.weight.dtype == dtype_to_check if getattr(submodule, "bias", None) is not None: - self.assertEqual(submodule.bias.dtype, dtype_to_check) + assert submodule.bias.dtype == dtype_to_check def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): components, text_lora_config, denoiser_lora_config = self.get_dummy_components() @@ -2025,23 +1639,20 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): if storage_dtype is not None: denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) check_linear_dtype(denoiser, storage_dtype, compute_dtype) - return pipe _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe_fp32 = initialize_pipeline(storage_dtype=None) pipe_fp32(**inputs, generator=torch.manual_seed(0))[0] - pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0] - pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0] @require_peft_version_greater("0.14.0") - def test_layerwise_casting_peft_input_autocast_denoiser(self): - r""" + def test_layerwise_casting_peft_input_autocast_denoiser(self, tmpdirname): + """ A test that checks if layerwise casting works correctly with PEFT layers and forward pass does not fail. This is different from `test_layerwise_casting_inference_denoiser` as that disables the application of layerwise cast hooks on the PEFT layers (relevant logic in `models.modeling_utils.ModelMixin.enable_layerwise_casting`). @@ -2054,7 +1665,6 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self): See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details. """ - from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS from diffusers.hooks.layerwise_casting import ( _PEFT_AUTOCAST_DISABLE_HOOK, @@ -2066,35 +1676,32 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self): compute_dtype = torch.float32 def check_module(denoiser): - # This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser) for name, module in denoiser.named_modules(): if not isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS): continue dtype_to_check = storage_dtype - if any(re.search(pattern, name) for pattern in patterns_to_check): + if any((re.search(pattern, name) for pattern in patterns_to_check)): dtype_to_check = compute_dtype if getattr(module, "weight", None) is not None: - self.assertEqual(module.weight.dtype, dtype_to_check) + assert module.weight.dtype == dtype_to_check if getattr(module, "bias", None) is not None: - self.assertEqual(module.bias.dtype, dtype_to_check) + assert module.bias.dtype == dtype_to_check if isinstance(module, BaseTunerLayer): - self.assertTrue(getattr(module, "_diffusers_hook", None) is not None) - self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None) + assert getattr(module, "_diffusers_hook", None) is not None + assert module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None - # 1.
Test forward with add_adapter components, _, denoiser_lora_config = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device, dtype=compute_dtype) - pipe.set_progress_bar_config(disable=None) + pipe.set_progress_bar_config(disable=None) denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN if getattr(denoiser, "_skip_layerwise_casting_patterns", None) is not None: patterns_to_check += tuple(denoiser._skip_layerwise_casting_patterns) - apply_layerwise_casting( denoiser, storage_dtype=storage_dtype, compute_dtype=compute_dtype, skip_modules_pattern=patterns_to_check ) @@ -2103,83 +1710,73 @@ def check_module(denoiser): _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe(**inputs, generator=torch.manual_seed(0))[0] - # 2. Test forward with load_lora_weights - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts - ) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - components, _, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device, dtype=compute_dtype) - pipe.set_progress_bar_config(disable=None) - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - apply_layerwise_casting( - denoiser, - storage_dtype=storage_dtype, - compute_dtype=compute_dtype, - skip_modules_pattern=patterns_to_check, - ) - check_module(denoiser) - - _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe(**inputs, generator=torch.manual_seed(0))[0] - - @parameterized.expand([4, 8, 16]) - def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) + components, _, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device, dtype=compute_dtype) + pipe.set_progress_bar_config(disable=None) + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + apply_layerwise_casting( + denoiser, + storage_dtype=storage_dtype, + compute_dtype=compute_dtype, + skip_modules_pattern=patterns_to_check, + ) + check_module(denoiser) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe(**inputs, generator=torch.manual_seed(0))[0] + @pytest.mark.parametrize("lora_alpha", [4, 8, 16]) + def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha, tmpdirname, pipe): + _, text_lora_config, denoiser_lora_config = 
self.get_dummy_components(lora_alpha=lora_alpha) pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) - with tempfile.TemporaryDirectory() as tmpdir: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) - self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) - pipe.unload_lora_weights() + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts, **lora_metadatas) + pipe.unload_lora_weights() + out = pipe.lora_state_dict(tmpdirname, return_lora_metadata=True) + if len(out) == 3: + (_, _, parsed_metadata) = out + elif len(out) == 2: + (_, parsed_metadata) = out + denoiser_key = ( + f"{self.pipeline_class.transformer_name}" + if self.transformer_kwargs is not None + else f"{self.pipeline_class.unet_name}" + ) + assert any((k.startswith(f"{denoiser_key}.") for k in parsed_metadata)) - out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True) - if len(out) == 3: - _, _, parsed_metadata = out - elif len(out) == 2: - _, parsed_metadata = out + check_module_lora_metadata( + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key + ) - denoiser_key = ( - f"{self.pipeline_class.transformer_name}" - if self.transformer_kwargs is not None - else f"{self.pipeline_class.unet_name}" - ) - self.assertTrue(any(k.startswith(f"{denoiser_key}.") for k in parsed_metadata)) + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + text_encoder_key = self.pipeline_class.text_encoder_name + assert any((k.startswith(f"{text_encoder_key}.") for k in parsed_metadata)) check_module_lora_metadata( - parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key ) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - text_encoder_key = self.pipeline_class.text_encoder_name - self.assertTrue(any(k.startswith(f"{text_encoder_key}.") for k in parsed_metadata)) - check_module_lora_metadata( - parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key - ) - - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - text_encoder_2_key = "text_encoder_2" - self.assertTrue(any(k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata)) - check_module_lora_metadata( - parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + text_encoder_2_key = "text_encoder_2" + assert any((k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata)) + check_module_lora_metadata( + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key + ) - @parameterized.expand([4, 8, 16]) - def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) - pipe = self.pipeline_class(**components).to(torch_device) + @pytest.mark.parametrize("lora_alpha", [4, 8, 16]) + def 
test_lora_adapter_metadata_save_load_inference(self, lora_alpha, tmpdirname, pipe): + _, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha) _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline( @@ -2187,183 +1784,151 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): ) output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdir: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) - self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) - pipe.unload_lora_weights() - pipe.load_lora_weights(tmpdir) - - output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) - self.assertTrue( - np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match." - ) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts, **lora_metadatas) + pipe.unload_lora_weights() + pipe.load_lora_weights(tmpdirname) + output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match." - def test_lora_unload_add_adapter(self): + def test_lora_unload_add_adapter(self, pipe): """Tests if `unload_lora_weights()` -> `add_adapter()` works.""" - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components).to(torch_device) + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - - # unload and then add. pipe.unload_lora_weights() + pipe, _ = self.add_adapters_to_pipeline( pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config ) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - def test_inference_load_delete_load_adapters(self): - "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works." 
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + def test_inference_load_delete_load_adapters(self, base_pipe_output, tmpdirname, pipe): + """Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works.""" + _, text_lora_config, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.get_base_pipe_output() - if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config) - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." if self.has_two_text_encoders or self.has_three_text_encoders: lora_loadable_components = self.pipeline_class._lora_loadable_modules if "text_encoder_2" in lora_loadable_components: pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - # First, delete adapter and compare. - pipe.delete_adapters(pipe.get_active_adapters()[0]) - output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3)) - self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3)) + pipe.delete_adapters(pipe.get_active_adapters()[0]) + output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert not np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3) + assert np.allclose(base_pipe_output, output_no_adapter, atol=1e-3, rtol=1e-3) - # Then load adapter and compare. 
- pipe.load_lora_weights(tmpdirname) - output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3)) + pipe.load_lora_weights(tmpdirname) + output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3) - def _test_group_offloading_inference_denoiser(self, offload_type, use_stream): + def _test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe): from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook onload_device = torch_device offload_device = torch.device("cpu") + _, _, denoiser_lora_config = self.get_dummy_components() - components, text_lora_config, denoiser_lora_config = self.get_dummy_components() + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." + + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts) + assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + components, _, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + check_if_lora_correctly_set(denoiser) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + denoiser.enable_group_offload( + onload_device=onload_device, + offload_device=offload_device, + offload_type=offload_type, + num_blocks_per_group=1, + use_stream=use_stream, + ) + for _, component in pipe.components.items(): + if isinstance(component, torch.nn.Module): + component.to(torch_device) - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser) + assert group_offload_hook_1 is not None - components, _, _ = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.set_progress_bar_config(disable=None) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.unload_lora_weights() + group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser) + assert group_offload_hook_2 is not None - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - check_if_lora_correctly_set(denoiser) - _, _, inputs = self.get_dummy_inputs(with_generator=False) + output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841 + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + 
check_if_lora_correctly_set(denoiser) + group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser) + assert group_offload_hook_3 is not None - # Test group offloading with load_lora_weights - denoiser.enable_group_offload( - onload_device=onload_device, - offload_device=offload_device, - offload_type=offload_type, - num_blocks_per_group=1, - use_stream=use_stream, - ) - # Place other model-level components on `torch_device`. - for _, component in pipe.components.items(): - if isinstance(component, torch.nn.Module): - component.to(torch_device) - group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser) - self.assertTrue(group_offload_hook_1 is not None) - output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - # Test group offloading after removing the lora - pipe.unload_lora_weights() - group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser) - self.assertTrue(group_offload_hook_2 is not None) - output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841 - - # Add the lora again and check if group offloading works - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - check_if_lora_correctly_set(denoiser) - group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser) - self.assertTrue(group_offload_hook_3 is not None) - output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3)) - - @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)]) + output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0] + assert np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3) + + @pytest.mark.parametrize( + "offload_type, use_stream", + [("block_level", True), ("leaf_level", False), ("leaf_level", True)], + ) @require_torch_accelerator - def test_group_offloading_inference_denoiser(self, offload_type, use_stream): + def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe): for cls in inspect.getmro(self.__class__): if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests: - # Skip this test if it is overwritten by child class. We need to do this because parameterized - # materializes the test methods on invocation which cannot be overridden. return - self._test_group_offloading_inference_denoiser(offload_type, use_stream) + self._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe) - @require_torch_accelerator - def test_lora_loading_model_cpu_offload(self): - components, _, denoiser_lora_config = self.get_dummy_components() + @pytest.mark.skipif(torch_device == "cpu", reason="test requires accelerator+PyTorch") + def test_lora_loading_model_cpu_offload(self, tmpdirname, pipe): + _, _, denoiser_lora_config = self.get_dummy_components() _, _, inputs = self.get_dummy_inputs(with_generator=False) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." 
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts - ) - # reinitialize the pipeline to mimic the inference workflow. - components, _, denoiser_lora_config = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.enable_model_cpu_offload(device=torch_device) - pipe.load_lora_weights(tmpdirname) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts) + + components, _, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.enable_model_cpu_offload(device=torch_device) + pipe.load_lora_weights(tmpdirname) + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser." output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3)) + assert np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 7f849219c16f..988834acf546 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -24,6 +24,7 @@ import numpy as np import PIL.Image import PIL.ImageOps +import pytest import requests from numpy.linalg import norm from packaging import version @@ -265,7 +266,7 @@ def slow(test_case): Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. """ - return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) + return pytest.mark.skipif(not _run_slow_tests, reason="test is slow")(test_case) def nightly(test_case): @@ -275,7 +276,7 @@ def nightly(test_case): Slow tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them. 
""" - return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case) + return pytest.mark.skipif(not _run_nightly_tests, reason="test is nightly")(test_case) def is_torch_compile(test_case): @@ -350,9 +351,9 @@ def decorator(test_case): # These decorators are for accelerator-specific behaviours that are not GPU-specific def require_torch_accelerator(test_case): """Decorator marking a test that requires an accelerator backend and PyTorch.""" - return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")( - test_case - ) + return pytest.mark.skipif( + not (is_torch_available() and torch_device != "cpu"), reason="test requires accelerator+PyTorch" + )(test_case) def require_torch_multi_gpu(test_case): @@ -441,9 +442,9 @@ def require_big_accelerator(test_case): device_properties = torch.cuda.get_device_properties(0) total_memory = device_properties.total_memory / (1024**3) - return unittest.skipUnless( - total_memory >= BIG_GPU_MEMORY, - f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory", + return pytest.mark.skipif( + not total_memory >= BIG_GPU_MEMORY, + reason=f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory", )(test_case) @@ -509,7 +510,7 @@ def require_peft_backend(test_case): Decorator marking a test that requires PEFT backend, this would require some specific versions of PEFT and transformers. """ - return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case) + return pytest.mark.skipif(not USE_PEFT_BACKEND, reason="test requires PEFT backend")(test_case) def require_timm(test_case): @@ -550,8 +551,8 @@ def decorator(test_case): correct_peft_version = is_peft_available() and version.parse( version.parse(importlib.metadata.version("peft")).base_version ) > version.parse(peft_version) - return unittest.skipUnless( - correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}" + return pytest.mark.skipif( + not correct_peft_version, reason=f"test requires PEFT backend with the version greater than {peft_version}" )(test_case) return decorator @@ -567,9 +568,9 @@ def decorator(test_case): correct_transformers_version = is_transformers_available() and version.parse( version.parse(importlib.metadata.version("transformers")).base_version ) > version.parse(transformers_version) - return unittest.skipUnless( - correct_transformers_version, - f"test requires transformers with the version greater than {transformers_version}", + return pytest.mark.skipif( + not correct_transformers_version, + reason=f"test requires transformers with the version greater than {transformers_version}", )(test_case) return decorator