@@ -332,15 +332,6 @@ def parse_args(input_args=None):
332
332
help = "SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
333
333
"More details here: https://arxiv.org/abs/2303.09556." ,
334
334
)
335
- parser .add_argument (
336
- "--force_snr_gamma" ,
337
- action = "store_true" ,
338
- help = (
339
- "When using SNR gamma with rescaled betas for zero terminal SNR, a divide-by-zero error can cause NaN"
340
- " condition when computing the SNR with a sigma value of zero. This parameter overrides the check,"
341
- " allowing the use of SNR gamma with a terminal SNR model. Use with caution, and closely monitor results."
342
- ),
343
- )
344
335
parser .add_argument ("--use_ema" , action = "store_true" , help = "Whether to use EMA model." )
345
336
parser .add_argument (
346
337
"--allow_tf32" ,
@@ -554,18 +545,6 @@ def main(args):
554
545
# Load scheduler and models
555
546
noise_scheduler = DDPMScheduler .from_pretrained (args .pretrained_model_name_or_path , subfolder = "scheduler" )
556
547
# Check for terminal SNR in combination with SNR Gamma
557
- if (
558
- args .snr_gamma
559
- and not args .force_snr_gamma
560
- and (
561
- hasattr (noise_scheduler .config , "rescale_betas_zero_snr" ) and noise_scheduler .config .rescale_betas_zero_snr
562
- )
563
- ):
564
- raise ValueError (
565
- f"The selected noise scheduler for the model { args .pretrained_model_name_or_path } uses rescaled betas for zero SNR.\n "
566
- "When this configuration is present, the parameter --snr_gamma may not be used without parameter --force_snr_gamma.\n "
567
- "This is due to a mathematical incompatibility between our current SNR gamma implementation, and a sigma value of zero."
568
- )
569
548
text_encoder_one = text_encoder_cls_one .from_pretrained (
570
549
args .pretrained_model_name_or_path , subfolder = "text_encoder" , revision = args .revision
571
550
)
@@ -1013,9 +992,17 @@ def compute_time_ids(original_size, crops_coords_top_left):
1013
992
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
1014
993
# This is discussed in Section 4.2 of the same paper.
1015
994
snr = compute_snr (timesteps )
1016
- mse_loss_weights = (
995
+ base_weight = (
1017
996
torch .stack ([snr , args .snr_gamma * torch .ones_like (timesteps )], dim = 1 ).min (dim = 1 )[0 ] / snr
1018
997
)
998
+
999
+ if noise_scheduler .config .prediction_type == "v_prediction" :
1000
+ # Velocity objective: add one to the base weight, so the loss weight is bounded below by one (assuming base_weight >= 0).
1001
+ mse_loss_weights = base_weight + 1
1002
+ else :
1003
+ # Epsilon and sample both use the same loss weights.
1004
+ mse_loss_weights = base_weight
1005
+
1019
1006
# We first calculate the original loss. Then we mean over the non-batch dimensions and
1020
1007
# rebalance the sample-wise losses with their respective loss weights.
1021
1008
# Finally, we take the mean of the rebalanced loss.
0 commit comments