
DDIM produces incorrect samples with SDXL (epsilon or v-prediction) #6068

Closed
@bghira


Describe the bug

When generating images with SDXL using the DDIM scheduler, the outputs retain some residual noise.

This gives the images a "smudgy" look. At lower step counts, DDIM and Euler also diverge much more than they should, because the error from the misaligned timesteps accumulates over the sampling loop.
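One possible source of the residual noise (an assumption, not confirmed in this issue) is that the last DDIM step stops short of a fully clean target. The scalar sketch below shows the deterministic (eta = 0) DDIM update for an epsilon-prediction model; v-prediction is not covered. With an oracle noise prediction, stepping to `alpha_bar_prev = 1.0` recovers x0 exactly, while stepping to `alphas_cumprod[0]` leaves `sqrt(1 - alphas_cumprod[0]) * eps` in the output. The value 0.99915 assumes a `scaled_linear` schedule with `beta_start = 0.00085` and `DDIMScheduler`'s `set_alpha_to_one=False`, which makes the final step use `alphas_cumprod[0]`; whether this model's config uses those settings is an assumption. `ddim_step_eta0` is a hypothetical helper, not a diffusers function.

```python
import math


def ddim_step_eta0(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """Deterministic DDIM update (eta = 0) for an epsilon-prediction model, on scalars."""
    # Clean-sample estimate implied by the current sample and the noise prediction.
    x0_pred = (x_t - math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_bar_t)
    # Move to the previous noise level, re-adding the predicted noise direction.
    return math.sqrt(alpha_bar_prev) * x0_pred + math.sqrt(1.0 - alpha_bar_prev) * eps_pred


# Toy setup with an oracle noise prediction: x0 = 1.0, eps = 0.5, alpha_bar_t = 0.9.
x0, eps, a_t = 1.0, 0.5, 0.9
x_t = math.sqrt(a_t) * x0 + math.sqrt(1.0 - a_t) * eps

# Final step to alpha_bar = 1 (sigma = 0): lands exactly on x0.
full = ddim_step_eta0(x_t, eps, a_t, alpha_bar_prev=1.0)
# Final step to alphas_cumprod[0] (assumed 1 - 0.00085): some noise is left behind.
partial = ddim_step_eta0(x_t, eps, a_t, alpha_bar_prev=0.99915)
residual = partial - math.sqrt(0.99915) * x0  # equals sqrt(0.00085) * eps
```

The residual is small per step, but it is never removed, which matches the "smudgy" description qualitatively.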

In brief tests, appending an extra timestep with a sigma of zero to the end of the schedule appears to resolve the problem.
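A minimal sketch of what "appending a zero sigma" means, in plain Python. It assumes SDXL's usual training schedule (`scaled_linear` betas from 0.00085 to 0.012 over 1000 steps) and "leading" timestep spacing with `steps_offset = 1`; those values are assumptions about this model's config. It builds the sigma schedule `sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t)`, appends `0.0`, and takes a single Euler step into that final zero sigma. All function names are hypothetical helpers, not diffusers APIs.

```python
import math


def scaled_linear_alphas_cumprod(n=1000, beta_start=0.00085, beta_end=0.012):
    # "scaled_linear": betas are linear in sqrt-space, then squared.
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    alphas_cumprod, prod = [], 1.0
    for i in range(n):
        beta = (lo + (hi - lo) * i / (n - 1)) ** 2
        prod *= 1.0 - beta
        alphas_cumprod.append(prod)
    return alphas_cumprod


def leading_timesteps(num_inference_steps, n=1000, steps_offset=1):
    # "leading" spacing: multiples of the step ratio, highest first, plus an offset.
    step_ratio = n // num_inference_steps
    return [i * step_ratio + steps_offset for i in reversed(range(num_inference_steps))]


def sigmas_with_terminal_zero(timesteps, alphas_cumprod):
    # Noise level for each timestep, then an extra final entry with sigma = 0.
    sigmas = [math.sqrt((1.0 - alphas_cumprod[t]) / alphas_cumprod[t]) for t in timesteps]
    return sigmas + [0.0]


def euler_step(x, x0_pred, sigma, sigma_next):
    # Euler step on the probability-flow ODE, using d = (x - x0_pred) / sigma.
    d = (x - x0_pred) / sigma
    return x + (sigma_next - sigma) * d


ac = scaled_linear_alphas_cumprod()
ts = leading_timesteps(30)  # [958, 925, ..., 34, 1]
sigmas = sigmas_with_terminal_zero(ts, ac)
# Stepping into sigma = 0 returns the model's clean-sample prediction exactly.
final = euler_step(x=3.0, x0_pred=1.0, sigma=sigmas[-2], sigma_next=sigmas[-1])
```

Because the last step targets sigma = 0, the sampler's output is the denoised prediction itself rather than a sample at a small but nonzero noise level.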

Reproduction

This script uses a modified Euler scheduler to create fully-denoised images:

```python
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler

model_id = "ptx0/terminus-xl-gamma-training"
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id, add_watermarker=False, torch_dtype=torch.bfloat16
).to("cuda")
generator = torch.Generator("cuda").manual_seed(420420420)

prompt = "the artful dodger, cool dog in sunglasses sitting on a recliner in the dark, with the white noise reflecting on his sunglasses"
num_inference_steps = 30
guidance_scale = 7.5


def rescale_zero_terminal_snr_sigmas(sigmas):
    """Rescale a descending Euler sigma schedule to zero terminal SNR."""
    # Work in ascending-noise order (the trailing 0.0 sigma becomes the first entry).
    sigmas = sigmas.flip(0)
    alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert back to sigmas; floor the terminal alpha_bar so the largest sigma stays finite.
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas_bar[-1] = 4.8973451890853435e-08
    sigmas = ((1 - alphas_bar) / alphas_bar) ** 0.5
    return sigmas.flip(0)


# Read the zero-SNR flag from the original scheduler config before swapping in Euler.
zsnr = getattr(pipe.scheduler.config, "rescale_betas_zero_snr", False)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
if zsnr:
    # Wrap set_timesteps so every schedule the pipeline builds gets rescaled.
    tsbase = pipe.scheduler.set_timesteps

    def tspatch(*args, **kwargs):
        tsbase(*args, **kwargs)
        pipe.scheduler.sigmas = rescale_zero_terminal_snr_sigmas(pipe.scheduler.sigmas)

    pipe.scheduler.set_timesteps = tspatch

edited_image = pipe(
    prompt=prompt,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    generator=generator,
    guidance_rescale=0.7,
).images[0]
edited_image.save("edited_image.png")
```

It uses the zero-terminal-SNR sigma rescaling code ported by @Beinsezii in #6024.
[Image: output of the modified Euler script]
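The hard-coded `alphas_bar[-1] = 4.8973451890853435e-08` in the script above exists because zero terminal SNR means `alpha_bar = 0` at the last timestep, where `sigma = sqrt((1 - alpha_bar) / alpha_bar)` is infinite. A minimal plain-Python sketch of the same shift-and-scale on a toy, ascending-noise `alphas_cumprod` list (the toy values and the function names are hypothetical) shows both properties: the first entry is preserved and the last becomes exactly zero unless a positive floor is applied.

```python
import math


def rescale_alphas_cumprod_zero_snr(alphas_cumprod, floor=None):
    """Shift/scale sqrt(alpha_bar) so the last entry is 0 and the first is unchanged."""
    sqrt_ab = [math.sqrt(a) for a in alphas_cumprod]
    first, last = sqrt_ab[0], sqrt_ab[-1]
    # Shift so the last entry is zero, then scale so the first entry keeps its value.
    sqrt_ab = [(s - last) * first / (first - last) for s in sqrt_ab]
    alphas_bar = [s * s for s in sqrt_ab]
    if floor is not None:
        # Keep a tiny positive alpha_bar so the corresponding sigma stays finite.
        alphas_bar[-1] = floor
    return alphas_bar


def to_sigmas(alphas_bar):
    # sigma = sqrt((1 - a) / a), which diverges as a -> 0.
    return [math.inf if a == 0.0 else math.sqrt((1.0 - a) / a) for a in alphas_bar]


toy = [0.999, 0.9, 0.5, 0.1, 0.01]  # hypothetical alphas_cumprod, least noisy first
exact = rescale_alphas_cumprod_zero_snr(toy)
floored = rescale_alphas_cumprod_zero_snr(toy, floor=4.8973451890853435e-08)
```

Here `to_sigmas(exact)[-1]` is infinite, while `to_sigmas(floored)[-1]` is large but finite, which is what a sampler that scales the initial noise by the largest sigma needs.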

However, with vanilla DDIM, the results are far worse:

```python
import torch
from diffusers import StableDiffusionXLPipeline, DDIMScheduler

model_id = "ptx0/terminus-xl-gamma-training"
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id, add_watermarker=False, torch_dtype=torch.bfloat16
).to("cuda")
# Use DDIM explicitly, keeping the model's own scheduler config
# (e.g. rescale_betas_zero_snr, timestep_spacing).
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
generator = torch.Generator("cuda").manual_seed(420420420)

prompt = "the artful dodger, cool dog in sunglasses sitting on a recliner in the dark, with the white noise reflecting on his sunglasses"
num_inference_steps = 30
guidance_scale = 7.5

edited_image = pipe(
    prompt=prompt,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    generator=generator,
    guidance_rescale=0.7,
).images[0]
edited_image.save("edited_image_ddim.png")
```

[Image: output of the vanilla DDIM script]

Logs

No response

System Info

  • diffusers version: 0.21.4
  • Platform: Linux-5.19.0-45-generic-x86_64-with-glibc2.31
  • Python version: 3.9.16
  • PyTorch version (GPU?): 2.1.0+cu118 (True)
  • Huggingface_hub version: 0.16.4
  • Transformers version: 4.30.2
  • Accelerate version: 0.18.0
  • xFormers version: 0.0.22.post4+cu118
  • Using GPU in script?: A100-80G PCIe
  • Using distributed or parallel set-up in script?: FALSE

Who can help?

@patrickvonplaten @yiyixuxu
