Skip to content

Commit 86bd991

Browse files
authored
Fixed noise_pred_text referenced before assignment. (#9537)
* Fixed local variable noise_pred_text referenced before assignment when using PAG with guidance scale and guidance rescale at the same time. * Fixed style. * Made returning text pred noise an argument.
1 parent 02eeb8e commit 86bd991

6 files changed

+18
-12
lines changed

src/diffusers/pipelines/pag/pag_utils.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ def _get_pag_scale(self, t):
9898
else:
9999
return self.pag_scale
100100

101-
def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
101+
def _apply_perturbed_attention_guidance(
102+
self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False
103+
):
102104
r"""
103105
Apply perturbed attention guidance to the noise prediction.
104106
@@ -107,9 +109,11 @@ def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_gui
107109
do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
108110
guidance_scale (float): The scale factor for the guidance term.
109111
t (int): The current time step.
112+
return_pred_text (bool): Whether to return the text noise prediction.
110113
111114
Returns:
112-
torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
115+
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applying
116+
perturbed attention guidance and the text noise prediction.
113117
"""
114118
pag_scale = self._get_pag_scale(t)
115119
if do_classifier_free_guidance:
@@ -122,6 +126,8 @@ def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_gui
122126
else:
123127
noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
124128
noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
129+
if return_pred_text:
130+
return noise_pred, noise_pred_text
125131
return noise_pred
126132

127133
def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):

src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -893,8 +893,8 @@ def __call__(
893893

894894
# perform guidance
895895
if self.do_perturbed_attention_guidance:
896-
noise_pred = self._apply_perturbed_attention_guidance(
897-
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
896+
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
897+
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
898898
)
899899
elif self.do_classifier_free_guidance:
900900
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

src/diffusers/pipelines/pag/pipeline_pag_sd.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -993,8 +993,8 @@ def __call__(
993993

994994
# perform guidance
995995
if self.do_perturbed_attention_guidance:
996-
noise_pred = self._apply_perturbed_attention_guidance(
997-
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
996+
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
997+
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
998998
)
999999

10001000
elif self.do_classifier_free_guidance:

src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1237,8 +1237,8 @@ def __call__(
12371237

12381238
# perform guidance
12391239
if self.do_perturbed_attention_guidance:
1240-
noise_pred = self._apply_perturbed_attention_guidance(
1241-
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
1240+
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
1241+
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
12421242
)
12431243

12441244
elif self.do_classifier_free_guidance:

src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1437,8 +1437,8 @@ def denoising_value_valid(dnv):
14371437

14381438
# perform guidance
14391439
if self.do_perturbed_attention_guidance:
1440-
noise_pred = self._apply_perturbed_attention_guidance(
1441-
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
1440+
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
1441+
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
14421442
)
14431443
elif self.do_classifier_free_guidance:
14441444
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1649,8 +1649,8 @@ def denoising_value_valid(dnv):
16491649

16501650
# perform guidance
16511651
if self.do_perturbed_attention_guidance:
1652-
noise_pred = self._apply_perturbed_attention_guidance(
1653-
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
1652+
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
1653+
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
16541654
)
16551655
elif self.do_classifier_free_guidance:
16561656
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

0 commit comments

Comments
 (0)