
DPMSolverMultistep add rescale_betas_zero_snr #7097


Merged · 7 commits · Feb 27, 2024
Changes from 3 commits
60 changes: 59 additions & 1 deletion src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -70,6 +70,42 @@ def alpha_bar_fn(t):
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)

# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).


Args:
betas (`torch.FloatTensor`):
the betas that the scheduler is being initialized with.

Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_bar_sqrt = alphas_cumprod.sqrt()

# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T

# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas

return betas


class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
@@ -144,6 +180,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
An offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
Diffusion.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -173,6 +213,7 @@ def __init__(
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
):
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
@@ -191,8 +232,17 @@ def __init__(
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)

self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

if rescale_betas_zero_snr:
# Close to 0 without being 0 so first sigma is not inf
# FP16 smallest positive subnormal works well here
self.alphas_cumprod[-1] = 2**-24

# Currently we only support VP-type noise schedule
self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
@@ -880,6 +930,11 @@ def step(
if self.step_index is None:
self._init_step_index(timestep)

# store old dtype because model_output isn't always the same it seems
return_dtype = sample.dtype
Collaborator: do you mean model_output.dtype isn't always the same as sample.dtype before the upcast?

@Beinsezii (Contributor Author), Feb 26, 2024: Turns out that was because model_output = self.convert_model_output(model_output, sample=sample) ends up creating a shadowed tensor cast to the sample's dtype. I moved the sample upcast after this call, so return_dtype is no longer needed. Outputs are the same.
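
For readers skimming the thread, here is a small self-contained sketch of the upcast-then-downcast pattern being discussed (the function name and the toy update are made up for illustration; this is not the scheduler's actual step logic): do the numerically sensitive math in float32 and return the result in the caller's dtype.

```python
import torch

def stable_update(sample: torch.Tensor, model_output: torch.Tensor) -> torch.Tensor:
    return_dtype = sample.dtype                 # remember the caller's dtype (e.g. float16)
    sample = sample.to(torch.float32)           # upcast to avoid precision issues
    model_output = model_output.to(torch.float32)
    prev_sample = sample - 0.1 * model_output   # stand-in for the real DPM-Solver update
    return prev_sample.to(return_dtype)         # cast back before handing the sample on

x = torch.randn(4, dtype=torch.float16)
eps = torch.randn(4, dtype=torch.float16)
print(stable_update(x, eps).dtype)  # torch.float16
```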

# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)

# Improve numerical stability for small number of steps
lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
self.config.euler_at_final
@@ -897,7 +952,7 @@

if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
)
else:
noise = None
@@ -912,6 +967,9 @@
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1

# Cast sample back to expected dtype
prev_sample = prev_sample.to(return_dtype)

# upon completion increase step index by one
self._step_index += 1

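As a quick sanity check of what the new flag buys, the rescaling helper can be exercised on its own. A minimal sketch, assuming the copied-from helper in scheduling_ddim is importable at module level (as the `# Copied from` comment indicates) and using Stable-Diffusion-style scaled_linear betas:

```python
import torch
from diffusers.schedulers.scheduling_ddim import rescale_zero_terminal_snr

betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000) ** 2  # scaled_linear schedule
rescaled = rescale_zero_terminal_snr(betas)

print(torch.cumprod(1 - betas, dim=0)[-1])      # small but non-zero: terminal SNR > 0
print(torch.cumprod(1 - rescaled, dim=0)[-1])   # 0.0: terminal SNR is exactly zero
print(torch.tensor(2.0**-24, dtype=torch.float16))  # 5.96e-08, FP16's smallest positive
                                                     # subnormal, used above so sigma != inf
```
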
4 changes: 4 additions & 0 deletions tests/schedulers/test_scheduler_dpm_multi.py
@@ -213,6 +213,10 @@ def test_inference_steps(self):
for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)

def test_rescale_betas_zero_snr(self):
for rescale_betas_zero_snr in [True, False]:
self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)

def test_full_loop_no_noise(self):
sample = self.full_loop()
result_mean = torch.mean(torch.abs(sample))
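
To close the loop, a hypothetical end-to-end usage sketch of the new option (the model ID and prompt are placeholders, and overriding a config value through `from_config` assumes a diffusers build that already includes this PR):

```python
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Swap in DPM-Solver++ multistep with the zero-terminal-SNR rescaling enabled
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True
)

image = pipe("a black cat asleep on a black couch, dimly lit room").images[0]
image.save("dark_sample.png")
```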