
Commit fa1f470

[examples] misc fixes (#1886)
* misc fixes
* more comments
* Update examples/textual_inversion/textual_inversion.py

Co-authored-by: Patrick von Platen <[email protected]>

* set transformers verbosity to warning

Co-authored-by: Patrick von Platen <[email protected]>
Parent: 423c3a4 · Commit: fa1f470

3 files changed: +145 −70 lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 64 additions & 30 deletions
@@ -1,6 +1,7 @@
 import argparse
 import hashlib
 import itertools
+import logging
 import math
 import os
 import warnings
@@ -12,6 +13,9 @@
 import torch.utils.checkpoint
 from torch.utils.data import Dataset
 
+import datasets
+import diffusers
+import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
@@ -236,6 +240,24 @@ def parse_args(input_args=None):
             " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
         ),
     )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+            ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+            "Only applicable when `--with_tracking` is passed."
+        ),
+    )
     parser.add_argument(
         "--mixed_precision",
         type=str,
@@ -422,7 +444,7 @@ def main(args):
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
-        log_with="tensorboard",
+        log_with=args.report_to,
         logging_dir=logging_dir,
     )
 
@@ -435,9 +457,27 @@ def main(args):
             "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
         )
 
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
+    # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
 
+    # Generate class images if prior preservation is enabled.
     if args.with_prior_preservation:
         class_images_dir = Path(args.class_data_dir)
         if not class_images_dir.exists():
@@ -502,11 +542,7 @@ def main(args):
 
     # Load the tokenizer
     if args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(
-            args.tokenizer_name,
-            revision=args.revision,
-            use_fast=False,
-        )
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
     elif args.pretrained_model_name_or_path:
         tokenizer = AutoTokenizer.from_pretrained(
             args.pretrained_model_name_or_path,
@@ -518,38 +554,36 @@ def main(args):
     # import correct text encoder class
     text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
 
-    # Load models and create wrapper for stable diffusion
+    # Load scheduler and models
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder = text_encoder_cls.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="text_encoder",
-        revision=args.revision,
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="vae",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="unet",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )
 
+    vae.requires_grad_(False)
+    if not args.train_text_encoder:
+        text_encoder.requires_grad_(False)
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
         else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")
 
-    vae.requires_grad_(False)
-    if not args.train_text_encoder:
-        text_encoder.requires_grad_(False)
-
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         if args.train_text_encoder:
             text_encoder.gradient_checkpointing_enable()
 
+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
@@ -568,6 +602,7 @@ def main(args):
     else:
         optimizer_class = torch.optim.AdamW
 
+    # Optimizer creation
     params_to_optimize = (
         itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
     )
@@ -579,8 +614,7 @@ def main(args):
         eps=args.adam_epsilon,
     )
 
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
-
+    # Dataset and DataLoaders creation:
     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
         instance_prompt=args.instance_prompt,
@@ -615,6 +649,7 @@ def main(args):
         power=args.lr_power,
     )
 
+    # Prepare everything with our `accelerator`.
     if args.train_text_encoder:
         unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             unet, text_encoder, optimizer, train_dataloader, lr_scheduler
@@ -623,17 +658,16 @@ def main(args):
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             unet, optimizer, train_dataloader, lr_scheduler
         )
-    accelerator.register_for_checkpointing(lr_scheduler)
 
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    # Move text_encode and vae to gpu.
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
+    # Move vae and text_encoder to device and cast to weight_dtype
     vae.to(accelerator.device, dtype=weight_dtype)
     if not args.train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)
@@ -664,6 +698,7 @@ def main(args):
     global_step = 0
     first_epoch = 0
 
+    # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint != "latest":
             path = os.path.basename(args.resume_from_checkpoint)
@@ -772,9 +807,8 @@ def main(args):
             if global_step >= args.max_train_steps:
                 break
 
-    accelerator.wait_for_everyone()
-
     # Create the pipeline using using the trained modules and save it.
+    accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
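
For reference, the two new dreambooth flags map directly onto runtime behaviour: `--report_to` is forwarded to `Accelerator(log_with=...)` and `--allow_tf32` flips PyTorch's TF32 matmul switch. Below is a minimal standalone sketch of that pattern, not part of the commit; the flag names and the `logging_dir` keyword come from the diff (and match the accelerate version used there), while the script structure, run name, and logged value are illustrative.

# Sketch, not part of the commit: how --allow_tf32 and --report_to are consumed at runtime.
import argparse

import torch
from accelerate import Accelerator


def parse_args():
    parser = argparse.ArgumentParser()
    # Same flags train_dreambooth.py gains in this commit
    parser.add_argument("--allow_tf32", action="store_true")
    parser.add_argument("--report_to", type=str, default="tensorboard")
    parser.add_argument("--logging_dir", type=str, default="logs")
    return parser.parse_args()


def main():
    args = parse_args()

    # TF32 speeds up float32 matmuls on Ampere GPUs at slightly reduced precision.
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # The tracker choice is handed straight to Accelerate, mirroring
    # Accelerator(..., log_with=args.report_to, logging_dir=logging_dir) in the diff above.
    accelerator = Accelerator(log_with=args.report_to, logging_dir=args.logging_dir)
    accelerator.init_trackers("dreambooth_sketch")
    accelerator.log({"train_loss": 0.0}, step=0)  # placeholder metric
    accelerator.end_training()


if __name__ == "__main__":
    main()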

examples/text_to_image/train_text_to_image.py

Lines changed: 6 additions & 1 deletion
@@ -411,6 +411,7 @@ def main():
         logging_dir=logging_dir,
     )
 
+    # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
@@ -419,7 +420,7 @@ def main():
     logger.info(accelerator.state, main_process_only=False)
     if accelerator.is_local_main_process:
         datasets.utils.logging.set_verbosity_warning()
-        transformers.utils.logging.set_verbosity_info()
+        transformers.utils.logging.set_verbosity_warning()
         diffusers.utils.logging.set_verbosity_info()
     else:
         datasets.utils.logging.set_verbosity_error()
@@ -577,6 +578,7 @@ def tokenize_captions(examples, is_train=True):
         )
         return inputs.input_ids
 
+    # Preprocessing the datasets.
     train_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
@@ -605,6 +607,7 @@ def collate_fn(examples):
         input_ids = torch.stack([example["input_ids"] for example in examples])
         return {"pixel_values": pixel_values, "input_ids": input_ids}
 
+    # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
     )
@@ -623,6 +626,7 @@ def collate_fn(examples):
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )
 
+    # Prepare everything with our `accelerator`.
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, lr_scheduler
     )
@@ -668,6 +672,7 @@ def collate_fn(examples):
     global_step = 0
     first_epoch = 0
 
+    # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint != "latest":
             path = os.path.basename(args.resume_from_checkpoint)
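
The one-word change above (transformers verbosity from info to warning on the main process) is part of a per-process logging pattern the example scripts now share: the local main process keeps warnings plus diffusers progress info, every other rank is silenced down to errors. A self-contained sketch of that pattern follows, not part of the commit; it reuses the calls shown in the diff verbatim, and only the module-level logger name and function name are illustrative.

# Sketch of the shared per-process verbosity setup used by the example scripts.
import logging

import datasets
import diffusers
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger

logger = get_logger(__name__)  # illustrative logger name


def configure_logging(accelerator: Accelerator) -> None:
    # One identical basicConfig on every process so early messages are not lost.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        # Main process: warnings from datasets/transformers, progress info from diffusers
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        # Every other rank stays quiet except for errors
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()


if __name__ == "__main__":
    configure_logging(Accelerator())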
