Skip to content

[SD3 dreambooth-lora training] small updates + bug fixes #9682

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/dreambooth/README_sd3.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ accelerate launch train_dreambooth_lora_sd3.py \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-5 \
--learning_rate=4e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
Expand Down
55 changes: 46 additions & 9 deletions examples/dreambooth/train_dreambooth_lora_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def save_model_card(
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="openrail++",
license="other",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
Expand Down Expand Up @@ -186,7 +186,7 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)

# run inference
Expand Down Expand Up @@ -608,6 +608,12 @@ def parse_args(input_args=None):
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--cache_latents",
action="store_true",
default=False,
help="Cache the VAE latents",
)
parser.add_argument(
"--report_to",
type=str,
Expand All @@ -628,6 +634,15 @@ def parse_args(input_args=None):
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--upcast_before_saving",
action="store_true",
default=False,
help=(
"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
"Defaults to precision dtype used for training to save memory"
),
)
parser.add_argument(
"--prior_generation_precision",
type=str,
Expand Down Expand Up @@ -1394,6 +1409,16 @@ def load_model_hook(models, input_dir):
logger.warning(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
if args.train_text_encoder and args.text_encoder_lr:
logger.warning(
f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:"
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
)
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
# --learning_rate
params_to_optimize[1]["lr"] = args.learning_rate
params_to_optimize[2]["lr"] = args.learning_rate

optimizer = optimizer_class(
params_to_optimize,
Expand Down Expand Up @@ -1440,6 +1465,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
return prompt_embeds, pooled_prompt_embeds

# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
# the redundant encoding.
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
args.instance_prompt, text_encoders, tokenizers
Expand Down Expand Up @@ -1500,7 +1528,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
power=args.lr_power,
)

# Prepare everything with our `accelerator`.
# Prepare everything with our `accelerator`.
if args.train_text_encoder:
(
Expand Down Expand Up @@ -1607,8 +1634,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
if args.train_text_encoder:
models_to_accumulate.extend([text_encoder_one, text_encoder_two])
with accelerator.accumulate(models_to_accumulate):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
prompts = batch["prompts"]

# encode batch prompts when custom prompts are provided for each image -
Expand Down Expand Up @@ -1639,7 +1667,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)

# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
if args.cache_latents:
model_input = latents_cache[step].sample()
else:
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
model_input = model_input.to(dtype=weight_dtype)

Expand Down Expand Up @@ -1773,6 +1805,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
)
text_encoder_one.to(weight_dtype)
text_encoder_two.to(weight_dtype)
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
Expand All @@ -1793,15 +1827,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
epoch=epoch,
torch_dtype=weight_dtype,
)

del text_encoder_one, text_encoder_two, text_encoder_three
free_memory()
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two, text_encoder_three
free_memory()

# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = unwrap_model(transformer)
transformer = transformer.to(torch.float32)
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)

if args.train_text_encoder:
Expand Down
23 changes: 10 additions & 13 deletions examples/dreambooth/train_dreambooth_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import argparse
import copy
import gc
import itertools
import logging
import math
Expand Down Expand Up @@ -51,7 +50,7 @@
StableDiffusion3Pipeline,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
from diffusers.utils import (
check_min_version,
is_wandb_available,
Expand Down Expand Up @@ -119,7 +118,7 @@ def save_model_card(
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="openrail++",
license="other",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
Expand Down Expand Up @@ -164,7 +163,7 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)

# run inference
Expand All @@ -190,8 +189,7 @@ def log_validation(
)

del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
free_memory()

return images

Expand Down Expand Up @@ -1065,8 +1063,7 @@ def main(args):
image.save(image_filename)

del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
free_memory()

# Handle the repository creation
if accelerator.is_main_process:
Expand Down Expand Up @@ -1386,9 +1383,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
del tokenizers, text_encoders
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del text_encoder_one, text_encoder_two, text_encoder_three
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
free_memory()

# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
Expand Down Expand Up @@ -1708,6 +1703,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
)
text_encoder_one.to(weight_dtype)
text_encoder_two.to(weight_dtype)
text_encoder_three.to(weight_dtype)
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
Expand All @@ -1730,8 +1728,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two, text_encoder_three
torch.cuda.empty_cache()
gc.collect()
free_memory()

# Save the lora layers
accelerator.wait_for_everyone()
Expand Down
Loading