@@ -927,17 +927,22 @@ def load_model_hook(models, input_dir):
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
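The gist of this first hunk: before `accelerator.prepare` the dataloader is not yet sharded, so the per-process length and the scheduler step counts have to be estimated by hand and scaled by the number of processes. A minimal sketch of that arithmetic under purely hypothetical values (the dataset, batch, and warmup numbers below are illustrative, not from this PR):

import math

# Hypothetical example values, for illustration only (not taken from the PR).
len_train_dataloader = 1000        # total batches before accelerate shards the dataloader
num_processes = 4                  # processes launched by accelerate
gradient_accumulation_steps = 2
num_train_epochs = 3
lr_warmup_steps = 100

# Per-process dataloader length after sharding, estimated before accelerator.prepare.
len_train_dataloader_after_sharding = math.ceil(len_train_dataloader / num_processes)                       # 250
# Optimizer updates each process performs per epoch.
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)   # 125

# The scheduler is stepped on every process, so both counts are scaled by num_processes.
num_training_steps_for_scheduler = num_train_epochs * num_update_steps_per_epoch * num_processes            # 1500
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes                                            # 400

print(num_training_steps_for_scheduler, num_warmup_steps_for_scheduler)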
@@ -962,8 +967,14 @@ def load_model_hook(models, input_dir):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+        logger.warning(
+            f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+            f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+            f"This inconsistency may result in the learning rate scheduler not functioning properly."
+        )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
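The warning added in the second hunk can only fire when `args.max_train_steps` was left unset and the dataloader length observed after `accelerator.prepare` differs from the pre-prepare estimate the scheduler was built with. A minimal sketch of that check, again with hypothetical lengths chosen only for illustration:

import math

# Hypothetical example values, for illustration only (not taken from the PR).
num_processes = 4
gradient_accumulation_steps = 2
num_train_epochs = 3

# Estimate used when the scheduler was created (before accelerator.prepare).
len_train_dataloader_after_sharding = 250
estimated_updates_per_epoch = math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)
num_training_steps_for_scheduler = num_train_epochs * estimated_updates_per_epoch * num_processes  # 1500

# Actual per-process length after accelerator.prepare (e.g. the split came out uneven).
len_train_dataloader = 248
num_update_steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)
max_train_steps = num_train_epochs * num_update_steps_per_epoch  # 372

if num_training_steps_for_scheduler != max_train_steps * num_processes:  # 1500 != 1488
    print("Scheduler was built with a stale step count; the LR schedule may finish early or late.")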