Skip to content

Commit 863bb75

Browse files
yiyixuxu, yiyixuxu, sayakpaul
authored and committed
Postprocessing refactor img2img (huggingface#3268)
* refactor img2img VaeImageProcessor.postprocess
* remove copy from for init, run_safety_checker, decode_latents

Co-authored-by: Sayak Paul <[email protected]>

---------

Co-authored-by: yiyixuxu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 6a84a74 commit 863bb75

File tree

7 files changed

+198
-125
lines changed

7 files changed

+198
-125
lines changed

src/diffusers/image_processor.py

Lines changed: 37 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -13,15 +13,15 @@
1313
# limitations under the License.
1414

1515
import warnings
16-
from typing import Union
16+
from typing import List, Optional, Union
1717

1818
import numpy as np
1919
import PIL
2020
import torch
2121
from PIL import Image
2222

2323
from .configuration_utils import ConfigMixin, register_to_config
24-
from .utils import CONFIG_NAME, PIL_INTERPOLATION
24+
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
2525

2626

2727
class VaeImageProcessor(ConfigMixin):
@@ -82,7 +82,7 @@ def numpy_to_pt(images):
8282
@staticmethod
8383
def pt_to_numpy(images):
8484
"""
85-
Convert a numpy image to a pytorch tensor
85+
Convert a pytorch tensor to a numpy image
8686
"""
8787
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
8888
return images
@@ -94,6 +94,13 @@ def normalize(images):
9494
"""
9595
return 2.0 * images - 1.0
9696

97+
@staticmethod
98+
def denormalize(images):
99+
"""
100+
Denormalize an image array to [0,1]
101+
"""
102+
return (images / 2 + 0.5).clamp(0, 1)
103+
97104
def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
98105
"""
99106
Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
@@ -165,17 +172,39 @@ def preprocess(
165172

166173
def postprocess(
167174
self,
168-
image,
175+
image: torch.FloatTensor,
169176
output_type: str = "pil",
177+
do_denormalize: Optional[List[bool]] = None,
170178
):
171-
if isinstance(image, torch.Tensor) and output_type == "pt":
179+
if not isinstance(image, torch.Tensor):
180+
raise ValueError(
181+
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
182+
)
183+
if output_type not in ["latent", "pt", "np", "pil"]:
184+
deprecation_message = (
185+
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
186+
"`pil`, `np`, `pt`, `latent`"
187+
)
188+
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
189+
output_type = "np"
190+
191+
if output_type == "latent":
192+
return image
193+
194+
if do_denormalize is None:
195+
do_denormalize = [self.config.do_normalize] * image.shape[0]
196+
197+
image = torch.stack(
198+
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
199+
)
200+
201+
if output_type == "pt":
172202
return image
173203

174204
image = self.pt_to_numpy(image)
175205

176206
if output_type == "np":
177207
return image
178-
elif output_type == "pil":
208+
209+
if output_type == "pil":
179210
return self.numpy_to_pil(image)
180-
else:
181-
raise ValueError(f"Unsupported output_type {output_type}.")

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 31 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import inspect
16+
import warnings
1617
from typing import Any, Callable, Dict, List, Optional, Union
1718

1819
import numpy as np
@@ -202,6 +203,7 @@ def __init__(
202203
new_config = dict(unet.config)
203204
new_config["sample_size"] = 64
204205
unet._internal_dict = FrozenDict(new_config)
206+
205207
self.register_modules(
206208
vae=vae,
207209
text_encoder=text_encoder,
@@ -212,11 +214,8 @@ def __init__(
212214
feature_extractor=feature_extractor,
213215
)
214216
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
215-
216217
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
217-
self.register_to_config(
218-
requires_safety_checker=requires_safety_checker,
219-
)
218+
self.register_to_config(requires_safety_checker=requires_safety_checker)
220219

221220
def enable_sequential_cpu_offload(self, gpu_id=0):
222221
r"""
@@ -436,17 +435,32 @@ def _encode_prompt(
436435
return prompt_embeds
437436

438437
def run_safety_checker(self, image, device, dtype):
439-
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
440-
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
441-
image, has_nsfw_concept = self.safety_checker(
442-
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
443-
)
438+
if self.safety_checker is None:
439+
has_nsfw_concept = None
440+
else:
441+
if torch.is_tensor(image):
442+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
443+
else:
444+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
445+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
446+
image, has_nsfw_concept = self.safety_checker(
447+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
448+
)
444449
return image, has_nsfw_concept
445450

446451
def decode_latents(self, latents):
452+
warnings.warn(
453+
(
454+
"The decode_latents method is deprecated and will be removed in a future version. Please"
455+
" use VaeImageProcessor instead"
456+
),
457+
FutureWarning,
458+
)
447459
latents = 1 / self.vae.config.scaling_factor * latents
448460
image = self.vae.decode(latents).sample
449461
image = (image / 2 + 0.5).clamp(0, 1)
462+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
463+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
450464
return image
451465

452466
def prepare_extra_step_kwargs(self, generator, eta):
@@ -730,27 +744,19 @@ def __call__(
730744
if callback is not None and i % callback_steps == 0:
731745
callback(i, t, latents)
732746

733-
if output_type not in ["latent", "pt", "np", "pil"]:
734-
deprecation_message = (
735-
f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
736-
"`pil`, `np`, `pt`, `latent`"
737-
)
738-
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
739-
output_type = "np"
740-
741-
if output_type == "latent":
747+
if not output_type == "latent":
748+
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
749+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
750+
else:
742751
image = latents
743752
has_nsfw_concept = None
744753

754+
if has_nsfw_concept is None:
755+
do_denormalize = [True] * image.shape[0]
745756
else:
746-
image = self.decode_latents(latents)
747-
748-
if self.safety_checker is not None:
749-
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
750-
else:
751-
has_nsfw_concept = False
757+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
752758

753-
image = self.image_processor.postprocess(image, output_type=output_type)
759+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
754760

755761
# Offload last model to CPU
756762
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 29 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import inspect
16+
import warnings
1617
from typing import Any, Callable, Dict, List, Optional, Union
1718

1819
import numpy as np
@@ -205,6 +206,7 @@ def __init__(
205206
new_config = dict(unet.config)
206207
new_config["sample_size"] = 64
207208
unet._internal_dict = FrozenDict(new_config)
209+
208210
self.register_modules(
209211
vae=vae,
210212
text_encoder=text_encoder,
@@ -215,11 +217,8 @@ def __init__(
215217
feature_extractor=feature_extractor,
216218
)
217219
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
218-
219220
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
220-
self.register_to_config(
221-
requires_safety_checker=requires_safety_checker,
222-
)
221+
self.register_to_config(requires_safety_checker=requires_safety_checker)
223222

224223
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
225224
def enable_sequential_cpu_offload(self, gpu_id=0):
@@ -443,17 +442,30 @@ def _encode_prompt(
443442
return prompt_embeds
444443

445444
def run_safety_checker(self, image, device, dtype):
446-
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
447-
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
448-
image, has_nsfw_concept = self.safety_checker(
449-
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
450-
)
445+
if self.safety_checker is None:
446+
has_nsfw_concept = None
447+
else:
448+
if torch.is_tensor(image):
449+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
450+
else:
451+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
452+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
453+
image, has_nsfw_concept = self.safety_checker(
454+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
455+
)
451456
return image, has_nsfw_concept
452457

453458
def decode_latents(self, latents):
459+
warnings.warn(
460+
"The decode_latents method is deprecated and will be removed in a future version. Please"
461+
" use VaeImageProcessor instead",
462+
FutureWarning,
463+
)
454464
latents = 1 / self.vae.config.scaling_factor * latents
455465
image = self.vae.decode(latents).sample
456466
image = (image / 2 + 0.5).clamp(0, 1)
467+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
468+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
457469
return image
458470

459471
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -738,27 +750,19 @@ def __call__(
738750
if callback is not None and i % callback_steps == 0:
739751
callback(i, t, latents)
740752

741-
if output_type not in ["latent", "pt", "np", "pil"]:
742-
deprecation_message = (
743-
f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
744-
"`pil`, `np`, `pt`, `latent`"
745-
)
746-
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
747-
output_type = "np"
748-
749-
if output_type == "latent":
753+
if not output_type == "latent":
754+
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
755+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
756+
else:
750757
image = latents
751758
has_nsfw_concept = None
752759

760+
if has_nsfw_concept is None:
761+
do_denormalize = [True] * image.shape[0]
753762
else:
754-
image = self.decode_latents(latents)
755-
756-
if self.safety_checker is not None:
757-
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
758-
else:
759-
has_nsfw_concept = False
763+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
760764

761-
image = self.image_processor.postprocess(image, output_type=output_type)
765+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
762766

763767
# Offload last model to CPU
764768
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:

tests/others/test_image_processor.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ def to_np(self, image):
4242
return image
4343

4444
def test_vae_image_processor_pt(self):
45-
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
45+
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
4646

4747
input_pt = self.dummy_sample
4848
input_np = self.to_np(input_pt)
@@ -59,7 +59,7 @@ def test_vae_image_processor_pt(self):
5959
), f"decoded output does not match input for output_type {output_type}"
6060

6161
def test_vae_image_processor_np(self):
62-
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
62+
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
6363
input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)
6464

6565
for output_type in ["pt", "np", "pil"]:
@@ -72,7 +72,7 @@ def test_vae_image_processor_np(self):
7272
), f"decoded output does not match input for output_type {output_type}"
7373

7474
def test_vae_image_processor_pil(self):
75-
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
75+
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
7676

7777
input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)
7878
input_pil = image_processor.numpy_to_pil(input_np)

tests/pipelines/pipeline_params.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,10 @@
2222

2323
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
2424

25+
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
26+
27+
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
28+
2529
IMAGE_VARIATION_PARAMS = frozenset(
2630
[
2731
"image",

0 commit comments

Comments
 (0)