diff --git a/scripts/stable/library/model_util.py b/scripts/stable/library/model_util.py
index be410a02..991fafa8 100644
--- a/scripts/stable/library/model_util.py
+++ b/scripts/stable/library/model_util.py
@@ -10,7 +10,7 @@ import diffusers
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
-from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline  # , UNet2DConditionModel
+from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, StableUnCLIPImg2ImgPipeline  # , UNet2DConditionModel
 from safetensors.torch import load_file, save_file
 from library.original_unet import UNet2DConditionModel
 from library.utils import setup_logging
@@ -656,6 +656,77 @@ def convert_key(key):
     return new_sd
 
 
+def convert_ldm_clip_checkpoint_v2_fix(checkpoint, max_length):
+    # annoyingly different from the v1 layout!
+    def convert_key(key):
+        if not key.startswith("cond_stage_model"):
+            return None
+
+        # common conversion
+        key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
+        key = key.replace("cond_stage_model.model.", "text_model.")
+
+        if "resblocks" in key:
+            # resblocks conversion
+            key = key.replace(".resblocks.", ".layers.")
+            if ".ln_" in key:
+                key = key.replace(".ln_", ".layer_norm")
+            elif ".mlp." in key:
+                key = key.replace(".c_fc.", ".fc1.")
+                key = key.replace(".c_proj.", ".fc2.")
+            elif ".attn.out_proj" in key:
+                key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
+            elif ".attn.in_proj" in key:
+                key = None  # special case, handled separately below
+            else:
+                raise ValueError(f"unexpected key in SD: {key}")
+        elif ".positional_embedding" in key:
+            key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
+        elif ".text_projection" in key:
+            key = None  # not used???
+        elif ".logit_scale" in key:
+            key = None  # not used???
+        elif ".token_embedding" in key:
+            key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
+        elif ".ln_final" in key:
+            key = key.replace(".ln_final", ".final_layer_norm")
+        return key
+
+    keys = list(checkpoint.keys())
+    new_sd = {}
+    for key in keys:
+        # remove resblocks 23
+        if ".resblocks.23." in key:
+            continue
+        if "embedder.model" in key:
+            continue
+        new_key = convert_key(key)
+        if new_key is None:
+            continue
+        new_sd[new_key] = checkpoint[key]
+
+    # convert attn
+    for key in keys:
+        if ".resblocks.23." in key:
+            continue
+        if "embedder.model" in key:
+            continue
+        if ".resblocks" in key and ".attn.in_proj_" in key:
+            # split into three (q, k, v)
+            values = torch.chunk(checkpoint[key], 3)
+
+            key_suffix = ".weight" if "weight" in key else ".bias"
+            key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
+            key_pfx = key_pfx.replace("_weight", "")
+            key_pfx = key_pfx.replace("_bias", "")
+            key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
+            new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
+            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
+            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
+
+    return new_sd
+
+
 # endregion
 
 
@@ -1015,33 +1086,58 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
     vae = AutoencoderKL(**vae_config).to(device)
     info = vae.load_state_dict(converted_vae_checkpoint)
     logger.info(f"loading vae: {info}")
-
-    # convert text_model
+
     if v2:
-        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
-        cfg = CLIPTextConfig(
-            vocab_size=49408,
-            hidden_size=1024,
-            intermediate_size=4096,
-            num_hidden_layers=23,
-            num_attention_heads=16,
-            max_position_embeddings=77,
-            hidden_act="gelu",
-            layer_norm_eps=1e-05,
-            dropout=0.0,
-            attention_dropout=0.0,
-            initializer_range=0.02,
-            initializer_factor=1.0,
-            pad_token_id=1,
-            bos_token_id=0,
-            eos_token_id=2,
-            model_type="clip_text_model",
-            projection_dim=512,
-            torch_dtype="float32",
-            transformers_version="4.25.0.dev0",
-        )
-        text_model = CLIPTextModel._from_config(cfg)
-        info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+        try:
+            converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2_fix(state_dict, 77)
+            cfg = CLIPTextConfig(
+                attention_dropout=0.0,
+                bos_token_id=0,
+                dropout=0.0,
+                eos_token_id=2,
+                hidden_act="gelu",
+                hidden_size=1024,
+                initializer_factor=1.0,
+                initializer_range=0.02,
+                intermediate_size=4096,
+                layer_norm_eps=1e-05,
+                max_position_embeddings=77,
+                model_type="clip_text_model",
+                num_attention_heads=16,
+                num_hidden_layers=23,
+                pad_token_id=1,
+                projection_dim=512,
+                torch_dtype="float16",
+                transformers_version="4.28.0.dev0",
+                vocab_size=49408,
+            )
+            text_model = CLIPTextModel._from_config(cfg)
+            info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+        except Exception as e:
+            converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
+            cfg = CLIPTextConfig(
+                vocab_size=49408,
+                hidden_size=1024,
+                intermediate_size=4096,
+                num_hidden_layers=23,
+                num_attention_heads=16,
+                max_position_embeddings=77,
+                hidden_act="gelu",
+                layer_norm_eps=1e-05,
+                dropout=0.0,
+                attention_dropout=0.0,
+                initializer_range=0.02,
+                initializer_factor=1.0,
+                pad_token_id=1,
+                bos_token_id=0,
+                eos_token_id=2,
+                model_type="clip_text_model",
+                projection_dim=512,
+                torch_dtype="float32",
+                transformers_version="4.25.0.dev0",
+            )
+            text_model = CLIPTextModel._from_config(cfg)
+            info = text_model.load_state_dict(converted_text_encoder_checkpoint)
     else:
         converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
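
For reference, a minimal, hypothetical usage sketch of the patched loader (not part of the diff). The checkpoint path is a placeholder; the call uses only the load_models_from_stable_diffusion_checkpoint signature visible in the last hunk header, and the return order (text encoder, VAE, U-Net) is assumed to match upstream sd-scripts. With v2=True the loader now tries convert_ldm_clip_checkpoint_v2_fix first (which skips the unCLIP "embedder.model" keys) and falls back to convert_ldm_clip_checkpoint_v2 on failure.

    from library.model_util import load_models_from_stable_diffusion_checkpoint

    # Placeholder path: any SD 2.x / Stable unCLIP checkpoint (.safetensors or .ckpt).
    ckpt_path = "checkpoints/sd21-unclip.safetensors"

    # v2=True routes text-encoder conversion through the new try/except added above.
    text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(True, ckpt_path, device="cpu")
    print(type(text_encoder).__name__, type(vae).__name__, type(unet).__name__)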