@@ -1279,7 +1279,7 @@ def main(args):
         for name, param in text_encoder_one.named_parameters():
             if "token_embedding" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param = param.to(dtype=torch.float32)
+                param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_one.append(param)
             else:
@@ -1288,7 +1288,7 @@ def main(args):
         for name, param in text_encoder_two.named_parameters():
             if "token_embedding" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param = param.to(dtype=torch.float32)
+                param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_two.append(param)
             else:
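
Note on the two hunks above, which apply the same fix to each text encoder: param = param.to(dtype=torch.float32) only rebinds the local loop variable to a new tensor, so the parameter actually stored on the module (and later appended to text_lora_parameters_one/two and handed to the optimizer) stays in fp16. Assigning to param.data upcasts the stored tensor in place. A minimal sketch of the difference, using a toy module rather than the real text encoders:

import torch
from torch import nn

# toy stand-in for a text encoder loaded in fp16 (illustrative only)
encoder = nn.Linear(4, 4).to(dtype=torch.float16)

for name, param in encoder.named_parameters():
    param = param.to(dtype=torch.float32)       # rebinds the local name only
print(next(encoder.parameters()).dtype)         # still torch.float16

for name, param in encoder.named_parameters():
    param.data = param.to(dtype=torch.float32)  # replaces the tensor the module holds
print(next(encoder.parameters()).dtype)         # torch.float32
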
@@ -1725,19 +1725,19 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
     elif args.train_text_encoder_ti:  # args.train_text_encoder_ti
         num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
-
+    # flag used for textual inversion
+    pivoted = False
     for epoch in range(first_epoch, args.num_train_epochs):
         # if performing any kind of optimization of text_encoder params
         if args.train_text_encoder or args.train_text_encoder_ti:
             if epoch == num_train_epochs_text_encoder:
                 print("PIVOT HALFWAY", epoch)
                 # stopping optimization of text_encoder params
-                # re setting the optimizer to optimize only on unet params
-                optimizer.param_groups[1]["lr"] = 0.0
-                optimizer.param_groups[2]["lr"] = 0.0
+                # this flag is used to reset the optimizer to optimize only on unet params
+                pivoted = True
 
             else:
-                # still optimizng the text encoder
+                # still optimizing the text encoder
                 text_encoder_one.train()
                 text_encoder_two.train()
                 # set top parameter requires_grad = True for gradient checkpointing works
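
A quick worked example of the pivot point computed above, with hypothetical flag values (not taken from the diff):

num_train_epochs = 10
train_text_encoder_ti_frac = 0.5  # hypothetical value, for illustration
num_train_epochs_text_encoder = int(train_text_encoder_ti_frac * num_train_epochs)  # 5
# epochs 0-4: the token embeddings train alongside the unet
# epoch 5 onward: pivoted is set and only the unet keeps a non-zero learning rate
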
@@ -1747,6 +1747,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
         unet.train()
         for step, batch in enumerate(train_dataloader):
+            if pivoted:
+                # stopping optimization of text_encoder params
+                # resetting the optimizer to optimize only on unet params
+                optimizer.param_groups[1]["lr"] = 0.0
+                optimizer.param_groups[2]["lr"] = 0.0
+
             with accelerator.accumulate(unet):
                 prompts = batch["prompts"]
                 # encode batch prompts when custom prompts are provided for each image -
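
The hunk above is the second half of the pivot change: instead of zeroing the text encoder learning rates once at the start of the pivot epoch, the script now raises the pivoted flag there and re-applies the zeroing at every training step. One plausible reading (not stated in the diff itself) is that re-applying it per step keeps the freeze in effect even if something later rewrites the group learning rates. A self-contained sketch of the pattern, with toy parameters, and with group indices 1 and 2 taken from the diff as the two text encoder groups:

import torch

# toy parameter groups standing in for the real unet / text encoder params (illustrative only)
unet_params = [torch.nn.Parameter(torch.zeros(2))]
text_encoder_one_params = [torch.nn.Parameter(torch.zeros(2))]
text_encoder_two_params = [torch.nn.Parameter(torch.zeros(2))]

optimizer = torch.optim.AdamW(
    [
        {"params": unet_params, "lr": 1e-4},              # param_groups[0]: unet
        {"params": text_encoder_one_params, "lr": 5e-5},  # param_groups[1]: text encoder one
        {"params": text_encoder_two_params, "lr": 5e-5},  # param_groups[2]: text encoder two
    ]
)

pivot_epoch, num_epochs, steps_per_epoch = 1, 2, 3
pivoted = False
for epoch in range(num_epochs):
    if epoch == pivot_epoch:
        pivoted = True
    for step in range(steps_per_epoch):
        if pivoted:
            # re-applied every step, so the text encoder groups stay frozen from the pivot on
            optimizer.param_groups[1]["lr"] = 0.0
            optimizer.param_groups[2]["lr"] = 0.0
        # ... forward, backward and optimizer.step() would go here ...
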
@@ -1885,8 +1891,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
                 # every step, we reset the embeddings to the original embeddings.
                 if args.train_text_encoder_ti:
-                    for idx, text_encoder in enumerate(text_encoders):
-                        embedding_handler.retract_embeddings()
+                    embedding_handler.retract_embeddings()
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
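
On the last hunk: retract_embeddings() was previously called once per text encoder inside a loop whose variables were never used; the diff collapses that to a single call, which suggests the handler already resets both encoders internally (its implementation is not shown here, so that is an assumption). The per-step reset itself is the usual textual inversion trick of restoring every embedding row except the newly inserted tokens, so only those tokens accumulate updates. A rough sketch of that idea with hypothetical names (embedding_layer, original_weights, new_token_ids):

import torch

def retract_embeddings_sketch(embedding_layer, original_weights, new_token_ids):
    # embedding_layer: the torch.nn.Embedding being trained
    # original_weights: copy of the embedding matrix saved before training started
    # new_token_ids: indices of the inserted, trainable tokens
    with torch.no_grad():
        keep = torch.ones(embedding_layer.num_embeddings, dtype=torch.bool)
        keep[new_token_ids] = False  # leave the new tokens' rows as trained
        embedding_layer.weight.data[keep] = original_weights[keep].to(
            device=embedding_layer.weight.device, dtype=embedding_layer.weight.dtype
        )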