Commit 44586e2

Add option to not decode latents in the inversion process
1 parent cc4ed9a

1 file changed: 21 additions, 11 deletions

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py

@@ -653,6 +653,10 @@ def get_inverse_timesteps(self, num_inference_steps, strength, device):
         init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
 
         t_start = max(num_inference_steps - init_timestep, 0)
+
+        # safety for t_start overflow to prevent empty timesteps slice
+        if t_start == num_inference_steps:
+            return self.inverse_scheduler.timesteps, num_inference_steps
         timesteps = self.inverse_scheduler.timesteps[:-t_start]
 
         return timesteps, num_inference_steps - t_start
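
For context on the new guard: when `strength` is small enough that `int(num_inference_steps * strength)` rounds to 0, `t_start` equals `num_inference_steps` and the `[:-t_start]` slice comes back empty. A minimal plain-Python sketch of that edge case (illustrative values only, not pipeline code):

num_inference_steps = 50
timesteps = list(range(num_inference_steps, 0, -1))  # stand-in for inverse_scheduler.timesteps

strength = 0.01
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # -> 0
t_start = max(num_inference_steps - init_timestep, 0)                          # -> 50

print(len(timesteps[:-t_start]))  # 0 -- the empty slice the new branch now short-circuits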
@@ -958,6 +962,7 @@ def invert(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        decode_latents: bool = False,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -975,13 +980,12 @@ def invert(
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
             image (`PIL.Image.Image`):
-                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
-                be masked out with `mask_image` and repainted according to `prompt`.
+                `Image`, or tensor representing an image batch to produce the inverted latents, guided by `prompt`.
             inpaint_strength (`float`, *optional*, defaults to 0.8):
-                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
-                is 1, the denoising process will be run on the masked area for the full number of iterations specified
-                in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to
-                that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+                Conceptually, indicates how far into the noising process to run latent inversion. Must be between 0 and 1. When `strength`
+                is 1, the inversion process will be run for the full number of iterations specified
+                in `num_inference_steps`. `image` will be used as a reference for the inversion process, adding more noise
+                the larger the `strength`. If `strength` is 0, no inpainting will occur.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
@@ -1007,11 +1011,14 @@ def invert(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            decode_latents (`bool`, *optional*, defaults to `False`):
+                Whether or not to decode the inverted latents into a generated image. Setting this argument to `True`
+                will decode all inverted latents for each timestep into a list of generated images.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                Whether or not to return a [`~pipelines.stable_diffusion.DiffEditInversionPipelineOutput`] instead of a
                 plain tuple.
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
@@ -1064,8 +1071,9 @@ def invert(
         Returns:
             [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] or
             `tuple`: [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`]
-            if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is the inverted
-            latents tensors ordered by increasing noise, and then second is the corresponding decoded images.
+            if `return_dict` is `True`, otherwise a `tuple`. When returning a tuple, the first element is the inverted
+            latents tensors ordered by increasing noise, and then second is the corresponding decoded images if
+            `decode_latents` is `True`, otherwise `None`.
         """
 
         # 1. Check inputs
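
As a usage sketch of the new flag (not part of this commit): the snippet below loads the DiffEdit pipeline and runs inversion twice, once with the default `decode_latents=False` and once with it enabled. The checkpoint id, the image URL, the scheduler wiring, and the `latents`/`images` output fields are assumptions based on the surrounding pipeline code, not something this diff defines.

import torch
from diffusers import DDIMInverseScheduler, DDIMScheduler, StableDiffusionDiffEditPipeline
from diffusers.utils import load_image

# Assumed checkpoint and scheduler setup for the DiffEdit pipeline.
pipe = StableDiffusionDiffEditPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

image = load_image("https://example.com/fruit-bowl.png").resize((768, 768))  # placeholder URL

# Default path after this commit: keep only the inverted latents, skip the VAE decode.
inv = pipe.invert(prompt="a bowl of fruits", image=image)
print(inv.latents.shape)  # assumed `latents` field on DiffEditInversionPipelineOutput
print(inv.images)         # expected to be None, since decode_latents defaults to False

# Opt back in to the previous behaviour: decode every inverted latent to an image.
latents, images = pipe.invert(
    prompt="a bowl of fruits", image=image, decode_latents=True, return_dict=False
)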
@@ -1184,10 +1192,12 @@ def invert(
         latents = torch.cat(list(reversed(inverted_latents)))
 
         # 8. Post-processing
-        image = self.decode_latents(latents.detach())
+        image = None
+        if decode_latents:
+            image = self.decode_latents(latents.detach())
 
         # 9. Convert to PIL.
-        if output_type == "pil":
+        if decode_latents and output_type == "pil":
             image = self.numpy_to_pil(image)
 
         # Offload last model to CPU
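
Why the flag matters for this post-processing step: `inverted_latents` holds one latent batch per inversion step (built by the loop not shown in this diff), so the concatenated tensor that was previously decoded unconditionally grows with `num_inference_steps`. A rough shape sketch with standalone torch and hypothetical sizes:

import torch

num_steps, batch, channels, h, w = 50, 1, 4, 64, 64

# One latent batch captured per inversion step.
inverted_latents = [torch.randn(batch, channels, h, w) for _ in range(num_steps)]

# As on line 1192 of the new file: reverse to increasing-noise order, then concatenate.
latents = torch.cat(list(reversed(inverted_latents)))
print(latents.shape)  # torch.Size([50, 4, 64, 64])

# With decode_latents=False the pipeline now stops here and returns `latents`;
# with decode_latents=True it still pushes all 50 latent batches through the VAE decoder.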
