@@ -172,6 +172,7 @@ def __call__(
172
172
prompt_embeds : Optional [torch .FloatTensor ] = None ,
173
173
negative_prompt_embeds : Optional [torch .FloatTensor ] = None ,
174
174
ip_adapter_image : Optional [PipelineImageInput ] = None ,
175
+ ip_adapter_image_embeds : Optional [List [torch .FloatTensor ]] = None ,
175
176
output_type : Optional [str ] = "pil" ,
176
177
return_dict : bool = True ,
177
178
callback_on_step_end : Optional [Callable [[int , int , Dict ], None ]] = None ,
@@ -296,21 +297,15 @@ def __call__(
296
297
negative_prompt ,
297
298
prompt_embeds ,
298
299
negative_prompt_embeds ,
300
+ ip_adapter_image ,
301
+ ip_adapter_image_embeds ,
299
302
callback_on_step_end_tensor_inputs ,
300
303
)
301
304
self ._guidance_scale = guidance_scale
302
305
self ._image_guidance_scale = image_guidance_scale
303
306
304
307
device = self ._execution_device
305
308
306
- if ip_adapter_image is not None :
307
- output_hidden_state = False if isinstance (self .unet .encoder_hid_proj , ImageProjection ) else True
308
- image_embeds , negative_image_embeds = self .encode_image (
309
- ip_adapter_image , device , num_images_per_prompt , output_hidden_state
310
- )
311
- if self .do_classifier_free_guidance :
312
- image_embeds = torch .cat ([image_embeds , negative_image_embeds , negative_image_embeds ])
313
-
314
309
if image is None :
315
310
raise ValueError ("`image` input cannot be undefined." )
316
311
@@ -335,6 +330,14 @@ def __call__(
335
330
negative_prompt_embeds = negative_prompt_embeds ,
336
331
)
337
332
333
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None :
334
+ image_embeds = self .prepare_ip_adapter_image_embeds (
335
+ ip_adapter_image ,
336
+ ip_adapter_image_embeds ,
337
+ device ,
338
+ batch_size * num_images_per_prompt ,
339
+ self .do_classifier_free_guidance ,
340
+ )
338
341
# 3. Preprocess image
339
342
image = self .image_processor .preprocess (image )
340
343
@@ -635,6 +638,65 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
635
638
636
639
return image_embeds , uncond_image_embeds
637
640
641
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Return the per-adapter list of IP-Adapter image embeddings.

    Exactly one of `ip_adapter_image` / `ip_adapter_image_embeds` is expected
    (the caller's `check_inputs` rejects passing both).

    Args:
        ip_adapter_image: Raw image input(s), one per installed IP-Adapter
            projection layer; a single image is wrapped into a list.
        ip_adapter_image_embeds: Precomputed embeddings, one tensor per
            adapter. Under classifier-free guidance each tensor is assumed to
            hold the triple (positive, negative, negative) stacked on dim 0.
        device: Target device for freshly encoded embeddings.
        num_images_per_prompt: How many times each embedding is replicated.
        do_classifier_free_guidance: Whether to produce the CFG triple.

    Returns:
        list[torch.Tensor]: One embedding tensor per IP-Adapter.

    Raises:
        ValueError: If the number of raw images does not match the number of
            IP-Adapter projection layers.
    """
    if ip_adapter_image_embeds is not None:
        # Precomputed path: split off the stored triple, replicate each part
        # along the batch dimension, and re-assemble in the same order.
        out = []
        for stored in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                # The middle chunk duplicates the negative and is discarded,
                # matching how the triple is built in the encode path below.
                positive, _, negative = stored.chunk(3)
                tail = (1,) * (positive.dim() - 1)
                positive = positive.repeat(num_images_per_prompt, *tail)
                negative = negative.repeat(num_images_per_prompt, *tail)
                out.append(torch.cat([positive, negative, negative]))
            else:
                tail = (1,) * (stored.dim() - 1)
                out.append(stored.repeat(num_images_per_prompt, *tail))
        return out

    # Encode path: one raw image per projection layer on the UNet.
    images = ip_adapter_image if isinstance(ip_adapter_image, list) else [ip_adapter_image]
    projection_layers = self.unet.encoder_hid_proj.image_projection_layers
    if len(images) != len(projection_layers):
        raise ValueError(
            f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(images)} images and {len(projection_layers)} IP Adapters."
        )

    out = []
    for image, projection_layer in zip(images, projection_layers):
        # Plain ImageProjection consumes pooled embeds; every other projection
        # type consumes the encoder's hidden states.
        use_hidden_state = not isinstance(projection_layer, ImageProjection)
        embeds, negative_embeds = self.encode_image(image, device, 1, use_hidden_state)
        # Replicate via stack, which prepends a batch axis of size
        # num_images_per_prompt (as the original implementation did).
        embeds = torch.stack([embeds] * num_images_per_prompt, dim=0)
        negative_embeds = torch.stack([negative_embeds] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            # Negative embeds appear twice: once for the uncond-text branch and
            # once for the uncond-image branch of this pipeline's guidance.
            embeds = torch.cat([embeds, negative_embeds, negative_embeds])
            embeds = embeds.to(device)
        out.append(embeds)
    return out
638
700
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
639
701
def run_safety_checker (self , image , device , dtype ):
640
702
if self .safety_checker is None :
@@ -687,6 +749,8 @@ def check_inputs(
687
749
negative_prompt = None ,
688
750
prompt_embeds = None ,
689
751
negative_prompt_embeds = None ,
752
+ ip_adapter_image = None ,
753
+ ip_adapter_image_embeds = None ,
690
754
callback_on_step_end_tensor_inputs = None ,
691
755
):
692
756
if callback_steps is not None and (not isinstance (callback_steps , int ) or callback_steps <= 0 ):
@@ -728,6 +792,21 @@ def check_inputs(
728
792
f" { negative_prompt_embeds .shape } ."
729
793
)
730
794
795
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None :
796
+ raise ValueError (
797
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
798
+ )
799
+
800
+ if ip_adapter_image_embeds is not None :
801
+ if not isinstance (ip_adapter_image_embeds , list ):
802
+ raise ValueError (
803
+ f"`ip_adapter_image_embeds` has to be of type `list` but is { type (ip_adapter_image_embeds )} "
804
+ )
805
+ elif ip_adapter_image_embeds [0 ].ndim not in [3 , 4 ]:
806
+ raise ValueError (
807
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is { ip_adapter_image_embeds [0 ].ndim } D"
808
+ )
809
+
731
810
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
732
811
def prepare_latents (self , batch_size , num_channels_latents , height , width , dtype , device , generator , latents = None ):
733
812
shape = (
0 commit comments