import argparse
import hashlib
import itertools
+import logging
import math
import os
import warnings
import torch.utils.checkpoint
from torch.utils.data import Dataset

+import datasets
+import diffusers
+import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
@@ -236,6 +240,24 @@ def parse_args(input_args=None):
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+            ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+            "Only applicable when `--with_tracking` is passed."
+        ),
+    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
@@ -422,7 +444,7 @@ def main(args):
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
-        log_with="tensorboard",
+        log_with=args.report_to,
        logging_dir=logging_dir,
    )

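For context, a minimal sketch (not part of this change) of how a `log_with` value such as the one passed via `--report_to` is consumed by Accelerate's tracker API; the project name and logged values below are purely illustrative, and the chosen backend (e.g. wandb) must be installed:

from accelerate import Accelerator

# Illustrative only: "dreambooth" and the logged values are made up.
accelerator = Accelerator(log_with="wandb")  # same role as log_with=args.report_to
accelerator.init_trackers("dreambooth", config={"learning_rate": 5e-6})
accelerator.log({"train_loss": 0.1}, step=1)
accelerator.end_training()
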
@@ -435,9 +457,27 @@ def main(args):
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
+    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

+    # Generate class images if prior preservation is enabled.
    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
@@ -502,11 +542,7 @@ def main(args):

    # Load the tokenizer
    if args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(
-            args.tokenizer_name,
-            revision=args.revision,
-            use_fast=False,
-        )
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
    elif args.pretrained_model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
@@ -518,38 +554,36 @@ def main(args):
    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

-    # Load models and create wrapper for stable diffusion
+    # Load scheduler and models
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder = text_encoder_cls.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="text_encoder",
-        revision=args.revision,
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="vae",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="unet",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

+    vae.requires_grad_(False)
+    if not args.train_text_encoder:
+        text_encoder.requires_grad_(False)
+
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

-    vae.requires_grad_(False)
-    if not args.train_text_encoder:
-        text_encoder.requires_grad_(False)
-
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
@@ -568,6 +602,7 @@ def main(args):
    else:
        optimizer_class = torch.optim.AdamW

+    # Optimizer creation
    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
    )
@@ -579,8 +614,7 @@ def main(args):
        eps=args.adam_epsilon,
    )

-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
-
+    # Dataset and DataLoaders creation:
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
@@ -615,6 +649,7 @@ def main(args):
        power=args.lr_power,
    )

+    # Prepare everything with our `accelerator`.
    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
@@ -623,17 +658,16 @@ def main(args):
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )
-    accelerator.register_for_checkpointing(lr_scheduler)

+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

-    # Move text_encode and vae to gpu.
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
+    # Move vae and text_encoder to device and cast to weight_dtype
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)
@@ -664,6 +698,7 @@ def main(args):
    global_step = 0
    first_epoch = 0

+    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
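For context, resuming builds on Accelerate's state checkpointing API; a minimal sketch (not part of this change) with made-up checkpoint paths:

from accelerate import Accelerator

accelerator = Accelerator()
# ... prepare model, optimizer, dataloader, lr_scheduler with accelerator.prepare(...) ...

# Save training state to a hypothetical directory:
accelerator.save_state("output/checkpoint-500")

# Later, restore model, optimizer, and scheduler state from it:
accelerator.load_state("output/checkpoint-500")
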
@@ -772,9 +807,8 @@ def main(args):
            if global_step >= args.max_train_steps:
                break

-    accelerator.wait_for_everyone()
-
    # Create the pipeline using the trained modules and save it.
+    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
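A note on the new `--allow_tf32` flag: a minimal, self-contained sketch of what the toggle does, assuming PyTorch on an Ampere-or-newer GPU; the tensor shapes are arbitrary and purely illustrative.

import torch

# Mirrors the effect of passing --allow_tf32: float32 matmuls may run on
# TensorFloat-32 tensor cores (Ampere and newer), trading a small amount of
# precision for throughput.
torch.backends.cuda.matmul.allow_tf32 = True

if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    c = a @ b  # executed with TF32 when the hardware supports it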