Skip to content

Commit ca9b41a

Browse files
cmdr2w4ffl35
authored andcommitted
Update the K-Diffusion SD pipeline, to allow calling it with only prompt_embeds (instead of always requiring a prompt) (huggingface#2962)
1 parent 7ca1d46 commit ca9b41a

File tree

1 file changed

+47
-6
lines changed

1 file changed

+47
-6
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -364,10 +364,17 @@ def decode_latents(self, latents):
364364
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
365365
return image
366366

367-
def check_inputs(self, prompt, height, width, callback_steps):
368-
if not isinstance(prompt, str) and not isinstance(prompt, list):
369-
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
370-
367+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
368+
def check_inputs(
369+
self,
370+
prompt,
371+
height,
372+
width,
373+
callback_steps,
374+
negative_prompt=None,
375+
prompt_embeds=None,
376+
negative_prompt_embeds=None,
377+
):
371378
if height % 8 != 0 or width % 8 != 0:
372379
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
373380

@@ -379,6 +386,32 @@ def check_inputs(self, prompt, height, width, callback_steps):
379386
f" {type(callback_steps)}."
380387
)
381388

389+
if prompt is not None and prompt_embeds is not None:
390+
raise ValueError(
391+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
392+
" only forward one of the two."
393+
)
394+
elif prompt is None and prompt_embeds is None:
395+
raise ValueError(
396+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
397+
)
398+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
399+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
400+
401+
if negative_prompt is not None and negative_prompt_embeds is not None:
402+
raise ValueError(
403+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
404+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
405+
)
406+
407+
if prompt_embeds is not None and negative_prompt_embeds is not None:
408+
if prompt_embeds.shape != negative_prompt_embeds.shape:
409+
raise ValueError(
410+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
411+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
412+
f" {negative_prompt_embeds.shape}."
413+
)
414+
382415
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
383416
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
384417
if latents is None:
@@ -483,10 +516,18 @@ def __call__(
483516
width = width or self.unet.config.sample_size * self.vae_scale_factor
484517

485518
# 1. Check inputs. Raise error if not correct
486-
self.check_inputs(prompt, height, width, callback_steps)
519+
self.check_inputs(
520+
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
521+
)
487522

488523
# 2. Define call parameters
489-
batch_size = 1 if isinstance(prompt, str) else len(prompt)
524+
if prompt is not None and isinstance(prompt, str):
525+
batch_size = 1
526+
elif prompt is not None and isinstance(prompt, list):
527+
batch_size = len(prompt)
528+
else:
529+
batch_size = prompt_embeds.shape[0]
530+
490531
device = self._execution_device
491532
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
492533
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`

0 commit comments

Comments
 (0)