Skip to content

[2064]: Add Karras to DPMSolverMultistepScheduler #3001

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/diffusers/schedulers/scheduling_deis_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def __init__(
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_timesteps
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because it's no longer the same (as we added use_karras_sigmas). So, in a separate PR, we can add support for Karras sigmans to this scheduler too.

Cc: @patrickvonplaten

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense!

def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Expand Down
52 changes: 51 additions & 1 deletion src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.

use_karras_sigmas (`bool`, *optional*, defaults to `False`):
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
Expand All @@ -136,6 +139,7 @@ def __init__(
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
use_karras_sigmas: Optional[bool] = False,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
Expand Down Expand Up @@ -181,6 +185,7 @@ def __init__(
self.timesteps = torch.from_numpy(timesteps)
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self.use_karras_sigmas = use_karras_sigmas

def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Expand All @@ -199,6 +204,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
.astype(np.int64)
)

if self.use_karras_sigmas:
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
timesteps = np.flip(timesteps).copy().astype(np.int64)

# when num_inference_steps == num_train_timesteps, we can end up with
# duplicates in timesteps.
_, unique_indices = np.unique(timesteps, return_index=True)
Expand Down Expand Up @@ -248,6 +260,44 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:

return sample

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = np.log(sigma)

# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]

# get sigmas range
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1

low = log_sigmas[low_idx]
high = log_sigmas[high_idx]

# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)

# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""

sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()

rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas

def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
)

if self.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas)
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
Expand Down Expand Up @@ -241,14 +241,14 @@ def _sigma_to_t(self, sigma, log_sigmas):
return t

# Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""

sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()

rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, self.num_inference_steps)
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
Expand Down
6 changes: 6 additions & 0 deletions tests/schedulers/test_scheduler_dpm_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,12 @@ def test_full_loop_with_v_prediction(self):

assert abs(result_mean.item() - 0.2251) < 1e-3

def test_full_loop_with_karras_and_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
result_mean = torch.mean(torch.abs(sample))

assert abs(result_mean.item() - 0.2096) < 1e-3

def test_switch(self):
# make sure that iterating over schedulers with same config names gives same results
# for defaults
Expand Down