-
Notifications
You must be signed in to change notification settings - Fork 6k
Postprocessing refactor img2img #3268
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0240918
161db12
43068ef
4cd9ecf
91092f0
350d4ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
# limitations under the License. | ||
|
||
import inspect | ||
import warnings | ||
from typing import Any, Callable, Dict, List, Optional, Union | ||
|
||
import numpy as np | ||
|
@@ -202,6 +203,7 @@ def __init__( | |
new_config = dict(unet.config) | ||
new_config["sample_size"] = 64 | ||
unet._internal_dict = FrozenDict(new_config) | ||
|
||
self.register_modules( | ||
vae=vae, | ||
text_encoder=text_encoder, | ||
|
@@ -212,11 +214,8 @@ def __init__( | |
feature_extractor=feature_extractor, | ||
) | ||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) | ||
|
||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) | ||
self.register_to_config( | ||
requires_safety_checker=requires_safety_checker, | ||
) | ||
self.register_to_config(requires_safety_checker=requires_safety_checker) | ||
|
||
def enable_sequential_cpu_offload(self, gpu_id=0): | ||
r""" | ||
|
@@ -436,17 +435,32 @@ def _encode_prompt( | |
return prompt_embeds | ||
|
||
def run_safety_checker(self, image, device, dtype): | ||
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") | ||
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) | ||
image, has_nsfw_concept = self.safety_checker( | ||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype) | ||
) | ||
if self.safety_checker is None: | ||
has_nsfw_concept = None | ||
else: | ||
if torch.is_tensor(image): | ||
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") | ||
else: | ||
feature_extractor_input = self.image_processor.numpy_to_pil(image) | ||
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) | ||
image, has_nsfw_concept = self.safety_checker( | ||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype) | ||
) | ||
return image, has_nsfw_concept | ||
|
||
def decode_latents(self, latents): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perfect! It's essentially a revert of what we had before |
||
warnings.warn( | ||
( | ||
"The decode_latents method is deprecated and will be removed in a future version. Please" | ||
" use VaeImageProcessor instead" | ||
), | ||
FutureWarning, | ||
) | ||
latents = 1 / self.vae.config.scaling_factor * latents | ||
image = self.vae.decode(latents).sample | ||
image = (image / 2 + 0.5).clamp(0, 1) | ||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 | ||
image = image.cpu().permute(0, 2, 3, 1).float().numpy() | ||
return image | ||
|
||
def prepare_extra_step_kwargs(self, generator, eta): | ||
|
@@ -730,27 +744,19 @@ def __call__( | |
if callback is not None and i % callback_steps == 0: | ||
callback(i, t, latents) | ||
|
||
if output_type not in ["latent", "pt", "np", "pil"]: | ||
deprecation_message = ( | ||
f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: " | ||
"`pil`, `np`, `pt`, `latent`" | ||
) | ||
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) | ||
output_type = "np" | ||
|
||
if output_type == "latent": | ||
if not output_type == "latent": | ||
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample | ||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) | ||
else: | ||
image = latents | ||
has_nsfw_concept = None | ||
|
||
if has_nsfw_concept is None: | ||
do_denormalize = [True] * image.shape[0] | ||
else: | ||
image = self.decode_latents(latents) | ||
|
||
if self.safety_checker is not None: | ||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) | ||
else: | ||
has_nsfw_concept = False | ||
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] | ||
|
||
image = self.image_processor.postprocess(image, output_type=output_type) | ||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) | ||
|
||
# Offload last model to CPU | ||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: | ||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -13,6 +13,7 @@ | |||||||
# limitations under the License. | ||||||||
|
||||||||
import inspect | ||||||||
import warnings | ||||||||
from typing import Any, Callable, Dict, List, Optional, Union | ||||||||
|
||||||||
import numpy as np | ||||||||
|
@@ -205,6 +206,7 @@ def __init__( | |||||||
new_config = dict(unet.config) | ||||||||
new_config["sample_size"] = 64 | ||||||||
unet._internal_dict = FrozenDict(new_config) | ||||||||
|
||||||||
self.register_modules( | ||||||||
vae=vae, | ||||||||
text_encoder=text_encoder, | ||||||||
|
@@ -215,11 +217,8 @@ def __init__( | |||||||
feature_extractor=feature_extractor, | ||||||||
) | ||||||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) | ||||||||
|
||||||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) | ||||||||
self.register_to_config( | ||||||||
requires_safety_checker=requires_safety_checker, | ||||||||
) | ||||||||
self.register_to_config(requires_safety_checker=requires_safety_checker) | ||||||||
|
||||||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload | ||||||||
def enable_sequential_cpu_offload(self, gpu_id=0): | ||||||||
|
@@ -443,17 +442,30 @@ def _encode_prompt( | |||||||
return prompt_embeds | ||||||||
|
||||||||
def run_safety_checker(self, image, device, dtype): | ||||||||
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") | ||||||||
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) | ||||||||
image, has_nsfw_concept = self.safety_checker( | ||||||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype) | ||||||||
) | ||||||||
if self.safety_checker is None: | ||||||||
has_nsfw_concept = None | ||||||||
else: | ||||||||
if torch.is_tensor(image): | ||||||||
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") | ||||||||
else: | ||||||||
feature_extractor_input = self.image_processor.numpy_to_pil(image) | ||||||||
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) | ||||||||
image, has_nsfw_concept = self.safety_checker( | ||||||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype) | ||||||||
) | ||||||||
return image, has_nsfw_concept | ||||||||
|
||||||||
def decode_latents(self, latents): | ||||||||
warnings.warn( | ||||||||
"The decode_latents method is deprecated and will be removed in a future version. Please" | ||||||||
" use VaeImageProcessor instead", | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we also specify the method of |
||||||||
FutureWarning, | ||||||||
) | ||||||||
latents = 1 / self.vae.config.scaling_factor * latents | ||||||||
image = self.vae.decode(latents).sample | ||||||||
image = (image / 2 + 0.5).clamp(0, 1) | ||||||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 | ||||||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy() | ||||||||
return image | ||||||||
|
||||||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs | ||||||||
|
@@ -738,27 +750,19 @@ def __call__( | |||||||
if callback is not None and i % callback_steps == 0: | ||||||||
callback(i, t, latents) | ||||||||
|
||||||||
if output_type not in ["latent", "pt", "np", "pil"]: | ||||||||
deprecation_message = ( | ||||||||
f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: " | ||||||||
"`pil`, `np`, `pt`, `latent`" | ||||||||
) | ||||||||
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) | ||||||||
output_type = "np" | ||||||||
|
||||||||
if output_type == "latent": | ||||||||
if not output_type == "latent": | ||||||||
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample | ||||||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) | ||||||||
else: | ||||||||
image = latents | ||||||||
has_nsfw_concept = None | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think this is consistent with the documentation as @pcuenca pointed out. And it is consistent with how it's used in current pipelines so I think it's ok diffusers/src/diffusers/pipelines/stable_diffusion/__init__.py Lines 30 to 32 in 7c1bb9a
|
||||||||
|
||||||||
if has_nsfw_concept is None: | ||||||||
do_denormalize = [True] * image.shape[0] | ||||||||
else: | ||||||||
image = self.decode_latents(latents) | ||||||||
|
||||||||
if self.safety_checker is not None: | ||||||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) | ||||||||
else: | ||||||||
has_nsfw_concept = False | ||||||||
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] | ||||||||
|
||||||||
image = self.image_processor.postprocess(image, output_type=output_type) | ||||||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) | ||||||||
|
||||||||
# Offload last model to CPU | ||||||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is neat!
But would there be a need to denormalize images individually? I would assume if
do_denormalize
is True then ALL the images will be denormalized. Why individually?Nit: Should it be "unnormalize"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
normally we should denormalize all images if we have normalized them in preprocessing (i.e.
if self.config.do_normalize is True
)however, there is this edge case when the image is black due to
safety_checker
- in that case, we will pass ado_denormalize
argument to allow skip the denormalize process only for these nsfw images.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Superb! Thanks for explaining!