-
Notifications
You must be signed in to change notification settings - Fork 6k
[Examples] Add support for Min-SNR weighting strategy for better convergence #2899
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a009f1d
ecf008f
01b4d70
24cab1d
c73fdba
c8a2856
ca0c158
76e9446
052bc88
c481147
835b5ee
7c842f2
3f078bc
1d9f3bc
d2ce5e6
667d23d
a154335
ad3fb92
084a341
f91f6bd
565566f
bf837f5
7434dcd
077c957
db8bbbd
96e7254
245b558
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -155,6 +155,28 @@ python train_text_to_image_flax.py \ | |
</jax> | ||
</frameworkcontent> | ||
|
||
## Training with Min-SNR weighting | ||
|
||
We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence | ||
by rebalancing the loss. In order to use it, one needs to set the `--snr_gamma` argument. The recommended | ||
value when using it is 5.0. | ||
|
||
You can find [this project on Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) that compares the loss surfaces of the following setups: | ||
|
||
* Training without the Min-SNR weighting strategy | ||
* Training with the Min-SNR weighting strategy (`snr_gamma` set to 5.0) | ||
* Training with the Min-SNR weighting strategy (`snr_gamma` set to 1.0) | ||
|
||
For our small Pokemons dataset, the effects of Min-SNR weighting strategy might not appear to be pronounced, but for larger datasets, we believe the effects will be more pronounced. | ||
|
||
Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds. | ||
|
||
<Tip warning={true}> | ||
|
||
Training with Min-SNR weighting strategy is only supported in PyTorch. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for future PR: Could be cool to add this in jax as well, will be useful for the jax event. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @yiyixuxu could you take a look? |
||
|
||
</Tip> | ||
|
||
## LoRA | ||
|
||
You can also use Low-Rank Adaptation of Large Language Models (LoRA), a fine-tuning technique for accelerating training large models, for fine-tuning text-to-image models. For more details, take a look at the [LoRA training](lora#text-to-image) guide. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,15 +41,74 @@ | |
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel | ||
from diffusers.optimization import get_scheduler | ||
from diffusers.training_utils import EMAModel | ||
from diffusers.utils import check_min_version, deprecate | ||
from diffusers.utils import check_min_version, deprecate, is_wandb_available | ||
from diffusers.utils.import_utils import is_xformers_available | ||
|
||
|
||
if is_wandb_available(): | ||
import wandb | ||
|
||
|
||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. | ||
check_min_version("0.15.0.dev0") | ||
|
||
logger = get_logger(__name__, log_level="INFO") | ||
|
||
DATASET_NAME_MAPPING = { | ||
"lambdalabs/pokemon-blip-captions": ("image", "text"), | ||
} | ||
|
||
|
||
def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added this method as a part of the PR as well. Handles EMA offload and unload properly to ensure inference is being done with the EMA'd checkpoints. |
||
logger.info("Running validation... ") | ||
|
||
pipeline = StableDiffusionPipeline.from_pretrained( | ||
args.pretrained_model_name_or_path, | ||
vae=vae, | ||
text_encoder=text_encoder, | ||
tokenizer=tokenizer, | ||
unet=accelerator.unwrap_model(unet), | ||
safety_checker=None, | ||
revision=args.revision, | ||
torch_dtype=weight_dtype, | ||
) | ||
pipeline = pipeline.to(accelerator.device) | ||
pipeline.set_progress_bar_config(disable=True) | ||
|
||
if args.enable_xformers_memory_efficient_attention: | ||
pipeline.enable_xformers_memory_efficient_attention() | ||
|
||
if args.seed is None: | ||
generator = None | ||
else: | ||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) | ||
|
||
images = [] | ||
for i in range(len(args.validation_prompts)): | ||
with torch.autocast("cuda"): | ||
image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0] | ||
|
||
images.append(image) | ||
|
||
for tracker in accelerator.trackers: | ||
if tracker.name == "tensorboard": | ||
np_images = np.stack([np.asarray(img) for img in images]) | ||
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") | ||
elif tracker.name == "wandb": | ||
tracker.log( | ||
{ | ||
"validation": [ | ||
wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}") | ||
for i, image in enumerate(images) | ||
] | ||
} | ||
) | ||
else: | ||
logger.warn(f"image logging not implemented for {tracker.name}") | ||
|
||
del pipeline | ||
torch.cuda.empty_cache() | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description="Simple example of a training script.") | ||
|
@@ -111,6 +170,13 @@ def parse_args(): | |
"value if set." | ||
), | ||
) | ||
parser.add_argument( | ||
"--validation_prompts", | ||
type=str, | ||
default=None, | ||
nargs="+", | ||
help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), | ||
) | ||
parser.add_argument( | ||
"--output_dir", | ||
type=str, | ||
|
@@ -192,6 +258,13 @@ def parse_args(): | |
parser.add_argument( | ||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." | ||
) | ||
parser.add_argument( | ||
"--snr_gamma", | ||
type=float, | ||
default=None, | ||
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " | ||
"More details here: https://arxiv.org/abs/2303.09556.", | ||
) | ||
parser.add_argument( | ||
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." | ||
) | ||
|
@@ -297,6 +370,21 @@ def parse_args(): | |
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." | ||
) | ||
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") | ||
parser.add_argument( | ||
"--validation_epochs", | ||
type=int, | ||
default=5, | ||
help="Run validation every X epochs.", | ||
) | ||
parser.add_argument( | ||
"--tracker_project_name", | ||
type=str, | ||
default="text2image-fine-tune", | ||
help=( | ||
"The `project_name` argument passed to Accelerator.init_trackers for" | ||
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" | ||
), | ||
) | ||
|
||
args = parser.parse_args() | ||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) | ||
|
@@ -314,11 +402,6 @@ def parse_args(): | |
return args | ||
|
||
|
||
dataset_name_mapping = { | ||
"lambdalabs/pokemon-blip-captions": ("image", "text"), | ||
} | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
|
||
|
@@ -410,6 +493,30 @@ def main(): | |
else: | ||
raise ValueError("xformers is not available. Make sure it is installed correctly") | ||
|
||
def compute_snr(timesteps): | ||
""" | ||
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 | ||
""" | ||
alphas_cumprod = noise_scheduler.alphas_cumprod | ||
sqrt_alphas_cumprod = alphas_cumprod**0.5 | ||
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 | ||
|
||
# Expand the tensors. | ||
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 | ||
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() | ||
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): | ||
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] | ||
alpha = sqrt_alphas_cumprod.expand(timesteps.shape) | ||
|
||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() | ||
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): | ||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] | ||
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) | ||
|
||
# Compute SNR. | ||
snr = (alpha / sigma) ** 2 | ||
return snr | ||
|
||
# `accelerate` 0.16.0 will have better support for customized saving | ||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): | ||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format | ||
|
@@ -507,7 +614,7 @@ def load_model_hook(models, input_dir): | |
column_names = dataset["train"].column_names | ||
|
||
# 6. Get the column names for input/target. | ||
dataset_columns = dataset_name_mapping.get(args.dataset_name, None) | ||
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) | ||
if args.image_column is None: | ||
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] | ||
else: | ||
|
@@ -626,7 +733,9 @@ def collate_fn(examples): | |
# We need to initialize the trackers we use, and also store our configuration. | ||
# The trackers initializes automatically on the main process. | ||
if accelerator.is_main_process: | ||
accelerator.init_trackers("text2image-fine-tune", config=vars(args)) | ||
tracker_config = dict(vars(args)) | ||
tracker_config.pop("validation_prompts") | ||
accelerator.init_trackers(args.tracker_project_name, tracker_config) | ||
|
||
# Train! | ||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps | ||
|
@@ -715,7 +824,23 @@ def collate_fn(examples): | |
|
||
# Predict the noise residual and compute loss | ||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
|
||
if args.snr_gamma is None: | ||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
else: | ||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. | ||
# Since we predict the noise instead of x_0, the original formulation is slightly changed. | ||
Comment on lines
+831
to
+832
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. some models (sd2.1 and above) use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
# This is discussed in Section 4.2 of the same paper. | ||
snr = compute_snr(timesteps) | ||
mse_loss_weights = ( | ||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr | ||
) | ||
# We first calculate the original loss. Then we mean over the non-batch dimensions and | ||
# rebalance the sample-wise losses with their respective loss weights. | ||
# Finally, we take the mean of the rebalanced loss. | ||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | ||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights | ||
loss = loss.mean() | ||
Comment on lines
+835
to
+843
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! |
||
|
||
# Gather the losses across all processes for logging (if we use distributed training). | ||
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() | ||
|
@@ -750,6 +875,26 @@ def collate_fn(examples): | |
if global_step >= args.max_train_steps: | ||
break | ||
|
||
if accelerator.is_main_process: | ||
if args.validation_prompts is not None and epoch % args.validation_epochs == 0: | ||
if args.use_ema: | ||
# Store the UNet parameters temporarily and load the EMA parameters to perform inference. | ||
ema_unet.store(unet.parameters()) | ||
ema_unet.copy_to(unet.parameters()) | ||
log_validation( | ||
vae, | ||
text_encoder, | ||
tokenizer, | ||
unet, | ||
args, | ||
accelerator, | ||
weight_dtype, | ||
global_step, | ||
) | ||
if args.use_ema: | ||
# Switch back to the original UNet parameters. | ||
ema_unet.restore(unet.parameters()) | ||
|
||
# Create the pipeline using the trained modules and save it. | ||
accelerator.wait_for_everyone() | ||
if accelerator.is_main_process: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of curiosity, where is this value proposed? Is there a rule of thumb when choosing a value for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's reported in the paper. A gamma of 5.0 always leads to better results in the experiments presented by the authors in the paper.