Stable Diffusion img2img: img2img pipeline does not appear to work at below 0.1 strength with any scheduler #1867

Closed
@cadaeix

Description

Describe the bug

With StableDiffusionImg2ImgPipeline, all of the schedulers I tested in the course of #1866 error out in the NSFW safety checker when the img2img strength is 0.0 or 0.05.

While investigating this issue, a local copy of a community pipeline (StableDiffusionLongPromptWeightingPipeline) seemed to get past the above error, but every output generated at those strength values triggered the NSFW checker.

Removing the NSFW checker (as described below) shows that the output for strengths 0.0 and 0.05 appears to be random noise after one step, while 0.1 produces the expected result: an output very similar to the initial image after one step.

It's unlikely that 0.0 and 0.05 would be used as strength values in practice, but for consistency's sake I think these values should behave similarly to 0.1.

Tested schedulers: EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, DPMSolverSinglestepScheduler, DPMSolverMultistepScheduler
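
For context, the 0.1 cutoff lines up with how many denoising steps actually get scheduled. A minimal arithmetic sketch, assuming the pipeline derives the step count from int(num_inference_steps * strength) in its get_timesteps helper (illustrative names only, not the diffusers source):

# Illustrative arithmetic only; variable names mirror the assumed
# get_timesteps behaviour, this is not the library source.
num_inference_steps = 15

for strength in (0.0, 0.05, 0.1):
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    steps_to_run = num_inference_steps - t_start
    print(f"strength={strength}: {steps_to_run} denoising step(s)")

# strength=0.0  -> 0 denoising step(s)
# strength=0.05 -> 0 denoising step(s)  (int(15 * 0.05) == 0)
# strength=0.1  -> 1 denoising step(s)  (int(15 * 0.1)  == 1)

With 14 or 15 steps, any strength below roughly 1/15 ≈ 0.067 therefore schedules zero steps, which matches 0.05 failing while 0.1 runs exactly one step.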

Excerpt from stack trace of NSFW safety checker error:

diffusers\pipelines\stable_diffusion\pipeline_stable_diffusion_img2img.py:595, in StableDiffusionImg2ImgPipeline.__call__(self, prompt, image, strength, num_inference_steps, guidance_scale, negative_prompt, num_images_per_prompt, eta, generator, output_type, return_dict, callback, callback_steps, **kwargs)
    592 image = self.decode_latents(latents)
    594 # 10. Run safety checker
--> 595 image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
    597 # 11. Convert to PIL
    598 if output_type == "pil":

diffusers\pipelines\stable_diffusion\pipeline_stable_diffusion_img2img.py:343, in StableDiffusionImg2ImgPipeline.run_safety_checker(self, image, device, dtype)
    341 def run_safety_checker(self, image, device, dtype):
    342     if self.safety_checker is not None:
--> 343         safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
    344         image, has_nsfw_concept = self.safety_checker(
    345             images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
    346         )
    347     else:

transformers\image_processing_utils.py:437, in BaseImageProcessor.__call__(self, images, **kwargs)
    435 def __call__(self, images, **kwargs) -> BatchFeature:
    436     """Preprocess an image or a batch of images."""
--> 437     return self.preprocess(images, **kwargs)

transformers\models\clip\image_processing_clip.py:302, in CLIPImageProcessor.preprocess(self, images, do_resize, size, resample, do_center_crop, crop_size, do_rescale, rescale_factor, do_normalize, image_mean, image_std, do_convert_rgb, return_tensors, data_format, **kwargs)
    299 image_std = image_std if image_std is not None else self.image_std
    300 do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
--> 302 if not is_batched(images):
    303     images = [images]
    305 if not valid_images(images):

transformers\image_utils.py:88, in is_batched(img)
     86 def is_batched(img):
     87     if isinstance(img, (list, tuple)):
---> 88         return is_valid_image(img[0])
     89     return False
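
A plausible chain behind the "list index out of range" (this is my assumption, not something I have verified against the pipeline source): with zero scheduled steps, timesteps[:1] is an empty tensor, broadcasting it against the init latents while adding noise collapses the batch dimension to zero, decode_latents then yields an empty image array, numpy_to_pil returns an empty list, and is_batched indexes into that empty list. A minimal sketch of the broadcasting part:

import torch

# Hedged sketch: broadcasting an empty per-step tensor against the latents
# produces a zero-size batch (hypothetical stand-in for the add_noise step).
init_latents = torch.randn(1, 4, 64, 64)
empty_timesteps = torch.tensor([])            # no denoising steps selected
sigma = empty_timesteps.reshape(-1, 1, 1, 1)  # shape (0, 1, 1, 1)
noised = init_latents + sigma * torch.randn_like(init_latents)
print(noised.shape)                           # torch.Size([0, 4, 64, 64])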

Reproduction of noisy outputs:

I had a copy of StableDiffusionLongPromptWeightingPipeline that I was experimenting with in a Jupyter notebook (edit: I forgot that I had changed something in the timesteps; I will update this later). I commented out has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) and the other lines that referenced has_nsfw_concept, and passed False to nsfw_content_detected in the return.

Result of the above with strength 0.05

[image: output at strength 0.05]

Same settings with strength 0.1

[image: output at strength 0.1]

Settings for above outputs:
image: "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
model: runwayml/stable-diffusion-v1-5
width: 512
height: 512
steps: 14
seed: random
scale: 10
sampler: EulerAncestralDiscreteScheduler
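
For reference, the NSFW check can also be skipped without editing the pipeline file: run_safety_checker only calls the checker when self.safety_checker is not None (see line 342 in the trace above), so clearing the attribute on a loaded pipeline should have the same effect. A minimal sketch:

from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")

# run_safety_checker falls into its else branch when safety_checker is None,
# so the check is skipped and nsfw_content_detected comes back as None.
pipe.safety_checker = None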

Reproduction

import requests
from PIL import Image
from io import BytesIO

from diffusers import StableDiffusionImg2ImgPipeline, EulerDiscreteScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler, DPMSolverSinglestepScheduler, DPMSolverMultistepScheduler

device = "cuda"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
).to(device)


url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((768, 512))

# Each scheduler is instantiated with its default config and swapped into the pipeline
for scheduler in [EulerDiscreteScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler, DPMSolverSinglestepScheduler, DPMSolverMultistepScheduler]:
    pipe.scheduler = scheduler()
    print(scheduler.__name__)
    for strength in [0.0, 0.05]:
        try:
            images = pipe(
                prompt="a fantasy landscape",
                negative_prompt=None,
                image=init_image,
                strength=strength,
                num_inference_steps=15,
                guidance_scale=10,
                num_images_per_prompt=1
            ).images
            images[0].save(f"fantasy_landscape_{strength}.png")
            print(
                f"fantasy_landscape_{strength}.png saved at img2img strength {strength}")
        except Exception as e:
            print(
                f"fantasy_landscape_{strength}.png at img2img strength {strength} failed with error:\n {e} ")

Logs

Fetching 15 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 4995.20it/s]
EulerDiscreteScheduler
0it [00:00, ?it/s]
fantasy_landscape_0.0.png at img2img strength 0.0 failed with error:
 list index out of range 
0it [00:00, ?it/s]
fantasy_landscape_0.05.png at img2img strength 0.05 failed with error:
 list index out of range 
LMSDiscreteScheduler
0it [00:00, ?it/s]
fantasy_landscape_0.0.png at img2img strength 0.0 failed with error:
 list index out of range
0it [00:00, ?it/s]
fantasy_landscape_0.05.png at img2img strength 0.05 failed with error:
 list index out of range
EulerAncestralDiscreteScheduler
0it [00:00, ?it/s]
fantasy_landscape_0.0.png at img2img strength 0.0 failed with error:
 list index out of range
0it [00:00, ?it/s]
fantasy_landscape_0.05.png at img2img strength 0.05 failed with error:
 list index out of range
DPMSolverSinglestepScheduler
0it [00:00, ?it/s]
fantasy_landscape_0.0.png at img2img strength 0.0 failed with error:
 list index out of range
0it [00:00, ?it/s]
fantasy_landscape_0.05.png at img2img strength 0.05 failed with error:
 list index out of range
DPMSolverMultistepScheduler
0it [00:00, ?it/s]
fantasy_landscape_0.0.png at img2img strength 0.0 failed with error:
 list index out of range
0it [00:00, ?it/s]
fantasy_landscape_0.05.png at img2img strength 0.05 failed with error:
 list index out of range
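
Until the pipeline handles these values consistently, a user-side workaround is to clamp strength so that at least one step gets scheduled. This is a hypothetical helper based on the arithmetic sketch above, not part of diffusers:

def clamp_strength(strength: float, num_inference_steps: int) -> float:
    # Hypothetical guard: keep strength high enough that the scheduler
    # selects at least one timestep.
    return max(strength, 1.0 / num_inference_steps)

print(clamp_strength(0.05, 15))  # ~0.0667, enough for one denoising step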

System Info

  • diffusers version: 0.11.1
  • Platform: Windows-10-10.0.19045-SP0
  • Python version: 3.10.8
  • PyTorch version (GPU?): 1.13.1+cu116 (True)
  • Huggingface_hub version: 0.11.1
  • Transformers version: 4.25.1
  • Using GPU in script?: Yes, RTX3090
  • Using distributed or parallel set-up in script?: No

    Labels

    bug (Something isn't working), stale (Issues that haven't received updates)
