
Commit 3b1e986

At-sushi authored and Jimmy committed
Fix TypeError when using prompt_embeds and negative_prompt (huggingface#2982)

* test: Added test case
* fix: fixed type checking issue on _encode_prompt
* fix: fixed copies consistency
* fix: one copy was not sufficient
1 parent 562e937 commit 3b1e986

20 files changed: +58 -19 lines

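Why the one-line change fixes the crash: when a caller supplies prompt_embeds instead of a text prompt, _encode_prompt receives prompt=None, so the old guard compared NoneType against the type of negative_prompt and raised a TypeError even for valid inputs. Below is a standalone sketch of the old versus new guard; it is an illustration only, not the pipeline code itself.

# Standalone illustration of the guard _encode_prompt applies when
# building unconditional embeddings (not the diffusers code itself).

def old_check(prompt, negative_prompt):
    # Old behavior: with prompt_embeds the prompt is None, so NoneType
    # never matches str/list and a TypeError is raised for valid inputs.
    if negative_prompt is None:
        return "ok"
    elif type(prompt) is not type(negative_prompt):
        raise TypeError("`negative_prompt` should be the same type to `prompt`")
    return "ok"

def new_check(prompt, negative_prompt):
    # Fixed behavior: only enforce the type match when a raw prompt was given.
    if negative_prompt is None:
        return "ok"
    elif prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError("`negative_prompt` should be the same type to `prompt`")
    return "ok"

# Caller passed prompt_embeds, so prompt is None but negative_prompt is a list.
prompt, negative_prompt = None, ["this is a negative prompt"]

try:
    old_check(prompt, negative_prompt)
except TypeError as err:
    print("old guard raises:", err)

print("new guard returns:", new_check(prompt, negative_prompt))  # "ok"
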
src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 1 addition & 1 deletion

@@ -369,7 +369,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 1 addition & 1 deletion

@@ -378,7 +378,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

Lines changed: 1 addition & 1 deletion

@@ -387,7 +387,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 1 addition & 1 deletion

@@ -372,7 +372,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

Lines changed: 1 addition & 1 deletion

@@ -384,7 +384,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Lines changed: 1 addition & 1 deletion

@@ -427,7 +427,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py

Lines changed: 1 addition & 1 deletion

@@ -256,7 +256,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 1 addition & 1 deletion

@@ -385,7 +385,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 1 addition & 1 deletion

@@ -437,7 +437,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py

Lines changed: 1 addition & 1 deletion

@@ -376,7 +376,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py

Lines changed: 1 addition & 1 deletion

@@ -288,7 +288,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py

Lines changed: 1 addition & 1 deletion

@@ -315,7 +315,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

Lines changed: 1 addition & 1 deletion

@@ -279,7 +279,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py

Lines changed: 1 addition & 1 deletion

@@ -520,7 +520,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

Lines changed: 1 addition & 1 deletion

@@ -296,7 +296,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

Lines changed: 1 addition & 1 deletion

@@ -296,7 +296,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py

Lines changed: 1 addition & 1 deletion

@@ -416,7 +416,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py

Lines changed: 1 addition & 1 deletion

@@ -316,7 +316,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py

Lines changed: 1 addition & 1 deletion

@@ -305,7 +305,7 @@ def _encode_prompt(
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 39 additions & 0 deletions

@@ -251,6 +251,45 @@ def test_stable_diffusion_negative_prompt_embeds(self):
 
         assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
 
+    def test_stable_diffusion_prompt_embeds_with_plain_negative_prompt_list(self):
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        negative_prompt = 3 * ["this is a negative prompt"]
+        inputs["negative_prompt"] = negative_prompt
+        inputs["prompt"] = 3 * [inputs["prompt"]]
+
+        # forward
+        output = sd_pipe(**inputs)
+        image_slice_1 = output.images[0, -3:, -3:, -1]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["negative_prompt"] = negative_prompt
+        prompt = 3 * [inputs.pop("prompt")]
+
+        text_inputs = sd_pipe.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=sd_pipe.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_inputs = text_inputs["input_ids"].to(torch_device)
+
+        prompt_embeds = sd_pipe.text_encoder(text_inputs)[0]
+
+        inputs["prompt_embeds"] = prompt_embeds
+
+        # forward
+        output = sd_pipe(**inputs)
+        image_slice_2 = output.images[0, -3:, -3:, -1]
+
+        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
     def test_stable_diffusion_ddim_factor_8(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator

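For reference, a minimal sketch of the user-facing call this commit unblocks: pre-computed prompt_embeds combined with a plain-text negative_prompt. The model id, prompts, device, and step count are illustrative assumptions, not part of the commit; it assumes a diffusers build that exposes prompt_embeds on StableDiffusionPipeline.__call__, as the test above does.

# Sketch only: illustrative model id, prompts, and device.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

prompt = ["an astronaut riding a horse"] * 3

# Encode the prompt manually, mirroring the new test.
text_inputs = pipe.tokenizer(
    prompt,
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
prompt_embeds = pipe.text_encoder(text_inputs.input_ids.to("cuda"))[0]

# Before this commit, the call below raised
# "TypeError: `negative_prompt` should be the same type to `prompt` ..."
# because prompt was None while negative_prompt was a list.
images = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt=["blurry, low quality"] * 3,
    num_inference_steps=20,
).images
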