[WIP] VaeImageProcessor Image Postprocessing refactor #2943

Closed (wants to merge 33 commits)

Changes from 30 commits

Commits (33)
3768ed9
refactor post processing with imageprocessor (update decode_latents/r…
Mar 16, 2023
71eda72
fix
Mar 31, 2023
83da056
make style
Mar 31, 2023
d13bc7f
Merge branch 'main' into postprocessing-refactor
yiyixuxu Mar 31, 2023
51cabe2
Merge branch 'main' into postprocessing-refactor
patrickvonplaten Apr 4, 2023
d91bcc9
refactor post-processing
Apr 6, 2023
6cbd1ac
add pipelinelatenttestermixin
Apr 7, 2023
c6d2405
update stablediffusionpipeline fast test using pipelinelatenttestermixim
Apr 8, 2023
fe8e13e
fix
Apr 9, 2023
4b09a20
alt
Apr 9, 2023
ce19bc9
refactor all pipelines
Apr 9, 2023
8065199
update alt-img2img
Apr 9, 2023
2c76ca3
make style
Apr 9, 2023
1a2c7a9
Merge branch 'main' into postprocessing-refactor
yiyixuxu Apr 9, 2023
389fdfe
fix model edit pipeline
Apr 9, 2023
db33f87
evaluation model for depth estimator in testing
yiyixuxu Apr 10, 2023
a491a38
Merge branch 'main' into postprocessing-refactor
patrickvonplaten Apr 11, 2023
5cde78c
fix
Apr 17, 2023
c5e69b9
Merge branch 'postprocessing-refactor' of https://github.com/huggingf…
Apr 17, 2023
ccf5b37
merge with main after fixing conflicts
Apr 17, 2023
928c35b
style + copy
Apr 18, 2023
809f1fe
fix tests
Apr 18, 2023
9ebd8f9
style
Apr 18, 2023
3dcb7b1
fix onnx
Apr 18, 2023
fabf88a
fix tests
Apr 18, 2023
04d27fb
style
Apr 18, 2023
68bb18b
fix alt img2img test
Apr 18, 2023
106c43a
Update src/diffusers/image_processor.py
yiyixuxu Apr 18, 2023
acf0d60
Update src/diffusers/image_processor.py
yiyixuxu Apr 18, 2023
7c1bb9a
Merge branch 'main' into postprocessing-refactor
yiyixuxu Apr 20, 2023
a09ecca
resolve merge
Apr 25, 2023
f03ff17
valueerror not needed
Apr 25, 2023
8c4a31b
fix output_type=latents
Apr 25, 2023
40 changes: 34 additions & 6 deletions src/diffusers/image_processor.py
@@ -13,15 +13,15 @@
# limitations under the License.

import warnings
from typing import Union
from typing import Union, Optional, List

import numpy as np
import PIL
import torch
from PIL import Image

from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME, PIL_INTERPOLATION
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate


class VaeImageProcessor(ConfigMixin):
@@ -82,7 +82,7 @@ def numpy_to_pt(images):
@staticmethod
def pt_to_numpy(images):
"""
Convert a numpy image to a pytorch tensor
Convert a pytorch tensor to a numpy image
"""
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
return images
@@ -93,6 +93,13 @@ def normalize(images):
Normalize an image array to [-1,1]
"""
return 2.0 * images - 1.0

@staticmethod
def denormalize(images):
"""
Denormalize an image array to [0,1]
"""
return (images / 2 + 0.5).clamp(0, 1)

def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
"""
@@ -165,10 +172,31 @@ def preprocess(

def postprocess(
self,
image,
image: torch.FloatTensor,
output_type: str = "pil",
):
if isinstance(image, torch.Tensor) and output_type == "pt":
do_normalize: Optional[Union[List[bool], bool]] = None,
Member:
nit: maybe rename to do_denormalize (if we keep it)

):
if not isinstance(image, torch.Tensor):
raise ValueError(
f"Input for postprocess is in incorrect format: {type(image)}. we only support pytorch tensor"
)
if output_type not in ["latent", "pt", "np", "pil"]:
Member:
Note: the raise ValueError that happens later at the end of the function will no longer trigger.

else:
raise ValueError(f"Unsupported output_type {output_type}.")
We could remove it, or we could leave it there anyway so everything still works after we remove the deprecation warning.

Collaborator Author:
Oh thanks, I think we can just remove it (we can add it back if we ever remove the warning).

deprecation_message = (
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
"`pil`, `np`, `pt`, `latent`"
)
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"

if output_type == "latent":
return image

Contributor:
Suggested change
if not isinstance(do_normalize, list):
do_normalize = image.shape[0] * [do_normalize or self.config.do_normalize]

Contributor:
I think we will need this to support the blacked-out image

Collaborator Author:
@patrickvonplaten

Can we just check in postprocess whether the image is black, and not denormalize if so?

I will do what you proposed here, but I just want to understand; I think it will help me make better design decisions in the future.

Contributor:
We could do this as well; I also thought about it. I'm a bit worried, though, that it's a bit of black magic. In the one-in-a-million case where SD produces an image that has exactly only 0s, we should actually denormalize it to an image that has only 0.5s. It's highly unlikely, but it could happen. So in this sense there is a difference between a "blacked-out" image due to NSFW and a "blacked-out" image due to SD.

In a sense the postprocessor should not have to know anything about the safety checker - they should be as orthogonal / disentangled as possible. But I definitely understand your point here; it's quite some extra code.

Member (@pcuenca, Apr 23, 2023):
I don't like the black magic that checks whether the image is black. I think the chance for SD to produce a fully black image is really remote, but it's strange that the post-processor has to be aware of safety checker images.

Another alternative would be for run_safety_checker to normalize black images so it returns -1s instead of 0s. This is actually consistent with the fact that run_safety_checker receives normalized images and also returns normalized images if they are ok, but it returns denormalized images (black represented as 0) if they are NSFW.

We could do something like this after this line:

            image = torch.stack([self.image_processor.normalize(image[i]) if has_nsfw_concept[i] else image[i] for i in range(image.shape[0])])

And then we could remove the do_normalize argument here.

Would that be preferable?

Collaborator Author:
Ohh, I think it makes sense to return -1 for NSFW pixels; then the post-processor doesn't have to treat them differently,
but we also need to make sure run_safety_checker is backward-compatible.

Maybe we can do this only when the image is a PyTorch tensor? Although I think that would be a little bit confusing, no?

if torch.is_tensor(image):            
    image = torch.stack([self.image_processor.normalize(image[i]) if has_nsfw_concept[i] else image[i] for i in range(image.shape[0])])
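For concreteness, a hypothetical sketch of what run_safety_checker could look like under this proposal (an illustration only, not code from this PR): blacked-out NSFW images are re-normalized to -1s so that postprocess can denormalize every image the same way.

def run_safety_checker(self, image, device, dtype):
    # Hypothetical sketch of the alternative discussed above.
    if self.safety_checker is None:
        return image, None
    feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
    safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
    image, has_nsfw_concept = self.safety_checker(
        images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
    )
    if torch.is_tensor(image) and any(has_nsfw_concept):
        # NSFW outputs come back blacked out as 0s; map them to -1s so that the
        # standard (x / 2 + 0.5) denormalization in postprocess keeps them black.
        image = torch.stack(
            [self.image_processor.normalize(image[i]) if has_nsfw_concept[i] else image[i] for i in range(image.shape[0])]
        )
    return image, has_nsfw_concept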

if not isinstance(do_normalize, list):
do_normalize = image.shape[0] * [do_normalize or self.config.do_normalize]
Comment on lines +194 to +195
Contributor:
I assume the goal is to default to self.config.do_normalize only when the argument is None. When the argument is False, this will still fall back to self.config.do_normalize.

When we use optional boolean arguments in Python, a good rule of thumb is to resolve them to a default value as early as possible in the function by explicitly checking is None, and, if any other behavior depends on whether they were set, to track that with a separate flag.

i.e.

def foo(bar: Optional[bool] = None):
    if bar is None:
        bar = False  # the default value
        bar_was_set = False
    else:
        bar_was_set = True

    # from here on, bar can always be treated as just a boolean, and if anything
    # is needed based on it not being set, `bar_was_set` can be used

image = torch.stack([self.denormalize(image[i]) if do_normalize[i] else image[i] for i in range(image.shape[0])])
Contributor:
Nit: I find these oneliners hard to read even though what they're doing is quite simple. I think golang got it right by only allowing a single looping construct. Could we turn this into a for loop?

Member:
I'll always agree that clarity is more important than the number of lines, but I think in this case a loop would involve creating a temporary list to hold the images and then doing the stack outside the loop, which for me could be harder to parse than the oneliner. With the oneliner I see that this is "denormalizing somehow" and don't bother to look at the details unless I need to. With the loop, I'm more or less forced to read all the lines and understand them. Also, I hate vars. Maybe Python should be more functional too ;)

Contributor (@williamberman, Apr 27, 2023):
Actually, once you use the list comprehension syntax, you're going to allocate the temporary list regardless :)

>>> foo = [1, 2, 3]
>>> bar = [x for x in foo]  # <- not a generator
>>> bar.__class__
<class 'list'>

Plus, if the stack function were to take a generator, it would have to instantiate it regardless, as there will be a fixed number of dimensions for the resulting tensor (maybe there could be some more efficient implementation by doing incremental instantiations, but since it's really a list of pointers, I'm not sure).

IIRC, the go argument here for the single looping construct is that it's always best to just make memory allocation explicit and let the caller handle it?

Agreed that it makes you read the body of the loop to know what it's doing, but I think that's a good thing!

re should python be more functional, guido agrees but selectively so ;) https://www.artima.com/weblogs/viewpost.jsp?thread=98196

Member:
haha, my comment was not about efficiency or memory, I dislike vars because they create state and take space in your mind :) But it's always fun to have these conversations, keep it going :)

Contributor:
of course 😁
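For reference, a hypothetical for-loop version of the one-liner under discussion (a behavior-preserving sketch, not what the PR settles on):

# Hypothetical sketch: the same per-image denormalization as an explicit loop
denormalized_images = []
for i in range(image.shape[0]):
    if do_normalize[i]:
        denormalized_images.append(self.denormalize(image[i]))
    else:
        denormalized_images.append(image[i])
image = torch.stack(denormalized_images)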


if output_type == "pt":
return image

image = self.pt_to_numpy(image)
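Putting the refactored postprocess together, a hypothetical usage sketch (the processor construction, tensor shapes, and default config values are assumptions for illustration; do_normalize is the parameter name at this revision of the PR):

import torch

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()                 # assumes the default config, where do_normalize is True
decoded = torch.rand(2, 3, 64, 64) * 2 - 1      # stand-in for vae.decode(latents / scaling_factor).sample, values in [-1, 1]

pil_images = processor.postprocess(decoded, output_type="pil")    # list of PIL images
np_images = processor.postprocess(decoded, output_type="np")      # numpy array with values in [0, 1]
pt_images = processor.postprocess(decoded, output_type="pt", do_normalize=[True, False])  # per-image control
passthrough = processor.postprocess(decoded, output_type="latent")  # returned unchanged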
41 changes: 20 additions & 21 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import torch
@@ -22,6 +23,7 @@
from diffusers.utils import is_accelerate_available, is_accelerate_version

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
@@ -174,6 +176,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_vae_slicing(self):
@@ -425,17 +428,25 @@ def _encode_prompt(

return prompt_embeds

def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
def run_safety_checker(self, image, device, dtype, output_type="pil"):
if self.safety_checker is None or output_type == "latent":
has_nsfw_concept = False
Member:
I think this should always be a list; see another comment on that. Not sure if we could break existing workflows, though.

Contributor:
Previously, we would set has_nsfw_concept to None when we couldn't run the safety checker. False would only be returned if the safety_checker said definitively no. Did we decide to change that?

Collaborator Author:
Yeah, I think I asked @patrickvonplaten that - we can double-check with him again, I guess.

Is there a use case where None vs. False would make a difference in terms of how we use the pipeline? I understand they mean different things.

Contributor:
Happy to leave it at None here, as it might be better for backwards compatibility.

else:
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept

def decode_latents(self, latents):
Contributor (@patrickvonplaten, Apr 4, 2023):
I think we should deprecate this function and instead encourage people to do the following:

image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
image = self.image_processor.postprocess(image, output_type="pt")

instead.

This means we also should put:

 image = (image / 2 + 0.5).clamp(0, 1)

in the post-processor, which is good as it's clearly post-processing.
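A hypothetical before/after sketch of that migration for downstream code (pipe stands in for any Stable Diffusion-style pipeline instance; an illustration, not an official snippet):

# Before (deprecated path): decode_latents returns a numpy array in [0, 1]
image = pipe.decode_latents(latents)
image = pipe.numpy_to_pil(image)

# After (suggested path): decode with the VAE and post-process with the image processor
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
image = pipe.image_processor.postprocess(image, output_type="pil")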

warnings.warn(
(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead"
),
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
@@ -699,24 +710,12 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

if output_type == "latent":
image = latents
has_nsfw_concept = None
elif output_type == "pil":
# 8. Post-processing
image = self.decode_latents(latents)

# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample

# 10. Convert to PIL
image = self.numpy_to_pil(image)
else:
# 8. Post-processing
image = self.decode_latents(latents)
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type)
Contributor:
I think we should avoid passing output_type to run_safety_checker. It looks like output_type is only used to check whether the output type is latent and, if so, to skip running the safety checker. I think that check should be done outside of run_safety_checker.

Collaborator Author:
@williamberman
Yeah, agreed.
Actually, I already updated that on the img2img pipeline - maybe you could review the code change on that pipeline instead?

I think at some point I gave up on updating all 30ish pipelines and decided to only iterate on the text2img and img2img pipelines until the design is finalized. Sorry it's a little bit confusing 😂
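As a rough illustration of that suggestion, the latent check could live in __call__ rather than in run_safety_checker (a hypothetical sketch, not the code in this PR):

# Hypothetical sketch: keep run_safety_checker unaware of output_type
if output_type == "latent":
    image = latents
    has_nsfw_concept = None
else:
    image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
    image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    image = self.image_processor.postprocess(image, output_type=output_type)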


# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
image = self.image_processor.postprocess(image, output_type=output_type)
Contributor:
I'm not sure what types are allowed to be passed from call to run_safety_checker to the safety checker itself. From reading the code, my understanding is that the safety checker can only take numpy arrays and torch tensors for its image argument but here it looks like we might be passing PIL images as well? The safety checker's code for censoring the images would throw an error in that case or overwrite the PIL images to be numpy arrays.

Contributor:
I think it would perhaps make this PR a bit easier if we were to first refactor the safety checker and the run_safety_checker function so it only took one copy of the images, and then if we want to apply the censoring, we can call into a separate helper function. I think the call method, the run_safety_checker method, and the forward function of the safety checker are a bit too tied together right now
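A very rough sketch of the split being suggested (hypothetical names and signatures; the actual StableDiffusionSafetyChecker censors inside its forward pass, so it would need to change as well):

# Hypothetical sketch: detection and censoring as separate, explicitly called steps
def check_nsfw(self, image, device, dtype):
    feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
    safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
    _, has_nsfw_concept = self.safety_checker(
        images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
    )
    return has_nsfw_concept

def censor_images(image, has_nsfw_concept):
    # Black out only the flagged images; the caller decides when to apply this
    for i, is_nsfw in enumerate(has_nsfw_concept):
        if is_nsfw:
            image[i] = torch.zeros_like(image[i])
    return image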


# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
@@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
Expand Down Expand Up @@ -202,6 +203,7 @@ def __init__(
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)

self.register_modules(
vae=vae,
text_encoder=text_encoder,
Expand All @@ -212,11 +214,8 @@ def __init__(
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(
requires_safety_checker=requires_safety_checker,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Expand Down Expand Up @@ -435,18 +434,30 @@ def _encode_prompt(

return prompt_embeds

def run_safety_checker(self, image, device, dtype):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
def run_safety_checker(self, image, device, dtype, output_type="pil"):
if self.safety_checker is None or output_type == "latent":
has_nsfw_concept = False
else:
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept

def decode_latents(self, latents):
warnings.warn(
(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead"
),
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image

def prepare_extra_step_kwargs(self, generator, eta):
Expand Down Expand Up @@ -730,27 +741,12 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

if output_type not in ["latent", "pt", "np", "pil"]:
deprecation_message = (
f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
"`pil`, `np`, `pt`, `latent`"
)
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"

if output_type == "latent":
image = latents
has_nsfw_concept = None
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample

else:
image = self.decode_latents(latents)

if self.safety_checker is not None:
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
has_nsfw_concept = False
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type)

image = self.image_processor.postprocess(image, output_type=output_type)
image = self.image_processor.postprocess(image, output_type=output_type)
Comment on lines +747 to +749
Member:
I suppose we focused on just a couple of pipelines for review; just to note, this would be missing the same do_normalize logic we have in other pipelines (if we keep it).


# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
@@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import warnings
from typing import Callable, List, Optional, Union

import numpy as np
@@ -22,9 +23,10 @@

from diffusers.utils import is_accelerate_available

from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging, randn_tensor
from ...utils import deprecate, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -184,6 +186,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_sequential_cpu_offload(self, gpu_id=0):
@@ -225,14 +228,15 @@ def _execution_device(self):
return self.device

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
def run_safety_checker(self, image, device, dtype, output_type="pil"):
if self.safety_checker is None or output_type == "latent":
has_nsfw_concept = False
else:
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -255,6 +259,11 @@ def prepare_extra_step_kwargs(self, generator, eta):

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
@@ -560,15 +569,22 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

# 11. Post-processing
image = self.decode_latents(latents)
if output_type not in ["latent", "pt", "np", "pil"]:
deprecation_message = (
f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
"`pil`, `np`, `pt`, `latent`"
)
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"

if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample

# 12. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
image, has_nsfw_concept = self.run_safety_checker(
image, device, image_embeddings.dtype, output_type=output_type
)

# 13. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type=output_type)

if not return_dict:
return (image, has_nsfw_concept)