Skip to content

Commit 74e43a4

Browse files
bghira, bghira, and sayakpaul
authored
Resolve v_prediction issue for min-SNR gamma weighted loss function (#5096)
* Resolve v_prediction issue for min-SNR gamma weighted loss function
* Combine MSE loss calculation of epsilon and velocity, with a note about the application of the epsilon code to sample prediction
* style

---------

Co-authored-by: bghira <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 81331f3 commit 74e43a4

File tree

1 file changed

+9
-22
lines changed

1 file changed

+9
-22
lines changed

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 9 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -332,15 +332,6 @@ def parse_args(input_args=None):
332332
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
333333
"More details here: https://arxiv.org/abs/2303.09556.",
334334
)
335-
parser.add_argument(
336-
"--force_snr_gamma",
337-
action="store_true",
338-
help=(
339-
"When using SNR gamma with rescaled betas for zero terminal SNR, a divide-by-zero error can cause NaN"
340-
" condition when computing the SNR with a sigma value of zero. This parameter overrides the check,"
341-
" allowing the use of SNR gamma with a terminal SNR model. Use with caution, and closely monitor results."
342-
),
343-
)
344335
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
345336
parser.add_argument(
346337
"--allow_tf32",
@@ -554,18 +545,6 @@ def main(args):
554545
# Load scheduler and models
555546
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
556547
# Check for terminal SNR in combination with SNR Gamma
557-
if (
558-
args.snr_gamma
559-
and not args.force_snr_gamma
560-
and (
561-
hasattr(noise_scheduler.config, "rescale_betas_zero_snr") and noise_scheduler.config.rescale_betas_zero_snr
562-
)
563-
):
564-
raise ValueError(
565-
f"The selected noise scheduler for the model {args.pretrained_model_name_or_path} uses rescaled betas for zero SNR.\n"
566-
"When this configuration is present, the parameter --snr_gamma may not be used without parameter --force_snr_gamma.\n"
567-
"This is due to a mathematical incompatibility between our current SNR gamma implementation, and a sigma value of zero."
568-
)
569548
text_encoder_one = text_encoder_cls_one.from_pretrained(
570549
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
571550
)
@@ -1013,9 +992,17 @@ def compute_time_ids(original_size, crops_coords_top_left):
1013992
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
1014993
# This is discussed in Section 4.2 of the same paper.
1015994
snr = compute_snr(timesteps)
1016-
mse_loss_weights = (
995+
base_weight = (
1017996
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
1018997
)
998+
999+
if noise_scheduler.config.prediction_type == "v_prediction":
1000+
# Velocity objective needs to be floored to an SNR weight of one.
1001+
mse_loss_weights = base_weight + 1
1002+
else:
1003+
# Epsilon and sample both use the same loss weights.
1004+
mse_loss_weights = base_weight
1005+
10191006
# We first calculate the original loss. Then we mean over the non-batch dimensions and
10201007
# rebalance the sample-wise losses with their respective loss weights.
10211008
# Finally, we take the mean of the rebalanced loss.

0 commit comments

Comments
 (0)