@@ -36,6 +36,7 @@
 from ...utils import (
     PIL_INTERPOLATION,
     BaseOutput,
+    deprecate,
     is_accelerate_available,
     is_accelerate_version,
     logging,
@@ -721,23 +722,31 @@ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None
         )

         if isinstance(generator, list):
-            init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
-            ]
-            init_latents = torch.cat(init_latents, dim=0)
+            latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]
+            latents = torch.cat(latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
-
-        init_latents = self.vae.config.scaling_factor * init_latents
-
-        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
-            raise ValueError(
-                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
-            )
+            latents = self.vae.encode(image).latent_dist.sample(generator)
+
+        latents = self.vae.config.scaling_factor * latents
+
+        if batch_size != latents.shape[0]:
+            if batch_size % latents.shape[0] == 0:
+                # expand image_latents for batch_size
+                deprecation_message = (
+                    f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
+                    " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+                    " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+                    " your script to pass as many initial images as text prompts to suppress this warning."
+                )
+                deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+                additional_latents_per_image = batch_size // latents.shape[0]
+                latents = torch.cat([latents] * additional_latents_per_image, dim=0)
+            else:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
+                )
         else:
-            init_latents = torch.cat([init_latents], dim=0)
-
-        latents = init_latents
+            latents = torch.cat([latents], dim=0)

         return latents

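For context, here is a minimal standalone sketch of the duplication rule introduced above. It uses a dummy tensor in place of the VAE-encoded image latents; the shapes and the `batch_size` value are illustrative, and a plain `warnings.warn` stands in for diffusers' internal `deprecate` helper so the snippet only needs `torch`.

```python
import warnings

import torch

# Dummy stand-ins: 2 encoded images, 6 text prompts (illustrative values only).
latents = torch.randn(2, 4, 64, 64)
batch_size = 6

if batch_size != latents.shape[0]:
    if batch_size % latents.shape[0] == 0:
        # Duplicate the image latents so every prompt gets a copy, mirroring the
        # deprecation path in the diff (warnings.warn is a stand-in for `deprecate`).
        warnings.warn(
            f"Passed {batch_size} prompts but only {latents.shape[0]} images; duplicating images.",
            FutureWarning,
        )
        additional_latents_per_image = batch_size // latents.shape[0]
        latents = torch.cat([latents] * additional_latents_per_image, dim=0)
    else:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
        )

print(latents.shape)  # torch.Size([6, 4, 64, 64]) -> one latent per prompt
```

As the diff shows, a prompt count that is a clean multiple of the image count now duplicates the latents and emits a deprecation warning instead of falling through undeduplicated, while a non-multiple still raises the same `ValueError` as before.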