Commit 1ad630c

sayakpaul and patil-suraj authored, Jimmy committed
[Examples] Add support for Min-SNR weighting strategy for better convergence (huggingface#2899)
* improve stable unclip doc.
* feat: support for applying min-snr weighting for faster convergence.
* add: support for validation logging with wandb
* make not a required arg.
* fix: arg name.
* fix: cli args.
* fix: tracker config.
* fix: loss calculation.
* fix: validation logging.
* fix: unwrap call.
* fix: validation logging.
* fix: interval.
* fix: checkpointing push to hub.
* fix: https://github.com/huggingface/diffusers/commit/c8a2856c6d5e45577bf4c24dee06b1a4a2f5c050#commitcomment-106913193
* fix: norm group test for UNet3D.
* address PR comments.
* remove unneeded code.
* add: entry in the readme and docs.
* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>

---------

Co-authored-by: Suraj Patil <[email protected]>
1 parent af5887a commit 1ad630c

File tree

3 files changed: +192, -9 lines changed

docs/source/en/training/text2image.mdx

Lines changed: 22 additions & 0 deletions
@@ -155,6 +155,28 @@ python train_text_to_image_flax.py \
 </jax>
 </frameworkcontent>
 
+## Training with Min-SNR weighting
+
+We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556), which helps achieve faster convergence by rebalancing the loss. To use it, set the `--snr_gamma` argument; the recommended value is 5.0.
+
+You can find [this project on Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) that compares the loss surfaces of the following setups:
+
+* Training without the Min-SNR weighting strategy
+* Training with the Min-SNR weighting strategy (`snr_gamma` set to 5.0)
+* Training with the Min-SNR weighting strategy (`snr_gamma` set to 1.0)
+
+For our small Pokemons dataset, the effects of the Min-SNR weighting strategy may not appear pronounced, but we expect them to be clearer on larger datasets.
+
+Also note that in this example the model predicts either `epsilon` (i.e., the noise) or `v_prediction`; the Min-SNR weighting formulation we use holds for both cases.
+
+<Tip warning={true}>
+
+Training with the Min-SNR weighting strategy is only supported in PyTorch.
+
+</Tip>
+
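To make the weighting rule concrete: the per-timestep loss weight enabled by `--snr_gamma` is `min(SNR(t), gamma) / SNR(t)`. Below is a minimal, self-contained sketch of that rule (editor's illustration, assuming a default `DDPMScheduler` noise schedule from `diffusers`):

```python
import torch
from diffusers import DDPMScheduler

# Sketch of the Min-SNR loss weight, w(t) = min(SNR(t), gamma) / SNR(t).
# `gamma` corresponds to the `--snr_gamma` CLI argument (5.0 is recommended).
scheduler = DDPMScheduler(num_train_timesteps=1000)
gamma = 5.0

alphas_cumprod = scheduler.alphas_cumprod
timesteps = torch.tensor([10, 250, 500, 750, 990])

# Closed form of the SNR for a DDPM schedule: alpha_bar_t / (1 - alpha_bar_t).
snr = alphas_cumprod[timesteps] / (1.0 - alphas_cumprod[timesteps])
weights = snr.clamp(max=gamma) / snr

# Early (low-noise, high-SNR) timesteps get down-weighted; noisy ones keep weight 1.0.
for t, w in zip(timesteps.tolist(), weights.tolist()):
    print(f"t={t:4d}  weight={w:.4f}")
```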
 ## LoRA
 
 You can also use Low-Rank Adaptation of Large Language Models (LoRA), a fine-tuning technique for accelerating the training of large models, for fine-tuning text-to-image models. For more details, take a look at the [LoRA training](lora#text-to-image) guide.

examples/text_to_image/README.md

Lines changed: 16 additions & 0 deletions
@@ -111,6 +111,22 @@ image = pipe(prompt="yoda").images[0]
 image.save("yoda-pokemon.png")
 ```
 
+#### Training with Min-SNR weighting
+
+We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556), which helps achieve faster convergence by rebalancing the loss. To use it, set the `--snr_gamma` argument; the recommended value is 5.0.
+
+You can find [this project on Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) that compares the loss surfaces of the following setups:
+
+* Training without the Min-SNR weighting strategy
+* Training with the Min-SNR weighting strategy (`snr_gamma` set to 5.0)
+* Training with the Min-SNR weighting strategy (`snr_gamma` set to 1.0)
+
+For our small Pokemons dataset, the effects of the Min-SNR weighting strategy may not appear pronounced, but we expect them to be clearer on larger datasets.
+
+Also note that in this example the model predicts either `epsilon` (i.e., the noise) or `v_prediction`; the Min-SNR weighting formulation we use holds for both cases.
+
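For reference, the reweighted objective this corresponds to can be written as follows (editor's summary of Section 3.4 of the paper, stated here for the `epsilon`-prediction case):

```latex
\mathcal{L}
= \mathbb{E}_{t,\, x_0,\, \epsilon}\!\left[
    \frac{\min\!\left(\mathrm{SNR}(t),\, \gamma\right)}{\mathrm{SNR}(t)}
    \,\bigl\lVert \epsilon - \epsilon_\theta(x_t, t) \bigr\rVert^2
  \right],
\qquad
\mathrm{SNR}(t) = \frac{\bar{\alpha}_t}{1 - \bar{\alpha}_t}
```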
 ## Training with LoRA
 
 Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.

examples/text_to_image/train_text_to_image.py

Lines changed: 154 additions & 9 deletions
@@ -41,15 +41,74 @@
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version, deprecate
+from diffusers.utils import check_min_version, deprecate, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
+if is_wandb_available():
+    import wandb
+
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.15.0.dev0")
 
 logger = get_logger(__name__, log_level="INFO")
 
+DATASET_NAME_MAPPING = {
+    "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
+    logger.info("Running validation... ")
+
+    pipeline = StableDiffusionPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        vae=vae,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        unet=accelerator.unwrap_model(unet),
+        safety_checker=None,
+        revision=args.revision,
+        torch_dtype=weight_dtype,
+    )
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    if args.enable_xformers_memory_efficient_attention:
+        pipeline.enable_xformers_memory_efficient_attention()
+
+    if args.seed is None:
+        generator = None
+    else:
+        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+    images = []
+    for i in range(len(args.validation_prompts)):
+        with torch.autocast("cuda"):
+            image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+
+        images.append(image)
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+        elif tracker.name == "wandb":
+            tracker.log(
+                {
+                    "validation": [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+                        for i, image in enumerate(images)
+                    ]
+                }
+            )
+        else:
+            logger.warn(f"image logging not implemented for {tracker.name}")
+
+    del pipeline
+    torch.cuda.empty_cache()
+
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
@@ -111,6 +170,13 @@ def parse_args():
             "value if set."
         ),
     )
+    parser.add_argument(
+        "--validation_prompts",
+        type=str,
+        default=None,
+        nargs="+",
+        help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -192,6 +258,13 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--snr_gamma",
+        type=float,
+        default=None,
+        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+        "More details here: https://arxiv.org/abs/2303.09556.",
+    )
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
@@ -297,6 +370,21 @@ def parse_args():
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
     parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+    parser.add_argument(
+        "--validation_epochs",
+        type=int,
+        default=5,
+        help="Run validation every X epochs.",
+    )
+    parser.add_argument(
+        "--tracker_project_name",
+        type=str,
+        default="text2image-fine-tune",
+        help=(
+            "The `project_name` argument passed to Accelerator.init_trackers. For"
+            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+        ),
+    )
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -314,11 +402,6 @@ def parse_args():
     return args
 
 
-dataset_name_mapping = {
-    "lambdalabs/pokemon-blip-captions": ("image", "text"),
-}
-
-
 def main():
     args = parse_args()
 
@@ -410,6 +493,30 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
 
+    def compute_snr(timesteps):
+        """
+        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+        """
+        alphas_cumprod = noise_scheduler.alphas_cumprod
+        sqrt_alphas_cumprod = alphas_cumprod**0.5
+        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+        # Expand the tensors.
+        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
+        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
+            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
+
+        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
+            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
+
+        # Compute SNR.
+        snr = (alpha / sigma) ** 2
+        return snr
+
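Since `compute_snr` builds the SNR from `sqrt(alpha_bar)` and `sqrt(1 - alpha_bar)`, it reduces to the closed form `alpha_bar_t / (1 - alpha_bar_t)`. A quick standalone check of that identity (editor's sketch, assuming a default `DDPMScheduler` in place of the script's `noise_scheduler`):

```python
import torch
from diffusers import DDPMScheduler

# (sqrt(abar) / sqrt(1 - abar))**2 should equal abar / (1 - abar).
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
timesteps = torch.tensor([0, 250, 500, 750, 999])

alphas_cumprod = noise_scheduler.alphas_cumprod
alpha = alphas_cumprod[timesteps] ** 0.5
sigma = (1.0 - alphas_cumprod[timesteps]) ** 0.5

snr = (alpha / sigma) ** 2
closed_form = alphas_cumprod[timesteps] / (1.0 - alphas_cumprod[timesteps])
assert torch.allclose(snr, closed_form)
```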
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -507,7 +614,7 @@ def load_model_hook(models, input_dir):
     column_names = dataset["train"].column_names
 
     # 6. Get the column names for input/target.
-    dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
+    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
     if args.image_column is None:
         image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
     else:
@@ -626,7 +733,9 @@ def collate_fn(examples):
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initialize automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+        tracker_config = dict(vars(args))
+        tracker_config.pop("validation_prompts")
+        accelerator.init_trackers(args.tracker_project_name, tracker_config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -715,7 +824,23 @@ def collate_fn(examples):
 
                 # Predict the noise residual and compute loss
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                if args.snr_gamma is None:
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                else:
+                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+                    # This is discussed in Section 4.2 of the same paper.
+                    snr = compute_snr(timesteps)
+                    mse_loss_weights = (
+                        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+                    )
+                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
+                    # rebalance the sample-wise losses with their respective loss weights.
+                    # Finally, we take the mean of the rebalanced loss.
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+                    loss = loss.mean()
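The `torch.stack(...).min(...)` expression above is an element-wise `min(SNR(t), gamma)` divided by `SNR(t)`. A small standalone demo of the weighting (editor's sketch; the SNR values are made up):

```python
import torch

# Per-sample weights w = min(snr, gamma) / snr, as in the training loop above.
snr = torch.tensor([0.25, 1.0, 5.0, 20.0])
gamma = 5.0

stacked = torch.stack([snr, gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
clamped = snr.clamp(max=gamma) / snr  # equivalent, arguably clearer form

assert torch.equal(stacked, clamped)
print(clamped)  # tensor([1.0000, 1.0000, 1.0000, 0.2500])
# High-SNR (low-noise) timesteps are down-weighted; all others keep weight 1.0.
```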
 
                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
@@ -750,6 +875,26 @@ def collate_fn(examples):
             if global_step >= args.max_train_steps:
                 break
 
+        if accelerator.is_main_process:
+            if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+                if args.use_ema:
+                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+                    ema_unet.store(unet.parameters())
+                    ema_unet.copy_to(unet.parameters())
+                log_validation(
+                    vae,
+                    text_encoder,
+                    tokenizer,
+                    unet,
+                    args,
+                    accelerator,
+                    weight_dtype,
+                    global_step,
+                )
+                if args.use_ema:
+                    # Switch back to the original UNet parameters.
+                    ema_unet.restore(unet.parameters())
+
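The `store` / `copy_to` / `restore` dance above is the standard way to validate with EMA weights without disturbing the live training weights. A minimal sketch of the pattern (editor's illustration, assuming the `diffusers.training_utils.EMAModel` API this script already imports):

```python
import torch
from diffusers.training_utils import EMAModel

model = torch.nn.Linear(8, 8)       # stand-in for the UNet
ema = EMAModel(model.parameters())  # in the script: ema_unet

# ... during training, ema.step(model.parameters()) updates the shadow weights ...

ema.store(model.parameters())    # stash the live training weights
ema.copy_to(model.parameters())  # swap in the EMA weights for validation
# ... run validation / inference with `model` here ...
ema.restore(model.parameters())  # put the training weights back, untouched
```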
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
