Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/diffusers/loaders/lora_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,14 @@ def _load_lora_into_text_encoder(
# convert state dict
state_dict = convert_state_dict_to_peft(state_dict)

# transformers>=5 flattened CLIPTextModel: the `text_model.` wrapper module
# was removed, so the encoder submodules (and `named_modules()`) are no longer
# prefixed with `text_model.`. kohya-style LoRAs are converted with that prefix,
# so strip it when the text encoder is flattened to keep the keys aligned with
# the module names; otherwise `rank` stays empty and loading fails (#13984).
if not hasattr(text_encoder, "text_model"):
state_dict = {k.removeprefix("text_model."): v for k, v in state_dict.items()}

for name, _ in text_encoder.named_modules():
if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
rank_key = f"{name}.lora_B.weight"
Expand Down
54 changes: 53 additions & 1 deletion tests/lora/test_lora_layers_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch
from parameterized import parameterized
from PIL import Image
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel

from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel
from diffusers.utils import load_image, logging
Expand Down Expand Up @@ -272,6 +272,58 @@ def test_modify_padding_mode(self):
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass

def test_kohya_clip_text_encoder_lora_loads_with_flattened_clip(self):
# Regression test for #13984: a kohya FLUX LoRA carrying CLIP text-encoder
# weights (lora_te1_*) must load even when transformers flattens
# CLIPTextModel (drops the `text_model.` wrapper). Previously the converted
# keys kept the `text_model.` prefix while the module names did not, so the
# rank dict came out empty and loading raised an IndexError.
from diffusers.loaders.lora_base import _load_lora_into_text_encoder
from diffusers.loaders.lora_conversion_utils import _convert_kohya_flux_lora_to_diffusers

rank, hidden, inter = 4, 32, 64
text_encoder = CLIPTextModel(
CLIPTextConfig(
hidden_size=hidden,
intermediate_size=inter,
num_hidden_layers=1,
num_attention_heads=2,
vocab_size=100,
max_position_embeddings=77,
)
)
self.assertFalse(hasattr(text_encoder, "text_model")) # flattened (transformers>=5)

shapes = {
"self_attn_q_proj": (hidden, hidden),
"self_attn_k_proj": (hidden, hidden),
"self_attn_v_proj": (hidden, hidden),
"self_attn_out_proj": (hidden, hidden),
"mlp_fc1": (inter, hidden),
"mlp_fc2": (hidden, inter),
}
kohya_state_dict = {}
for module, (out_features, in_features) in shapes.items():
base = f"lora_te1_text_model_encoder_layers_0_{module}"
kohya_state_dict[f"{base}.lora_down.weight"] = torch.randn(rank, in_features)
kohya_state_dict[f"{base}.lora_up.weight"] = torch.randn(out_features, rank)
kohya_state_dict[f"{base}.alpha"] = torch.tensor(float(rank))

converted = _convert_kohya_flux_lora_to_diffusers(kohya_state_dict)

_load_lora_into_text_encoder(
state_dict=converted,
network_alphas=None,
text_encoder=text_encoder,
prefix="text_encoder",
)

# The adapter must actually be injected onto the CLIP attention/MLP layers.
self.assertTrue(len(text_encoder.peft_config) > 0)
k_proj = text_encoder.encoder.layers[0].self_attn.k_proj
self.assertTrue(hasattr(k_proj, "lora_A"))
self.assertEqual(k_proj.lora_A["default_0"].weight.shape[0], rank)


class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxControlPipeline
Expand Down
Loading