Skip to content

Commit 4520e12

Browse files
adapt PixArtAlphaPipeline for pixart-lcm model (#5974)
* adapt PixArtAlphaPipeline for pixart-lcm model
* remove original_inference_steps from __call__

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent 6182604 commit 4520e12

File tree

1 file changed

+46
-2
lines changed

1 file changed

+46
-2
lines changed

src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,51 @@
134134
}
135135

136136

137+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # Default path: no custom schedule — delegate step-count handling to the scheduler.
    if timesteps is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedules only work if the scheduler's `set_timesteps` exposes a
    # `timesteps` keyword; probe its signature before attempting the call.
    accepted = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if "timesteps" not in accepted:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )
    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
180+
181+
137182
class PixArtAlphaPipeline(DiffusionPipeline):
138183
r"""
139184
Pipeline for text-to-image generation using PixArt-Alpha.
@@ -783,8 +828,7 @@ def __call__(
783828
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
784829

785830
# 4. Prepare timesteps
786-
self.scheduler.set_timesteps(num_inference_steps, device=device)
787-
timesteps = self.scheduler.timesteps
831+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
788832

789833
# 5. Prepare latents.
790834
latent_channels = self.transformer.config.in_channels

0 commit comments

Comments
 (0)