Describe the bug
When generating images with SDXL and DDIM, there is residual noise left in the outputs.
This gives the images a "smudgy" look, and at lower step counts DDIM and Euler diverge far more than they should, because of the cumulative impact of the timesteps not being aligned properly.
In some brief tests, simply appending an extra timestep with a zero sigma to the end of the schedule appears to resolve the problem.
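As a rough illustration of that fix, here is a minimal sketch (the helper name is hypothetical) that ensures a scheduler's 1-D sigmas tensor terminates at zero:

import torch

def append_zero_sigma(sigmas: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: make sure the noise schedule ends at sigma = 0,
    # so the final sampling step denoises completely instead of leaving
    # residual noise in the output.
    if sigmas[-1] != 0:
        sigmas = torch.cat([sigmas, sigmas.new_zeros(1)])
    return sigmas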
Reproduction
This script uses a modified Euler scheduler to create fully-denoised images:
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler

model_id = "ptx0/terminus-xl-gamma-training"
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id, add_watermarker=False, torch_dtype=torch.bfloat16
).to("cuda")
generator = torch.Generator("cuda").manual_seed(420420420)
prompt = "the artful dodger, cool dog in sunglasses sitting on a recliner in the dark, with the white noise reflecting on his sunglasses"
num_inference_steps = 30
guidance_scale = 7.5

def rescale_zero_terminal_snr_sigmas(sigmas):
    # Convert sigmas to alphas_bar_sqrt, rescale so the terminal timestep
    # has zero SNR, then convert back to sigmas.
    sigmas = sigmas.flip(0)
    alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt back to sigmas.
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    # Clamp the terminal alpha_bar to a tiny positive value so the
    # corresponding sigma stays finite.
    alphas_bar[-1] = 4.8973451890853435e-08
    sigmas = ((1 - alphas_bar) / alphas_bar) ** 0.5
    return sigmas.flip(0)

zsnr = getattr(pipe.scheduler.config, "rescale_betas_zero_snr", False)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
if zsnr:
    # Monkeypatch set_timesteps so the sigmas are rescaled to zero terminal
    # SNR every time the schedule is rebuilt.
    tsbase = pipe.scheduler.set_timesteps

    def tspatch(*args, **kwargs):
        tsbase(*args, **kwargs)
        pipe.scheduler.sigmas = rescale_zero_terminal_snr_sigmas(pipe.scheduler.sigmas)

    pipe.scheduler.set_timesteps = tspatch

edited_image = pipe(
    prompt=prompt,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    generator=generator,
    guidance_rescale=0.7,
).images[0]
edited_image.save("edited_image.png")
It uses the sigma-rescaling code ported by @Beinsezii in #6024.
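As a quick sanity check (hypothetical, not part of the original script), the patched schedule can be inspected to confirm it now starts at a very large sigma and terminates exactly at zero:

pipe.scheduler.set_timesteps(num_inference_steps, device="cuda")
print(pipe.scheduler.sigmas[0], pipe.scheduler.sigmas[-1])  # expect a large value and tensor(0.)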
However, with vanilla DDIM, the results are far worse:
import torch
from diffusers import StableDiffusionXLPipeline

model_id = "ptx0/terminus-xl-gamma-training"
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id, add_watermarker=False, torch_dtype=torch.bfloat16
).to("cuda")
# The pipeline's default (DDIM) scheduler is used unchanged.
generator = torch.Generator("cuda").manual_seed(420420420)
prompt = "the artful dodger, cool dog in sunglasses sitting on a recliner in the dark, with the white noise reflecting on his sunglasses"
num_inference_steps = 30
guidance_scale = 7.5

edited_image = pipe(
    prompt=prompt,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    generator=generator,
    guidance_rescale=0.7,
).images[0]
edited_image.save("edited_image.png")
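A possible alternative worth comparing, rather than the Euler monkeypatch above: DDIMScheduler already exposes rescale_betas_zero_snr and trailing timestep spacing (from "Common Diffusion Noise Schedules and Sample Steps Are Flawed"); whether those flags are enough to remove the residual noise is exactly what this issue is about.

from diffusers import DDIMScheduler

# Enforce a zero-SNR terminal step via DDIM's built-in config flags instead
# of patching the scheduler by hand.
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config,
    rescale_betas_zero_snr=True,
    timestep_spacing="trailing",
)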
Logs
No response
System Info
- diffusers version: 0.21.4
- Platform: Linux-5.19.0-45-generic-x86_64-with-glibc2.31
- Python version: 3.9.16
- PyTorch version (GPU?): 2.1.0+cu118 (True)
- Huggingface_hub version: 0.16.4
- Transformers version: 4.30.2
- Accelerate version: 0.18.0
- xFormers version: 0.0.22.post4+cu118
- Using GPU in script?: A100-80G PCIe
- Using distributed or parallel set-up in script?: FALSE