Commit 3e2fc0b

Clarify purpose and mark as deprecated
Fix inversion prompt broadcasting

1 parent: 69dcac2

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
1 file changed: 24 additions, 15 deletions
@@ -36,6 +36,7 @@
 from ...utils import (
     PIL_INTERPOLATION,
     BaseOutput,
+    deprecate,
     is_accelerate_available,
     is_accelerate_version,
     logging,
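
Side note on the new import: `deprecate` is the diffusers utility for emitting versioned deprecation warnings, and it is called in the next hunk. The snippet below is an illustrative sketch of that call pattern, assuming a diffusers install whose version is still below 1.0.0 (so the call warns rather than raises); the name, version, and message arguments simply mirror the hunk that follows.

    # Illustrative sketch of the deprecate() call used later in this commit.
    from diffusers.utils import deprecate

    deprecation_message = (
        "Please make sure to pass as many initial images as text prompts "
        "to suppress this warning."
    )
    # Emits a FutureWarning; standard_warn=False skips the generic prefix
    # and shows only the custom message above.
    deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)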
@@ -721,23 +722,31 @@ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
         )
 
         if isinstance(generator, list):
-            init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
-            ]
-            init_latents = torch.cat(init_latents, dim=0)
+            latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]
+            latents = torch.cat(latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
-
-        init_latents = self.vae.config.scaling_factor * init_latents
-
-        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
-            raise ValueError(
-                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
-            )
+            latents = self.vae.encode(image).latent_dist.sample(generator)
+
+        latents = self.vae.config.scaling_factor * latents
+
+        if batch_size != latents.shape[0]:
+            if batch_size % latents.shape[0] == 0:
+                # expand image_latents for batch_size
+                deprecation_message = (
+                    f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
+                    " images (`image`). Initial images are now duplicated to match the number of text prompts. Note"
+                    " that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to update"
+                    " your script to pass as many initial images as text prompts to suppress this warning."
+                )
+                deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+                additional_latents_per_image = batch_size // latents.shape[0]
+                latents = torch.cat([latents] * additional_latents_per_image, dim=0)
+            else:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
+                )
         else:
-            init_latents = torch.cat([init_latents], dim=0)
-
-        latents = init_latents
+            latents = torch.cat([latents], dim=0)
 
         return latents
 
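In summary, `prepare_image_latents` now broadcasts the encoded image latents to the number of text prompts when the prompt count is an exact multiple of the image count, emits a deprecation warning for that implicit duplication, and still raises when the counts cannot be matched. Below is a minimal standalone sketch of that broadcasting rule, assuming only PyTorch; the function name and tensor shapes are illustrative and not part of the pipeline's API.

    # Standalone sketch of the latent-broadcasting rule introduced above.
    import torch

    def broadcast_image_latents(latents: torch.Tensor, batch_size: int) -> torch.Tensor:
        if batch_size != latents.shape[0]:
            if batch_size % latents.shape[0] == 0:
                # e.g. 1 encoded image and 4 prompts -> repeat the latents 4 times
                repeats = batch_size // latents.shape[0]
                latents = torch.cat([latents] * repeats, dim=0)
            else:
                # e.g. 3 images cannot be evenly duplicated to 4 prompts
                raise ValueError(
                    f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
                )
        return latents

    latents = torch.randn(1, 4, 64, 64)               # one image in VAE latent space
    print(broadcast_image_latents(latents, 4).shape)  # torch.Size([4, 4, 64, 64])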