[WIP] VaeImageProcessorImage Postprocessing refactor #2943

Closed
wants to merge 33 commits into from
Changes from 16 commits
3768ed9
refactor post processing with imageprocessor (update decode_latents/r…
Mar 16, 2023
71eda72
fix
Mar 31, 2023
83da056
make style
Mar 31, 2023
d13bc7f
Merge branch 'main' into postprocessing-refactor
yiyixuxu Mar 31, 2023
51cabe2
Merge branch 'main' into postprocessing-refactor
patrickvonplaten Apr 4, 2023
d91bcc9
refactor post-processing
Apr 6, 2023
6cbd1ac
add pipelinelatenttestermixin
Apr 7, 2023
c6d2405
update stablediffusionpipeline fast test using pipelinelatenttestermixim
Apr 8, 2023
fe8e13e
fix
Apr 9, 2023
4b09a20
alt
Apr 9, 2023
ce19bc9
refactor all pipelines
Apr 9, 2023
8065199
update alt-img2img
Apr 9, 2023
2c76ca3
make style
Apr 9, 2023
1a2c7a9
Merge branch 'main' into postprocessing-refactor
yiyixuxu Apr 9, 2023
389fdfe
fix model edit pipeline
Apr 9, 2023
db33f87
evaluation model for depth estimator in testing
yiyixuxu Apr 10, 2023
a491a38
Merge branch 'main' into postprocessing-refactor
patrickvonplaten Apr 11, 2023
5cde78c
fix
Apr 17, 2023
c5e69b9
Merge branch 'postprocessing-refactor' of https://github.com/huggingf…
Apr 17, 2023
ccf5b37
merge with main after fixing conflicts
Apr 17, 2023
928c35b
style + copy
Apr 18, 2023
809f1fe
fix tests
Apr 18, 2023
9ebd8f9
style
Apr 18, 2023
3dcb7b1
fix onnx
Apr 18, 2023
fabf88a
fix tests
Apr 18, 2023
04d27fb
style
Apr 18, 2023
68bb18b
fix alt img2img test
Apr 18, 2023
106c43a
Update src/diffusers/image_processor.py
yiyixuxu Apr 18, 2023
acf0d60
Update src/diffusers/image_processor.py
yiyixuxu Apr 18, 2023
7c1bb9a
Merge branch 'main' into postprocessing-refactor
yiyixuxu Apr 20, 2023
a09ecca
resolve merge
Apr 25, 2023
f03ff17
valueerror not needed
Apr 25, 2023
8c4a31b
fix output_type=latents
Apr 25, 2023
18 changes: 16 additions & 2 deletions src/diffusers/image_processor.py
Expand Up @@ -21,7 +21,7 @@
from PIL import Image

from .configuration_utils import ConfigMixin, register_to_config
-from .utils import CONFIG_NAME, PIL_INTERPOLATION
+from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate


class VaeImageProcessor(ConfigMixin):
Expand Down Expand Up @@ -82,7 +82,7 @@ def numpy_to_pt(images):
@staticmethod
def pt_to_numpy(images):
"""
-Convert a numpy image to a pytorch tensor
+Convert a pytorch tensor to a numpy image
"""
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
return images
Expand Down Expand Up @@ -164,6 +164,20 @@ def postprocess(
image,
output_type: str = "pil",
):
if output_type not in ["latent", "pt", "np", "pil"]:
Member

Note: the raise ValueError at the end of the function will no longer trigger:

    else:
        raise ValueError(f"Unsupported output_type {output_type}.")

We could remove it, or leave it in place so that everything still works once the deprecation warning is removed.

Collaborator Author

Oh thanks, I think we can just remove it (we can add it back if we ever remove the warning).
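
For reference, the fallback being discussed can be sketched on its own. This is only an illustration: the helper name `resolve_output_type` is hypothetical, and the sketch uses the standard-library `warnings` module instead of the `deprecate` utility from the diff.

```python
import warnings

SUPPORTED_OUTPUT_TYPES = ("latent", "pt", "np", "pil")


def resolve_output_type(output_type: str) -> str:
    """Return a supported output type, falling back to "np" with a warning.

    Hypothetical helper for illustration only; it is not part of diffusers.
    """
    if output_type not in SUPPORTED_OUTPUT_TYPES:
        warnings.warn(
            f"the output_type {output_type} is outdated. Please make sure to set it "
            "to one of these instead: `pil`, `np`, `pt`, `latent`",
            FutureWarning,
        )
        return "np"
    return output_type
```

Because every unsupported value is rewritten to "np" up front, a trailing `else: raise ValueError(...)` branch can no longer be reached, which is the point raised in the comment above.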

deprecation_message = (
f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
"`pil`, `np`, `pt`, `latent`"
)
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"

if output_type == "latent":
return image

Contributor

Suggested change
if not isinstance(do_normalize, list):
do_normalize = image.shape[0] * [do_normalize or self.config.do_normalize]

Contributor

I think we will need this to support the blacked-out image
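
A minimal sketch of the per-image flag idea, assuming a hypothetical helper `expand_do_normalize` and a `config_default` parameter that stands in for `self.config.do_normalize`. Note one deliberate difference from the suggested `do_normalize or self.config.do_normalize`: with `or`, an explicit `False` is overridden whenever the config value is `True`, so the sketch falls back to the default only for `None`.

```python
from typing import List, Optional, Union


def expand_do_normalize(
    do_normalize: Optional[Union[bool, List[bool]]],
    batch_size: int,
    config_default: bool = True,
) -> List[bool]:
    """Broadcast a single flag to one flag per image in the batch.

    Hypothetical helper. An explicit False is preserved; only None falls
    back to the configured default.
    """
    if isinstance(do_normalize, list):
        if len(do_normalize) != batch_size:
            raise ValueError("expected one do_normalize flag per image")
        return do_normalize
    flag = config_default if do_normalize is None else do_normalize
    return [flag] * batch_size
```

With one flag per image, a batch in which only some images were blacked out by the safety checker could skip the rescaling for exactly those images.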

Collaborator Author

@patrickvonplaten

Can we just check inside postprocess whether the image is black, and skip denormalization if so?

I will do what you proposed here, but I want to understand the reasoning, since I think it will help me make better design decisions in the future.

Contributor

We could do this as well; I also thought about it. I'm a bit worried, though, that it's a bit of black magic. In the one-in-a-million case where SD produces an image that contains only 0s, we should actually denormalize it to an image that contains only 0.5. It's highly unlikely, but it could happen. So in this sense there is a difference between a "blacked-out" image due to NSFW and a "blacked-out" image due to SD.

In a sense the postprocessor should not have to know anything about the safety checker - they should be as orthogonal / disentangled as possible. But I definitely understand your point here, it's quite some extra code.
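
To make the ambiguity concrete, here is a small sketch in plain Python. Flat lists of floats stand in for tensors, and `denormalize` is a local stand-in for `(x / 2 + 0.5).clamp(0, 1)`, not the library function.

```python
def denormalize(value: float) -> float:
    """Map a value from [-1, 1] to [0, 1], clamping like (x / 2 + 0.5).clamp(0, 1)."""
    return min(max(value / 2 + 0.5, 0.0), 1.0)


# An all-zero image decoded by the VAE is mid-gray after denormalization...
vae_output = [0.0, 0.0, 0.0]
print([denormalize(v) for v in vae_output])  # [0.5, 0.5, 0.5]

# ...whereas an image zeroed by the safety checker is already in [0, 1]
# and must stay black.
censored = [0.0, 0.0, 0.0]
print(censored == vae_output)  # True: the values alone cannot tell the two cases apart
```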

Member

@pcuenca Apr 23, 2023

I don't like the black magic that checks whether the image is black. I think the chance of SD producing a fully black image is really remote, but it's strange that the post-processor has to be aware of safety checker images.

Another alternative would be for run_safety_checker to normalize black images so it returns -1s instead of 0s. This is actually consistent with the fact that run_safety_checker receives normalized images and also returns normalized images if they are ok, but it returns denormalized images (black represented as 0) if they are NSFW.

We could do something like this after this line:

            image = torch.stack([self.image_processor.normalize(image[i]) if has_nsfw_concept[i] else image[i] for i in range(image.shape[0])])

And then we could remove the do_normalize argument here.

Would that be preferable?
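
The proposal above can be sketched without torch. The assumptions are that `normalize` maps [0, 1] to [-1, 1] as `2 * x - 1`, and that flat lists of floats stand in for image tensors. `renormalize_censored` is a hypothetical name.

```python
from typing import List

Image = List[float]  # flat list of pixel values standing in for a tensor


def normalize(image: Image) -> Image:
    """Map [0, 1] to [-1, 1] (the inverse of x / 2 + 0.5)."""
    return [2.0 * v - 1.0 for v in image]


def denormalize(image: Image) -> Image:
    """Map [-1, 1] to [0, 1] with clamping."""
    return [min(max(v / 2 + 0.5, 0.0), 1.0) for v in image]


def renormalize_censored(images: List[Image], has_nsfw_concept: List[bool]) -> List[Image]:
    """Re-normalize only the images that the checker blacked out."""
    return [normalize(img) if flagged else img for img, flagged in zip(images, has_nsfw_concept)]


batch = [[0.0, 0.0], [-0.5, 0.5]]  # first image censored to 0s, second still in [-1, 1]
batch = renormalize_censored(batch, [True, False])
print([denormalize(img) for img in batch])  # [[0.0, 0.0], [0.25, 0.75]]
```

After this step every image in the batch is in [-1, 1] again, so a single unconditional denormalization keeps censored images black.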

Collaborator Author

Ohh, I think it makes sense to return -1 for NSFW pixels, so that the post-processor doesn't have to treat them differently. But we also need to make sure run_safety_checker stays backward-compatible.

Maybe we can do this only when the image is a PyTorch tensor? I think that would be a little bit confusing, no?

if torch.is_tensor(image):            
    image = torch.stack([self.image_processor.normalize(image[i]) if has_nsfw_concept[i] else image[i] for i in range(image.shape[0])])

if image.min() < 0:
Contributor

Suggested change
if image.min() < 0:

image = (image / 2 + 0.5).clamp(0, 1)
Member

hmm the condition looks a bit brittle in my opinion. Would it make sense to use an arg (and maybe default it to True)?

Collaborator Author

oh so this is just to make sure it won't denormalize anything twice (like the question you asked below) - it will always denormalize unless it's already within [0, 1]

Contributor

Good catch here @pcuenca - actually I think we should indeed have this default to True always, since we currently always do this normalization as post-processing:

@yiyixuxu We should do this check for the input, not the output

Contributor

This would have been pretty dangerous, as it would have led to silent, hard-to-debug bugs (imagine the VAE decodes to an output whose min value is >= 0; then we would not have applied the correct post-processing).
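
A rough illustration of that failure mode, again with flat lists standing in for tensors. Both function names and the `do_denormalize` parameter are hypothetical and chosen only for this sketch.

```python
from typing import List


def postprocess_heuristic(image: List[float]) -> List[float]:
    """Denormalize only if some value is negative (the check under review)."""
    if min(image) < 0:
        return [min(max(v / 2 + 0.5, 0.0), 1.0) for v in image]
    return image


def postprocess_flag(image: List[float], do_denormalize: bool = True) -> List[float]:
    """Denormalize whenever the caller asks for it, regardless of the values."""
    if do_denormalize:
        return [min(max(v / 2 + 0.5, 0.0), 1.0) for v in image]
    return image


bright = [0.0, 0.5, 1.0]  # a valid [-1, 1] output that happens to be non-negative
print(postprocess_heuristic(bright))  # [0.0, 0.5, 1.0] (skipped, so the image stays too dark)
print(postprocess_flag(bright))       # [0.5, 0.75, 1.0]
```

The heuristic fails silently on exactly the inputs it cannot classify, while the flag version leaves the decision with the caller, who knows whether the tensor was already denormalized.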

Contributor

Ok nice this logic works


if isinstance(image, torch.Tensor) and output_type == "pt":
return image

Expand Down
42 changes: 21 additions & 21 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import torch
Expand All @@ -22,6 +23,7 @@
from diffusers.utils import is_accelerate_available, is_accelerate_version

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
Expand Down Expand Up @@ -166,6 +168,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_vae_slicing(self):
Expand Down Expand Up @@ -417,17 +420,26 @@ def _encode_prompt(

return prompt_embeds

def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
def run_safety_checker(self, image, device, dtype, output_type="pil"):
if self.safety_checker is None or output_type == "latent":
has_nsfw_concept = False
Member

I think this should always be a list, see another comment on that. Not sure if we could break existing workflows though.

Contributor

Previously, we would set the has_nsfw_concept to None, when we couldn't run the safety checker. False would only be returned if the safety_checker said definitively no. Did we decide to change that?

Collaborator Author

yeah, I think I asked @patrickvonplaten that - we can double-check with him again I guess

Is there a use case where None vs. False would make a difference in terms of how we use the pipeline? I understand they mean different things.

Contributor

Happy to leave it at None here, as it might be better for backwards compatibility.
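
One way downstream code might depend on the distinction is sketched below. `describe_check` is a hypothetical consumer, not something in the pipeline.

```python
from typing import List, Optional


def describe_check(has_nsfw_concept: Optional[List[bool]]) -> str:
    """Summarize a safety-check result, keeping "not run" distinct from "clean"."""
    if has_nsfw_concept is None:
        return "safety checker not run"
    if any(has_nsfw_concept):
        return "flagged"
    return "clean"
```

A caller written this way reports a skipped check separately from a clean one. It would also raise a `TypeError` on a bare `False`, since `any()` needs an iterable, which is one argument for returning either `None` or a list and nothing else.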

else:
image = self.image_processor.postprocess(image, output_type="pt")
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
Member

Wouldn't this denormalize twice?

Collaborator Author

The if image.min() < 0: condition makes sure it only denormalizes once.

Contributor

Sorry, I don't fully understand this - why do we run self.image_processor.postprocess twice?

Contributor

I think there is a misunderstanding here. The postprocess function should always apply the denormalization x / 2 + 0.5, so we should not have to run this function twice here.

Collaborator Author

@patrickvonplaten
I'm really confused 😣😣😣

Should we refactor it differently and not move the denormalization to postprocessing then?
see comment here #2943 (comment)

The original workflow is:
image -> vae.decode -> denormalize ((image / 2 + 0.5).clamp(0, 1)) -> run_safety_checker -> numpy_to_pil

Now we have decided to move the denormalization into postprocessing:
image -> vae.decode -> run_safety_checker -> image_processor.postprocess (e.g. denormalize + numpy_to_pil...)

This means:
(1) we need to call self.image_processor.postprocess(image, output_type="pt") from run_safety_checker, since safety_checker expects its input to be in [0, 1]
(2) the image we send to image_processor.postprocess might already have been denormalized (when we run safety_checker), or it might not (when there is no safety_checker), so we need to handle the two cases differently

Did I understand something wrong?

Contributor

I think it's just this question that needs to be addressed:
"is there a better way to check if the image is already denormalized"
We cannot decide here, based on the values of the tensor, whether it has already been denormalized.

Imagine that for some reason Stable Diffusion generates a tensor whose values are all larger than 0 - in this case we would never denormalize, which is bad.

A better check is to add a flag argument to the postprocess function that allows the safety checker to not denormalize it.

Contributor

But we also don't really need this here. I understand that we need both a post-processed image in PyTorch and in PIL format, but we should not put the post processed PyTorch image into the post-process function again.

Then we also don't have this double post process problem

Member

Yes, I agree with Patrick, I'd rather not chain calls to postprocess. I see two potential solutions:

  • Just convert the post-processed image to PIL (maybe we can add a helper to the postprocessor if we need to).
  • Maybe just swap the calls? Perhaps I'm confused but wouldn't this work?
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
image = self.image_processor.postprocess(image, output_type="pt")
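
The difference between chaining the calls and swapping them can be seen in a toy version. Here `postprocess` is a stand-in that always denormalizes, flat lists replace tensors, and rounded 8-bit integers replace PIL images. None of this is the real `VaeImageProcessor`.

```python
from typing import List


def postprocess(image: List[float], output_type: str):
    """Stand-in for an image processor that always denormalizes."""
    denorm = [min(max(v / 2 + 0.5, 0.0), 1.0) for v in image]
    if output_type == "pt":
        return denorm
    if output_type == "pil":
        return [round(v * 255) for v in denorm]  # 8-bit values, as a PIL image would hold
    raise ValueError(f"Unsupported output_type {output_type}.")


raw = [-1.0, 0.0, 1.0]  # decoder output in [-1, 1]

# Chained: the second call sees values that were already denormalized.
chained = postprocess(postprocess(raw, "pt"), "pil")

# Swapped: both conversions start from the same raw tensor.
pil_like = postprocess(raw, "pil")
pt_like = postprocess(raw, "pt")

print(chained)   # [128, 191, 255]
print(pil_like)  # [0, 128, 255]
```

In the chained version black comes out as mid-gray, whereas deriving both outputs from the raw tensor needs neither a value check nor a flag inside run_safety_checker.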

Collaborator Author

@pcuenca yes we can swap the calls here. I agree it's better to swap and it will avoid calling postprocess twice inside run_safety_checker.

I still need to call postprocess again with the output of run_safety_checker, which is already denormalized https://github.com/huggingface/diffusers/blob/postprocessing-refactor/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L712

Collaborator Author

@patrickvonplaten I think the flag will solve my problem - I will go with that for now. We can discuss later and happy to refactor again

safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept

def decode_latents(self, latents):
Contributor

@patrickvonplaten Apr 4, 2023

I think we should deprecate this function and instead advise people to do the following:

image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
image = self.image_processor.postprocess(image, output_type="pt")

instead.

This means we also should put:

 image = (image / 2 + 0.5).clamp(0, 1)

in the post-processor, which is good as it's clearly post-processing.

warnings.warn(
(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead"
),
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
Expand Down Expand Up @@ -691,24 +703,12 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

if output_type == "latent":
image = latents
has_nsfw_concept = None
elif output_type == "pil":
# 8. Post-processing
image = self.decode_latents(latents)

# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample

# 10. Convert to PIL
image = self.numpy_to_pil(image)
else:
# 8. Post-processing
image = self.decode_latents(latents)
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type)
Contributor

I think we should avoid passing output_type to run_safety_checker. It looks like output_type is only used to check whether the output type is latent and, if so, to skip the safety checker. I think that should be done outside of run_safety_checker.

Collaborator Author

@williamberman
yeah agreed
actually, I already updated that on img2img pipeline - maybe you could review the code change on that pipeline instead?

I think at some point I gave up on updating all 30ish pipelines and decided to only iterate on the text2img and img2img pipelines until the design is finalized. Sorry it's a little bit confusing 😂
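
A possible shape for moving the latent check into the caller, sketched with injected callables so that it runs on its own. `finalize`, `decode`, and the simplified `run_safety_checker` signature are assumptions made for illustration.

```python
from typing import Callable, List, Optional, Tuple

Image = List[float]  # flat list of values standing in for a tensor


def finalize(
    latents: Image,
    output_type: str,
    decode: Callable[[Image], Image],
    run_safety_checker: Callable[[Image], Tuple[Image, List[bool]]],
) -> Tuple[Image, Optional[List[bool]]]:
    """Decide in the caller whether to decode and check, so the checker needs no output_type."""
    if output_type == "latent":
        return latents, None  # nothing was decoded, so there is nothing to check
    image = decode(latents)
    return run_safety_checker(image)
```

With the branch in the caller, run_safety_checker keeps its original three-argument signature and never sees latents.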


# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
image = self.image_processor.postprocess(image, output_type=output_type)
Contributor

I'm not sure what types are allowed to be passed from __call__ to run_safety_checker and on to the safety checker itself. From reading the code, my understanding is that the safety checker can only take numpy arrays and torch tensors for its image argument, but here it looks like we might be passing PIL images as well? The safety checker's code for censoring the images would throw an error in that case or overwrite the PIL images with numpy arrays.

Contributor

I think it would perhaps make this PR a bit easier if we were to first refactor the safety checker and the run_safety_checker function so that it only took one copy of the images, and then, if we want to apply the censoring, we can call into a separate helper function. I think the __call__ method, the run_safety_checker method, and the forward function of the safety checker are a bit too tied together right now.
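
A very rough sketch of that separation, with detection and censoring as two independent helpers. The per-image scores and the threshold are placeholders: the real safety checker derives its decision from CLIP embeddings, which is not modeled here.

```python
from typing import List

Image = List[float]  # flat list of pixel values standing in for a tensor


def detect_nsfw(scores: List[float], threshold: float = 0.5) -> List[bool]:
    """Detection only: one flag per image. Scores and threshold are placeholders."""
    return [score > threshold for score in scores]


def censor(images: List[Image], flags: List[bool]) -> List[Image]:
    """Censoring only: replace flagged images with zeros of the same length."""
    return [[0.0] * len(img) if flagged else img for img, flagged in zip(images, flags)]
```

Under this split the detector would need only the CLIP input, and the caller could apply `censor` to whichever representation it holds at that point.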


# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
Expand Down Expand Up @@ -194,6 +195,7 @@ def __init__(
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)

self.register_modules(
vae=vae,
text_encoder=text_encoder,
Expand All @@ -204,11 +206,8 @@ def __init__(
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(
requires_safety_checker=requires_safety_checker,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Expand Down Expand Up @@ -427,18 +426,31 @@ def _encode_prompt(

return prompt_embeds

def run_safety_checker(self, image, device, dtype):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
def run_safety_checker(self, image, device, dtype, output_type="pil"):
if self.safety_checker is None or output_type == "latent":
has_nsfw_concept = False
else:
image = self.image_processor.postprocess(image, output_type="pt")
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept

def decode_latents(self, latents):
warnings.warn(
(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead"
),
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image

def prepare_extra_step_kwargs(self, generator, eta):
Expand Down Expand Up @@ -722,27 +734,12 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

if output_type not in ["latent", "pt", "np", "pil"]:
deprecation_message = (
f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
"`pil`, `np`, `pt`, `latent`"
)
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"

if output_type == "latent":
image = latents
has_nsfw_concept = None
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample

else:
image = self.decode_latents(latents)

if self.safety_checker is not None:
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
has_nsfw_concept = False
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type)

image = self.image_processor.postprocess(image, output_type=output_type)
image = self.image_processor.postprocess(image, output_type=output_type)
Comment on lines +747 to +749
Member

I suppose we focused on just a couple of pipelines for review, just to note this would be missing the same do_normalize logic we have in other pipelines (if we keep it).


# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import warnings
from typing import Callable, List, Optional, Union

import numpy as np
Expand All @@ -22,9 +23,10 @@

from diffusers.utils import is_accelerate_available

from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging, randn_tensor
from ...utils import deprecate, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
Expand Down Expand Up @@ -184,6 +186,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_sequential_cpu_offload(self, gpu_id=0):
Expand Down Expand Up @@ -225,14 +228,16 @@ def _execution_device(self):
return self.device

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
def run_safety_checker(self, image, device, dtype, output_type="pil"):
if self.safety_checker is None or output_type == "latent":
has_nsfw_concept = False
else:
image = self.image_processor.postprocess(image, output_type="pt")
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
Expand All @@ -255,6 +260,11 @@ def prepare_extra_step_kwargs(self, generator, eta):

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
Expand Down Expand Up @@ -560,15 +570,22 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

# 11. Post-processing
image = self.decode_latents(latents)
if output_type not in ["latent", "pt", "np", "pil"]:
deprecation_message = (
f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
"`pil`, `np`, `pt`, `latent`"
)
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"

if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample

# 12. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
image, has_nsfw_concept = self.run_safety_checker(
image, device, image_embeddings.dtype, output_type=output_type
)

# 13. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type=output_type)

if not return_dict:
return (image, has_nsfw_concept)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import inspect
import warnings
from itertools import repeat
from typing import Callable, List, Optional, Union

import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
Expand Down Expand Up @@ -129,10 +131,29 @@ def __init__(
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype, output_type="pil"):
if self.safety_checker is None or output_type == "latent":
has_nsfw_concept = False
else:
image = self.image_processor.postprocess(image, output_type="pt")
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
Expand Down Expand Up @@ -680,21 +701,14 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

# 8. Post-processing
image = self.decode_latents(latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample

if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
self.device
)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
else:
has_nsfw_concept = None
image, has_nsfw_concept = self.run_safety_checker(
image, self.device, text_embeddings.dtype, output_type=output_type
)

if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type=output_type)

if not return_dict:
return (image, has_nsfw_concept)