@@ -36,7 +36,6 @@
 from ...utils import (
     PIL_INTERPOLATION,
     BaseOutput,
-    deprecate,
     is_accelerate_available,
     is_accelerate_version,
     logging,

@@ -722,31 +721,23 @@ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None
             )
 
         if isinstance(generator, list):
-            latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]
-            latents = torch.cat(latents, dim=0)
+            init_latents = [
+                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+            ]
+            init_latents = torch.cat(init_latents, dim=0)
         else:
-            latents = self.vae.encode(image).latent_dist.sample(generator)
-
-        latents = self.vae.config.scaling_factor * latents
-
-        if batch_size != latents.shape[0]:
-            if batch_size % latents.shape[0] == 0:
-                # expand image_latents for batch_size
-                deprecation_message = (
-                    f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
-                    " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
-                    " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
-                    " your script to pass as many initial images as text prompts to suppress this warning."
-                )
-                deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
-                additional_latents_per_image = batch_size // latents.shape[0]
-                latents = torch.cat([latents] * additional_latents_per_image, dim=0)
-            else:
-                raise ValueError(
-                    f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
-                )
+            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+        init_latents = self.vae.config.scaling_factor * init_latents
+
+        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+            raise ValueError(
+                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+            )
         else:
-            latents = torch.cat([latents], dim=0)
+            init_latents = torch.cat([init_latents], dim=0)
+
+        latents = init_latents
 
         return latents
 
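For readability, here is the replacement logic assembled in one place. This is a sketch, not the pipeline's actual method: `vae` stands in for the pipeline's `self.vae`, and the dtype/device handling and generator-length validation that sit above the hunk are omitted.

import torch


def prepare_image_latents(vae, image, batch_size, generator=None):
    # Sketch of the reworked branch; `vae` is assumed to be an AutoencoderKL-style
    # model exposing .encode(...).latent_dist and .config.scaling_factor.
    if isinstance(generator, list):
        # Encode each image slice with its own torch.Generator so per-sample
        # sampling stays reproducible when a list of generators is passed.
        init_latents = [
            vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
        ]
        init_latents = torch.cat(init_latents, dim=0)
    else:
        init_latents = vae.encode(image).latent_dist.sample(generator)

    # Scale the encoded image into the latent space the diffusion model expects.
    init_latents = vae.config.scaling_factor * init_latents

    # A prompt/image count mismatch that is not an exact multiple still raises.
    # The removed deprecation branch, which tiled the latents for exact multiples,
    # is gone, so no implicit duplication happens here anymore.
    if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
        )
    else:
        init_latents = torch.cat([init_latents], dim=0)

    return init_latents

The user-visible change follows from the removed deprecation message: passing fewer initial images than text prompts no longer silently duplicates latents, so scripts are expected to pass as many initial images as prompts.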