Skip to content

Commit 6d71b76

Browse files
dxqbclaudesayakpaul
authored
Complete Kohya LoRA conversion for Qwen and Z-Image (#14080)
* Fix Kohya LoRA conversion for Z-Image modules whose names contain underscores _convert_non_diffusers_z_image_lora_to_diffusers reverses Kohya's `.`->`_` flattening with a blanket `_`->`.` split, guarded only by a small protected-n-gram list (attention to_q/k/v/out, feed_forward) plus post-hoc fixes for context_refiner/noise_refiner. Z-Image's other modules whose names contain underscores were over-split: all_final_layer, all_x_embedder, adaLN_modulation, cap_embedder and t_embedder came out as all.final.layer, adaLN.modulation, ... and failed to load with "unexpected keys". Extend the existing dot->underscore post-normalization to re-merge these names, so Kohya (lora_unet_) Z-Image LoRAs load. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> * Fix Kohya LoRA conversion for Qwen top-level (non-block) modules _convert_non_diffusers_qwen_lora_to_diffusers's convert_key hardcodes the transformer_blocks prefix and assumes every lora_unet_ key lives under a block: it strips a transformer_blocks_ prefix and re-prepends transformer_blocks., which collapses the top-level modules (img_in, txt_in, proj_out, norm_out.linear, time_text_embed.timestep_embedder.linear_1/2) onto each other. They end up as transformer_blocks..weight / ...a.down.weight and trip the 'state_dict should be empty' guard. Resolve these six modules via an explicit flattened->dotted map before the block logic runs, preserving the .lora_down/.lora_up/.alpha suffix, so Kohya (lora_unet_) Qwen LoRAs load. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent b549ca9 commit 6d71b76

1 file changed

Lines changed: 37 additions & 4 deletions

File tree

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2232,8 +2232,26 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
22322232
if has_lora_unet:
22332233
state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
22342234

2235+
# Top-level (non-block) modules: convert_key below assumes every key lives under
2236+
# transformer_blocks_ and blindly strips/re-prepends that prefix, which collapses
2237+
# these module names onto each other. Map them explicitly before that logic runs.
2238+
# The flattened name -> dotted diffusers name is fixed, and the .lora_down/.lora_up/
2239+
# .alpha suffix is preserved.
2240+
top_level_modules = {
2241+
"img_in": "img_in",
2242+
"txt_in": "txt_in",
2243+
"proj_out": "proj_out",
2244+
"norm_out_linear": "norm_out.linear",
2245+
"time_text_embed_timestep_embedder_linear_1": "time_text_embed.timestep_embedder.linear_1",
2246+
"time_text_embed_timestep_embedder_linear_2": "time_text_embed.timestep_embedder.linear_2",
2247+
}
2248+
22352249
def convert_key(key: str) -> str:
22362250
prefix = "transformer_blocks"
2251+
for flat, dotted in top_level_modules.items():
2252+
if key == flat or key.startswith(flat + "."):
2253+
return dotted + key[len(flat) :]
2254+
22372255
if "." in key:
22382256
base, suffix = key.rsplit(".", 1)
22392257
else:
@@ -2803,12 +2821,27 @@ def normalize_out_key(k: str) -> str:
28032821
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
28042822

28052823
# Normalize ZImage-specific dot-separated module names to underscore form so they
2806-
# match the diffusers model parameter names (context_refiner, noise_refiner).
2807-
state_dict = {
2808-
k.replace("context.refiner.", "context_refiner.").replace("noise.refiner.", "noise_refiner."): v
2809-
for k, v in state_dict.items()
2824+
# match the diffusers model parameter names. convert_key blindly split every "_",
2825+
# so module names whose own names contain underscores (and aren't protected as the
2826+
# attention/feed_forward n-grams are) come out over-split here. This runs on the full
2827+
# key (before the weight/alpha handlers below) so it fixes .lora_A/B and .alpha alike.
2828+
zimage_module_name_fixups = {
2829+
"context.refiner.": "context_refiner.",
2830+
"noise.refiner.": "noise_refiner.",
2831+
"adaLN.modulation.": "adaLN_modulation.",
2832+
"all.final.layer.": "all_final_layer.",
2833+
"all.x.embedder.": "all_x_embedder.",
2834+
"cap.embedder.": "cap_embedder.",
2835+
"t.embedder.": "t_embedder.",
28102836
}
28112837

2838+
def fixup_module_names(k: str) -> str:
2839+
for dotted, underscored in zimage_module_name_fixups.items():
2840+
k = k.replace(dotted, underscored)
2841+
return k
2842+
2843+
state_dict = {fixup_module_names(k): v for k, v in state_dict.items()}
2844+
28122845
converted_state_dict = {}
28132846
all_keys = list(state_dict.keys())
28142847
down_key = ".lora_down.weight"

0 commit comments

Comments
 (0)