fix DPM Scheduler with use_karras_sigmas option #6477

Merged on Jan 19, 2024 (26 commits)

Conversation

@yiyixuxu (Collaborator) commented Jan 7, 2024

This PR:

  1. deprecates the dpmsolver and sde-dpmsolver algorithm types in the DPM schedulers
  2. introduces a new argument, final_sigmas_type:
    • When final_sigmas_type="sigma_min", we use the minimum sigma value as the last sigma; this is the current behavior.
    • When final_sigmas_type="zero", we denoise to sigma=0 in the last step. This is newly introduced in this PR and matches the k-diffusion implementation. At num_inference_steps=25, it achieves significantly better results for SDXL across the various configurations we tested (see the testing script below and the minimal usage sketch that follows).
  3. updates the default config to set final_sigmas_type="zero"

Fixes #6295
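A minimal usage sketch of the new argument (the model ID is illustrative; the scheduler override follows the from_config pattern used later in this thread):

import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# "zero" (the new default) denoises all the way to sigma = 0 on the last step,
# matching k-diffusion; "sigma_min" keeps the previous stop-at-smallest-sigma behavior.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, use_karras_sigmas=True, final_sigmas_type="zero"
)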

Testing script:

import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
from diffusers import DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler
import os

from diffusers.utils import make_image_grid

config_min = {"final_sigmas_type":"sigma_min"}
config_min_euler = {"final_sigmas_type":"sigma_min", "euler_at_final": True }
config_zero = {"final_sigmas_type":"zero"}

schedulers = {
    "DPMPP_2M": {
        "min": (DPMSolverMultistepScheduler, config_min),
        "min_euler": (DPMSolverMultistepScheduler, config_min_euler),
        "zero": (DPMSolverMultistepScheduler, config_zero),
     },
     "DPMPP_2M_K": {
        "min": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, **config_min}),
        "min_euler": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, **config_min_euler}),
        "zero": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, **config_zero}),
     },
     "DPMPP_2M_SDE": {
        "min": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", **config_min}),
        "min_euler": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", **config_min_euler}),
        "zero": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", **config_zero}),
     },
     "DPMPP_2M_SDE_K": {
        "min": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True, **config_min}),
        "min_euler": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True, **config_min_euler}),
        "zero": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++", **config_zero}),
     },
     "DPMPP": {
        "min": (DPMSolverSinglestepScheduler, config_min),
        "min_euler": (DPMSolverSinglestepScheduler, config_min_euler),
        "zero": (DPMSolverSinglestepScheduler, config_zero),
     },
     "DPMPP_K": {
        "min": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, **config_min}),
        "min_euler": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, **config_min_euler}),
        "zero": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, **config_zero}),
     },
}


## Test SD-XL

# model_id = "stabilityai/stable-diffusion-xl-base-1.0"
model_id = "frankjoshua/juggernautXL_version6Rundiffusion"
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
    add_watermarker=False)
pipe = pipe.to('cuda')
prompt = "Adorable infant playing with a variety of colorful rattle toys."
save_dir = './test_juggernautxl_baby'

if not os.path.exists(save_dir):
    os.mkdir(save_dir)

steps = 25

params = {
    "prompt": [prompt],
    "num_inference_steps": steps,
    "guidance_scale": 7,
}
for scheduler_name in schedulers.keys():
    for seed in [12345, 123]:
        out_imgs = []
        scheduler_configs = schedulers[scheduler_name]
        for scheduler_config_name in scheduler_configs.keys():
            generator = torch.Generator(device='cuda').manual_seed(seed)
            scheduler = scheduler_configs[scheduler_config_name][0].from_pretrained(
                model_id,
                subfolder="scheduler",
                **scheduler_configs[scheduler_config_name][1],
            )
            pipe.scheduler = scheduler

            img = pipe(**params, generator=generator).images[0]
            out_imgs.append(img)
        out_img = make_image_grid(out_imgs, rows=1, cols=3)    
        out_img.save(os.path.join(save_dir, f"seed_{seed}_steps_{steps}_{scheduler_name}.png"))

Outputs (from left to right):

  • left: current default (final_sigmas_type="sigma_min")
  • middle: final_sigmas_type="sigma_min" + euler_at_final=True
  • right: new default set in this PR (final_sigmas_type="zero")

DPM++2M
[image: seed_123_steps_25_DPMPP_2M]

DPM++2M Karras
[image: seed_123_steps_25_DPMPP_2M_K]

DPM++2M SDE
[image: seed_123_steps_25_DPMPP_2M_SDE]

DPM++2M Karras SDE
[image: seed_123_steps_25_DPMPP_2M_SDE_K]

DPM++
[image: seed_123_steps_25_DPMPP]

DPM++ Karras
[image: seed_123_steps_25_DPMPP_K]

@yiyixuxu (Collaborator, Author) commented Jan 7, 2024

cc @LuChengTHU here for his insights :)

@@ -267,7 +267,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
yiyixuxu (Collaborator, Author):

@patrickvonplaten

Was there a reason to set the sigmas this way, or is it just a mistake that we didn't catch in the PR?

By repeating the last sigma twice, we make the very last step a dummy step with zero step size,
i.e., this code here reduces to x_t = sample, because sigma_t = sigma_s0 and h = 0:

x_t = (
    (sigma_t / sigma_s0) * sample
    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
    - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
)
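To make the zero-size step concrete, here is a minimal numeric sketch (illustrative values; with the duplicated sigma, the two adjacent steps share the same sigma and alpha, so h = 0):

import torch

sigma_t = sigma_s0 = torch.tensor(0.0292)  # e.g. sigma_min repeated
alpha_t = torch.tensor(1.0)                # illustrative
h = torch.tensor(0.0)                      # log-ratio of identical sigmas

sample = torch.randn(4)
D0, D1 = torch.randn(4), torch.randn(4)

x_t = (
    (sigma_t / sigma_s0) * sample                   # factor is exactly 1
    - (alpha_t * (torch.exp(-h) - 1.0)) * D0        # exp(-0) - 1 == 0
    - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1  # likewise 0
)
assert torch.allclose(x_t, sample)  # the final "update" leaves the sample unchanged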

if self.config.solver_type == "midpoint":

Contributor:

Nice catch! Was this the culprit?

yiyixuxu (Collaborator, Author):

Yes, this was it!

@yiyixuxu (Collaborator, Author) commented Jan 8, 2024

I used this test script to compare against k-diffusion. I had to hardcode the k-sampler in order to force identical noise construction; I added comments where I updated the code.

The results look like the same quality to me; are they? I might need a little help from more trained eyes @ivanprado

k-diffusion
[image: yiyi_test_5_out_k]

diffusers
[image: yiyi_test_5_out_d]

import torch
from diffusers import StableDiffusionXLKDiffusionPipeline, AutoPipelineForText2Image
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils.torch_utils import randn_tensor

from tqdm.auto import trange, tqdm

def default_noise_sampler(x):
    return lambda generator: randn_tensor(x.shape, dtype=x.dtype, device=x.device, generator=generator)

@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
    """DPM-Solver++(2M) SDE."""

    if solver_type not in {'heun', 'midpoint'}:
        raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    # yiyi edit: use custom noise_sampler to match diffusers
    noise_sampler = default_noise_sampler(x)
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    old_denoised = None
    h_last = None
    h = None  # initialize so `h_last = h` at the end of the loop is always defined

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # DPM-Solver++(2M) SDE
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            eta_h = eta * h

            x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised

            if old_denoised is not None:
                r = h_last / h
                if solver_type == 'heun':
                    x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
                elif solver_type == 'midpoint':
                    x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)

            if eta:
                # yiyi edit: use custom noise_sampler to match diffusers
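                # (`generator` here is the module-level generator created just before the pipeline call below)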
                noise = noise_sampler(generator)
                x = x + noise * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise


        old_denoised = denoised
        h_last = h
    return x

seed = 42
model_id = "frankjoshua/juggernautXL_version6Rundiffusion"
prompt = "Adorable infant playing with a variety of colorful rattle toys."



# k-diffusion pipeline
pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
    model_id,
    use_safetensors=True,
)
pipe.enable_model_cpu_offload()

pipe.sampler = sample_dpmpp_2m_sde

generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(
    prompt, 
    generator=generator, 
    guidance_scale=3.0,
    num_inference_steps=25, 
    use_karras_sigmas=True,
    height=768,
    width=1344,
).images[0]

image.save("yiyi_test_out_k.png")


del pipe
torch.cuda.empty_cache()

# diffusers pipeline

pipe = AutoPipelineForText2Image.from_pretrained(
    model_id, 
    add_watermarker = False,
    use_safetensors=True,
)
pipe.enable_model_cpu_offload()

pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config, 
        algorithm_type="sde-dpmsolver++",
        use_karras_sigmas=True,
        )
generator = torch.Generator(device="cuda").manual_seed(seed)
results = pipe(
    prompt=prompt, 
    guidance_scale=3.,
    generator=generator,  
    num_inference_steps=25, 
    height=768, 
    width=1344,)
results.images[0].save("yiyi_test_out_d.png")

@ivanprado (Contributor) commented Jan 8, 2024

Thank you @yiyixuxu!! The results are now really good in my eyes. Thank you! But the DPMSolverSinglestepScheduler is also affected by the issue (see the image below). Could you also apply the fix to it?

Code to reproduce the issue with DPMSolverSinglestepScheduler:

import torch
from diffusers import AutoPipelineForText2Image
from diffusers.schedulers import DPMSolverSinglestepScheduler

seed = 42
model_id = "frankjoshua/juggernautXL_version6Rundiffusion"
prompt = "Adorable infant playing with a variety of colorful rattle toys."

# diffusers pipeline
pipe = AutoPipelineForText2Image.from_pretrained(
    model_id,
    add_watermarker = False,
    use_safetensors=True,
)
pipe.enable_model_cpu_offload()

pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
        pipe.scheduler.config,
        use_karras_sigmas=True,
        )
generator = torch.Generator(device="cuda").manual_seed(seed)
results = pipe(
    prompt=prompt,
    guidance_scale=3.,
    generator=generator,
    num_inference_steps=25,
    height=768,
    width=1344,)
results.images[0].save("yiyi_test_out_d_single.png")
[image: Screenshot 2024-01-08 at 14 42 00]

@patrickvonplaten (Contributor) left a review:

Good job fixing the bug! Would be great to also apply the fix to scheduling_dpmsolver_singlestep.py

@yiyixuxu (Collaborator, Author):

@patrickvonplaten

I came across this issue crowsonkb/k-diffusion#43 (comment) on k-diffusion and found out that skipping the last denoising step is a trick people use for low step counts. Do you remember whether this is the reason we did np.concatenate([sigmas, sigmas[-1:]])? We should have documented it somewhere if that's the case, no?

Similarly, Auto1111 has an option to skip the second-to-last step. I'm going to run some testing on my end for SD1.5 and SDXL. If this is not a bug, then maybe we should add a new config option instead...


@patrickvonplaten (Contributor) commented:

> I came across this issue crowsonkb/k-diffusion#43 (comment) on k-diffusion and found out that skipping the last denoising step is a trick people use for low step counts. Do you remember whether this is the reason we did np.concatenate([sigmas, sigmas[-1:]])? We should have documented it somewhere if that's the case, no?
>
> Similarly, Auto1111 has an option to skip the second-to-last step. I'm going to run some testing on my end for SD1.5 and SDXL. If this is not a bug, then maybe we should add a new config option instead...

Good catch! I'm actually not 100% sure why we added the sigmas this way, and I can't find the original PR anymore. Maybe worth tracking down who added the first use_karras_sigmas.

@yiyixuxu (Collaborator, Author) commented Jan 15, 2024

Testing

I compared outputs for final_sigmas_type="default" vs final_sigmas_type="denoise_to_zero" for:

  • both SDXL and SD1.5
  • DPMSolverMultistepScheduler and DPMSolverSinglestepScheduler
  • all algorithm_types that are compatible with final_sigmas_type="denoise_to_zero", i.e. dpmsolver++ and sde-dpmsolver++
  • use_karras_sigmas=True and use_karras_sigmas=False

I can see a nice improvement across all configurations for SDXL, but not so much for SD1.5. (It's possible the difference just isn't obvious in the examples I used; I tested SD1.5 on the baby example too and didn't see much difference there either, but the SD1.5 babies are really scary looking, so I won't post them here.)

Testing script:

import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline  # StableDiffusionPipeline is used by the commented SD1.5 block below
from diffusers import DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler
import os

common_config = {'beta_start': 0.00085, 'beta_end': 0.012, 'beta_schedule': 'scaled_linear'}
schedulers = {
    "DPMPP_2M_default": (DPMSolverMultistepScheduler, {"final_sigmas_type":"default", "euler_at_final": True }),
    "DPMPP_2M_K_default": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "final_sigmas_type":"default", "euler_at_final": True}),
    "DPMPP_2M_zero": (DPMSolverMultistepScheduler, {"final_sigmas_type":"denoise_to_zero"}),
    "DPMPP_2M_K_zero": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "final_sigmas_type":"denoise_to_zero"}),
    "DPMPP_2M_SDE_default": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "final_sigmas_type":"default", "euler_at_final": True}),
    "DPMPP_2M_SDE_K_default": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True, "final_sigmas_type":"default", "euler_at_final": True}),
    "DPMPP_2M_SDE_zero": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "final_sigmas_type":"denoise_to_zero"}),
    "DPMPP_2M_SDE_K_zero": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++", "final_sigmas_type":"denoise_to_zero"}),
    "DPMPP_default": (DPMSolverSinglestepScheduler, {"final_sigmas_type":"default", "lower_order_final": True }),
    "DPMPP_K_default": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, "final_sigmas_type":"default", "lower_order_final": True}),
    "DPMPP_zero": (DPMSolverSinglestepScheduler, {"final_sigmas_type":"denoise_to_zero"}),
    "DPMPP_K_zero": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, "final_sigmas_type":"denoise_to_zero"}),
}


## Test SDXL/juggernautXL

#model_id = "stabilityai/stable-diffusion-xl-base-1.0"
model_id = "frankjoshua/juggernautXL_version6Rundiffusion"
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
    add_watermarker=False)
pipe = pipe.to('cuda')
prompt = "Adorable infant playing with a variety of colorful rattle toys."
save_dir = './yiyi_test_juggernautxl_baby'

# # test sd 1.5
# model_id = "runwayml/stable-diffusion-v1-5"
# pipe = StableDiffusionPipeline.from_pretrained(
#     model_id,
#     torch_dtype=torch.float16,
#     use_safetensors=True,
#    variant="fp16",
# )
# pipe = pipe.to('cuda')
# prompt = "an astronaut riding a horse on mars"
# save_dir = './yiyi_test_1.5_baby'

if not os.path.exists(save_dir):
    os.mkdir(save_dir)

steps = 25

params = {
    "prompt": [prompt],
    "num_inference_steps": steps,
    "guidance_scale": 7,
}
for scheduler_name in schedulers.keys():
    for seed in [12345, 123]:
        generator = torch.Generator(device='cuda').manual_seed(seed)

        scheduler = schedulers[scheduler_name][0].from_pretrained(
            model_id,
            subfolder="scheduler",
            **schedulers[scheduler_name][1],
        )
        pipe.scheduler = scheduler

        sdxl_img = pipe(**params, generator=generator).images[0]
        sdxl_img.save(os.path.join(save_dir, f"seed_{seed}_steps_{steps}_{scheduler_name}.png"))

Results

SDXL/juggernautXL

DPM++2M
  • default final_sigmas_type + euler_at_final: [image: seed_123_steps_25_DPMPP_2M_default]
  • denoise_to_zero (this PR): [image: seed_123_steps_25_DPMPP_2M_zero]

DPM++2M Karras
  • default final_sigmas_type + euler_at_final: [image: seed_123_steps_25_DPMPP_2M_K_default]
  • denoise_to_zero (this PR): [image: seed_123_steps_25_DPMPP_2M_K_zero]

DPM++2M SDE
  • default final_sigmas_type + euler_at_final: [image: seed_123_steps_25_DPMPP_2M_SDE_default]
  • denoise_to_zero (this PR): [image: seed_123_steps_25_DPMPP_2M_SDE_zero]

DPM++2M Karras SDE
  • default final_sigmas_type + euler_at_final: [image: seed_123_steps_25_DPMPP_2M_SDE_K_default]
  • denoise_to_zero (this PR): [image: seed_123_steps_25_DPMPP_2M_SDE_K_zero]

DPM++
  • default final_sigmas_type + lower_order_final: [image: seed_123_steps_25_DPMPP_default]
  • denoise_to_zero (this PR): [image: seed_123_steps_25_DPMPP_zero]

DPM++ Karras
  • default final_sigmas_type + lower_order_final: [image: seed_123_steps_25_DPMPP_K_default]
  • denoise_to_zero (this PR): [image: seed_123_steps_25_DPMPP_K_zero]

SD1.5

DPM++2M
  • default final_sigmas_type + euler_at_final: [image: seed_123_steps_25_DPMPP_2M_default]
  • denoise_to_zero (this PR): [image: seed_123_steps_25_DPMPP_2M_zero]

DPM++2M Karras
  • default final_sigmas_type + euler_at_final: [image: seed_123_steps_25_DPMPP_2M_K_default]
  • denoise_to_zero (this PR): [image: seed_123_steps_25_DPMPP_2M_K_zero]

DPM++2M SDE
  • default final_sigmas_type + euler_at_final: [image: seed_123_steps_25_DPMPP_2M_SDE_default]
  • denoise_to_zero (this PR): [image: seed_123_steps_25_DPMPP_2M_SDE_zero]

DPM++2M Karras SDE
  • default final_sigmas_type + euler_at_final: [image: seed_123_steps_25_DPMPP_2M_SDE_K_default]
  • denoise_to_zero (this PR): [image: seed_123_steps_25_DPMPP_2M_SDE_K_zero]

DPM++
  • default final_sigmas_type + lower_order_final: [image: seed_123_steps_25_DPMPP_default]
  • denoise_to_zero (this PR): [image: seed_123_steps_25_DPMPP_zero]

DPM++ Karras
  • default final_sigmas_type + lower_order_final: [image: seed_123_steps_25_DPMPP_K_default]
  • denoise_to_zero (this PR): [image: seed_123_steps_25_DPMPP_K_zero]

@yiyixuxu (Collaborator, Author) commented Jan 15, 2024

@patrickvonplaten

> Good catch! I'm actually not 100% sure why we added the sigmas this way, and I can't find the original PR anymore. Maybe worth tracking down who added the first use_karras_sigmas.

I tracked this all the way down to myself 😱😱😱😱😱😱 Pretty sure it was a bug, but we didn't catch it because it didn't make much difference for SD1.5:
#4986

Never mind, it wasn't me!!! 🙈 Even though I was the one who added sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32), I did that to match the previous behavior. Previously, all the step calculations were based on timesteps; on the last step, prev_timestep was set to 0, which maps to sigmas[-1]:

if isinstance(timestep, torch.Tensor):
    timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
    step_index = len(self.timesteps) - 1
else:
    step_index = step_index.item()
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]

I'm still pretty sure it wasn't on purpose. However, I think it is ok to leave "default" as the default, mainly because:

  1. only dpmsolver++ and sde-dpmsolver++ work with the denoise_to_zero option; dpmsolver and sde-dpmsolver don't work with it
  2. the official implementation from the author (when use_karras_sigmas=False) did not denoise to zero either (even though it did not skip the last step as we do with Karras sigmas), and I observed a pretty nice performance boost too when applying denoise_to_zero with the default sigmas!
  3. some people use this (skipping the last denoising step) as a trick for low step counts ("Stabilize the sampling of DPM-Solver++2M by a stabilizing trick", crowsonkb/k-diffusion#43 (comment)), even though I did not notice much improvement in my quick experiment

@@ -164,6 +164,7 @@ def __init__(
lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
final_sigmas_type: Optional[str] = "default", # "denoise_to_zero", "default"
Contributor:

Are both options necessary? Does each have an advantage in certain scenarios?

@yiyixuxu (Collaborator, Author) commented Jan 16, 2024:

  • denoise_to_zero matches the last step in k-diffusion. We can apply this with or without the Karras sigmas; see some of my comparison results in #6477 (comment).
  • default is what we currently have:
    • in my experiments it was worse than the other option we just added, but obviously I did not test many different settings, and I believe it will do better at higher step counts
    • I want to point out that there is some inconsistency between the current behavior of use_karras_sigmas=True and use_karras_sigmas=False: with Karras sigmas, we skip the last denoising step. I'm not entirely sure why it is done this way, so I'm not comfortable just updating it; I hope to address that in a new PR so we don't delay getting this one merged.

Current code for use_karras_sigmas:

            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)

I think it should be:

            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=(num_inference_steps+1))
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
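To illustrate the suggestion, here is a minimal sketch of the two schedules (a standalone reimplementation of the rho-spaced Karras formula; the sigma_min/sigma_max/rho values are illustrative and not necessarily what _convert_to_karras would use):

import numpy as np

def karras_sigmas(n, sigma_min=0.0292, sigma_max=14.6146, rho=7.0):
    # rho-spaced interpolation between sigma_max and sigma_min (Karras et al., 2022)
    ramp = np.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

n = 25
sigmas = karras_sigmas(n)
current = np.concatenate([sigmas, sigmas[-1:]])  # duplicated last sigma -> final step has zero size
proposed = karras_sigmas(n + 1)                  # n + 1 distinct sigmas -> every step moves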

Contributor:

I'm wondering if we should change the default to "denoise_to_zero" here, though, if you think it always gives better results. Also cc @sayakpaul

yiyixuxu (Collaborator, Author):

We can do that! I will deprecate dpmsolver and sde-dpmsolver since they won't work with denoise_to_zero, and I think they are not really used at all.

Contributor:

Ok for me!

@@ -189,6 +190,11 @@ def __init__(
else:
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")

if algorithm_type != "dpmsolver++" and final_sigmas_type == "denoise_to_zero":
Contributor:

It would be nice to support sde-dpmsolver++ in DPMSolverSinglestepScheduler. It's one of the best combinations in k-diffusion, and it's missing in diffusers.

Reply:

Yes, I agree. +1 for sde-dpmsolver++

@@ -0,0 +1,842 @@
# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved.
Contributor:

I don't think it makes sense to create a whole new "legacy" scheduler class here. I'd advocate for just deprecating it in the original file

yiyixuxu (Collaborator, Author):

Ok! Will do that instead. I was just kind of eager to get rid of the code sooner 😛


if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++")
elif algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
raise ValueError(
Contributor:

Can we just throw a deprecation warning directly here instead of a ValueError?
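For illustration, a minimal sketch of the deprecation-warning route using the standard library (diffusers also has its own deprecate helper, whose exact signature isn't shown here):

import warnings

if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
    warnings.warn(
        f"algorithm_type={algorithm_type!r} is deprecated and will be removed in a future "
        "version; use 'dpmsolver++' or 'sde-dpmsolver++' instead.",
        FutureWarning,
    )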

Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
paper, and the `dpmsolver++` type implements the algorithms in the
Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver++` type implements the algorithms in the
Contributor:

sde-dpmsolver++ is not supported whereas the doc states it is supported. I guess this documentation should be updated.

yiyixuxu (Collaborator, Author):

Yes, I will update the doc now. I will try to add support for sde-dpmsolver++ soon too; it's been on my to-do list for a long time.

@hipsterusername commented:
For whoever it may benefit: I updated diffusers to run this PR in Invoke's latest build, and I can confirm that it fixes a long-standing complaint from many end users about SDXL schedulers.

Looking forward to seeing this in a release 🙌

[image]

@patrickvonplaten (Contributor) left a review:

Good job!

@spezialspezial (Contributor):

Nailed it.

@yiyixuxu yiyixuxu mentioned this pull request Jan 22, 2024
@alexisrolland (Contributor):

Yay 🫶

AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
Successfully merging this pull request may close these issues:

  • Artifacts with DPM++ 2M SDE Karras, even when using use_lu_lambdas (#6295)