
Commit 619e3ab

[bug fix] advanced dreambooth lora sdxl - fixes bugs described in #6486 (#6599)
* fixes bugs: 1. redundant retraction 2. param clone 3. stopping optimization of text encoder params
* param upscaling
* style
1 parent 9e2804f commit 619e3ab

File tree

1 file changed (+14, -9 lines)


examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 14 additions & 9 deletions
@@ -1279,7 +1279,7 @@ def main(args):
         for name, param in text_encoder_one.named_parameters():
             if "token_embedding" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param = param.to(dtype=torch.float32)
+                param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_one.append(param)
             else:
@@ -1288,7 +1288,7 @@ def main(args):
         for name, param in text_encoder_two.named_parameters():
             if "token_embedding" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param = param.to(dtype=torch.float32)
+                param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_two.append(param)
             else:
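The two hunks above are the "param upscaling" fix from the commit message: `param = param.to(dtype=torch.float32)` only rebinds the local loop variable, so the parameter stored inside the text encoder keeps its fp16 dtype, while assigning to `param.data` upcasts the tensor the module actually holds. A minimal sketch of the difference, using a toy embedding layer rather than the SDXL text encoders:

```python
import torch

# toy stand-in for an fp16 text encoder with a token embedding
emb = torch.nn.Embedding(8, 4).to(dtype=torch.float16)

# old code: rebinding the loop variable leaves the module's parameter untouched
for name, param in emb.named_parameters():
    param = param.to(dtype=torch.float32)
print(emb.weight.dtype)  # torch.float16 -> the module is unchanged

# fixed code: assigning to .data swaps the tensor the module actually holds
for name, param in emb.named_parameters():
    param.data = param.to(dtype=torch.float32)
print(emb.weight.dtype)  # torch.float32
```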
@@ -1725,19 +1725,19 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
     elif args.train_text_encoder_ti:  # args.train_text_encoder_ti
         num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
-
+    # flag used for textual inversion
+    pivoted = False
     for epoch in range(first_epoch, args.num_train_epochs):
         # if performing any kind of optimization of text_encoder params
         if args.train_text_encoder or args.train_text_encoder_ti:
             if epoch == num_train_epochs_text_encoder:
                 print("PIVOT HALFWAY", epoch)
                 # stopping optimization of text_encoder params
-                # re setting the optimizer to optimize only on unet params
-                optimizer.param_groups[1]["lr"] = 0.0
-                optimizer.param_groups[2]["lr"] = 0.0
+                # this flag is used to reset the optimizer to optimize only on unet params
+                pivoted = True
 
             else:
-                # still optimizng the text encoder
+                # still optimizing the text encoder
                 text_encoder_one.train()
                 text_encoder_two.train()
                 # set top parameter requires_grad = True for gradient checkpointing works
@@ -1747,6 +1747,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
         unet.train()
         for step, batch in enumerate(train_dataloader):
+            if pivoted:
+                # stopping optimization of text_encoder params
+                # re setting the optimizer to optimize only on unet params
+                optimizer.param_groups[1]["lr"] = 0.0
+                optimizer.param_groups[2]["lr"] = 0.0
+
             with accelerator.accumulate(unet):
                 prompts = batch["prompts"]
                 # encode batch prompts when custom prompts are provided for each image -
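The two hunks above implement the "stopping optimization of text encoder params" fix: instead of zeroing the text-encoder learning rates once at the pivot epoch, the script now sets a `pivoted` flag and re-applies the zeroing inside the step loop. A plausible reason, hedged because the linked issue is not reproduced here, is that a value written directly into `optimizer.param_groups[...]["lr"]` can be overwritten again by `lr_scheduler.step()`, so it has to be reasserted every step. A self-contained sketch of that interaction with hypothetical parameters, where group 0 stands in for the UNet and group 1 for a text encoder:

```python
import torch

unet_p = torch.nn.Parameter(torch.randn(4))
te_p = torch.nn.Parameter(torch.randn(4))

optimizer = torch.optim.AdamW(
    [{"params": [unet_p], "lr": 1e-4}, {"params": [te_p], "lr": 5e-5}]
)
# a scheduler that rewrites every param group's lr from its base value on each step
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=[lambda step: 1.0, lambda step: 1.0]
)

# zeroing once, outside the step loop, does not stick:
optimizer.param_groups[1]["lr"] = 0.0
lr_scheduler.step()
print(optimizer.param_groups[1]["lr"])  # 5e-05 again, the scheduler restored it

# the fix: re-apply the zeroing inside the step loop once pivoted
pivoted = True
for step in range(2):
    if pivoted:
        optimizer.param_groups[1]["lr"] = 0.0
    # ... forward / backward / optimizer.step() / lr_scheduler.step() ...
```

In the real script, param groups 1 and 2 appear to hold the two text encoders' parameters, which is what the `param_groups[1]` and `param_groups[2]` assignments in the diff correspond to.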
@@ -1885,8 +1891,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
                 # every step, we reset the embeddings to the original embeddings.
                 if args.train_text_encoder_ti:
-                    for idx, text_encoder in enumerate(text_encoders):
-                        embedding_handler.retract_embeddings()
+                    embedding_handler.retract_embeddings()
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
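The last hunk is the "redundant retraction" fix: the removed loop iterated over both text encoders but never used the loop variable, so `retract_embeddings()` (which, per the comment in the diff, resets the embeddings to the original embeddings every step) simply ran twice per step. A single call is equivalent, assuming the handler covers all tracked text encoders in one pass. A hypothetical stand-in handler, not the script's real one, just to show the pattern:

```python
class DummyEmbeddingHandler:
    """Hypothetical stand-in for the script's embedding handler (not the real class)."""

    def __init__(self):
        self.retractions = 0

    def retract_embeddings(self):
        # the script's method resets embeddings to their original values each step;
        # here we only count how often it is invoked
        self.retractions += 1


handler = DummyEmbeddingHandler()
text_encoders = ["text_encoder_one", "text_encoder_two"]

# old pattern: the loop variable is never used, so the same work runs twice per step
for idx, text_encoder in enumerate(text_encoders):
    handler.retract_embeddings()
assert handler.retractions == 2

# new pattern: one call per step does the same job
handler = DummyEmbeddingHandler()
handler.retract_embeddings()
assert handler.retractions == 1
```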
