fix the add_noise function for dpm-multi et al #5158

Merged
merged 3 commits into from
Sep 23, 2023

yiyixuxu (Collaborator) commented Sep 23, 2023

My PR to refactor schedulers, #4986, broke the img2img pipelines :(
I updated add_noise incorrectly and never tested it with img2img, and our current tests did not catch it.

This PR updates add_noise to use sigmas.
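
To make "use sigmas" concrete, here is a minimal sketch of noising a sample from a sigma. It assumes the variance-preserving parameterization in which sigma = sigma_t / alpha_t and alpha_t^2 + sigma_t^2 = 1. The function names are illustrative, and plain Python lists stand in for tensors. This is not the scheduler code from this PR, which also has to look up the sigma that corresponds to each timestep.

```python
import math


def sigma_to_alpha_sigma_t(sigma):
    """Split a sigma into (alpha_t, sigma_t), assuming alpha_t**2 + sigma_t**2 == 1."""
    alpha_t = 1.0 / math.sqrt(sigma**2 + 1.0)
    sigma_t = sigma * alpha_t
    return alpha_t, sigma_t


def add_noise(original, noise, sigma):
    """Return alpha_t * x + sigma_t * n for each (x, n) pair.

    Illustrative only: `original` and `noise` are flat lists of floats here.
    """
    alpha_t, sigma_t = sigma_to_alpha_sigma_t(sigma)
    return [alpha_t * x + sigma_t * n for x, n in zip(original, noise)]


# With sigma == 0 the sample is returned unchanged; a large sigma makes the
# result mostly noise.
clean = add_noise([1.0, 2.0], [5.0, 5.0], sigma=0.0)
noisy = add_noise([1.0, 2.0], [5.0, 5.0], sigma=10.0)
```

In img2img the noise level at which this is applied depends on strength, which is why a mistake in add_noise shows up there and not in plain text-to-image runs.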

testing

I compared the outputs from this branch against the outputs from a branch where I reverted the commit that introduced the bug (#4986).

I used a single strength value, 0.8, and tested both the use_karras_sigmas=True and use_karras_sigmas=False configs. The results are identical.

testing use_karras_sigmas=True

import requests
import numpy as np
from PIL import Image
from io import BytesIO
import torch

from diffusers import (
    StableDiffusionImg2ImgPipeline,
    DPMSolverSinglestepScheduler,
    DPMSolverMultistepScheduler,
    UniPCMultistepScheduler,
    DEISMultistepScheduler,
)

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((768, 512))

device = "cuda"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
).to(device)

scheduler_config = dict(pipe.scheduler.config)

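seed = 33
strength = 0.8
# The branch label only affects the output filename; switch it to match the checked-out branch.
# Note: the test_5_out/ directory must already exist.
# branch = "current"
branch = "revert-dpm-refactor"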

# DPMSolverMultistepScheduler
sched_name = "dpm_miulti"
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    scheduler_config, use_karras_sigmas=True, algorithm_type="dpmsolver++")

generator = torch.Generator(device="cuda").manual_seed(seed)
images = pipe(
    prompt="a fantasy landscape",
    negative_prompt=None,
    image=init_image,
    strength=strength,
    num_inference_steps=15,
    generator=generator,
    guidance_scale=10,
    num_images_per_prompt=1
).images
images[0].save(f"test_5_out/[{sched_name}]_{strength:.2f}_[{branch}].png")

# DPMSolverMultistepScheduler(sde) 
sched_name = "dpm_miulti_sde"
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")

generator = torch.Generator(device="cuda").manual_seed(seed)
images = pipe(
    prompt="a fantasy landscape",
    negative_prompt=None,
    image=init_image,
    strength=strength,
    num_inference_steps=15,
    guidance_scale=10,
    generator=generator,
    num_images_per_prompt=1
).images
images[0].save(f"test_5_out/[{sched_name}]_{strength:.2f}_[{branch}].png")

# DPMSolverSinglestepScheduler
sched_name = "dpm_single"
pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
    scheduler_config, use_karras_sigmas=True)

generator = torch.Generator(device="cuda").manual_seed(seed)
images = pipe(
    prompt="a fantasy landscape",
    negative_prompt=None,
    image=init_image,
    strength=strength,
    num_inference_steps=15,
    guidance_scale=10,
    generator=generator,
    num_images_per_prompt=1
).images
images[0].save(f"test_5_out/[{sched_name}]_{strength:.2f}_[{branch}].png")

# UniPCMultistepScheduler
sched_name = "unipc"
pipe.scheduler = UniPCMultistepScheduler.from_config(
    scheduler_config, use_karras_sigmas=True)

generator = torch.Generator(device="cuda").manual_seed(seed)
images = pipe(
    prompt="a fantasy landscape",
    negative_prompt=None,
    image=init_image,
    strength=strength,
    num_inference_steps=15,
    guidance_scale=10,
    generator=generator,
    num_images_per_prompt=1
).images
images[0].save(f"test_5_out/[{sched_name}]_{strength:.2f}_[{branch}].png")

# DEISMultistepScheduler
sched_name = "deis"
pipe.scheduler = DEISMultistepScheduler.from_config(
    scheduler_config, use_karras_sigmas=True)

generator = torch.Generator(device="cuda").manual_seed(seed)
images = pipe(
    prompt="a fantasy landscape",
    negative_prompt=None,
    image=init_image,
    strength=strength,
    num_inference_steps=15,
    generator=generator,
    guidance_scale=10,
    num_images_per_prompt=1
).images
images[0].save(f"test_5_out/[{sched_name}]_{strength:.2f}_[{branch}].png")

DEISMultistepScheduler

this PR | revert dpm-refactor PR
[image: deis 0 80 current] | [image: deis 0 80 revert-dpm-refactor]

DPMSolverMultistepScheduler(sde)

this PR | revert dpm-refactor PR
[image: dpm_miulti_sde 0 80 current] | [image: dpm_miulti_sde 0 80 revert-dpm-refactor]

DPMSolverMultistepScheduler

this PR | revert dpm-refactor PR
[image: dpm_miulti 0 80 current] | [image: dpm_miulti 0 80 revert-dpm-refactor]

DPMSolverSinglestepScheduler

this PR | revert dpm-refactor PR
[image: dpm_single 0 80 current] | [image: dpm_single 0 80 revert-dpm-refactor]

UniPCMultistepScheduler

this PR | revert dpm-refactor PR
[image: unipc 0 80 current] | [image: unipc 0 80 revert-dpm-refactor]

testing use_karras_sigmas=False

DEISMultistepScheduler

this PR | revert dpm-refactor PR
[image: deis 0 80 current] | [image: deis 0 80 revert-dpm-refactor]

DPMSolverMultistepScheduler(sde)

this PR | revert dpm-refactor PR
[image: dpm_miulti_sde 0 80 current] | [image: dpm_miulti_sde 0 80 revert-dpm-refactor]

DPMSolverMultistepScheduler

this PR | revert dpm-refactor PR
[image: dpm_miulti 0 80 current] | [image: dpm_miulti 0 80 revert-dpm-refactor]

DPMSolverSinglestepScheduler

this PR | revert dpm-refactor PR
[image: dpm_single 0 80 current] | [image: dpm_single 0 80 revert-dpm-refactor]

UniPCMultistepScheduler

this PR | revert dpm-refactor PR
[image: unipc 0 80 current] | [image: unipc 0 80 revert-dpm-refactor]
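
For anyone who wants to re-check the pairs above without comparing them by eye, here is a hypothetical helper (not part of this PR) that compares the saved files by SHA-256 digest. It assumes the filename pattern used by the script above, [{sched_name}]_{strength:.2f}_[{branch}].png, with both branches' outputs in the same directory. Byte equality is stricter than pixel equality, so a mismatch calls for a closer look and does not by itself prove the images differ.

```python
import hashlib
from pathlib import Path


def file_digest(path):
    """Return the SHA-256 hex digest of a file's bytes."""
    return hashlib.sha256(Path(path).read_bytes()).hexdigest()


def compare_branches(out_dir, branch_a="current", branch_b="revert-dpm-refactor"):
    """Map each output prefix found for both branches to True if the files match.

    Filenames contain '[' and ']', which are special in glob patterns, so the
    directory is listed and filtered by suffix instead of globbed.
    """
    suffix_a = f"_[{branch_a}].png"
    results = {}
    for path_a in sorted(Path(out_dir).iterdir()):
        if not path_a.name.endswith(suffix_a):
            continue
        prefix = path_a.name[: -len(suffix_a)]
        path_b = path_a.with_name(f"{prefix}_[{branch_b}].png")
        if path_b.exists():
            results[prefix] = file_digest(path_a) == file_digest(path_b)
    return results


# Example (assumes the directory written by the script above):
#   for prefix, same in compare_branches("test_5_out").items():
#       print(prefix, "identical" if same else "DIFFERENT")
```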

yiyixuxu (Collaborator, Author) commented Sep 23, 2023

More testing here

I also tested various strength levels (0.1 to 0.95 in steps of 0.05) for each of the schedulers we updated, with both use_karras_sigmas=True and use_karras_sigmas=False.

I uploaded the outputs here: https://huggingface.co/datasets/YiYiXu/pr5158/tree/main
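
As a rough guide to what the sweep covers, the sketch below lists how many of the 15 inference steps each strength value would run. It assumes the img2img pipeline trims the schedule to int(num_inference_steps * strength) steps; treat that rule, and the helper name steps_run, as assumptions and not as a statement of the pipeline's exact code. The strength value also decides the timestep at which add_noise is called, so the sweep exercises add_noise across a range of noise levels.

```python
def steps_run(num_inference_steps, strength):
    """Number of denoising steps executed for a given strength (assumed rule)."""
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return num_inference_steps - t_start


# The 18 strength values of the sweep: 0.10, 0.15, ..., 0.95.
strengths = [round(0.1 + 0.05 * i, 2) for i in range(18)]

for s in strengths:
    print(f"strength={s:.2f} -> {steps_run(15, s)} of 15 steps")
```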

import requests
import numpy as np
from PIL import Image
from io import BytesIO
import torch

from diffusers import (
    StableDiffusionImg2ImgPipeline,
    DPMSolverSinglestepScheduler,
    DPMSolverMultistepScheduler,
    UniPCMultistepScheduler,
    DEISMultistepScheduler,
)

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((768, 512))

device = "cuda"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
).to(device)

scheduler_config = dict(pipe.scheduler.config)


# DPMSolverMultistepScheduler
sched_name = "dpm_miulti"
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    scheduler_config, use_karras_sigmas=True, algorithm_type="dpmsolver++")


for strength in torch.linspace(0.1,0.95,18):
    images = pipe(
        prompt="a fantasy landscape",
        negative_prompt=None,
        image=init_image,
        strength=strength,
        num_inference_steps=15,
        guidance_scale=10,
        num_images_per_prompt=1
    ).images
    images[0].save(f"test_2_out/[{sched_name}]_{strength:.2f}.png")

# DPMSolverMultistepScheduler(sde) 
sched_name = "dpm_miulti_sde"
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")


for strength in torch.linspace(0.1,0.95,18):
    images = pipe(
        prompt="a fantasy landscape",
        negative_prompt=None,
        image=init_image,
        strength=strength,
        num_inference_steps=15,
        guidance_scale=10,
        num_images_per_prompt=1
    ).images
    images[0].save(f"test_2_out/[{sched_name}]_{strength:.2f}.png")

# DPMSolverSinglestepScheduler
sched_name = "dpm_single"
pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
    scheduler_config, use_karras_sigmas=True)


for strength in torch.linspace(0.1,0.95,18):
    images = pipe(
        prompt="a fantasy landscape",
        negative_prompt=None,
        image=init_image,
        strength=strength,
        num_inference_steps=15,
        guidance_scale=10,
        num_images_per_prompt=1
    ).images
    images[0].save(f"test_2_out/[{sched_name}]_{strength:.2f}.png")

# UniPCMultistepScheduler
sched_name = "unipc"
pipe.scheduler = UniPCMultistepScheduler.from_config(
    scheduler_config, use_karras_sigmas=True)


for strength in torch.linspace(0.1,0.95,18):
    images = pipe(
        prompt="a fantasy landscape",
        negative_prompt=None,
        image=init_image,
        strength=strength,
        num_inference_steps=15,
        guidance_scale=10,
        num_images_per_prompt=1
    ).images
    images[0].save(f"test_2_out/[{sched_name}]_{strength:.2f}.png")

# DEISMultistepScheduler
sched_name = "deis"
pipe.scheduler = DEISMultistepScheduler.from_config(
    scheduler_config, use_karras_sigmas=True)

for strength in torch.linspace(0.1,0.95,18):
    images = pipe(
        prompt="a fantasy landscape",
        negative_prompt=None,
        image=init_image,
        strength=strength,
        num_inference_steps=15,
        guidance_scale=10,
        num_images_per_prompt=1
    ).images
    images[0].save(f"test_2_out/[{sched_name}]_{strength:.2f}.png")

pcuenca (Member) left a comment

Works for me (tested DPMSolverMultistepScheduler). The SDXL refiner step was broken, wasn't it? Did we receive any reports?

I'd suggest we merge this soon but then create tests (or fix them) as a followup, to ensure we catch similar problems in the future.

@pcuenca pcuenca requested a review from sayakpaul September 23, 2023 18:46
@yiyixuxu yiyixuxu merged commit 5b11c5d into main Sep 23, 2023
rvorias commented Sep 24, 2023

Thanks for fixing this so fast!

patrickvonplaten (Contributor) commented

Very clean fix - thanks a lot @yiyixuxu

@kashif kashif deleted the dpm-mstep-add-noise branch September 29, 2023 11:40
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* remove to _device() for sigmas

* update add_noise to use simgas

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
@yiyixuxu yiyixuxu mentioned this pull request Jan 8, 2024
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024

4 participants