
Commit 37b8edf

[train_dreambooth_lora.py] Fix the LR Schedulers when num_train_epochs is passed in a distributed training env (#10973)
* updated train_dreambooth_lora to fix the LR schedulers for `num_train_epochs` in distributed training env
* fixed formatting
* remove trailing newlines
* fixed style error
1 parent fbf6b85 commit 37b8edf
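
Background, in broad strokes: the LR scheduler is built before accelerator.prepare shards the train_dataloader, and a scheduler prepared by accelerate is advanced once per process for every optimizer step. When only --num_train_epochs is supplied, the old code derived args.max_train_steps from the unsharded dataloader length, so in multi-GPU runs the scheduler received a step budget roughly accelerator.num_processes times larger than the number of steps training actually performs, and the warmup/decay never completed on schedule. The arithmetic below is a minimal sketch of that mismatch with made-up numbers (dataloader length, process count, accumulation steps, and epoch count are hypothetical, not taken from the script):

import math

# Hypothetical setup (not from the script): 1000 batches/epoch before sharding,
# 4 processes, gradient accumulation of 2, 10 epochs.
len_dataloader, num_processes, grad_accum, num_epochs = 1000, 4, 2, 10

# Old behaviour: steps per epoch computed from the UNSHARDED dataloader.
old_updates_per_epoch = math.ceil(len_dataloader / grad_accum)             # 500
old_scheduler_budget = num_epochs * old_updates_per_epoch * num_processes  # 20000

# New behaviour: shard first, then count optimizer updates per process.
sharded_len = math.ceil(len_dataloader / num_processes)                    # 250
new_updates_per_epoch = math.ceil(sharded_len / grad_accum)                # 125
new_scheduler_budget = num_epochs * num_processes * new_updates_per_epoch  # 5000

# Over the whole run the scheduler is actually stepped about
# num_processes * updates_per_process times, i.e. roughly 5000 here, so the
# old budget of 20000 would leave the LR decay unfinished at the end of training.
print(old_scheduler_budget, new_scheduler_budget)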

File tree

1 file changed (+19 −7 lines)


examples/dreambooth/train_dreambooth_lora.py

Lines changed: 19 additions & 7 deletions
@@ -1119,17 +1119,22 @@ def compute_text_embeddings(prompt):
         )

     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
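
For reference, a self-contained sketch of how the two precomputed counts are handed to diffusers' get_scheduler. The optimizer, the parameter it wraps, and the numeric values below are placeholders chosen for illustration, not values from the training script:

import torch
from diffusers.optimization import get_scheduler

# Placeholder model parameter and optimizer, just to build a scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)

# Stand-ins for num_warmup_steps_for_scheduler / num_training_steps_for_scheduler.
num_warmup_steps_for_scheduler = 500 * 4   # lr_warmup_steps * num_processes (hypothetical)
num_training_steps_for_scheduler = 5000    # epochs * num_processes * updates/epoch (hypothetical)

lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps_for_scheduler,
    num_training_steps=num_training_steps_for_scheduler,
)

# Each optimizer step is followed by lr_scheduler.step(); under accelerate the
# prepared scheduler is stepped once per process, which is why both counts above
# are scaled by the number of processes.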
@@ -1146,8 +1151,15 @@ def compute_text_embeddings(prompt):

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
+
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
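
The warning added above fires when the dataloader that comes back from accelerator.prepare has a different length than the estimate used when the scheduler was created. The sketch below distills that idea with made-up lengths rather than real dataloader objects; it illustrates the intent of the check, not the script's exact expression:

import math

# Hypothetical lengths: 250 batches per process expected, 248 observed after prepare.
expected_len_after_sharding = 250   # estimate used when the scheduler was created
len_after_prepare = 248             # what accelerator.prepare actually produced
grad_accum, num_epochs, num_processes = 2, 10, 4

scheduler_budget = num_epochs * num_processes * math.ceil(expected_len_after_sharding / grad_accum)
actual_updates = num_epochs * math.ceil(len_after_prepare / grad_accum)

if scheduler_budget != actual_updates * num_processes:
    print(
        f"dataloader length after prepare ({len_after_prepare}) != expected "
        f"({expected_len_after_sharding}); the LR schedule may not line up with training."
    )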
