Skip to content

Commit c4b5d2f

Browse files
authored
[SD3 dreambooth lora] smol fix to checkpoint saving (#9993)
* smol change to fix checkpoint saving & resuming (as done in train_dreambooth_sd3.py) * style * modify comment to explain reasoning behind hidden size check
1 parent 7ac6e28 commit c4b5d2f

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,10 +1294,13 @@ def save_model_hook(models, weights, output_dir):
12941294
for model in models:
12951295
if isinstance(model, type(unwrap_model(transformer))):
12961296
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
1297-
elif isinstance(model, type(unwrap_model(text_encoder_one))):
1298-
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
1299-
elif isinstance(model, type(unwrap_model(text_encoder_two))):
1300-
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
1297+
elif isinstance(model, type(unwrap_model(text_encoder_one))): # or text_encoder_two
1298+
# both text encoders are of the same class, so we check hidden size to distinguish between the two
1299+
hidden_size = unwrap_model(model).config.hidden_size
1300+
if hidden_size == 768:
1301+
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
1302+
elif hidden_size == 1280:
1303+
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
13011304
else:
13021305
raise ValueError(f"unexpected save model: {model.__class__}")
13031306

0 commit comments

Comments
 (0)