6868 is_wandb_available ,
6969)
7070from diffusers .utils .import_utils import is_xformers_available
71+ from diffusers .utils .torch_utils import is_compiled_module
7172
7273
7374# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -1293,6 +1294,11 @@ def main(args):
12931294 else :
12941295 param .requires_grad = False
12951296
1297+ def unwrap_model (model ):
1298+ model = accelerator .unwrap_model (model )
1299+ model = model ._orig_mod if is_compiled_module (model ) else model
1300+ return model
1301+
12961302 # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
12971303 def save_model_hook (models , weights , output_dir ):
12981304 if accelerator .is_main_process :
@@ -1303,14 +1309,14 @@ def save_model_hook(models, weights, output_dir):
13031309 text_encoder_two_lora_layers_to_save = None
13041310
13051311 for model in models :
1306- if isinstance (model , type (accelerator . unwrap_model (unet ))):
1312+ if isinstance (model , type (unwrap_model (unet ))):
13071313 unet_lora_layers_to_save = convert_state_dict_to_diffusers (get_peft_model_state_dict (model ))
1308- elif isinstance (model , type (accelerator . unwrap_model (text_encoder_one ))):
1314+ elif isinstance (model , type (unwrap_model (text_encoder_one ))):
13091315 if args .train_text_encoder :
13101316 text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers (
13111317 get_peft_model_state_dict (model )
13121318 )
1313- elif isinstance (model , type (accelerator . unwrap_model (text_encoder_two ))):
1319+ elif isinstance (model , type (unwrap_model (text_encoder_two ))):
13141320 if args .train_text_encoder :
13151321 text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers (
13161322 get_peft_model_state_dict (model )
@@ -1338,11 +1344,11 @@ def load_model_hook(models, input_dir):
13381344 while len (models ) > 0 :
13391345 model = models .pop ()
13401346
1341- if isinstance (model , type (accelerator . unwrap_model (unet ))):
1347+ if isinstance (model , type (unwrap_model (unet ))):
13421348 unet_ = model
1343- elif isinstance (model , type (accelerator . unwrap_model (text_encoder_one ))):
1349+ elif isinstance (model , type (unwrap_model (text_encoder_one ))):
13441350 text_encoder_one_ = model
1345- elif isinstance (model , type (accelerator . unwrap_model (text_encoder_two ))):
1351+ elif isinstance (model , type (unwrap_model (text_encoder_two ))):
13461352 text_encoder_two_ = model
13471353 else :
13481354 raise ValueError (f"unexpected save model: { model .__class__ } " )
0 commit comments