@@ -364,10 +364,17 @@ def decode_latents(self, latents):
364
364
image = image .cpu ().permute (0 , 2 , 3 , 1 ).float ().numpy ()
365
365
return image
366
366
367
- def check_inputs (self , prompt , height , width , callback_steps ):
368
- if not isinstance (prompt , str ) and not isinstance (prompt , list ):
369
- raise ValueError (f"`prompt` has to be of type `str` or `list` but is { type (prompt )} " )
370
-
367
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
368
+ def check_inputs (
369
+ self ,
370
+ prompt ,
371
+ height ,
372
+ width ,
373
+ callback_steps ,
374
+ negative_prompt = None ,
375
+ prompt_embeds = None ,
376
+ negative_prompt_embeds = None ,
377
+ ):
371
378
if height % 8 != 0 or width % 8 != 0 :
372
379
raise ValueError (f"`height` and `width` have to be divisible by 8 but are { height } and { width } ." )
373
380
@@ -379,6 +386,32 @@ def check_inputs(self, prompt, height, width, callback_steps):
379
386
f" { type (callback_steps )} ."
380
387
)
381
388
389
+ if prompt is not None and prompt_embeds is not None :
390
+ raise ValueError (
391
+ f"Cannot forward both `prompt`: { prompt } and `prompt_embeds`: { prompt_embeds } . Please make sure to"
392
+ " only forward one of the two."
393
+ )
394
+ elif prompt is None and prompt_embeds is None :
395
+ raise ValueError (
396
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
397
+ )
398
+ elif prompt is not None and (not isinstance (prompt , str ) and not isinstance (prompt , list )):
399
+ raise ValueError (f"`prompt` has to be of type `str` or `list` but is { type (prompt )} " )
400
+
401
+ if negative_prompt is not None and negative_prompt_embeds is not None :
402
+ raise ValueError (
403
+ f"Cannot forward both `negative_prompt`: { negative_prompt } and `negative_prompt_embeds`:"
404
+ f" { negative_prompt_embeds } . Please make sure to only forward one of the two."
405
+ )
406
+
407
+ if prompt_embeds is not None and negative_prompt_embeds is not None :
408
+ if prompt_embeds .shape != negative_prompt_embeds .shape :
409
+ raise ValueError (
410
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
411
+ f" got: `prompt_embeds` { prompt_embeds .shape } != `negative_prompt_embeds`"
412
+ f" { negative_prompt_embeds .shape } ."
413
+ )
414
+
382
415
def prepare_latents (self , batch_size , num_channels_latents , height , width , dtype , device , generator , latents = None ):
383
416
shape = (batch_size , num_channels_latents , height // self .vae_scale_factor , width // self .vae_scale_factor )
384
417
if latents is None :
@@ -483,10 +516,18 @@ def __call__(
483
516
width = width or self .unet .config .sample_size * self .vae_scale_factor
484
517
485
518
# 1. Check inputs. Raise error if not correct
486
- self .check_inputs (prompt , height , width , callback_steps )
519
+ self .check_inputs (
520
+ prompt , height , width , callback_steps , negative_prompt , prompt_embeds , negative_prompt_embeds
521
+ )
487
522
488
523
# 2. Define call parameters
489
- batch_size = 1 if isinstance (prompt , str ) else len (prompt )
524
+ if prompt is not None and isinstance (prompt , str ):
525
+ batch_size = 1
526
+ elif prompt is not None and isinstance (prompt , list ):
527
+ batch_size = len (prompt )
528
+ else :
529
+ batch_size = prompt_embeds .shape [0 ]
530
+
490
531
device = self ._execution_device
491
532
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
492
533
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
0 commit comments