[WIP] VaeImageProcessor Image Postprocessing refactor #2943

Conversation
…un_safety_checker) and all pipelines copied these 2 methods
The documentation is not available anymore as the PR was closed or merged.

should fix #2871

I am all in for the idea. Let's do that.
@@ -691,24 +689,27 @@ def __call__(

```python
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        if output_type not in ["latent", "pt", "np", "pil"]:
```
Very nice!
```python
        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
        )
        return image, has_nsfw_concept

    def decode_latents(self, latents):
```
I think we should deprecate this function and instead advise people to do the following:

```python
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
image = self.image_processor.postprocess(image, output_type="pt")
```

This means we also have to put:

```python
image = (image / 2 + 0.5).clamp(0, 1)
```

in the post-processor, which is good as it's clearly post-processing.
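As a rough sketch of the idea (hypothetical, not the actual diffusers implementation; the method signature and the supported output types are assumptions for illustration), a `postprocess` with the denormalization folded in could look like:

```python
import torch

def postprocess(image: torch.Tensor, output_type: str = "pt"):
    # Denormalize from the VAE's [-1, 1] range to [0, 1] -- this now lives
    # in the post-processor, since it is clearly post-processing.
    image = (image / 2 + 0.5).clamp(0, 1)
    if output_type == "pt":
        return image
    if output_type == "np":
        # Channels-last numpy array, as a later PIL conversion would expect.
        return image.cpu().permute(0, 2, 3, 1).float().numpy()
    raise ValueError(f"Unsupported output_type {output_type}.")
```

The `pil` branch is omitted for brevity; it would convert the numpy array to `PIL.Image` objects.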
Great effort overall @yiyixuxu! I think this is very close to being merged :-)

1.) RE: the question above - hmm, not sure, do we really have to? The deprecation warning:

```python
if output_type not in ["latent", "pt", "np", "pil"]:
    deprecation_message = (
        f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
        "`pil`, `np`, `pt`, `latent`"
    )
    deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
    output_type = "np"
```

can live directly in `postprocess`. Then we would only have:

```python
if not output_type == "latent":
    image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
image = self.image_processor.postprocess(image, output_type=output_type)
```

for pipelines without safety checker, e.g.: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py

Also, many future pipelines won't need the safety checker or will have a different one. The other pipelines with safety checker would then have something like:

```python
if not output_type == "latent":
    image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
if self.safety_checker is None:
    has_nsfw_concept = False
elif self.safety_checker is not None and output_type == "latent":
    has_nsfw_concept = None
else:
    image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
image = self.image_processor.postprocess(image, output_type=output_type)
```

where IMO it could be nice to refactor this into:

```python
if not output_type == "latent":
    image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
    image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype, output_type=output_type)
image = self.image_processor.postprocess(image, output_type=output_type)
```

=> We can do this while keeping full backwards compatibility for the safety checker. I think this API is cleaner than creating one big function that uses one big "Copied from ...".

2.) Note this means we also have to call:

```python
self.image_processor.postprocess(image, output_type="pt")
```

in the safety checker. To keep backwards compatibility we should check whether the image is already post-processed (if outputs are already only between 0 and 1, the image was already post-processed).

3.)

Let me know if this makes sense :-)
@patrickvonplaten what's the difference between
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

How to test backward compatibility for
src/diffusers/image_processor.py
Outdated

```python
        if image.min() < 0:
            image = (image / 2 + 0.5).clamp(0, 1)
```
hmm the condition looks a bit brittle in my opinion. Would it make sense to use an arg (and maybe default it to True)?
oh so this is just to make sure it won't denormalize anything twice (like the question you asked below) - it will always denormalize unless it's already within [0, 1]
This would have been pretty dangerous, as it could have led to silent, hard-to-debug bugs (imagine the vae decodes to an output that has a min value >= 0 -> then we would not have applied the correct post-processing)
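To illustrate the failure mode being described (a hypothetical standalone example, not code from the PR): a perfectly valid decode in [-1, 1] space whose values all happen to be non-negative would slip past a value-based check:

```python
import torch

def looks_denormalized(image: torch.Tensor) -> bool:
    # The heuristic under discussion: treat anything with min >= 0
    # as already denormalized into [0, 1].
    return bool(image.min() >= 0)

# A bright decode in [-1, 1] space: every value is >= 0, so the heuristic
# would wrongly skip denormalization, returning 0.9 where 0.95 is correct.
bright = torch.full((1, 3, 8, 8), 0.9)
```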
```python
        image = self.image_processor.postprocess(image, output_type="pt")
        feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
```
Wouldn't this denormalize twice?
the `if image.min() < 0:` condition makes sure it only denormalizes once
Sorry, I don't fully understand this here - why do we run `self.image_processor.postprocess` twice?
I think there is a misunderstanding here. The `postprocess` function should always apply the denormalization `x / 2 + 0.5` - we should not have to run this function twice here
@patrickvonplaten
I'm really confused 😣😣😣

Should we refactor it differently and not move the denormalization to postprocessing then? See comment here #2943 (comment)

The original workflow is:

image -> vae.decode -> denormalize (`(image / 2 + 0.5).clamp(0, 1)`) -> run_safety_checker -> numpy_to_pil

Now we have decided to move the denormalization part to postprocessing:

image -> vae.decode -> run_safety_checker -> image_processor.postprocess (e.g. denormalize + numpy_to_pil...)

This means:

(1) we need to call `self.image_processor.postprocess(image, output_type="pt")` from `run_safety_checker`, since `safety_checker` expects its input to be in [0, 1]

(2) the image we send to `image_processor.postprocess` might have already been denormalized (in the case that we ran `safety_checker`); it's also possible that it has not (in the case that there is no `safety_checker`) - we need to handle these cases differently

Did I misunderstand something?
I think it's just this:

"is there a better way to check if the image is already denormalized"

that needs to be addressed; we cannot decide whether the image has already been denormalized based on the values of the tensor.

Imagine for some reason stable diffusion generates a tensor that is always larger than 0 - in this case we would never denormalize, which is bad.

A better check is to add a flag argument to the post-process function that allows the safety checker to opt out of denormalization
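The flag-based alternative described here could be sketched as follows (hypothetical; the `do_denormalize` name and signature are assumptions, not the merged API):

```python
import torch

def postprocess(image: torch.Tensor, output_type: str = "pt", do_denormalize: bool = True):
    # An explicit flag replaces value-based guessing: callers that already
    # denormalized the image (e.g. the safety checker path) opt out deliberately.
    if do_denormalize:
        image = (image / 2 + 0.5).clamp(0, 1)
    if output_type == "pt":
        return image
    return image.cpu().permute(0, 2, 3, 1).float().numpy()
```

The safety checker path would then call `postprocess(image, do_denormalize=False)` rather than relying on `image.min() < 0`.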
But we also don't really need this here. I understand that we need the post-processed image both in PyTorch and in PIL format, but we should not put the post-processed PyTorch image into the post-process function again. Then we also don't have this double post-process problem.
Yes, I agree with Patrick, I'd rather not chain calls to `postprocess`. I see two potential solutions:

- Just convert the post-processed image to PIL (maybe we can add a helper to the postprocessor if we need to).
- Maybe just swap the calls? Perhaps I'm confused, but wouldn't this work?

```python
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
image = self.image_processor.postprocess(image, output_type="pt")
```
@pcuenca yes, we can swap the calls here. I agree it's better to swap, and it will avoid calling `postprocess` twice inside `run_safety_checker`.

I still need to call `postprocess` again with the output of `run_safety_checker`, which is already denormalized: https://github.com/huggingface/diffusers/blob/postprocessing-refactor/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L712
@patrickvonplaten I think the flag will solve my problem - I will go with that for now. We can discuss later and happy to refactor again
Great, I think we made lots of progress here. I left a bunch of comments. Sorry, some final changes (maybe you can first just do them for stable diffusion text-to-image and stable diffusion image-to-image):

- 1.) Let's make sure that the postprocess only accepts PyTorch tensors as an input, nothing else. This makes the API very easy and testable.
- 2.) There is one edge case with the postprocessing: when the safety checker detects a nsfw image, we should not denormalize it.
- 3.) Let's make sure the safety checker is fully backwards compatible (in our pipeline now it will always just accept a "pt" tensor, but previously it accepted a numpy array, so let's still allow for this case).
- 4.) Differently to what I said before, let's only run the safety checker when the output type is not "latent" (sorry, I said this incorrectly before).

=> I left lots of comments and suggestions for points 1.) - 4.)

Maybe we can first only apply them to stable diffusion text2img and img2img, then I can do a final review, and then we apply to all other pipelines so that you can save some time? :-)
Force-pushed from 200d5e0 to acf0d60:

Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py (Co-authored-by: Patrick von Platen <[email protected]>); update img2img
updated the stable diffusion text2img and img2img - ready for a (hopefully 😅) final review
```python
            if isinstance(has_nsfw_concept, list)
            else not has_nsfw_concept
        )
        image = self.image_processor.postprocess(image, output_type=output_type, do_normalize=do_normalize)
```
@pcuenca could you also check here? Note we need this because of #2943 (review) IMO - the other idea is to detect whether an image is black in the postprocessor, as discussed here: #2943 (comment), but IMO that's a bit too brittle
Commented above. I don't like introducing black image detection in the post-processor, so I'd go with this or have `run_safety_checker` always return the same type of images (normalized).
@pcuenca could you maybe do a final review here (you only need to look at changes in this file: https://github.com/huggingface/diffusers/pull/2943/files#diff-82113feddf255a70df849256026f9aed834c341cd73c9c89513bd5ab9be8f13d ) and the comments here: #2943 (review) and here: #2943 (comment)
Great job, I think we're almost there @yiyixuxu!
Thanks for the patience iterating here @yiyixuxu, this is certainly a trickier PR than it looks 😅.

I mainly focused on the Stable Diffusion pipeline for this review. I think we broke `output_type="latent"`, and I suggested another alternative to deal with the black images by having `run_safety_checker` always return normalized images no matter what.

In addition to that, I think there's an unrelated bug/inconsistency, as we sometimes return a Boolean and sometimes a list to indicate whether nsfw concepts were found. According to the documentation, my understanding is we should always return lists, but maybe we need to deal with that in a different PR in case it breaks existing users' workflows. I'm pointing it out because I got slightly confused while reading this PR and had to double-check with `main`.
```python
            raise ValueError(
                f"Input for postprocess is in incorrect format: {type(image)}. we only support pytorch tensor"
            )
        if output_type not in ["latent", "pt", "np", "pil"]:
```
Note: the `raise ValueError` that happens later at the end of the function will no longer trigger.

diffusers/src/diffusers/image_processor.py, lines 208 to 209 in 7c1bb9a:

```python
        else:
            raise ValueError(f"Unsupported output_type {output_type}.")
```
oh thanks, I think we can just remove it (we can add it back if we ever remove the warning)
```python
        if output_type == "latent":
            return image
```
I don't like the black magic that checks whether the image is black. I think the chance for SD to produce a fully black image is really remote, but it's strange that the post-processor has to be aware of safety checker images.

Another alternative would be for `run_safety_checker` to normalize black images so it returns `-1`s instead of `0`s. This is actually consistent with the fact that `run_safety_checker` receives normalized images and also returns normalized images if they are ok, but it returns denormalized images (black represented as `0`) if they are NSFW.

We could do something like this after this line:

```python
image = torch.stack([self.image_processor.normalize(image[i]) if has_nsfw_concept[i] else image[i] for i in range(image.shape[0])])
```

And then we could remove the `do_normalize` argument here.

Would that be preferable?
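A quick sanity check of this idea (hypothetical, standalone stand-ins for the image processor's `normalize`/`denormalize`, assuming the usual [0, 1] <-> [-1, 1] mapping): re-normalizing the black 0s to -1s means a single unconditional denormalize in the post-processor still yields black, with no flag needed:

```python
import torch

def normalize(image: torch.Tensor) -> torch.Tensor:
    # Assumed [0, 1] -> [-1, 1] mapping of the image processor.
    return image * 2 - 1

def denormalize(image: torch.Tensor) -> torch.Tensor:
    return (image / 2 + 0.5).clamp(0, 1)

nsfw_black = torch.zeros(3, 4, 4)       # safety checker replaced this image with 0s
ok_image = torch.full((3, 4, 4), 1.0)   # a normal decode, still in [-1, 1]
has_nsfw_concept = [True, False]

# Re-normalize only the flagged images, then denormalize the whole batch once.
image = torch.stack([normalize(img) if nsfw else img
                     for img, nsfw in zip([nsfw_black, ok_image], has_nsfw_concept)])
out = denormalize(image)
# out[0] is black again (all 0s); out[1] is white (all 1s).
```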
```python
        output_type: str = "pil",
    ):
        if isinstance(image, torch.Tensor) and output_type == "pt":
        do_normalize: Optional[Union[List[bool], bool]] = None,
```
nit: maybe rename to `do_denormalize` (if we keep it)
```python
        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        has_nsfw_concept = False
```
I believe this is a bug from previous versions. According to the documentation:

diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py, lines 628 to 631 in 7c1bb9a:

    [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
    When returning a tuple, the first element is a list with the generated images, and the second element is a
    list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
    (nsfw) content, according to the `safety_checker`.

and

diffusers/src/diffusers/pipelines/stable_diffusion/__init__.py, lines 30 to 32 in 7c1bb9a:

    nsfw_content_detected (`List[bool]`)
        List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
        (nsfw) content, or `None` if safety checking could not be performed.

we are supposed to always return a list. Therefore we'd need to do something like this, the same thing in `run_safety_checker`, and then simplify `do_normalize` accordingly:

```diff
- has_nsfw_concept = False
+ has_nsfw_concept = [False] * latents.shape[0]
```
```python
        do_normalize = [not has_nsfw for has_nsfw in has_nsfw_concept] if isinstance(has_nsfw_concept, list) else not has_nsfw_concept
        image = self.image_processor.postprocess(image, output_type=output_type, do_normalize=do_normalize)
```
This doesn't work if `output_type` is `"latent"`: `image` will not be defined in that case. I think we need to move these two lines inside the `if`, and define `image = latents` in the `else` block.
```python
        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
    def run_safety_checker(self, image, device, dtype, output_type="pil"):
        if self.safety_checker is None or output_type == "latent":
            has_nsfw_concept = False
```
I think this should always be a list, see another comment on that. Not sure if we could break existing workflows though.
```python
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type)

        image = self.image_processor.postprocess(image, output_type=output_type)
        image = self.image_processor.postprocess(image, output_type=output_type)
```
I suppose we focused on just a couple of pipelines for review; just to note, this would be missing the same `do_normalize` logic we have in other pipelines (if we keep it).
```python
        if not isinstance(do_normalize, list):
            do_normalize = image.shape[0] * [do_normalize or self.config.do_normalize]

        image = torch.stack([self.denormalize(image[i]) if do_normalize[i] else image[i] for i in range(image.shape[0])])
```
Nit: I find these oneliners hard to read even though what they're doing is quite simple. I think golang got it right by only allowing a single looping construct. Could we turn this into a for loop?
I'll always agree that clarity is more important than the number of lines, but I think in this case a loop would involve creating a temporary list to hold the images and then doing the stack outside the loop, which for me could be harder to parse than the oneliner. With the oneliner I see that this is "denormalizing somehow" and don't bother to look at the details unless I need to. With the loop, I'm more or less forced to read all the lines and understand them. Also, I hate vars. Maybe Python should be more functional too ;)
Actually, once you use the list comprehension syntax, you're going to allocate the temporary list regardless :)

```python
>>> foo = [1, 2, 3]
>>> bar = [x for x in foo]  # <- not a generator
>>> bar.__class__
<class 'list'>
```

Plus, if the stack function were to take a generator, then it would have to instantiate it regardless, as there'll be a fixed number of dimensions for the resulting tensor (maybe there could be some more efficient implementation by doing incremental instantiations, but as it's really a list of pointers, I'm not sure).

IIRC, the Go argument for the single looping construct is that it's always best to just make memory allocation explicit and let the caller handle it?

Agreed that it makes you read the body of the loop to know what it's doing, but I think that's a good thing!

Re: should Python be more functional - Guido agrees, but selectively so ;) https://www.artima.com/weblogs/viewpost.jsp?thread=98196
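For reference, the explicit-loop version being debated would look roughly like this (a hypothetical sketch with standalone names; `denormalize` stands in for `self.denormalize`); both forms produce the same tensor:

```python
import torch

def denormalize(img: torch.Tensor) -> torch.Tensor:
    return (img / 2 + 0.5).clamp(0, 1)

image = torch.zeros(2, 3, 4, 4)
do_denormalize = [True, False]

# Loop version: an explicit temporary list, then one stack.
processed = []
for i in range(image.shape[0]):
    processed.append(denormalize(image[i]) if do_denormalize[i] else image[i])
image_loop = torch.stack(processed)

# One-liner version for comparison.
image_oneliner = torch.stack(
    [denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
)
```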
haha, my comment was not about efficiency or memory, I dislike vars because they create state and take space in your mind :) But it's always fun to have these conversations, keep it going :)
of course 😁
```python
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type)
```
I think we should avoid passing `output_type` to `run_safety_checker`. It looks like `output_type` is just being used to check if the output type is latent and, if so, not run the safety checker. I think that should be done outside of `run_safety_checker`.
@williamberman yeah, agreed.

Actually, I already updated that in the img2img pipeline - maybe you could review the code change on that pipeline instead?

I think at some point I gave up on updating all 30ish pipelines and decided to only iterate on the text2img and img2img pipelines until the design is finalized. Sorry it's a little bit confusing 😂
```python
        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
    def run_safety_checker(self, image, device, dtype, output_type="pil"):
        if self.safety_checker is None or output_type == "latent":
            has_nsfw_concept = False
```
Previously, we would set `has_nsfw_concept` to None when we couldn't run the safety checker. False would only be returned if the safety_checker said definitively no. Did we decide to change that?
yeah, I think I asked @patrickvonplaten about that - we can double-check with him again, I guess.

Is there a use case where `None` vs `False` would make a difference in terms of how we use the pipeline? I understand they mean different things.
Happy to leave it at `None` here, as it might be better for backwards compatibility
```python
        if not isinstance(do_normalize, list):
            do_normalize = image.shape[0] * [do_normalize or self.config.do_normalize]
```
I assume the goal is to default to `self.config.do_normalize` only when the argument is `None`. But when the argument is `False`, this will also fall back to `self.config.do_normalize`.

When we use optional boolean arguments in Python, a good rule of thumb is to resolve them to a default value as early as possible in the function by explicitly checking `is None`, and if there is other behavior based on them not being set, set alternative flags for that.

i.e.

```python
def foo(bar: Optional[bool] = None):
    if bar is None:
        bar = False  # the default value
        bar_was_set = False
    else:
        bar_was_set = True
    # from here on, bar can always be treated as just a boolean and if anything
    # is needed based on it not being set, `bar_was_set` can be used
```
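A minimal demonstration of the pitfall being described (hypothetical names; `config_do_normalize` stands in for `self.config.do_normalize`):

```python
# With `value or default`, an explicit False is indistinguishable from
# "not set" and silently falls back to the default.
config_do_normalize = True  # hypothetical config default

def broken(do_normalize=None):
    return 3 * [do_normalize or config_do_normalize]

def fixed(do_normalize=None):
    if do_normalize is None:
        do_normalize = config_do_normalize
    return 3 * [do_normalize]
```

Here `broken(False)` returns `[True, True, True]` (the caller's `False` is ignored), while `fixed(False)` returns `[False, False, False]`.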
```python
        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        image = self.image_processor.postprocess(image, output_type=output_type)
```
I'm not sure what types are allowed to be passed from `__call__` through `run_safety_checker` to the safety checker itself. From reading the code, my understanding is that the safety checker can only take numpy arrays and torch tensors for its `image` argument, but here it looks like we might be passing PIL images as well? The safety checker's code for censoring the images would throw an error in that case, or overwrite the PIL images with numpy arrays.
I think it would perhaps make this PR a bit easier if we were to first refactor the safety checker and the run_safety_checker function so it only took one copy of the images, and then, if we want to apply the censoring, we could call into a separate helper function. I think the `__call__` method, the run_safety_checker method, and the forward function of the safety checker are a bit too tied together right now.

@pcuenca while I like the idea, I don't think we can do it, as it would break backwards compatibility quite a bit. E.g. many people in my opinion are running

Before this PR this gave all-black [0, 0] PIL images, but now this would give:

```python
(255 * np.array([-1, -1])).round().astype("uint8")  # => array([1, 1], dtype=uint8)
```

a [1, 1] PIL image, which is different. I think we have to pass a denormalize list flag to
I think this PR was super helpful to discuss, but to get a quick PR merged, could we maybe open a new PR with just the changes for img2img? cc @yiyixuxu (we can remove the copy-from for this PR). Sorry for the difficult PR and process here.
Of course backwards compatibility is really important, and I agree it's safest to keep returning 0s for the black images even if it's slightly inconsistent. But I'm just wondering if that use case is already broken with this PR anyway? (I might be mistaken; I will look at it more carefully.)
It shouldn't have. Both
@williamberman I merged in one - working on a PR that will refactor all 28 others. Will close this one once that PR is merged.
This PR refactors the post-processing for all the relevant pipelines.

Can we wrap the post-processing code into a method, so that we can just do `# Copied from ...`? It will make it easier for people to contribute new pipelines and for us to maintain. What do you guys think? @patrickvonplaten @williamberman @sayakpaul @pcuenca