Description
Describe the bug
With StableDiffusionImg2ImgPipeline, all the schedulers I tested in the course of testing #1866 error out in the NSFW safety checker when the img2img strength is 0.0 or 0.05.
While investigating this issue, a local copy of a community pipeline (StableDiffusionLongPromptWeightingPipeline) seemed to get past the above error, but every output generated with those strength values triggered the NSFW checker.
Removing the NSFW checker there shows that the output for strength 0.0 and 0.05 appears to be random noise with one step, while 0.1 produces the expected result: an output very similar to the initial image with one step.
It's unlikely that 0.0 and 0.05 would be used as strength values, but for consistency's sake I think these values should behave similarly to 0.1.
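The difference between 0.05 and 0.1 would be consistent with the pipeline truncating num_inference_steps * strength to an integer. The snippet below is only a sketch of that arithmetic: the formula is an assumption and has not been checked against the 0.11.1 source, and effective_steps is a hypothetical helper, not a diffusers function.

```python
def effective_steps(num_inference_steps: int, strength: float) -> int:
    """Denoising steps left after scaling by strength (assumed formula)."""
    # Assumption: the product is truncated toward zero and capped at the
    # requested number of steps.
    return min(int(num_inference_steps * strength), num_inference_steps)


# 15 steps is the value used in the reproduction script.
for strength in (0.0, 0.05, 0.1):
    print(f"strength={strength}: {effective_steps(15, strength)} step(s)")
```

Under that assumption, 0.0 and 0.05 both leave zero steps with 15 inference steps, while 0.1 leaves exactly one.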
Tested schedulers: EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, DPMSolverSinglestepScheduler, DPMSolverMultistepScheduler
Excerpt from stack trace of NSFW safety checker error:
diffusers\pipelines\stable_diffusion\pipeline_stable_diffusion_img2img.py:595, in StableDiffusionImg2ImgPipeline.__call__(self, prompt, image, strength, num_inference_steps, guidance_scale, negative_prompt, num_images_per_prompt, eta, generator, output_type, return_dict, callback, callback_steps, **kwargs)
592 image = self.decode_latents(latents)
594 # 10. Run safety checker
--> 595 image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
597 # 11. Convert to PIL
598 if output_type == "pil":
diffusers\pipelines\stable_diffusion\pipeline_stable_diffusion_img2img.py:343, in StableDiffusionImg2ImgPipeline.run_safety_checker(self, image, device, dtype)
341 def run_safety_checker(self, image, device, dtype):
342 if self.safety_checker is not None:
--> 343 safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
344 image, has_nsfw_concept = self.safety_checker(
345 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
346 )
347 else:
transformers\image_processing_utils.py:437, in BaseImageProcessor.__call__(self, images, **kwargs)
435 def __call__(self, images, **kwargs) -> BatchFeature:
436 """Preprocess an image or a batch of images."""
--> 437 return self.preprocess(images, **kwargs)
transformers\models\clip\image_processing_clip.py:302, in CLIPImageProcessor.preprocess(self, images, do_resize, size, resample, do_center_crop, crop_size, do_rescale, rescale_factor, do_normalize, image_mean, image_std, do_convert_rgb, return_tensors, data_format, **kwargs)
299 image_std = image_std if image_std is not None else self.image_std
300 do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
--> 302 if not is_batched(images):
303 images = [images]
305 if not valid_images(images):
transformers\image_utils.py:88, in is_batched(img)
86 def is_batched(img):
87 if isinstance(img, (list, tuple)):
---> 88 return is_valid_image(img[0])
89 return False
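The last frame fails on img[0], which suggests the safety checker was handed an empty list of images. A minimal sketch of that failure mode, assuming an empty list reaches is_batched; is_valid_image here is a hypothetical stand-in, not the transformers implementation:

```python
def is_valid_image(img) -> bool:
    # Hypothetical stand-in for the real check in transformers.image_utils.
    return hasattr(img, "size")


def is_batched(img) -> bool:
    # Same shape as the is_batched frame shown in the traceback.
    if isinstance(img, (list, tuple)):
        return is_valid_image(img[0])  # IndexError when img is empty
    return False


try:
    is_batched([])
    error_message = None
except IndexError as exc:
    error_message = str(exc)

print(error_message)
```

The message matches the one in the logs, though this does not show why the list ends up empty inside the pipeline.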
Reproduction of noisy outputs:
I had a copy of StableDiffusionLongPromptWeightingPipeline that I was experimenting with in a Jupyter notebook (edit: I forgot that I changed something in the timesteps, will update this later). I commented out has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) and the other lines that referenced has_nsfw_concept, and passed False to nsfw_content_detected in the return.
Result of the above with strength 0.05
Same settings with strength 0.1
Settings for above outputs:
image: "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
model: runwayml/stable-diffusion-v1-5
width: 512
height: 512
steps: 14
seed: random
scale: 10
sampler: EulerAncestralDiscreteScheduler
Reproduction
import requests
from PIL import Image
from io import BytesIO
from diffusers import (
    StableDiffusionImg2ImgPipeline,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverSinglestepScheduler,
    DPMSolverMultistepScheduler,
)

device = "cuda"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
).to(device)

# Download and prepare the init image
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((768, 512))

for scheduler in [
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverSinglestepScheduler,
    DPMSolverMultistepScheduler,
]:
    # Each scheduler is instantiated with its default config
    pipe.scheduler = scheduler()
    print(scheduler.__name__)
    for strength in [0.0, 0.05]:
        try:
            images = pipe(
                prompt="a fantasy landscape",
                negative_prompt=None,
                image=init_image,
                strength=strength,
                num_inference_steps=15,
                guidance_scale=10,
                num_images_per_prompt=1,
            ).images
            images[0].save(f"fantasy_landscape_{strength}.png")
            print(
                f"fantasy_landscape_{strength}.png saved at img2img strength {strength}")
        except Exception as e:
            print(
                f"fantasy_landscape_{strength}.png at img2img strength {strength} failed with error:\n {e} ")
Logs
Fetching 15 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 4995.20it/s]
EulerDiscreteScheduler
0it [00:00, ?it/s]
fantasy_landscape_0.0.png at img2img strength 0.0 failed with error:
list index out of range
0it [00:00, ?it/s]
fantasy_landscape_0.05.png at img2img strength 0.05 failed with error:
list index out of range
LMSDiscreteScheduler
0it [00:00, ?it/s]
fantasy_landscape_0.0.png at img2img strength 0.0 failed with error:
list index out of range
0it [00:00, ?it/s]
fantasy_landscape_0.05.png at img2img strength 0.05 failed with error:
list index out of range
EulerAncestralDiscreteScheduler
0it [00:00, ?it/s]
fantasy_landscape_0.0.png at img2img strength 0.0 failed with error:
list index out of range
0it [00:00, ?it/s]
fantasy_landscape_0.05.png at img2img strength 0.05 failed with error:
list index out of range
DPMSolverSinglestepScheduler
0it [00:00, ?it/s]
fantasy_landscape_0.0.png at img2img strength 0.0 failed with error:
list index out of range
0it [00:00, ?it/s]
fantasy_landscape_0.05.png at img2img strength 0.05 failed with error:
list index out of range
DPMSolverMultistepScheduler
0it [00:00, ?it/s]
fantasy_landscape_0.0.png at img2img strength 0.0 failed with error:
list index out of range
0it [00:00, ?it/s]
fantasy_landscape_0.05.png at img2img strength 0.05 failed with error:
list index out of range
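Every run above shows a 0it progress bar, so no denoising step was executed before the error. Until the pipeline handles this, a caller-side guard could reject such combinations up front. This is only a sketch: check_strength is a hypothetical helper, and the step formula is the same assumption as elsewhere in this report (the product of num_inference_steps and strength truncated to an integer).

```python
def check_strength(num_inference_steps: int, strength: float) -> int:
    """Return the number of steps that would run, or raise if there are none."""
    if not 0.0 <= strength <= 1.0:
        raise ValueError(f"strength must be in [0, 1], got {strength}")
    # Assumed formula: truncate toward zero, cap at the requested steps.
    steps = min(int(num_inference_steps * strength), num_inference_steps)
    if steps < 1:
        raise ValueError(
            f"strength={strength} with num_inference_steps={num_inference_steps} "
            "leaves no denoising steps"
        )
    return steps


print(check_strength(15, 0.1))
```

Calling check_strength(15, 0.05) raises a ValueError instead of letting the request reach the safety checker.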
System Info
- diffusers version: 0.11.1
- Platform: Windows-10-10.0.19045-SP0
- Python version: 3.10.8
- PyTorch version (GPU?): 1.13.1+cu116 (True)
- Huggingface_hub version: 0.11.1
- Transformers version: 4.25.1
- Using GPU in script?: Yes, RTX3090
- Using distributed or parallel set-up in script?: No