@@ -536,7 +536,7 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
536
536
537
537
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
538
538
def prepare_ip_adapter_image_embeds (
539
- self , ip_adapter_image , ip_adapter_image_embeds , device , num_images_per_prompt , do_classifier_free_guidance
539
+ self , ip_adapter_image , ip_adapter_image_embeds , device , num_images_per_prompt , do_classifier_free_guidance , do_perturbed_attention_guidance
540
540
):
541
541
if ip_adapter_image_embeds is None :
542
542
if not isinstance (ip_adapter_image , list ):
@@ -560,6 +560,10 @@ def prepare_ip_adapter_image_embeds(
560
560
[single_negative_image_embeds ] * num_images_per_prompt , dim = 0
561
561
)
562
562
563
+ if do_perturbed_attention_guidance :
564
+ single_image_embeds = torch .cat ([single_image_embeds , single_image_embeds ], dim = 0 )
565
+ single_image_embeds = single_image_embeds .to (device )
566
+
563
567
if do_classifier_free_guidance :
564
568
single_image_embeds = torch .cat ([single_negative_image_embeds , single_image_embeds ])
565
569
single_image_embeds = single_image_embeds .to (device )
@@ -577,11 +581,16 @@ def prepare_ip_adapter_image_embeds(
577
581
single_negative_image_embeds = single_negative_image_embeds .repeat (
578
582
num_images_per_prompt , * (repeat_dims * len (single_negative_image_embeds .shape [1 :]))
579
583
)
580
- single_image_embeds = torch .cat ([single_negative_image_embeds , single_image_embeds ])
584
+ if do_perturbed_attention_guidance :
585
+ single_image_embeds = torch .cat ([single_negative_image_embeds , single_image_embeds , single_image_embeds ], dim = 0 )
586
+ else :
587
+ single_image_embeds = torch .cat ([single_negative_image_embeds , single_image_embeds ])
581
588
else :
582
589
single_image_embeds = single_image_embeds .repeat (
583
590
num_images_per_prompt , * (repeat_dims * len (single_image_embeds .shape [1 :]))
584
591
)
592
+ if do_perturbed_attention_guidance :
593
+ single_image_embeds = torch .cat ([single_image_embeds , single_image_embeds ], dim = 0 )
585
594
image_embeds .append (single_image_embeds )
586
595
587
596
return image_embeds
@@ -1170,6 +1179,7 @@ def __call__(
1170
1179
device ,
1171
1180
batch_size * num_images_per_prompt ,
1172
1181
self .do_classifier_free_guidance ,
1182
+ self .do_perturbed_attention_guidance ,
1173
1183
)
1174
1184
1175
1185
# 8. Denoising loop
@@ -1205,7 +1215,7 @@ def __call__(
1205
1215
if self .interrupt :
1206
1216
continue
1207
1217
1208
- # expand the latents if we are doing classifier free guidance
1218
+ # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both
1209
1219
latent_model_input = torch .cat ([latents ] * (prompt_embeds .shape [0 ] // latents .shape [0 ]))
1210
1220
1211
1221
latent_model_input = self .scheduler .scale_model_input (latent_model_input , t )
0 commit comments