Skip to content

Commit 5915c29

Browse files
authored
[ip-adapter] fix ip-adapter for StableDiffusionInstructPix2PixPipeline (#7820)
update prepare_ip_adapter_image_embeds for StableDiffusionInstructPix2PixPipeline
1 parent 21a7ff1 commit 5915c29

File tree

2 files changed

+87
-9
lines changed

2 files changed

+87
-9
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def __call__(
172172
prompt_embeds: Optional[torch.FloatTensor] = None,
173173
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
174174
ip_adapter_image: Optional[PipelineImageInput] = None,
175+
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
175176
output_type: Optional[str] = "pil",
176177
return_dict: bool = True,
177178
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -296,21 +297,15 @@ def __call__(
296297
negative_prompt,
297298
prompt_embeds,
298299
negative_prompt_embeds,
300+
ip_adapter_image,
301+
ip_adapter_image_embeds,
299302
callback_on_step_end_tensor_inputs,
300303
)
301304
self._guidance_scale = guidance_scale
302305
self._image_guidance_scale = image_guidance_scale
303306

304307
device = self._execution_device
305308

306-
if ip_adapter_image is not None:
307-
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
308-
image_embeds, negative_image_embeds = self.encode_image(
309-
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
310-
)
311-
if self.do_classifier_free_guidance:
312-
image_embeds = torch.cat([image_embeds, negative_image_embeds, negative_image_embeds])
313-
314309
if image is None:
315310
raise ValueError("`image` input cannot be undefined.")
316311

@@ -335,6 +330,14 @@ def __call__(
335330
negative_prompt_embeds=negative_prompt_embeds,
336331
)
337332

333+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
334+
image_embeds = self.prepare_ip_adapter_image_embeds(
335+
ip_adapter_image,
336+
ip_adapter_image_embeds,
337+
device,
338+
batch_size * num_images_per_prompt,
339+
self.do_classifier_free_guidance,
340+
)
338341
# 3. Preprocess image
339342
image = self.image_processor.preprocess(image)
340343

@@ -635,6 +638,65 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
635638

636639
return image_embeds, uncond_image_embeds
637640

641+
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Assemble the per-adapter image-embedding list consumed by the UNet's IP-Adapter layers.

    Exactly one of ``ip_adapter_image`` (raw image input, one per loaded IP Adapter)
    or ``ip_adapter_image_embeds`` (precomputed tensors, saved as ``[cond, neg, neg]``
    when classifier-free guidance was used) is expected. Each returned tensor is
    tiled to ``num_images_per_prompt`` and, under classifier-free guidance, carries
    the negative embeddings twice — this pipeline runs three UNet passes
    (text-conditioned, image-conditioned, unconditional).

    Returns:
        list of per-IP-Adapter embedding tensors.
    """
    if ip_adapter_image_embeds is not None:
        # Precomputed path: only re-batch the stored tensors, nothing is re-encoded.
        tiled_embeds = []
        for stored in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                # Saved layout is [cond, neg, neg]; chunks 2 and 3 are identical,
                # so keeping either one as the negative half is equivalent.
                cond, _, neg = stored.chunk(3)
                ones = [1] * (cond.ndim - 1)
                cond = cond.repeat(num_images_per_prompt, *ones)
                neg = neg.repeat(num_images_per_prompt, *ones)
                tiled_embeds.append(torch.cat([cond, neg, neg]))
            else:
                ones = [1] * (stored.ndim - 1)
                tiled_embeds.append(stored.repeat(num_images_per_prompt, *ones))
        return tiled_embeds

    # Encoding path: one raw image (or list entry) per loaded IP Adapter.
    if not isinstance(ip_adapter_image, list):
        ip_adapter_image = [ip_adapter_image]

    proj_layers = self.unet.encoder_hid_proj.image_projection_layers
    if len(ip_adapter_image) != len(proj_layers):
        raise ValueError(
            f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(proj_layers)} IP Adapters."
        )

    encoded = []
    for adapter_image, proj_layer in zip(ip_adapter_image, proj_layers):
        # A plain ImageProjection consumes pooled embeds; every other projector
        # type expects the image encoder's hidden states instead.
        wants_hidden_state = not isinstance(proj_layer, ImageProjection)
        embeds, negative_embeds = self.encode_image(adapter_image, device, 1, wants_hidden_state)
        embeds = torch.stack([embeds] * num_images_per_prompt, dim=0)
        negative_embeds = torch.stack([negative_embeds] * num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            embeds = torch.cat([embeds, negative_embeds, negative_embeds])
            embeds = embeds.to(device)

        encoded.append(embeds)

    return encoded
699+
638700
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
639701
def run_safety_checker(self, image, device, dtype):
640702
if self.safety_checker is None:
@@ -687,6 +749,8 @@ def check_inputs(
687749
negative_prompt=None,
688750
prompt_embeds=None,
689751
negative_prompt_embeds=None,
752+
ip_adapter_image=None,
753+
ip_adapter_image_embeds=None,
690754
callback_on_step_end_tensor_inputs=None,
691755
):
692756
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
@@ -728,6 +792,21 @@ def check_inputs(
728792
f" {negative_prompt_embeds.shape}."
729793
)
730794

795+
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
796+
raise ValueError(
797+
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
798+
)
799+
800+
if ip_adapter_image_embeds is not None:
801+
if not isinstance(ip_adapter_image_embeds, list):
802+
raise ValueError(
803+
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
804+
)
805+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
806+
raise ValueError(
807+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
808+
)
809+
731810
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
732811
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
733812
shape = (

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,6 @@ def prepare_extra_step_kwargs(self, generator, eta):
436436
extra_step_kwargs["generator"] = generator
437437
return extra_step_kwargs
438438

439-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.check_inputs
440439
def check_inputs(
441440
self,
442441
prompt,

0 commit comments

Comments
 (0)