[WIP] VaeImageProcessor image postprocessing refactor #2943
File: `src/diffusers/image_processor.py`

```diff
@@ -21,7 +21,7 @@
 from PIL import Image

 from .configuration_utils import ConfigMixin, register_to_config
-from .utils import CONFIG_NAME, PIL_INTERPOLATION
+from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate


 class VaeImageProcessor(ConfigMixin):
```
```diff
@@ -82,7 +82,7 @@ def numpy_to_pt(images):
     @staticmethod
     def pt_to_numpy(images):
         """
-        Convert a numpy image to a pytorch tensor
+        Convert a pytorch tensor to a numpy image
         """
         images = images.cpu().permute(0, 2, 3, 1).float().numpy()
         return images
```
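The corrected docstring now matches what the method actually does: it moves a batch of channels-first torch tensors to channels-last numpy arrays. A small numpy-only illustration of that layout change (a sketch, not the diffusers implementation; `pt_to_numpy_like` is a hypothetical name):

```python
import numpy as np

def pt_to_numpy_like(images):
    # Equivalent of images.cpu().permute(0, 2, 3, 1).float().numpy():
    # reorder a batch from NCHW (channels first) to NHWC (channels last).
    return np.transpose(images, (0, 2, 3, 1)).astype(np.float32)

batch = np.zeros((2, 3, 8, 8))   # 2 images, 3 channels, 8x8 pixels
out = pt_to_numpy_like(batch)    # shape becomes (2, 8, 8, 3)
```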
```diff
@@ -164,6 +164,20 @@ def postprocess(
         image,
         output_type: str = "pil",
     ):
+        if output_type not in ["latent", "pt", "np", "pil"]:
+            deprecation_message = (
+                f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
+                "`pil`, `np`, `pt`, `latent`"
+            )
+            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+            output_type = "np"
+
+        if output_type == "latent":
+            return image
```

> **Comment** (on the `output_type` check): Note: see `diffusers/src/diffusers/image_processor.py`, lines 208 to 209 in 7c1bb9a.
>
> **Reply:** Oh thanks, I think we can just remove it (we can add it back if we ever remove the warning).

> **Comment** (on `return image` for `"latent"`): I think we will need this to support the blacked-out image.
>
> **Reply:** Can we just check from the […]? I will do what you proposed here, but I just want to understand, and I think it will help me make better design decisions in the future.
>
> **Reply:** We could do this as well, I also thought about it. I'm a bit worried, though, that it's a bit of black magic. In the one-in-a-million case where SD produces an image that has exactly only 0s, we should actually normalize it to an image that has only 0.5s. It's highly unlikely, but it could happen. So in this sense there is a difference between a "blacked-out" image due to NSFW and a "blacked-out" image due to SD. In a sense the postprocessor should not have to know anything about the safety checker: they should be as orthogonal / disentangled as possible. But I definitely understand your point here, it's quite some extra code.
>
> **Reply:** I don't like the black magic that checks whether the image is black. I think the chance for SD to produce a fully black image is really remote, but it's strange that the post-processor has to be aware of safety-checker images. Another alternative would be for […]. We could do something like this after this line:
>
> ```python
> image = torch.stack([self.image_processor.normalize(image[i]) if has_nsfw_concept[i] else image[i] for i in range(image.shape[0])])
> ```
>
> And then we could remove the […]. Would that be preferable?
>
> **Reply:** Ohh, I think it makes sense to return […]. Maybe we can do this only when the image is a PyTorch tensor? I think that would be a little bit confusing, no?
>
> ```python
> if torch.is_tensor(image):
>     image = torch.stack([self.image_processor.normalize(image[i]) if has_nsfw_concept[i] else image[i] for i in range(image.shape[0])])
> ```
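The `torch.stack` suggestion can be sketched end to end with a numpy stand-in (`normalize` and `renormalize_flagged` are illustrative names, not diffusers APIs): re-normalizing only the safety-checker-blanked images means the single shared denormalization step later maps them back to pure black, while untouched images are unaffected.

```python
import numpy as np

def normalize(image):
    # [0, 1] -> [-1, 1], mirroring what VaeImageProcessor.normalize does
    return 2.0 * image - 1.0

def renormalize_flagged(images, has_nsfw_concept):
    # Re-normalize only the blacked-out (all-zero) NSFW images, so that the
    # single denormalization pass in postprocess restores them to black.
    return np.stack([
        normalize(img) if flagged else img
        for img, flagged in zip(images, has_nsfw_concept)
    ])

# one blanked NSFW image, one ordinary image still in [-1, 1]
batch = np.stack([np.zeros((3, 2, 2)), np.full((3, 2, 2), -0.5)])
renorm = renormalize_flagged(batch, [True, False])
denorm = renorm / 2 + 0.5   # the later postprocess denormalization
```

After the shared denormalization, the flagged image comes back as all zeros (pure black) and the ordinary image lands at 0.25, exactly as if it had never been touched.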
```diff
+        if image.min() < 0:
+            image = (image / 2 + 0.5).clamp(0, 1)
```

> **Comment:** Hmm, the condition looks a bit brittle in my opinion. Would it make sense to use an arg (and maybe default it to True)?
>
> **Reply:** Oh, this is just to make sure it won't denormalize anything twice (like the question you asked below): it will always denormalize unless the values are already within […].
>
> **Reply:** This would have been pretty dangerous, as it would have led to some silent, hard-to-debug bugs (imagine the VAE decodes to an output that has a min value >= 0; then we would not have applied the correct post-processing).
>
> **Reply:** Ok nice, this logic works.
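The point that the guard exists so values are never denormalized twice can be checked directly. This is an illustrative numpy version of the condition, not the torch code itself:

```python
import numpy as np

def denormalize_if_needed(image):
    # Mirror of `if image.min() < 0: image = (image / 2 + 0.5).clamp(0, 1)`:
    # only [-1, 1] inputs are shifted; [0, 1] inputs pass through untouched.
    if image.min() < 0:
        image = np.clip(image / 2 + 0.5, 0.0, 1.0)
    return image

raw = np.array([-1.0, 0.0, 1.0])     # fresh VAE output in [-1, 1]
once = denormalize_if_needed(raw)    # shifted into [0, 1]
twice = denormalize_if_needed(once)  # guard prevents a second shift
```

As the thread also notes, the same guard silently skips denormalization for a (rare) valid output whose minimum happens to be >= 0, which is why a flag argument was suggested instead.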
```diff
+        if isinstance(image, torch.Tensor) and output_type == "pt":
+            return image
```
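Putting the pieces of the new `postprocess` together, here is a minimal sketch of the control flow (a numpy stand-in under stated assumptions: the real method operates on torch tensors, handles `pil` conversion, and raises a deprecation warning via `deprecate` rather than silently falling back):

```python
import numpy as np

def postprocess_sketch(image, output_type="pil"):
    # Minimal stand-in for the refactored VaeImageProcessor.postprocess.
    if output_type not in ("latent", "pt", "np", "pil"):
        # the real method calls deprecate("Unsupported output_type", ...) here
        output_type = "np"
    if output_type == "latent":
        return image                                # raw latents, untouched
    if image.min() < 0:                             # denormalize [-1, 1] -> [0, 1]
        image = np.clip(image / 2 + 0.5, 0.0, 1.0)
    return image                                    # "pt"/"np" ("pil" omitted here)

latents = np.array([2.5, -3.0])
assert postprocess_sketch(latents, output_type="latent") is latents
out = postprocess_sketch(np.array([-1.0, 1.0]), output_type="np")
```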
File: Stable Diffusion text-to-image pipeline

```diff
@@ -13,6 +13,7 @@
 # limitations under the License.

 import inspect
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Union

 import torch
```

```diff
@@ -22,6 +23,7 @@
 from diffusers.utils import is_accelerate_available, is_accelerate_version

 from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
```
```diff
@@ -166,6 +168,7 @@ def __init__(
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.register_to_config(requires_safety_checker=requires_safety_checker)

     def enable_vae_slicing(self):
```
```diff
@@ -417,17 +420,26 @@ def _encode_prompt(

         return prompt_embeds

-    def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+    def run_safety_checker(self, image, device, dtype, output_type="pil"):
+        if self.safety_checker is None or output_type == "latent":
+            has_nsfw_concept = False
+        else:
+            image = self.image_processor.postprocess(image, output_type="pt")
+            feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
-        else:
-            has_nsfw_concept = None
         return image, has_nsfw_concept
```

> **Comment** (on `has_nsfw_concept = False`): I think this should always be a list, see another comment on that. Not sure if we could break existing workflows though.
>
> **Reply:** Previously, we would set the […].
>
> **Reply:** Yeah, I think I asked @patrickvonplaten about that; we can double-check with him again, I guess. Is there a use case […]?
>
> **Reply:** Happy to leave it at […].

> **Comment** (on the chained `postprocess` calls): Wouldn't this denormalize twice?
>
> **Reply:** The […].
>
> **Reply:** Sorry, I don't fully understand this here: why do we run […]?
>
> **Reply:** I think there is a misunderstanding here. The […].
>
> **Reply:** @patrickvonplaten Should we refactor it differently and not move the denormalization to postprocessing, then? The original workflow is […]; now we have decided to move the denormalization part to postprocessing. This means […]. Did I understand something wrong?
>
> **Reply:** I think it's just that this […]. Imagine for some reason Stable Diffusion generates a tensor that is always larger than 0: in this case we would never denormalize, which is bad. A better check is to add a flag argument to the post-process function that allows the safety checker to not denormalize it.
>
> **Reply:** But we also don't really need this here. I understand that we need both a post-processed image in PyTorch and in PIL format, but we should not put the post-processed PyTorch image into the post-process function again. Then we also don't have this double-post-process problem.
>
> **Reply:** Yes, I agree with Patrick, I'd rather not chain calls to `postprocess`:
>
> ```python
> feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
> image = self.image_processor.postprocess(image, output_type="pt")
> ```
>
> **Reply:** @pcuenca Yes, we can swap the calls here. I agree it's better to swap, and it will avoid calling […]. I still need to call […].
>
> **Reply:** @patrickvonplaten I think the flag will solve my problem; I will go with that for now. We can discuss later, and I'm happy to refactor again.

> **Comment** (on `decode_latents`): I think we should deprecate this function and instead advertise people to do the following:
>
> ```python
> image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
> image = self.image_processor.postprocess(image, output_type="pt")
> ```
>
> This means we should also put `image = (image / 2 + 0.5).clamp(0, 1)` in the post-processor, which is good as it's clearly post-processing.
||
warnings.warn( | ||
( | ||
"The decode_latents method is deprecated and will be removed in a future version. Please" | ||
" use VaeImageProcessor instead" | ||
), | ||
FutureWarning, | ||
) | ||
latents = 1 / self.vae.config.scaling_factor * latents | ||
image = self.vae.decode(latents).sample | ||
image = (image / 2 + 0.5).clamp(0, 1) | ||
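The deprecation points users at the two-line replacement quoted in the review. Here is a runnable sketch of that replacement with dummy stand-ins (`DummyVae`, `DummyDecoderOutput`, and `SCALING_FACTOR` are hypothetical; the real objects are `self.vae`, its decoder output, and `self.vae.config.scaling_factor`):

```python
import numpy as np

SCALING_FACTOR = 0.18215   # illustrative value of vae.config.scaling_factor

class DummyDecoderOutput:
    def __init__(self, sample):
        self.sample = sample

class DummyVae:
    def decode(self, latents):
        # identity "decoder": stands in for self.vae.decode(...)
        return DummyDecoderOutput(latents)

def decode_and_postprocess(vae, latents):
    # The advertised replacement for the deprecated decode_latents:
    #   image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
    #   image = self.image_processor.postprocess(image, output_type="pt")
    image = vae.decode(latents / SCALING_FACTOR).sample
    if image.min() < 0:   # denormalization now lives in postprocess
        image = np.clip(image / 2 + 0.5, 0.0, 1.0)
    return image

out = decode_and_postprocess(DummyVae(), np.array([-0.18215, 0.18215]))
```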
|
```diff
@@ -691,24 +703,12 @@ def __call__(
             if callback is not None and i % callback_steps == 0:
                 callback(i, t, latents)

-        if output_type == "latent":
-            image = latents
-            has_nsfw_concept = None
-        elif output_type == "pil":
-            # 8. Post-processing
-            image = self.decode_latents(latents)
-
-            # 9. Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-
-            # 10. Convert to PIL
-            image = self.numpy_to_pil(image)
-        else:
-            # 8. Post-processing
-            image = self.decode_latents(latents)
-
-            # 9. Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
+
+        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type)
+        image = self.image_processor.postprocess(image, output_type=output_type)

         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
```

> **Comment:** I think we should avoid passing `output_type` to `run_safety_checker` […].
>
> **Reply:** @williamberman I think at some point I gave up on updating all 30-ish pipelines and decided to only iterate on the text2img and img2img pipelines until the design is finalized. Sorry it's a little bit confusing 😂

> **Comment** (on the `postprocess` call): I'm not sure what types are allowed to be passed from `__call__` to `run_safety_checker` to the safety checker itself. From reading the code, my understanding is that the safety checker can only take numpy arrays and torch tensors for its […].
>
> **Reply:** I think it would perhaps make this PR a bit easier if we were to first refactor the safety checker and the `run_safety_checker` function so that they only took one copy of the images, and then, if we want to apply the censoring, we could call into a separate helper function. The `__call__` method, the `run_safety_checker` method, and the `forward` function of the safety checker are a bit too tied together right now.
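The rewritten tail of `__call__` collapses the old three-branch block into one linear flow. A stub walk-through of just that control flow (every callable here is a hypothetical stand-in injected as an argument; nothing is the real pipeline):

```python
def call_tail(latents, output_type, decode, run_safety_checker, postprocess):
    # Mirrors the new ending of __call__: decode unless raw latents were
    # requested, then always run the safety checker and shared postprocess.
    image = latents
    if output_type != "latent":
        image = decode(latents)
    image, has_nsfw_concept = run_safety_checker(image, output_type)
    image = postprocess(image, output_type)
    return image, has_nsfw_concept

# Minimal stand-ins to trace the flow:
decode = lambda lat: [x * 2 for x in lat]
checker = lambda img, ot: (img, False if ot == "latent" else [False])
post = lambda img, ot: img

image, flags = call_tail([1, 2], "np", decode, checker, post)
image_latent, flags_latent = call_tail([7], "latent", decode, checker, post)
```

Note how the `"latent"` path skips the decode entirely and, per the refactor, `run_safety_checker` short-circuits on it as well.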
File: Stable Diffusion image-to-image pipeline

```diff
@@ -13,6 +13,7 @@
 # limitations under the License.

 import inspect
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Union

 import numpy as np
```
```diff
@@ -194,6 +195,7 @@ def __init__(
             new_config = dict(unet.config)
             new_config["sample_size"] = 64
             unet._internal_dict = FrozenDict(new_config)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
```

```diff
@@ -204,11 +206,8 @@ def __init__(
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-
-        self.register_to_config(
-            requires_safety_checker=requires_safety_checker,
-        )
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)

     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
```
```diff
@@ -427,18 +426,31 @@ def _encode_prompt(

         return prompt_embeds

-    def run_safety_checker(self, image, device, dtype):
-        feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
-        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
-        image, has_nsfw_concept = self.safety_checker(
-            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
-        )
+    def run_safety_checker(self, image, device, dtype, output_type="pil"):
+        if self.safety_checker is None or output_type == "latent":
+            has_nsfw_concept = False
+        else:
+            image = self.image_processor.postprocess(image, output_type="pt")
+            feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
         return image, has_nsfw_concept

     def decode_latents(self, latents):
+        warnings.warn(
+            (
+                "The decode_latents method is deprecated and will be removed in a future version. Please"
+                " use VaeImageProcessor instead"
+            ),
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image

     def prepare_extra_step_kwargs(self, generator, eta):
```
```diff
@@ -722,27 +734,12 @@ def __call__(
             if callback is not None and i % callback_steps == 0:
                 callback(i, t, latents)

-        if output_type not in ["latent", "pt", "np", "pil"]:
-            deprecation_message = (
-                f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
-                "`pil`, `np`, `pt`, `latent`"
-            )
-            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
-            output_type = "np"
-
-        if output_type == "latent":
-            image = latents
-            has_nsfw_concept = None
-        else:
-            image = self.decode_latents(latents)
-
-            if self.safety_checker is not None:
-                image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-            else:
-                has_nsfw_concept = False
-
-            image = self.image_processor.postprocess(image, output_type=output_type)
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
+
+        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type)
+        image = self.image_processor.postprocess(image, output_type=output_type)

         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
```

> **Comment** (on lines +747 to +749): I suppose we focused on just a couple of pipelines for review; just to note, this would be missing the same […].