diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 5b5579664b55..68fd5b9fd29c 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -376,6 +376,14 @@ def _load_lora_into_text_encoder(
         # convert state dict
         state_dict = convert_state_dict_to_peft(state_dict)
 
+        # Under transformers>=5, CLIPTextModel was flattened and no longer has the
+        # `text_model.` wrapper module, so named_modules() returns unprefixed names
+        # like "encoder.layers.0.self_attn.q_proj". Kohya-sourced LoRA state dict
+        # keys still carry the "text_model." prefix from the old module layout.
+        # Strip it so that rank-key matching works regardless of the transformers version.
+        if not hasattr(text_encoder, "text_model"):
+            state_dict = {k.removeprefix("text_model."): v for k, v in state_dict.items()}
+
         for name, _ in text_encoder.named_modules():
             if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
                 rank_key = f"{name}.lora_B.weight"
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index b840d7ac72ce..3072d5ca9551 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -256,6 +256,48 @@ def test_lora_expansion_works_for_extra_keys(self):
             "LoRA should lead to different results.",
         )
 
+    def test_kohya_clip_text_encoder_flattened_compat(self):
+        """Regression test for #13984: kohya FLUX CLIP text-encoder LoRA under transformers>=5.
+
+        Under transformers>=5, CLIPTextModel was flattened and no longer has the `text_model.`
+        wrapper module. Kohya-sourced state dict keys still carry the "text_model." prefix,
+        which caused an empty rank dict and IndexError. This test verifies the fix by passing
+        a synthetic kohya-style state dict with the stale prefix to a text encoder that
+        doesn't have the `text_model` submodule.
+        """
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        # The flattened layout only exists under transformers>=5. Deleting `text_model` on the
+        # old layout would not flatten it: it would drop the layers that live under the wrapper
+        # and leave nothing for the LoRA to attach to, so skip instead of simulating.
+        if hasattr(pipe.text_encoder, "text_model"):
+            self.skipTest("Requires the flattened CLIPTextModel layout (transformers>=5).")
+
+        # Build a synthetic kohya-style state dict whose text-encoder keys carry the stale
+        # "text_model." prefix that real kohya conversion produces.
+        state_dict = {}
+        with torch.no_grad():
+            for name, module in pipe.text_encoder.named_modules():
+                if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
+                    # Kohya conversion emits keys with text_model. prefix
+                    kohya_key_b = f"text_encoder.text_model.{name}.lora_B.weight"
+                    kohya_key_a = f"text_encoder.text_model.{name}.lora_A.weight"
+                    state_dict[kohya_key_a] = torch.randn(2, module.weight.shape[1])  # (rank, in_features)
+                    state_dict[kohya_key_b] = torch.randn(module.weight.shape[0], 2)  # (out_features, rank)
+
+        # This should not raise IndexError (the pre-fix crash point)
+        pipe.load_lora_weights(state_dict)
+
+        self.assertTrue(
+            check_if_lora_correctly_set(pipe.text_encoder),
+            "LoRA should be correctly set in text encoder under transformers>=5 layout",
+        )
+
+        pipe.unload_lora_weights()
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass