Skip to content

Postprocessing refactor img2img #3268

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions src/diffusers/image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
# limitations under the License.

import warnings
from typing import Union
from typing import List, Optional, Union

import numpy as np
import PIL
import torch
from PIL import Image

from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME, PIL_INTERPOLATION
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate


class VaeImageProcessor(ConfigMixin):
Expand Down Expand Up @@ -82,7 +82,7 @@ def numpy_to_pt(images):
@staticmethod
def pt_to_numpy(images):
"""
Convert a numpy image to a pytorch tensor
Convert a pytorch tensor to a numpy image
"""
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
return images
Expand All @@ -94,6 +94,13 @@ def normalize(images):
"""
return 2.0 * images - 1.0

@staticmethod
def denormalize(images):
    """
    Denormalize an image array to [0,1]

    Inverse of `normalize` (which computes `2.0 * images - 1.0`): rescales with `images / 2 + 0.5` and then
    clips the result to the closed interval [0, 1], so out-of-range values saturate at the bounds.

    Args:
        images: a `torch.Tensor` or `np.ndarray` of pixel values.
            NOTE(review): presumably in [-1, 1] as produced by `normalize` - confirm against callers.

    Returns:
        A tensor (for tensor input) or array (for array input) with every value in [0, 1].
    """
    rescaled = images / 2 + 0.5
    # numpy arrays have no `.clamp`; use the numpy equivalent so both array kinds are supported
    if isinstance(rescaled, np.ndarray):
        return np.clip(rescaled, 0, 1)
    return rescaled.clamp(0, 1)

def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
"""
Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
Expand Down Expand Up @@ -165,17 +172,39 @@ def preprocess(

def postprocess(
self,
image,
image: torch.FloatTensor,
output_type: str = "pil",
do_denormalize: Optional[List[bool]] = None,
):
if isinstance(image, torch.Tensor) and output_type == "pt":
if not isinstance(image, torch.Tensor):
raise ValueError(
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
)
if output_type not in ["latent", "pt", "np", "pil"]:
deprecation_message = (
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
"`pil`, `np`, `pt`, `latent`"
)
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"

if output_type == "latent":
return image

if do_denormalize is None:
do_denormalize = [self.config.do_normalize] * image.shape[0]

image = torch.stack(
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
Comment on lines +197 to +198
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is neat!

But would there be a need to denormalize images individually? I would assume if do_denormalize is True then ALL the images will be denormalized. Why individually?

Nit: Should it be "unnormalize"?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Normally we should denormalize all images if we normalized them in preprocessing (i.e. if self.config.do_normalize is True).
However, there is an edge case where an image has been blacked out by the safety_checker. In that case we pass a do_denormalize argument so that denormalization is skipped only for those NSFW images.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Superb! Thanks for explaining!

)

if output_type == "pt":
return image

image = self.pt_to_numpy(image)

if output_type == "np":
return image
elif output_type == "pil":

if output_type == "pil":
return self.numpy_to_pil(image)
else:
raise ValueError(f"Unsupported output_type {output_type}.")
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
Expand Down Expand Up @@ -202,6 +203,7 @@ def __init__(
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)

self.register_modules(
vae=vae,
text_encoder=text_encoder,
Expand All @@ -212,11 +214,8 @@ def __init__(
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(
requires_safety_checker=requires_safety_checker,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Expand Down Expand Up @@ -436,17 +435,32 @@ def _encode_prompt(
return prompt_embeds

def run_safety_checker(self, image, device, dtype):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept

def decode_latents(self, latents):
Copy link
Collaborator Author

@yiyixuxu yiyixuxu Apr 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The decode_latents method is the same as decode_latents in StableDiffusionPipeline, with only the deprecation message added (the way the diff renders this change is a little confusing).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect! It's essentially a revert of what we had before

warnings.warn(
(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead"
),
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image

def prepare_extra_step_kwargs(self, generator, eta):
Expand Down Expand Up @@ -730,27 +744,19 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

if output_type not in ["latent", "pt", "np", "pil"]:
deprecation_message = (
f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
"`pil`, `np`, `pt`, `latent`"
)
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"

if output_type == "latent":
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None

if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
image = self.decode_latents(latents)

if self.safety_checker is not None:
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
has_nsfw_concept = False
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

image = self.image_processor.postprocess(image, output_type=output_type)
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
Expand Down Expand Up @@ -205,6 +206,7 @@ def __init__(
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)

self.register_modules(
vae=vae,
text_encoder=text_encoder,
Expand All @@ -215,11 +217,8 @@ def __init__(
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(
requires_safety_checker=requires_safety_checker,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
Expand Down Expand Up @@ -443,17 +442,30 @@ def _encode_prompt(
return prompt_embeds

def run_safety_checker(self, image, device, dtype):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept

def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also specify which method of VaeImageProcessor users should use? I think that would be more helpful and specific.

FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
Expand Down Expand Up @@ -738,27 +750,19 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

if output_type not in ["latent", "pt", "np", "pil"]:
deprecation_message = (
f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
"`pil`, `np`, `pt`, `latent`"
)
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"

if output_type == "latent":
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

has_nsfw_concept is now either a list of bool or None (None when the safety checker was not run, i.e. self.safety_checker is None or output_type == 'latent')

I think this is consistent with the documentation as @pcuenca pointed out. And it is consistent with how it's used in current pipelines so I think it's ok

nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, or `None` if safety checking could not be performed.


if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
image = self.decode_latents(latents)

if self.safety_checker is not None:
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
has_nsfw_concept = False
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

image = self.image_processor.postprocess(image, output_type=output_type)
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
Expand Down
6 changes: 3 additions & 3 deletions tests/others/test_image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def to_np(self, image):
return image

def test_vae_image_processor_pt(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)

input_pt = self.dummy_sample
input_np = self.to_np(input_pt)
Expand All @@ -59,7 +59,7 @@ def test_vae_image_processor_pt(self):
), f"decoded output does not match input for output_type {output_type}"

def test_vae_image_processor_np(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)

for output_type in ["pt", "np", "pil"]:
Expand All @@ -72,7 +72,7 @@ def test_vae_image_processor_np(self):
), f"decoded output does not match input for output_type {output_type}"

def test_vae_image_processor_pil(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)

input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)
input_pil = image_processor.numpy_to_pil(input_np)
Expand Down
4 changes: 4 additions & 0 deletions tests/pipelines/pipeline_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@

TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])

TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])

IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])

IMAGE_VARIATION_PARAMS = frozenset(
[
"image",
Expand Down
Loading