Skip to content

Commit e4fe941

Browse files
authored
[examples] update loss computation (#1861)
update loss computation
1 parent ac37384 commit e4fe941

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -732,7 +732,7 @@ def main(args):
732732
target, target_prior = torch.chunk(target, 2, dim=0)
733733

734734
# Compute instance loss
735-
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
735+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
736736

737737
# Compute prior loss
738738
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

examples/textual_inversion/textual_inversion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,8 @@ def main():
634634
else:
635635
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
636636

637-
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
637+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
638+
638639
accelerator.backward(loss)
639640

640641
optimizer.step()

0 commit comments

Comments
 (0)