Fix kohya FLUX CLIP text-encoder LoRA loading under transformers>=5 (#13984)#14029
Open
Jefsky wants to merge 1 commit into
Open
Fix kohya FLUX CLIP text-encoder LoRA loading under transformers>=5 (#13984)#14029Jefsky wants to merge 1 commit into
Jefsky wants to merge 1 commit into
Conversation
…uggingface#13984) Under transformers>=5, CLIPTextModel was flattened: the text_model. wrapper module was removed, so named_modules() returns unprefixed names like 'encoder.layers.0.self_attn.k_proj' instead of 'text_model.encoder.layers.0.self_attn.k_proj'. Kohya-sourced LoRA state dict keys still carry the stale 'text_model.' prefix after conversion, causing _load_lora_into_text_encoder to build an empty rank dict (nothing matches) and crash with IndexError in get_peft_kwargs. Fix: after the PEFT state dict conversion, strip 'text_model.' from state dict keys when the encoder doesn't have the text_model submodule (the transformers>=5 layout), so they align with named_modules() output. Added a regression test test_kohya_clip_text_encoder_flattened_compat that simulates the flattened CLIPTextModel layout and passes synthetic kohya-style keys.
BenjaminBossan
left a comment
Member
There was a problem hiding this comment.
Thanks for the PR. I agree that it should fix the specific issue that was mentioned and that should be safe to apply. My concern is that it is very specialized to this case and not a general solution. The PR covers this exact case:
But there could be other entries in the conversion mapping, now or added in the future, which are not covered by this patch. So I wonder if we should instead call Transformers get_model_conversion_mapping on the model (if it's a Transformers model) and apply the conversions from there.
I'll leave it up to the Diffusers maintainers to decide how to deal with this.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes #13984
Root cause
transformers>=5flattenedCLIPTextModel: thetext_model.wrapper module was removed, sotext_encoder.named_modules()now yields names likeencoder.layers.0.self_attn.k_projinstead oftext_model.encoder.layers.0.self_attn.k_proj.The kohya→diffusers conversion still emits text-encoder keys prefixed with
text_model.(e.g.text_encoder.text_model.encoder.layers.0.self_attn.k_proj.lora_B.weight). In_load_lora_into_text_encoder, therankdict is built by matchingnamed_modules()against the converted state-dict keys — under transformers>=5 nothing matches,rankstays empty, andget_peft_kwargsdoeslist(rank_dict.values())[0]→IndexError.The PEFT-side fix (#3212) doesn't help here because the crash happens before any PEFT state-dict injection (confirmed by the issue reporter).
Fix
In
_load_lora_into_text_encoder, after theconvert_state_dict_to_peftcall, strip the staletext_model.prefix from the converted state-dict keys when the text encoder doesn't have thetext_modelsubmodule (i.e. the transformers>=5 layout). The check useshasattr(text_encoder, "text_model"), not a version number, so it's forward-compatible.Test
Added
FluxLoRATests::test_kohya_clip_text_encoder_flattened_compat— a fast CPU regression test that:text_modelattribute fromtext_encoder(simulating transformers>=5)text_model.prefixVerification
text_model.prefix causesIndexError: list index out of rangeinget_peft_kwargs