150 changes: 123 additions & 27 deletions scripts/stable/library/model_util.py
@@ -10,7 +10,7 @@

import diffusers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, StableUnCLIPImg2ImgPipeline # , UNet2DConditionModel
from safetensors.torch import load_file, save_file
from library.original_unet import UNet2DConditionModel
from library.utils import setup_logging
@@ -656,6 +656,77 @@ def convert_key(key):
return new_sd


def convert_ldm_clip_checkpoint_v2_fix(checkpoint, max_length):
# the key layout is annoyingly different!
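# Descriptive note: this converter maps SD 2.x / OpenCLIP text encoder keys under
# "cond_stage_model.model.*" to Hugging Face CLIPTextModel keys under "text_model.*".
# Compared to convert_ldm_clip_checkpoint_v2 it additionally skips "embedder.model.*"
# keys (presumably the unCLIP image embedder, given the StableUnCLIPImg2ImgPipeline import above).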
def convert_key(key):
if not key.startswith("cond_stage_model"):
return None

# common conversion
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
key = key.replace("cond_stage_model.model.", "text_model.")

if "resblocks" in key:
# resblocks conversion
key = key.replace(".resblocks.", ".layers.")
if ".ln_" in key:
key = key.replace(".ln_", ".layer_norm")
elif ".mlp." in key:
key = key.replace(".c_fc.", ".fc1.")
key = key.replace(".c_proj.", ".fc2.")
elif ".attn.out_proj" in key:
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
elif ".attn.in_proj" in key:
key = None # special case; handled separately below
else:
raise ValueError(f"unexpected key in SD: {key}")
elif ".positional_embedding" in key:
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
elif ".text_projection" in key:
key = None # not used???
elif ".logit_scale" in key:
key = None # not used???
elif ".token_embedding" in key:
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
elif ".ln_final" in key:
key = key.replace(".ln_final", ".final_layer_norm")
return key

keys = list(checkpoint.keys())
new_sd = {}
for key in keys:
# skip resblocks.23 (the last transformer layer; not used, matching num_hidden_layers=23)
if ".resblocks.23." in key:
continue
if 'embedder.model' in key:
continue
new_key = convert_key(key)
if new_key is None:
continue
new_sd[new_key] = checkpoint[key]

# convert attn (the fused in_proj weights)
for key in keys:
if ".resblocks.23." in key:
continue
if 'embedder.model' in key:
continue
if ".resblocks" in key and ".attn.in_proj_" in key:
# split into three (q, k, v)
values = torch.chunk(checkpoint[key], 3)
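# in_proj_weight is the fused q/k/v projection with shape [3 * hidden_size, hidden_size]
# (and in_proj_bias has shape [3 * hidden_size]); torch.chunk splits along dim 0,
# yielding the q_proj, k_proj, and v_proj tensors assigned below.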

key_suffix = ".weight" if "weight" in key else ".bias"
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
key_pfx = key_pfx.replace("_weight", "")
key_pfx = key_pfx.replace("_bias", "")
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]

return new_sd
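# A few illustrative mappings performed by the converter above (keys taken from an
# assumed SD 2.x-style checkpoint; derived from the replace rules, not an exhaustive list):
#   cond_stage_model.model.transformer.resblocks.0.mlp.c_fc.weight
#       -> text_model.encoder.layers.0.mlp.fc1.weight
#   cond_stage_model.model.ln_final.weight
#       -> text_model.final_layer_norm.weight
#   cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight
#       -> text_model.encoder.layers.0.self_attn.{q,k,v}_proj.weight (chunked into three)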


# endregion


@@ -1015,33 +1086,58 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
vae = AutoencoderKL(**vae_config).to(device)
info = vae.load_state_dict(converted_vae_checkpoint)
logger.info(f"loading vae: {info}")

# convert text_model

if v2:
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=1024,
intermediate_size=4096,
num_hidden_layers=23,
num_attention_heads=16,
max_position_embeddings=77,
hidden_act="gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=512,
torch_dtype="float32",
transformers_version="4.25.0.dev0",
)
text_model = CLIPTextModel._from_config(cfg)
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
try:
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2_fix(state_dict, 77)
cfg = CLIPTextConfig(
attention_dropout=0.0,
bos_token_id=0,
dropout=0.0,
eos_token_id=2,
hidden_act="gelu",
hidden_size=1024,
initializer_factor=1.0,
initializer_range=0.02,
intermediate_size=4096,
layer_norm_eps=1e-05,
max_position_embeddings=77,
model_type="clip_text_model",
num_attention_heads=16,
num_hidden_layers=23,
pad_token_id=1,
projection_dim=512,
torch_dtype="float16",
transformers_version="4.28.0.dev0",
vocab_size=49408,
)
text_model = CLIPTextModel._from_config(cfg)
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
except Exception as e:
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=1024,
intermediate_size=4096,
num_hidden_layers=23,
num_attention_heads=16,
max_position_embeddings=77,
hidden_act="gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=512,
torch_dtype="float32",
transformers_version="4.25.0.dev0",
)
text_model = CLIPTextModel._from_config(cfg)
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
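# Possible refinement (a sketch, not part of this PR): the caught exception is otherwise
# discarded, so logging it would make the fallback easier to diagnose, e.g.:
#   logger.warning(f"new-format text encoder conversion failed ({e}); falling back to convert_ldm_clip_checkpoint_v2")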
else:
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
