Commit 4bb129d

wfng92 authored and Jimmy committed

Add min snr to text2img lora training script (huggingface#3459)
add min snr to text2img lora training script
1 parent ec2670d commit 4bb129d

1 file changed: 48 additions, 1 deletion

examples/text_to_image/train_text_to_image_lora.py
@@ -239,6 +239,13 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--snr_gamma",
+        type=float,
+        default=None,
+        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+        "More details here: https://arxiv.org/abs/2303.09556.",
+    )
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
@@ -472,6 +479,30 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
 
+    def compute_snr(timesteps):
+        """
+        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+        """
+        alphas_cumprod = noise_scheduler.alphas_cumprod
+        sqrt_alphas_cumprod = alphas_cumprod**0.5
+        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+        # Expand the tensors.
+        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
+        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
+            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
+
+        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
+            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
+
+        # Compute SNR.
+        snr = (alpha / sigma) ** 2
+        return snr
+
     lora_layers = AttnProcsLayers(unet.attn_processors)
 
     # Enable TF32 for faster training on Ampere GPUs,
@@ -727,7 +758,23 @@ def collate_fn(examples):
 
                 # Predict the noise residual and compute loss
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                if args.snr_gamma is None:
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                else:
+                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+                    # This is discussed in Section 4.2 of the same paper.
+                    snr = compute_snr(timesteps)
+                    mse_loss_weights = (
+                        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+                    )
+                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
+                    # rebalance the sample-wise losses with their respective loss weights.
+                    # Finally, we take the mean of the rebalanced loss.
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+                    loss = loss.mean()
 
                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
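
For context (not part of the commit): training is unchanged when --snr_gamma is omitted, since default=None keeps the original mean MSE loss, and the help text recommends 5.0 when enabling it. Below is a minimal, stand-alone sketch of what the min-SNR-gamma weight min(SNR, gamma) / SNR does, assuming an illustrative linear-beta schedule in place of the script's noise_scheduler (the variable names and schedule values here are hypothetical):

import torch

# Illustrative linear-beta schedule; the real script takes alphas_cumprod
# from its noise_scheduler instead.
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

timesteps = torch.tensor([10, 250, 500, 900])  # toy batch of sampled timesteps
snr_gamma = 5.0                                # value recommended in the help text

# SNR(t) = alpha_bar_t / (1 - alpha_bar_t), which equals the
# (sqrt(alpha_bar) / sqrt(1 - alpha_bar)) ** 2 returned by compute_snr.
snr = alphas_cumprod[timesteps] / (1.0 - alphas_cumprod[timesteps])

# Clamp SNR at gamma, then divide by SNR: low-noise (high-SNR) timesteps are
# down-weighted, while high-noise timesteps (SNR <= gamma) keep weight 1.0.
mse_loss_weights = torch.minimum(snr, torch.full_like(snr, snr_gamma)) / snr
print(mse_loss_weights)  # small weights for early timesteps, 1.0 for late ones

These weights are what the diff multiplies into the per-sample MSE (after averaging over the non-batch dimensions) before taking the final mean.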
