
Commit acf0d60

yiyixuxu authored and committed
Update src/diffusers/image_processor.py
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Co-authored-by: Patrick von Platen <[email protected]>

update img2img
1 parent 106c43a commit acf0d60


3 files changed: +47 -19 lines changed


src/diffusers/image_processor.py

Lines changed: 20 additions & 6 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import warnings
-from typing import Union
+from typing import Union, Optional, List
 
 import numpy as np
 import PIL
@@ -93,6 +93,13 @@ def normalize(images):
         Normalize an image array to [-1,1]
         """
         return 2.0 * images - 1.0
+
+    @staticmethod
+    def denormalize(images):
+        """
+        Denormalize an image array to [0,1]
+        """
+        return (images / 2 + 0.5).clamp(0, 1)
 
     def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
         """
@@ -165,9 +172,14 @@ def preprocess(
 
     def postprocess(
         self,
-        image,
+        image: torch.FloatTensor,
         output_type: str = "pil",
-    ):
+        do_normalize: Optional[Union[List[bool], bool]] = None,
+    ):
+        if not isinstance(image, torch.Tensor):
+            raise ValueError(
+                f"Input for postprocess is in incorrect format: {type(image)}. we only support pytorch tensor"
+            )
         if output_type not in ["latent", "pt", "np", "pil"]:
             deprecation_message = (
                 f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
@@ -179,10 +191,12 @@ def postprocess(
         if output_type == "latent":
             return image
 
-        if self.config.do_normalize:
-            image = (image / 2 + 0.5).clamp(0, 1)
+        if not isinstance(do_normalize, list):
+            do_normalize = image.shape[0] * [do_normalize or self.config.do_normalize]
+
+        image = torch.stack([self.denormalize(image[i]) if do_normalize[i] else image[i] for i in range(image.shape[0])])
 
-        if isinstance(image, torch.Tensor) and output_type == "pt":
+        if output_type == "pt":
             return image
 
         image = self.pt_to_numpy(image)
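
The key behavioural change above is that postprocess can now denormalize per image instead of all-or-nothing. Below is a minimal standalone sketch of that per-image logic using plain torch, with a made-up batch and flag list rather than real pipeline outputs (names and values here are illustrative, not part of the commit):

import torch


def denormalize(images: torch.Tensor) -> torch.Tensor:
    # Mirror of the new static method: map [-1, 1] back to [0, 1].
    return (images / 2 + 0.5).clamp(0, 1)


batch = torch.rand(3, 3, 64, 64) * 2 - 1   # three fake images in [-1, 1]
do_normalize = [True, False, True]         # hypothetical per-image flags

# Broadcast a single bool to the whole batch, as postprocess does when it
# receives a plain bool (or falls back to the config default).
if not isinstance(do_normalize, list):
    do_normalize = batch.shape[0] * [bool(do_normalize)]

# Denormalize only the flagged images, leave the others untouched.
out = torch.stack(
    [denormalize(batch[i]) if do_normalize[i] else batch[i] for i in range(batch.shape[0])]
)
print(out.shape, float(out.min()), float(out.max()))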

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 12 additions & 7 deletions
@@ -423,11 +423,14 @@ def _encode_prompt(
 
         return prompt_embeds
 
-    def run_safety_checker(self, image, device, dtype, output_type="pil"):
-        if self.safety_checker is None or output_type == "latent":
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is None:
             has_nsfw_concept = False
         else:
-            feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
             safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
@@ -705,10 +708,12 @@ def __call__(
 
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
-
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type)
-
-        image = self.image_processor.postprocess(image, output_type=output_type)
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            has_nsfw_concept = False
+
+        do_normalize = [not has_nsfw for has_nsfw in has_nsfw_concept] if isinstance(has_nsfw_concept, list) else not has_nsfw_concept
+        image = self.image_processor.postprocess(image, output_type=output_type, do_normalize=do_normalize)
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
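
From the caller's side nothing changes: the pipeline now maps the safety-checker output to per-image do_normalize flags internally before postprocess runs. A hedged end-to-end sketch of the affected path, assuming a CUDA device and a checkpoint such as runwayml/stable-diffusion-v1-5 (the checkpoint choice is an assumption, not part of this commit):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# The default output_type="pil" exercises the updated path: VAE decode,
# run_safety_checker, then image_processor.postprocess with per-image flags.
result = pipe("a photo of an astronaut riding a horse", num_inference_steps=25)
print(result.nsfw_content_detected)  # per-image safety-checker flags
result.images[0].save("astronaut.png")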

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 15 additions & 6 deletions
@@ -436,10 +436,13 @@ def _encode_prompt(
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype, output_type="pil"):
-        if self.safety_checker is None or output_type == "latent":
+        if self.safety_checker is None:
             has_nsfw_concept = False
         else:
-            feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
             safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
@@ -744,10 +747,16 @@ def __call__(
 
         if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
-
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype, output_type=output_type)
-
-        image = self.image_processor.postprocess(image, output_type=output_type)
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            has_nsfw_concept = False
+
+        do_normalize = (
+            [not has_nsfw for has_nsfw in has_nsfw_concept]
+            if isinstance(has_nsfw_concept, list)
+            else not has_nsfw_concept
+        )
+        image = self.image_processor.postprocess(image, output_type=output_type, do_normalize=do_normalize)
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
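
The img2img pipeline picks up the identical post-processing change, so the same call pattern applies there. A hedged sketch, again assuming a GPU and the same (illustrative) checkpoint, with a placeholder PIL image standing in for a real init image:

import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

init_image = Image.new("RGB", (512, 512), color=(120, 160, 200))  # placeholder input image

# Same post-processing path as the text-to-image pipeline: safety-checker
# results become per-image do_normalize values before postprocess runs.
result = pipe("a fantasy landscape, matte painting", image=init_image, strength=0.75)
print(result.nsfw_content_detected)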
