
Commit b934215

yiyixuxu, sayakpaul, BenjaminBossan, and stevhliu authored

[scheduler] support custom timesteps and sigmas (#7817)

* support custom sigmas and timesteps, dpm euler

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Benjamin Bossan <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
1 parent 5ed3abd · commit b934215

33 files changed: +1116 −205 lines

docs/source/en/using-diffusers/schedulers.md

Lines changed: 56 additions & 0 deletions

@@ -212,6 +212,62 @@ images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).
 images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
 ```
 
+## Custom Timestep Schedules
+
+With all our schedulers, you can choose one of the popular timestep schedules using configurations such as `timestep_spacing`, `interpolation_type`, and `use_karras_sigmas`. Some schedulers also provide the flexibility to use a custom timestep schedule. You can use any list of arbitrary timesteps; here we use the AYS timestep schedule as an example. It is a set of 10-step optimized timestep schedules released by researchers from Nvidia that can achieve significantly better quality compared to the preset timestep schedules. You can read more about their research [here](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/).
+
+```python
+from diffusers.schedulers import AysSchedules
+sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]
+print(sampling_schedule)
+```
+```
+[999, 845, 730, 587, 443, 310, 193, 116, 53, 13]
+```
+
+You can then create a pipeline and pass this custom timestep schedule to it as `timesteps`.
+
+```python
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    "SG161222/RealVisXL_V4.0",
+    torch_dtype=torch.float16,
+    variant="fp16",
+).to("cuda")
+
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, algorithm_type="sde-dpmsolver++")
+
+prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
+
+generator = torch.Generator(device="cpu").manual_seed(2487854446)
+
+image = pipe(
+    prompt=prompt,
+    negative_prompt="",
+    generator=generator,
+    timesteps=sampling_schedule,
+).images[0]
+```
+The generated image has better quality than one produced with the default linearly-spaced timestep schedule for the same number of steps, and it is comparable to the default schedule run for 25 steps.
+
+<div class="flex gap-4">
+  <div>
+    <img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ays.png"/>
+    <figcaption class="mt-2 text-center text-sm text-gray-500">AYS timestep schedule 10 steps</figcaption>
+  </div>
+  <div>
+    <img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/10.png"/>
+    <figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 10 steps</figcaption>
+  </div>
+  <div>
+    <img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/25.png"/>
+    <figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 25 steps</figcaption>
+  </div>
+</div>
+
+> [!TIP]
+> 🤗 Diffusers currently only supports `timesteps` and `sigmas` for a selected list of schedulers and pipelines, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to extend this feature to a scheduler or pipeline that does not currently support it!
+
+
 ## Models
 
 Models are loaded from the [`ModelMixin.from_pretrained`] method, which downloads and caches the latest version of the model weights and configurations. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache instead of re-downloading them.
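
The TIP above notes that `sigmas` is now supported alongside `timesteps`. As a rough companion to the committed documentation, here is a minimal sketch of passing a custom sigma schedule instead of custom timesteps. It assumes `StableDiffusionXLPipeline` is among the pipelines this PR updates to accept `sigmas`, reuses the `pipe`, `prompt`, and `generator` objects from the snippet above, and the sigma values are illustrative placeholders rather than an officially published schedule.

```python
# Minimal sketch: custom sigmas instead of custom timesteps. Assumes the SDXL pipeline
# built in the documentation example above (`pipe`, `prompt`, `generator`) and that it
# accepts a `sigmas` argument after this change; the sigma values are illustrative.
custom_sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]

image = pipe(
    prompt=prompt,
    negative_prompt="",
    generator=generator,
    sigmas=custom_sigmas,  # pass either `sigmas` or `timesteps`, never both
).images[0]
```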

src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py

Lines changed: 26 additions & 4 deletions

@@ -156,6 +156,7 @@ def retrieve_timesteps(
     num_inference_steps: Optional[int] = None,
     device: Optional[Union[str, torch.device]] = None,
     timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
     """
@@ -171,14 +172,18 @@ def retrieve_timesteps(
         device (`str` or `torch.device`, *optional*):
             The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         timesteps (`List[int]`, *optional*):
-            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
-            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
-            must be `None`.
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
 
     Returns:
         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
         second element is the number of inference steps.
     """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
     if timesteps is not None:
         accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
         if not accepts_timesteps:
@@ -189,6 +194,16 @@ def retrieve_timesteps(
         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
         timesteps = scheduler.timesteps
         num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
     else:
         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
         timesteps = scheduler.timesteps
@@ -865,6 +880,7 @@ def __call__(
         width: Optional[int] = None,
         num_inference_steps: int = 50,
         timesteps: List[int] = None,
+        sigmas: List[float] = None,
         denoising_end: Optional[float] = None,
         guidance_scale: float = 5.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -923,6 +939,10 @@ def __call__(
                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                 passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             denoising_end (`float`, *optional*):
                 When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
                 completed before it is intentionally prematurely terminated. As a result, the returned sample will
@@ -1104,7 +1124,9 @@ def __call__(
         )
 
         # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler, num_inference_steps, device, timesteps, sigmas
+        )
 
         # 5. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
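
The same `retrieve_timesteps` helper is duplicated into each pipeline touched by this commit, so the new `sigmas` branch behaves identically everywhere. Below is a minimal sketch of what that branch does, exercised directly against a scheduler; the checkpoint id and sigma values are illustrative, and it assumes `EulerDiscreteScheduler.set_timesteps` accepts a `sigmas` argument after this change (the commit message mentions the DPM and Euler schedulers).

```python
import inspect

from diffusers import EulerDiscreteScheduler

# Illustrative checkpoint; any repo with a scheduler config would work the same way.
scheduler = EulerDiscreteScheduler.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
)

custom_sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]

# Mirrors the guard added in retrieve_timesteps: only forward `sigmas` if the scheduler's
# `set_timesteps` signature actually accepts it, otherwise fail with a clear error.
if "sigmas" not in inspect.signature(scheduler.set_timesteps).parameters:
    raise ValueError(f"{scheduler.__class__} does not support custom sigmas schedules.")

scheduler.set_timesteps(sigmas=custom_sigmas, device="cpu")
print(scheduler.timesteps)       # timesteps derived from the custom sigma schedule
print(len(scheduler.timesteps))  # becomes the effective num_inference_steps downstream
```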

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

Lines changed: 30 additions & 4 deletions

@@ -137,6 +137,7 @@ def retrieve_timesteps(
     num_inference_steps: Optional[int] = None,
     device: Optional[Union[str, torch.device]] = None,
     timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
     """
@@ -152,14 +153,18 @@ def retrieve_timesteps(
         device (`str` or `torch.device`, *optional*):
             The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         timesteps (`List[int]`, *optional*):
-            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
-            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
-            must be `None`.
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
 
     Returns:
         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
         second element is the number of inference steps.
     """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
     if timesteps is not None:
         accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
         if not accepts_timesteps:
@@ -170,6 +175,16 @@ def retrieve_timesteps(
         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
         timesteps = scheduler.timesteps
         num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
     else:
         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
         timesteps = scheduler.timesteps
@@ -750,6 +765,7 @@ def __call__(
         width: Optional[int] = None,
         num_inference_steps: int = 50,
         timesteps: Optional[List[int]] = None,
+        sigmas: Optional[List[float]] = None,
         guidance_scale: float = 7.5,
         strength: float = 0.8,
         negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -783,6 +799,14 @@ def __call__(
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
                 expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             strength (`float`, *optional*, defaults to 0.8):
                 Higher strength leads to more differences between original video and generated video.
             guidance_scale (`float`, *optional*, defaults to 7.5):
@@ -912,7 +936,9 @@ def __call__(
         )
 
         # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler, num_inference_steps, device, timesteps, sigmas
+        )
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
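
Note that this video-to-video pipeline additionally trims whatever schedule `retrieve_timesteps` returns through `self.get_timesteps(...)` based on `strength`, so a custom `timesteps` or `sigmas` schedule is only partially run unless `strength=1.0`. Below is a simplified sketch of that truncation logic; it is an approximation of the usual diffusers `get_timesteps` helper, not a verbatim copy.

```python
from typing import List, Tuple


def truncate_for_strength(
    timesteps: List[int], num_inference_steps: int, strength: float, scheduler_order: int = 1
) -> Tuple[List[int], int]:
    """Approximate the strength-based truncation img2img/vid2vid pipelines apply after
    retrieve_timesteps: keep only the tail of the schedule so `strength` controls how
    much of the denoising trajectory is actually executed."""
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    trimmed = timesteps[t_start * scheduler_order :]
    return trimmed, len(trimmed)


# With the 10-step AYS schedule and strength=0.8, only the last 8 steps are run.
ays = [999, 845, 730, 587, 443, 310, 193, 116, 53, 13]
print(truncate_for_strength(ays, num_inference_steps=len(ays), strength=0.8))
# ([730, 587, 443, 310, 193, 116, 53, 13], 8)
```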

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 26 additions & 4 deletions

@@ -97,6 +97,7 @@ def retrieve_timesteps(
     num_inference_steps: Optional[int] = None,
     device: Optional[Union[str, torch.device]] = None,
     timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
     """
@@ -112,14 +113,18 @@ def retrieve_timesteps(
         device (`str` or `torch.device`, *optional*):
             The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         timesteps (`List[int]`, *optional*):
-            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
-            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
-            must be `None`.
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
 
     Returns:
         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
         second element is the number of inference steps.
     """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
     if timesteps is not None:
         accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
         if not accepts_timesteps:
@@ -130,6 +135,16 @@ def retrieve_timesteps(
         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
         timesteps = scheduler.timesteps
         num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
     else:
         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
         timesteps = scheduler.timesteps
@@ -892,6 +907,7 @@ def __call__(
         width: Optional[int] = None,
         num_inference_steps: int = 50,
         timesteps: List[int] = None,
+        sigmas: List[float] = None,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
@@ -941,6 +957,10 @@ def __call__(
                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                 passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             guidance_scale (`float`, *optional*, defaults to 7.5):
                 A higher guidance scale value encourages the model to generate images closely linked to the text
                 `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -1162,7 +1182,9 @@ def __call__(
             assert False
 
         # 5. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler, num_inference_steps, device, timesteps, sigmas
+        )
         self._num_timesteps = len(timesteps)
 
         # 6. Prepare latent variables

src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 22 additions & 4 deletions

@@ -79,6 +79,7 @@ def retrieve_timesteps(
     num_inference_steps: Optional[int] = None,
     device: Optional[Union[str, torch.device]] = None,
     timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
     """
@@ -94,14 +95,18 @@ def retrieve_timesteps(
         device (`str` or `torch.device`, *optional*):
             The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         timesteps (`List[int]`, *optional*):
-            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
-            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
-            must be `None`.
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
 
     Returns:
         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
         second element is the number of inference steps.
     """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
     if timesteps is not None:
         accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
         if not accepts_timesteps:
@@ -112,6 +117,16 @@ def retrieve_timesteps(
         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
         timesteps = scheduler.timesteps
         num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
     else:
         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
         timesteps = scheduler.timesteps
@@ -673,6 +688,7 @@ def __call__(
         width: Optional[int] = None,
         num_inference_steps: int = 50,
         timesteps: List[int] = None,
+        sigmas: List[float] = None,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
@@ -848,7 +864,9 @@ def __call__(
                 image_embeds = torch.cat([negative_image_embeds, image_embeds])
 
         # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler, num_inference_steps, device, timesteps, sigmas
+        )
 
         # 5. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
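
Across every pipeline touched by this commit, the guard at the top of `retrieve_timesteps` makes `timesteps` and `sigmas` mutually exclusive. A quick illustrative check, assuming `pipe` is any of the updated pipelines (for example the one from the documentation snippet above) that forwards both arguments to `retrieve_timesteps`:

```python
# Passing both custom timesteps and custom sigmas should raise the new ValueError;
# the specific values here are arbitrary and only serve to trigger the check.
try:
    pipe(
        prompt="a photo of a cat",
        timesteps=[999, 500, 1],
        sigmas=[14.6, 3.7, 0.0],
    )
except ValueError as err:
    print(err)  # Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values
```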
