Skip to content

Commit 4cc0b8b

Browse files
authored
fix compatibility issue between PAG and IP-adapter (#8379)
* fix compatibility issue between PAG and IP-adapter * fix compatibility issue between PAG and IP-adapter plus
1 parent 8950e80 commit 4cc0b8b

File tree

3 files changed

+17
-8
lines changed

3 files changed

+17
-8
lines changed

src/diffusers/loaders/unet.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -928,9 +928,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
928928
hidden_size = self.config.block_out_channels[block_id]
929929

930930
if cross_attention_dim is None or "motion_modules" in name:
931-
attn_processor_class = (
932-
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
933-
)
931+
attn_processor_class = self.attn_processors[name].__class__
934932
attn_procs[name] = attn_processor_class()
935933

936934
else:

src/diffusers/pipelines/pag_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def enable_pag(
4545
self._pag_applied_layers = pag_applied_layers
4646
self._pag_applied_layers_index = pag_applied_layers_index
4747
self._pag_cfg = pag_cfg
48-
48+
self._is_pag_enabled = True
4949
self._set_pag_attn_processor()
5050

5151
def _get_self_attn_layers(self):
@@ -180,6 +180,7 @@ def disable_pag(self):
180180
self._pag_applied_layers = None
181181
self._pag_applied_layers_index = None
182182
self._pag_cfg = None
183+
self._is_pag_enabled = False
183184

184185
@property
185186
def pag_adaptive_scaling(self):
@@ -191,4 +192,4 @@ def do_pag_adaptive_scaling(self):
191192

192193
@property
193194
def do_perturbed_attention_guidance(self):
194-
return hasattr(self, "_pag_scale") and self._pag_scale is not None and self._pag_scale > 0
195+
return self._is_pag_enabled

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
536536

537537
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
538538
def prepare_ip_adapter_image_embeds(
539-
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
539+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance, do_perturbed_attention_guidance
540540
):
541541
if ip_adapter_image_embeds is None:
542542
if not isinstance(ip_adapter_image, list):
@@ -560,6 +560,10 @@ def prepare_ip_adapter_image_embeds(
560560
[single_negative_image_embeds] * num_images_per_prompt, dim=0
561561
)
562562

563+
if do_perturbed_attention_guidance:
564+
single_image_embeds = torch.cat([single_image_embeds, single_image_embeds], dim=0)
565+
single_image_embeds = single_image_embeds.to(device)
566+
563567
if do_classifier_free_guidance:
564568
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
565569
single_image_embeds = single_image_embeds.to(device)
@@ -577,11 +581,16 @@ def prepare_ip_adapter_image_embeds(
577581
single_negative_image_embeds = single_negative_image_embeds.repeat(
578582
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
579583
)
580-
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
584+
if do_perturbed_attention_guidance:
585+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds, single_image_embeds], dim=0)
586+
else:
587+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
581588
else:
582589
single_image_embeds = single_image_embeds.repeat(
583590
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
584591
)
592+
if do_perturbed_attention_guidance:
593+
single_image_embeds = torch.cat([single_image_embeds, single_image_embeds], dim=0)
585594
image_embeds.append(single_image_embeds)
586595

587596
return image_embeds
@@ -1170,6 +1179,7 @@ def __call__(
11701179
device,
11711180
batch_size * num_images_per_prompt,
11721181
self.do_classifier_free_guidance,
1182+
self.do_perturbed_attention_guidance,
11731183
)
11741184

11751185
# 8. Denoising loop
@@ -1205,7 +1215,7 @@ def __call__(
12051215
if self.interrupt:
12061216
continue
12071217

1208-
# expand the latents if we are doing classifier free guidance
1218+
# expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both
12091219
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
12101220

12111221
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

0 commit comments

Comments (0)