Skip to content

Commit 81331f3

Browse files
bghirabghirasayakpaul
authored
Add x-prediction / prediction_type=sample support for SDXL fine-tuning (#5095)
Co-authored-by: bghira <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent 2997075 commit 81331f3

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,11 @@ def compute_time_ids(original_size, crops_coords_top_left):
998998
target = noise
999999
elif noise_scheduler.config.prediction_type == "v_prediction":
10001000
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
1001+
elif noise_scheduler.config.prediction_type == "sample":
1002+
# We set the target to latents here, but the model_pred will return the noise sample prediction.
1003+
target = model_input
1004+
# We will have to subtract the noise residual from the prediction to get the target sample.
1005+
model_pred = model_pred - noise
10011006
else:
10021007
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
10031008

0 commit comments

Comments
 (0)