Skip to content

Commit 5c07de9

Browse files
committed
fix mixed precision issue as proposed in huggingface#9565
1 parent 4ca5101 commit 5c07de9

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def log_validation(
186186
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
187187
f" {args.validation_prompt}."
188188
)
189-
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
189+
pipeline = pipeline.to(accelerator.device)
190190
pipeline.set_progress_bar_config(disable=True)
191191

192192
# run inference
@@ -1805,6 +1805,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18051805
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
18061806
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
18071807
)
1808+
text_encoder_one.to(weight_dtype)
1809+
text_encoder_two.to(weight_dtype)
18081810
pipeline = StableDiffusion3Pipeline.from_pretrained(
18091811
args.pretrained_model_name_or_path,
18101812
vae=vae,

examples/dreambooth/train_dreambooth_sd3.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def log_validation(
164164
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
165165
f" {args.validation_prompt}."
166166
)
167-
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
167+
pipeline = pipeline.to(accelerator.device)
168168
pipeline.set_progress_bar_config(disable=True)
169169

170170
# run inference
@@ -1704,6 +1704,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17041704
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
17051705
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
17061706
)
1707+
text_encoder_one.to(weight_dtype)
1708+
text_encoder_two.to(weight_dtype)
1709+
text_encoder_three.to(weight_dtype)
17071710
pipeline = StableDiffusion3Pipeline.from_pretrained(
17081711
args.pretrained_model_name_or_path,
17091712
vae=vae,

0 commit comments

Comments
 (0)