Skip to content

Commit e14a17c

Browse files
committed
Fix checkpoint overwrite bug in train_sana_sprint_diffusers
1 parent 1066de8 commit e14a17c

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

examples/research_projects/sana/train_sana_sprint_diffusers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1145,6 +1145,11 @@ def save_model_hook(models, weights, output_dir):
11451145
elif isinstance(unwrapped_model, type(unwrap_model(disc))):
11461146
# Save only the heads
11471147
torch.save(unwrapped_model.heads.state_dict(), os.path.join(output_dir, "disc_heads.pt"))
1148+
1149+
# Skip frozen pretrained_model
1150+
elif isinstance(unwrapped_model, type(unwrap_model(transformer))):
1151+
pass
1152+
11481153
else:
11491154
raise ValueError(f"unexpected save model: {unwrapped_model.__class__}")
11501155

@@ -1161,7 +1166,7 @@ def load_model_hook(models, input_dir):
11611166
model = models.pop()
11621167
unwrapped_model = unwrap_model(model)
11631168

1164-
if isinstance(unwrapped_model, type(unwrap_model(transformer))):
1169+
if isinstance(unwrapped_model, type(unwrap_model(transformer))) and getattr(unwrapped_model, 'guidance', False):
11651170
transformer_ = model # noqa: F841
11661171
elif isinstance(unwrapped_model, type(unwrap_model(disc))):
11671172
# Load only the heads

0 commit comments

Comments
 (0)