Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/diffusers/loaders/lora_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,14 @@ def _load_lora_into_text_encoder(
# convert state dict
state_dict = convert_state_dict_to_peft(state_dict)

# Under transformers>=5, CLIPTextModel was flattened and no longer has the
# `text_model.` wrapper module, so named_modules() returns unprefixed names
# like "encoder.layers.0.self_attn.q_proj". Kohya-sourced LoRA state dict
# keys still carry the "text_model." prefix from the old module layout.
# Strip it so that rank-key matching works regardless of the transformers version.
if not hasattr(text_encoder, "text_model"):
state_dict = {k.removeprefix("text_model."): v for k, v in state_dict.items()}

for name, _ in text_encoder.named_modules():
if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
rank_key = f"{name}.lora_B.weight"
Expand Down
42 changes: 42 additions & 0 deletions tests/lora/test_lora_layers_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,48 @@ def test_lora_expansion_works_for_extra_keys(self):
"LoRA should lead to different results.",
)

def test_kohya_clip_text_encoder_flattened_compat(self):
"""Regression test for #13984: kohya FLUX CLIP text-encoder LoRA under transformers>=5.

Under transformers>=5, CLIPTextModel was flattened and no longer has the `text_model.`
wrapper module. Kohya-sourced state dict keys still carry the "text_model." prefix,
which caused an empty rank dict and IndexError. This test verifies the fix by passing
a synthetic kohya-style state dict with the stale prefix to a text encoder that
doesn't have the `text_model` submodule.
"""
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

# Simulate transformers>=5 flattened CLIPTextModel by removing the text_model
# wrapper module attribute. Under transformers>=5 named_modules() no longer
# includes this prefix.
if hasattr(pipe.text_encoder, "text_model"):
del pipe.text_encoder.text_model

# Build a synthetic kohya-style state dict whose text-encoder keys carry the stale
# "text_model." prefix that real kohya conversion produces.
state_dict = {}
with torch.no_grad():
for name, module in pipe.text_encoder.named_modules():
if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
# Kohya conversion emits keys with text_model. prefix
kohya_key_b = f"text_encoder.text_model.{name}.lora_B.weight"
kohya_key_a = f"text_encoder.text_model.{name}.lora_A.weight"
state_dict[kohya_key_a] = torch.randn(2, module.weight.shape[0])
state_dict[kohya_key_b] = torch.randn(module.weight.shape[1], 2)

# This should not raise IndexError (the pre-fix crash point)
pipe.load_lora_weights(state_dict)

self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"LoRA should be correctly set in text encoder under transformers>=5 layout",
)

pipe.unload_lora_weights()

@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
Expand Down
Loading