
[WIP] VaeImageProcessor Image Postprocessing refactor #2943


Closed
wants to merge 33 commits

Conversation

yiyixuxu
Collaborator

@yiyixuxu yiyixuxu commented Mar 31, 2023

This PR refactors the post-processing for all the relevant pipelines.

Can we wrap the post-processing code into a method, so that we can just use `# Copied from`? It would make it easier for people to contribute new pipelines and for us to maintain

what do you guys think? @patrickvonplaten @williamberman @sayakpaul @pcuenca

        if output_type not in ["latent", "pt", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
            output_type = "np"

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None

        else:
            image = self.decode_latents(latents)

            if self.safety_checker is not None:
                image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
            else:
                has_nsfw_concept = False

        image = self.image_processor.postprocess(image, output_type=output_type)

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Mar 31, 2023

The documentation is not available anymore as the PR was closed or merged.

@yiyixuxu
Collaborator Author

yiyixuxu commented Apr 1, 2023

should fix #2871

@sayakpaul
Member

Can we wrap the post-processing code into a method, so that we can just use `# Copied from`? It would make it easier for people to contribute new pipelines and for us to maintain

I am all in for the idea. Let's do that.

@@ -691,24 +689,27 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

if output_type not in ["latent", "pt", "np", "pil"]:
Contributor

Very nice!

safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept

def decode_latents(self, latents):
Contributor

@patrickvonplaten patrickvonplaten Apr 4, 2023

I think we should deprecate this function and instead encourage people to do the following:

image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
image = self.image_processor.postprocess(image, output_type="pt")

This means we also should put:

 image = (image / 2 + 0.5).clamp(0, 1)

in the post-processor, which is good as it's clearly post-processing.
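A minimal sketch of what moving the denormalization into the processor could look like (illustrative only; the real `VaeImageProcessor.postprocess` in diffusers handles more cases and output types):

```python
import torch

def postprocess(image: torch.Tensor, output_type: str = "pt"):
    # denormalization moves here, since it is clearly post-processing:
    # map model output from [-1, 1] into [0, 1]
    image = (image / 2 + 0.5).clamp(0, 1)
    if output_type == "pt":
        return image
    # NCHW tensor -> NHWC numpy array for "np" (and, downstream, "pil")
    return image.cpu().permute(0, 2, 3, 1).float().numpy()
```

With this in place, pipelines only need `vae.decode` followed by a single `postprocess` call.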

@patrickvonplaten
Contributor

Great effort overall @yiyixuxu ! I think this is very close to being merged :-)

1.)
Can we make sure to also add the centering & post-processing also to the post-processor: #2943 (comment) ?
Note:
Here we should also correctly deprecate the method decode_latents as explained here: #2943 (comment)
People currently already use decode_latents on their own in their pipelines, so we need to be careful not to change the function and to gracefully deprecate it

2.) RE:

Can we wrap the post-processing code into a method, so that we can just use `# Copied from`? It would make it easier for people to contribute new pipelines and for us to maintain

Hmm, not sure - do we really have to?

The deprecation warning:

        if output_type not in ["latent", "pt", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
            output_type = "np"

can live directly in self.image_processor.postprocess which is cleaner anyways IMO. This way the VAEImageProcessor can also clearly define the accepted output_types =["latent", "pt", "np", "pil"]

Then we would only have:

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor).sample

        image = self.image_processor.postprocess(image, output_type=output_type)

for pipelines without safety checker, e.g.: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
Note I think the image processor should be able to process the output_type="latent" and just not do anything. This clearly differentiates between vae decoding and post-processing.

Also many future pipelines won't need the safety checker or have a different one.

The other pipelines with safety checker then would have something like

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor).sample

        if self.safety_checker is None:
            has_nsfw_concept = False
        elif self.safety_checker is not None and output_type == "latent":
            has_nsfw_concept = None
        else:
            image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)

        image = self.image_processor.postprocess(image, output_type=output_type)

where IMO it could be nice to refactor self.run_safety_checker(...) to have the API:

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor).sample

        image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype, output_type=output_type)

        image = self.image_processor.postprocess(image, output_type=output_type)

=> We can do this by keeping full backwards compatibility for the self.run_safety_checker function.

I think this API is cleaner than creating one big function that uses one big "Copied from ..." because:

  • It's nicer to read code directly instead of having to jump into functions
  • Better separation of responsibility (output_types are defined in the post processor; the safety checker is disentangled from the output processor; pipelines with no safety checker are differentiated from ones that have a safety checker)

Note this means we have to also call:

self.image_processor.postprocess(image, output_type="pt")

in the safety checker. To keep backwards compatibility we should check whether the image is already post-processed (if outputs are already only between 0 and 1, the image was already post-processed)

3.)
We need to add lots of tests here:

  • Make sure decode_latents is correctly deprecated and backwards compatible
  • Make sure run_safety_checker is backwards compatible
  • Make sure that every pipeline can correctly output different output types and accept different input types

Let me know if this makes sense :-)

@yiyixuxu
Collaborator Author

yiyixuxu commented Apr 6, 2023

@patrickvonplaten what's the difference between has_nsfw_concept=None and has_nsfw_concept=False?

@patrickvonplaten
Contributor

patrickvonplaten commented Apr 6, 2023

@patrickvonplaten what's the difference between has_nsfw_concept=None and has_nsfw_concept=False?
Not sure who added the has_nsfw_concept=None part though

@yiyixuxu yiyixuxu closed this Apr 9, 2023
@yiyixuxu yiyixuxu reopened this Apr 9, 2023
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@yiyixuxu
Collaborator Author

yiyixuxu commented Apr 9, 2023

@patrickvonplaten

How to test backward compatibility for run_safety_checker?

Comment on lines 178 to 179
if image.min() < 0:
image = (image / 2 + 0.5).clamp(0, 1)
Member

hmm the condition looks a bit brittle in my opinion. Would it make sense to use an arg (and maybe default it to True)?

Collaborator Author

oh so this is just to make sure it won't denormalize anything twice (like the question you asked below) - it will always denormalize unless it's already within [0, 1]

Contributor

Good catch here @pcuenca - actually I think we should indeed have this default to True always, since we currently always do this normalization as a post-processing step:

@yiyixuxu We should do this check for the input, not the output

Contributor

This would have been pretty dangerous as it could have led to some silent, hard-to-debug bugs (imagine the vae decodes to an output that has a min value of >= 0 -> then we would not have applied the correct post-processing)

Comment on lines 427 to 428
image = self.image_processor.postprocess(image, output_type="pt")
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
Member

Wouldn't this denormalize twice?

Collaborator Author

the if image.min() < 0: condition makes sure it only denormalizes once

Contributor

Sorry I don't fully understand this here - why do we run self.image_processor.post_process twice?

Contributor

I think there is a misunderstanding here. The postprocess function should always apply the normalization x / 2 + 0.5 - we should not have to run this function twice here

Collaborator Author

@patrickvonplaten
I'm really confused 😣😣😣

Should we refactor it differently and not move the denormalization to postprocessing then?
see comment here #2943 (comment)

the original workflow is
image -> vae.decode -> denormalize ((image / 2 + 0.5).clamp(0, 1)) -> run_safety_checker -> numpy_to_pil

now we have decided to move the denormalization part to postprocessing
image -> vae.decode -> run_safety_checker -> image_processor.postprocessing ( e.g. denormalize + numpy_to_pil...)

This means:
(1) we need to call self.image_processor.postprocess(image, output_type="pt") from run_safety_checker since the safety_checker expects its input to be in [0,1]
(2) the image we send to image_processor.postprocessing might have already been denormalized (in the case that we run the safety_checker); it's also possible that it has not (in the case that there is no safety_checker) - we need to address them differently

Did I understand something wrong?

Contributor

I think it's just that this:
"is there a better way to check if the image is already denormalized"
that needs to be addressed: we cannot check here, based on the values of the tensor, whether it has already been denormalized.

Imagine for some reason stable diffusion generates a tensor that is always larger than 0 - in this case we would never denormalize, which is bad.

A better check is to add a flag argument to the post-process function that allows the safety checker to not denormalize it
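A sketch of such a flag (the argument name here is illustrative; the PR later settles on `do_normalize`):

```python
import torch

def postprocess(image: torch.Tensor, output_type: str = "pt", do_denormalize: bool = True):
    # callers whose input is already in [0, 1] (e.g. the safety checker path)
    # pass do_denormalize=False so the mapping is never applied twice
    if do_denormalize:
        image = (image / 2 + 0.5).clamp(0, 1)
    return image  # "np"/"pil" conversions omitted for brevity
```

This keeps the decision explicit instead of inferring it from the tensor's value range.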

Contributor

But we also don't really need this here. I understand that we need both a post-processed image in PyTorch and in PIL format, but we should not put the post-processed PyTorch image into the post-process function again.

Then we also don't have this double post process problem

Member

Yes, I agree with Patrick, I'd rather not chain calls to postprocess. I see two potential solutions:

  • Just convert the post-processed image to PIL (maybe we can add a helper to the postprocessor if we need to).
  • Maybe just swap the calls? Perhaps I'm confused but wouldn't this work?
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
image = self.image_processor.postprocess(image, output_type="pt")

Collaborator Author

@pcuenca yes we can swap the calls here. I agree it's better to swap and it will avoid calling postprocess twice inside run_safety_checker.

I still need to call postprocess again with the output of run_safety_checker, which is already denormalized https://github.com/huggingface/diffusers/blob/postprocessing-refactor/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L712

Collaborator Author

@patrickvonplaten I think the flag will solve my problem - I will go with that for now. We can discuss later and happy to refactor again

Contributor

@patrickvonplaten patrickvonplaten left a comment

Great, I think we made lots of progress here. Left a bunch of comments. Sorry, some final changes (maybe you can first just do them for stable diffusion text-to-image and stable diffusion image-to-image):

  • 1.) Let's make sure that postprocess only accepts PyTorch tensors as an input, nothing else. This makes the API very easy and testable.
  • 2.) There is one edge case with the postprocessing: when the safety checker detects a nsfw image, we should not normalize it.
  • 3.) Let's make sure the safety checker is fully backwards compatible (in our pipeline it will now always just accept a "pt" tensor, but previously it accepted a numpy array, so let's still allow for this case).
  • 4.) Differently to what I said before, let's only run the safety checker when the output type is not "latent" (sorry, I said this incorrectly before).

=> I left lots of comments and suggestions for the points 1.) - 4.)

Maybe we can first only apply them to stable diffusion text 2 image and image 2 image and then I can do a final review and then we apply to all other pipelines so that you can save some time? :-)

Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Co-authored-by: Patrick von Platen <[email protected]>

update img2img
@yiyixuxu yiyixuxu force-pushed the postprocessing-refactor branch from 200d5e0 to acf0d60 Compare April 19, 2023 02:50
@yiyixuxu
Collaborator Author

@patrickvonplaten

updated the stable diffusion text2img and img2img - ready for a (hopefully😅) final review

if isinstance(has_nsfw_concept, list)
else not has_nsfw_concept
)
image = self.image_processor.postprocess(image, output_type=output_type, do_normalize=do_normalize)
Contributor

@pcuenca could you also check here. Note we need this because of #2943 (review) IMO - the other idea is to detect whether an image is black in the postprocessor as discussed here: #2943 (comment) but IMO it's a bit too brittle

Member

Commented above. I don't like to introduce black image detection in the post-processor so I'd go with this or by having run_safety_checker always return the same type of images (normalized).

@patrickvonplaten
Contributor

@pcuenca could you maybe do a final review here (only need to look at changes of this file: https://github.com/huggingface/diffusers/pull/2943/files#diff-82113feddf255a70df849256026f9aed834c341cd73c9c89513bd5ab9be8f13d ) and the comments here: #2943 (review) and here: #2943 (comment)

@patrickvonplaten
Contributor

Great job, think we're almost there @yiyixuxu !

Member

@pcuenca pcuenca left a comment

Thanks for the patience iterating here @yiyixuxu, this is certainly a trickier PR than it looks 😅.

I mainly focused on the Stable Diffusion pipeline for this review. I think we broke output_type="latent", and I suggested another alternative to deal with the black images by having run_safety_checker always return normalized images no matter what.

In addition to that, I think there's an unrelated bug/inconsistency as we sometimes return a Boolean and sometimes a list to indicate whether nsfw concepts were found. According to the documentation, my understanding is we should always return lists, but maybe we need to deal with that in a different PR in case it breaks existing users' workflows. I'm pointing it out because I got slightly confused while reading this PR and had to double-check with main.

raise ValueError(
f"Input for postprocess is in incorrect format: {type(image)}. we only support pytorch tensor"
)
if output_type not in ["latent", "pt", "np", "pil"]:
Member

Note: the raise ValueError that happens later at the end of the function will no longer trigger.

else:
raise ValueError(f"Unsupported output_type {output_type}.")
We could remove it, or we could leave it there anyway so everything still works after we remove the deprecation warning.

Collaborator Author

oh thanks, I think we can just remove it (can add it back if we ever remove the warning)


if output_type == "latent":
return image

Member

@pcuenca pcuenca Apr 23, 2023

I don't like the black magic that checks whether the image is black. I think the chance for SD to produce a full black image is really remote, but it's strange that the post-processor has to be aware of safety checker images.

Another alternative would be for run_safety_checker to normalize black images so it returns -1s instead of 0s. This is actually consistent with the fact that run_safety_checker receives normalized images and also returns normalized images if they are ok, but it returns denormalized images (black represented as 0) if they are NSFW.

We could do something like this after this line:

            image = torch.stack([self.image_processor.normalize(image[i]) if has_nsfw_concept[i] else image[i] for i in range(image.shape[0])])

And then we could remove the do_normalize argument here.

Would that be preferable?

output_type: str = "pil",
):
if isinstance(image, torch.Tensor) and output_type == "pt":
do_normalize: Optional[Union[List[bool], bool]] = None,
Member

nit: maybe rename to do_denormalize (if we keep it)


# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
has_nsfw_concept = False
Member

I believe this is a bug from previous versions. According to the documentation:

[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.

and

nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, or `None` if safety checking could not be performed.

we are supposed to always return a list. Therefore we'd need to do something like this, same thing in run_safety_checker and then simplify do_normalize accordingly:

Suggested change
has_nsfw_concept = False
has_nsfw_concept = [False] * latents.shape[0]

Comment on lines +723 to +724
do_normalize = [not has_nsfw for has_nsfw in has_nsfw_concept] if isinstance(has_nsfw_concept, list) else not has_nsfw_concept
image = self.image_processor.postprocess(image, output_type=output_type, do_normalize=do_normalize)
Member

This doesn't work if output_type is "latent". image will not be defined in that case. I think we need to move these two lines to inside the if, and define image = latent in the else block.

safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
def run_safety_checker(self, image, device, dtype, output_type="pil"):
if self.safety_checker is None or output_type == "latent":
has_nsfw_concept = False
Member

I think this should always be a list, see another comment on that. Not sure if we could break existing workflows though.

Comment on lines +747 to +749
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type)

image = self.image_processor.postprocess(image, output_type=output_type)
Member

I suppose we focused on just a couple of pipelines for review, just to note this would be missing the same do_normalize logic we have in other pipelines (if we keep it).

if not isinstance(do_normalize, list):
do_normalize = image.shape[0] * [do_normalize or self.config.do_normalize]

image = torch.stack([self.denormalize(image[i]) if do_normalize[i] else image[i] for i in range(image.shape[0])])
Contributor

Nit: I find these oneliners hard to read even though what they're doing is quite simple. I think golang got it right by only allowing a single looping construct. Could we turn this into a for loop?
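For comparison, a loop version of the one-liner might look like this (a sketch; the `denormalize` helper is inlined here so the snippet is self-contained):

```python
import torch

def denormalize_batch(image: torch.Tensor, do_normalize: list) -> torch.Tensor:
    # same behavior as the stack/comprehension one-liner, as an explicit loop
    processed = []
    for i in range(image.shape[0]):
        frame = image[i]
        if do_normalize[i]:
            frame = (frame / 2 + 0.5).clamp(0, 1)  # denormalize [-1, 1] -> [0, 1]
        processed.append(frame)
    return torch.stack(processed)
```

Either way a temporary list is built before `torch.stack`; the difference is purely readability.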

Member

I'll always agree that clarity is more important than the number of lines, but I think in this case a loop would involve creating a temporary list to hold the images and then doing the stack outside the loop, which for me could be harder to parse than the oneliner. With the oneliner I see that this is "denormalizing somehow" and don't bother to look at the details unless I need to. With the loop, I'm more or less forced to read all the lines and understand them. Also, I hate vars. Maybe Python should be more functional too ;)

Contributor

@williamberman williamberman Apr 27, 2023

Actually, once you use the list comprehension syntax, you're going to allocate the temporary list regardless :)

>>> foo = [1, 2, 3]
>>> bar = [x for x in foo]  # not a generator
>>> bar.__class__
<class 'list'>

Plus, if the stack function were to take a generator, then it would have to instantiate it regardless, as there'll be a fixed number of dimensions for the resulting tensor (maybe there could be some more efficient implementation by doing incremental instantiations, but as really it's a list of pointers, I'm not sure).

IIRC, the go argument here for the single looping construct is that it's always best to just make memory allocation explicit and let the caller handle it?

Agreed that it makes you read the body of the loop to know what it's doing, but I think that's a good thing!

re should python be more functional, guido agrees but selectively so ;) https://www.artima.com/weblogs/viewpost.jsp?thread=98196

Member

haha, my comment was not about efficiency or memory, I dislike vars because they create state and take space in your mind :) But it's always fun to have these conversations, keep it going :)

Contributor

of course 😁

else:
# 8. Post-processing
image = self.decode_latents(latents)
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type)
Contributor

I think we should avoid passing output_type to run_safety_checker. It looks like output_type is just being used to check for if the output type is latent and if so, not running the safety checker. I think that should be done outside of run_safety_checker.

Collaborator Author

@williamberman
yeah agreed
actually, I already updated that on img2img pipeline - maybe you could review the code change on that pipeline instead?

I think at some point I gave up on updating all 30ish pipelines and decided to only iterate on the text2img and img2img pipelines until the design is finalized. Sorry it's a little bit confusing 😂

safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
def run_safety_checker(self, image, device, dtype, output_type="pil"):
if self.safety_checker is None or output_type == "latent":
has_nsfw_concept = False
Contributor

Previously, we would set the has_nsfw_concept to None, when we couldn't run the safety checker. False would only be returned if the safety_checker said definitively no. Did we decide to change that?

Collaborator Author

yeah, I think I asked @patrickvonplaten that - we can double-check with him again I guess

Is there a use case where None vs. False would make a difference, in terms of how we use the pipeline? I understand they mean different things

Contributor

Happy to leave at None here as it might be better for backwards comp

Comment on lines +194 to +195
if not isinstance(do_normalize, list):
do_normalize = image.shape[0] * [do_normalize or self.config.do_normalize]
Contributor

I assume the goal is to default to self.config.do_normalize only when the argument is None. However, when the argument is False, this will also fall back to self.config.do_normalize because of the `or`.

When we use optional boolean arguments in python, a good rule of thumb is to set them to a default value as soon as possible in the function by explicitly checking `is None`; if there is other behavior based on them not being set, set alternative flags for that.

i.e.

def foo(bar: Optional[bool] = None):
    if bar is None:
        bar = False  # the default value
        bar_was_set = False
    else:
        bar_was_set = True

    # from here on, bar can always be treated as just a boolean and if anything
    # is needed based on it not being set, `bar_was_set` can be used
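Applied to the `do_normalize` argument, that pattern could look something like this (a sketch; `config_default` stands in for `self.config.do_normalize`, and the helper name is hypothetical):

```python
from typing import List, Optional, Union

def resolve_do_normalize(
    do_normalize: Optional[Union[List[bool], bool]],
    batch_size: int,
    config_default: bool = True,
) -> List[bool]:
    # fall back to the config only when the argument was not set at all,
    # so an explicit False is respected instead of being overridden by `or`
    if do_normalize is None:
        do_normalize = config_default
    if not isinstance(do_normalize, list):
        do_normalize = [do_normalize] * batch_size
    return do_normalize
```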


# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
image = self.image_processor.postprocess(image, output_type=output_type)
Contributor

I'm not sure what types are allowed to be passed from `__call__` to `run_safety_checker` to the safety checker itself. From reading the code, my understanding is that the safety checker can only take numpy arrays and torch tensors for its image argument, but here it looks like we might be passing PIL images as well? The safety checker's code for censoring the images would throw an error in that case or overwrite the PIL images to be numpy arrays.

Contributor

I think it would perhaps make this PR a bit easier if we were to first refactor the safety checker and the run_safety_checker function so it only took one copy of the images, and then if we want to apply the censoring, we can call into a separate helper function. I think the call method, the run_safety_checker method, and the forward function of the safety checker are a bit too tied together right now

@patrickvonplaten
Contributor

@pcuenca while I like the idea:
#2943 (comment)

I don't think we can do it as it would break backcomp quite a bit. E.g. many people in my opinion are running pipe.run_safety_checker(...) and then pipe.numpy_to_pil(...)

Before this PR this gave all black [0, 0] PIL images but now this would give:

(255 * np.array([-1, -1])).round().astype("uint8")

=>

array([1, 1], dtype=uint8)

[1, 1] PIL image which is different. I think we have to pass a denormalize list flag to post_process
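The wraparound can be checked directly (a quick numpy illustration of the point above; note the exact result of this out-of-range cast is platform-dependent, which is another reason to avoid it):

```python
import numpy as np

# black image returned as 0s (current behavior): stays black after the uint8 conversion
zeros = (255 * np.array([0.0, 0.0])).round().astype("uint8")

# black image returned as -1s (the proposed normalization): -255 is out of range for
# uint8 and wraps around to 1 on typical x86 platforms instead of staying black
wrapped = (255 * np.array([-1.0, -1.0])).round().astype("uint8")

print(zeros, wrapped)
```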

@patrickvonplaten
Contributor

Think this PR was super helpful to discuss, but to get a quick PR merged could we maybe open a new PR with just the changes for img2img cc @yiyixuxu (we can remove the copy-from for this PR)?
Then we can do a quick final review there, give ✔️ and merge it, and then in a quick second PR apply it to all other pipelines?

Sorry for the difficult PR and process here

@pcuenca
Member

pcuenca commented Apr 27, 2023

@pcuenca while I like the idea: #2943 (comment)

I don't think we can do it as it would break backcomp quite a bit. E.g. many people in my opinion are running pipe.run_safety_checker(...) and then pipe.numpy_to_pil(...)

Of course backwards compatibility is really important, and I agree it's safest to keep returning 0s for the black images even if it's slightly inconsistent. But just wondering if that use case is already broken with this PR anyway? (I might be mistaken, will look at it more carefully).

@patrickvonplaten
Contributor

@pcuenca while I like the idea: #2943 (comment)
I don't think we can do it as it would break backcomp quite a bit. E.g. many people in my opinion are running pipe.run_safety_checker(...) and then pipe.numpy_to_pil(...)

Of course backwards compatibility is really important, and I agree it's safest to keep returning 0s for the black images even if it's slightly inconsistent. But just wondering if that use case is already broken with this PR anyway? (I might be mistaken, will look at it more carefully).

It shouldn't have. Both decode_latents(...) and run_safety_checker(...) should be backwards comp IMO (or at least as much as possible)

@williamberman
Contributor

@yiyixuxu worth closing this now that you merged #3268 ?

@yiyixuxu
Collaborator Author

yiyixuxu commented May 2, 2023

@williamberman I merged in one - working on a PR that will refactor all 28 others. Will close this one once that PR is merged

@yiyixuxu yiyixuxu deleted the postprocessing-refactor branch July 14, 2024 19:28