Skip to content

Commit c7e62c4

Browse files
committed
Guidance on denoised predictions
1 parent 8545966 commit c7e62c4

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1088,8 +1088,10 @@ def __call__(
10881088

10891089
# perform guidance
10901090
if self.do_classifier_free_guidance:
1091-
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
10921091
if adaptive_projected_guidance:
1092+
sigma = self.scheduler.sigmas[self.scheduler.step_index]
1093+
noise_pred = latents - sigma * noise_pred
1094+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
10931095
noise_pred = normalized_guidance(
10941096
noise_pred_text,
10951097
noise_pred_uncond,
@@ -1098,7 +1100,9 @@ def __call__(
10981100
eta,
10991101
adaptive_projected_guidance_rescale_factor,
11001102
)
1103+
noise_pred = (latents - noise_pred) / sigma
11011104
else:
1105+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
11021106
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
11031107

11041108
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1281,8 +1281,10 @@ def __call__(
12811281

12821282
# perform guidance
12831283
if self.do_classifier_free_guidance:
1284-
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
12851284
if adaptive_projected_guidance:
1285+
sigma = self.scheduler.sigmas[self.scheduler.step_index]
1286+
noise_pred = latents - sigma * noise_pred
1287+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
12861288
noise_pred = normalized_guidance(
12871289
noise_pred_text,
12881290
noise_pred_uncond,
@@ -1291,7 +1293,9 @@ def __call__(
12911293
eta,
12921294
adaptive_projected_guidance_rescale_factor,
12931295
)
1296+
noise_pred = (latents - noise_pred) / sigma
12941297
else:
1298+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
12951299
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
12961300

12971301
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:

0 commit comments

Comments
 (0)