[Examples] Add support for Min-SNR weighting strategy for better convergence #2899
@@ -42,15 +42,23 @@
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version, deprecate
+from diffusers.utils import check_min_version, deprecate, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available


+if is_wandb_available():
+    import wandb
+
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.15.0.dev0")

 logger = get_logger(__name__, log_level="INFO")

+DATASET_NAME_MAPPING = {
+    "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+

 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
@@ -112,6 +120,13 @@ def parse_args():
             "value if set."
         ),
     )
+    parser.add_argument(
+        "--validation_prompts",
+        type=str,
+        default=None,
+        nargs="+",
+        help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -193,6 +208,13 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--snr_gamma",
+        type=float,
+        default=None,
+        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+        "More details here: https://arxiv.org/abs/2303.09556.",
+    )
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
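For reference, the weighting that `--snr_gamma` enables (Section 3.4 of https://arxiv.org/abs/2303.09556, written here in the noise-prediction form implemented in the loss hunk further down this diff; \bar{\alpha}_t is the scheduler's `alphas_cumprod[t]` and \gamma is `snr_gamma`):

    \mathrm{SNR}(t) = \bar{\alpha}_t / (1 - \bar{\alpha}_t)
    w_t = \min(\mathrm{SNR}(t), \gamma) / \mathrm{SNR}(t)
    L = \mathbb{E}\big[\, w_t \,\lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2 \,\big]

Low-noise timesteps with very large SNR are down-weighted to \gamma / \mathrm{SNR}(t), while high-noise timesteps with \mathrm{SNR}(t) \le \gamma keep weight 1, which is what the recommended \gamma = 5.0 refers to.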
@@ -298,6 +320,21 @@ def parse_args():
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
     parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+    parser.add_argument(
+        "--validation_epochs",
+        type=int,
+        default=5,
+        help="Run validation every X epochs.",
+    )
+    parser.add_argument(
+        "--tracker_project_name",
+        type=str,
+        default="text2image-fine-tune",
+        help=(
+            "The `project_name` argument passed to Accelerator.init_trackers for"
+            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+        ),
+    )

     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -325,9 +362,83 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token
     return f"{organization}/{model_id}"


-dataset_name_mapping = {
-    "lambdalabs/pokemon-blip-captions": ("image", "text"),
-}
+def expand_tensor(arr, timesteps, broadcast_shape):
+    """
+    Extract values from a 1-D numpy array for a batch of indices.
+    Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
+    """
+    res = arr.to(device=timesteps.device)[timesteps].float()
+    while len(res.shape) < len(broadcast_shape):
+        res = res[..., None]
+    return res.expand(broadcast_shape)
+
+
+def compute_snr(noise_scheduler):
+    """
+    Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+    """
+    alphas_cumprod = noise_scheduler.alphas_cumprod
+    sqrt_alphas_cumprod = alphas_cumprod**0.5
+    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+    def fn(timesteps):
+        alpha = expand_tensor(sqrt_alphas_cumprod, timesteps, timesteps.shape)
+        sigma = expand_tensor(sqrt_one_minus_alphas_cumprod, timesteps, timesteps.shape)
+        snr = (alpha / sigma) ** 2
+        return snr
+
+    return fn
+
+
+def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):

    [Review thread on the line above]
    > Is this related to this PR title, or does it just add logging?
    > The PR could have been without this utility, but it was important to have it in the PR because otherwise it was difficult to validate the effectiveness of the method. FWIW, I am not a fan of adding unrelated changes in a PR, but this seemed important.
    > Ok for me!

+    logger.info("Running validation... ")
+
+    pipeline = StableDiffusionPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        vae=vae,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        unet=accelerator.unwrap_model(unet),
+        safety_checker=None,
+        revision=args.revision,
+        torch_dtype=weight_dtype,
+    )
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    if args.enable_xformers_memory_efficient_attention:
+        pipeline.enable_xformers_memory_efficient_attention()
+
+    if args.seed is None:
+        generator = None
+    else:
+        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+    images = []
+    for i in range(len(args.validation_prompts)):
+        with torch.autocast("cuda"):
+            image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+
+        images.append(image)
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+        elif tracker.name == "wandb":
+            tracker.log(
+                {
+                    "validation": [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+                        for i, image in enumerate(images)
+                    ]
+                }
+            )
+        else:
+            logger.warn(f"image logging not implemented for {tracker.name}")
+
+    del pipeline
+    torch.cuda.empty_cache()
+
+
 def main():
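As a rough sanity check of the new helpers, the snippet below (not part of the PR; the model id and batch size are illustrative, and it assumes `torch` and `diffusers` are installed alongside the `expand_tensor`/`compute_snr` definitions above) shows how `compute_snr` could be exercised on its own:

import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
snr_fn = compute_snr(noise_scheduler)  # returns a function of a batch of timesteps

# Sample a small batch of timesteps and look at their signal-to-noise ratios.
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (4,))
snr = snr_fn(timesteps)  # shape (4,); equals alphas_cumprod[t] / (1 - alphas_cumprod[t])
print(snr)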
@@ -389,6 +500,8 @@ def main():
                     gitignore.write("step_*\n")
                 if "epoch_*" not in gitignore:
                     gitignore.write("epoch_*\n")
+                if "checkpoint-*" not in gitignore:
+                    gitignore.write("checkpoint-*\n")
         elif args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
@@ -476,6 +589,9 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

+    if args.snr_gamma is not None:
+        snr_fn = compute_snr(noise_scheduler)
+
     # Initialize the optimizer
     if args.use_8bit_adam:
         try:
@@ -526,7 +642,7 @@ def load_model_hook(models, input_dir):
     column_names = dataset["train"].column_names

     # 6. Get the column names for input/target.
-    dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
+    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
     if args.image_column is None:
         image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
     else:
@@ -645,7 +761,9 @@ def collate_fn(examples):
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+        tracker_config = dict(vars(args))
+        tracker_config.pop("validation_prompts")
+        accelerator.init_trackers(args.tracker_project_name, tracker_config)

     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
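As a quick illustration of the bookkeeping above (the numbers are made up, not from this PR), the effective batch size per optimizer step and the learning rate produced by the `--scale_lr` rule from earlier in the diff work out as follows:

# Hypothetical values, for illustration only.
train_batch_size = 16
gradient_accumulation_steps = 4
num_processes = 2        # e.g. two GPUs
learning_rate = 1e-4

total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps           # 128
scaled_lr = learning_rate * gradient_accumulation_steps * train_batch_size * num_processes  # 0.0128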
@@ -734,7 +852,23 @@ def collate_fn(examples):
                 # Predict the noise residual and compute loss
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                if args.snr_gamma is None:
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                else:
+                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+                    # This is discussed in Section 4.2 of the same paper.
+                    snr = snr_fn(timesteps)
+                    mse_loss_weights = (
+                        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+                    )
+                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
+                    # rebalance the sample-wise losses with their respective loss weights.
+                    # Finally, we take the mean of the rebalanced loss.
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+                    loss = loss.mean()

                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()

    [Review comment on lines +831 to +832]
    > some models (sd2.1 and above) use …

    [Review comment on lines +835 to +843]
    > Nice!
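To make the weighting in the hunk above concrete, here is a small standalone reproduction (toy shapes and SNR values; not part of the training script). With `snr_gamma = 5.0`, a sample whose timestep has SNR 25 is down-weighted to 0.2, while samples with SNR below 5 keep weight 1.0:

import torch
import torch.nn.functional as F

snr_gamma = 5.0
snr = torch.tensor([0.5, 4.0, 25.0])   # per-sample SNR(t) for a batch of 3
model_pred = torch.randn(3, 4, 8, 8)   # stand-in for the predicted noise latents
target = torch.randn(3, 4, 8, 8)       # stand-in for the true noise

# min(SNR, gamma) / SNR  ->  tensor([1.0, 1.0, 0.2])
mse_loss_weights = torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr

loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")   # (3, 4, 8, 8)
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights  # per-sample loss, reweighted
loss = loss.mean()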
@@ -769,6 +903,26 @@ def collate_fn(examples):
             if global_step >= args.max_train_steps:
                 break

+        if accelerator.is_main_process:
+            if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+                if args.use_ema:
+                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+                    ema_unet.store(unet.parameters())
+                    ema_unet.copy_to(unet.parameters())
+                log_validation(
+                    vae,
+                    text_encoder,
+                    tokenizer,
+                    unet,
+                    args,
+                    accelerator,
+                    weight_dtype,
+                    global_step,
+                )
+                if args.use_ema:
+                    # Switch back to the original UNet parameters.
+                    ema_unet.restore(unet.parameters())
+
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
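The EMA handling in the hunk above follows the usual store / copy_to / restore pattern of `diffusers.training_utils.EMAModel`. A minimal sketch with a toy module (the linear layer and the elided `ema.step` calls are illustrative, not from this PR):

import torch
from diffusers.training_utils import EMAModel

model = torch.nn.Linear(4, 4)       # toy stand-in for the UNet
ema = EMAModel(model.parameters())

# ... during training, ema.step(model.parameters()) would update the running averages ...

ema.store(model.parameters())    # stash the current (raw) training weights
ema.copy_to(model.parameters())  # swap in the EMA weights for validation/inference
# ... run validation here ...
ema.restore(model.parameters())  # put the raw training weights back and continue training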
@@ -786,7 +940,7 @@ def collate_fn(examples):
         pipeline.save_pretrained(args.output_dir)

         if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            repo.push_to_hub(commit_message="End of training", blocking=True, auto_lfs_prune=True)

     accelerator.end_training()