@@ -1119,17 +1119,22 @@ def compute_text_embeddings(prompt):
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1146,8 +1151,15 @@ def compute_text_embeddings(prompt):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
+
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
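For readers following the diff, the step arithmetic introduced above can be replayed in isolation. The sketch below is not part of the change; it substitutes made-up values for `args.*` and `accelerator.num_processes` to show how the warmup and total scheduler step counts are derived before `accelerator.prepare` may alter the dataloader length.

```python
import math

# Minimal sketch of the scheduler step arithmetic from the diff above.
# The concrete numbers are hypothetical stand-ins for args.* and
# accelerator.num_processes; they are not taken from the training script.
num_processes = 2                # accelerator.num_processes
lr_warmup_steps = 500            # args.lr_warmup_steps
num_train_epochs = 3             # args.num_train_epochs
gradient_accumulation_steps = 4  # args.gradient_accumulation_steps
len_train_dataloader = 10_000    # dataloader length *before* sharding
max_train_steps = None           # args.max_train_steps

# Warmup is counted in per-process scheduler steps, so it is scaled up.
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes

if max_train_steps is None:
    # After accelerator.prepare, each process sees roughly 1/num_processes of the batches.
    len_train_dataloader_after_sharding = math.ceil(len_train_dataloader / num_processes)
    num_update_steps_per_epoch = math.ceil(
        len_train_dataloader_after_sharding / gradient_accumulation_steps
    )
    # Total scheduler steps across all processes for the whole run.
    num_training_steps_for_scheduler = (
        num_train_epochs * num_processes * num_update_steps_per_epoch
    )
else:
    num_training_steps_for_scheduler = max_train_steps * num_processes

print(num_warmup_steps_for_scheduler, num_training_steps_for_scheduler)
# -> 1000 7500   (2 processes, 3 epochs, ceil(5000 / 4) = 1250 update steps per epoch)
```

Roughly speaking, the multiplication by the process count reflects that under `accelerate` the prepared scheduler is advanced once per process for each optimizer step, so its notion of "total steps" has to be the per-process step count times the number of processes.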