|
134 | 134 | }
|
135 | 135 |
|
136 | 136 |
|
| 137 | +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps |
| 138 | +def retrieve_timesteps( |
| 139 | + scheduler, |
| 140 | + num_inference_steps: Optional[int] = None, |
| 141 | + device: Optional[Union[str, torch.device]] = None, |
| 142 | + timesteps: Optional[List[int]] = None, |
| 143 | + **kwargs, |
| 144 | +): |
| 145 | + """ |
| 146 | + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles |
| 147 | + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. |
| 148 | +
|
| 149 | + Args: |
| 150 | + scheduler (`SchedulerMixin`): |
| 151 | + The scheduler to get timesteps from. |
| 152 | + num_inference_steps (`int`): |
| 153 | + The number of diffusion steps used when generating samples with a pre-trained model. If used, |
| 154 | + `timesteps` must be `None`. |
| 155 | + device (`str` or `torch.device`, *optional*): |
| 156 | + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. |
| 157 | + timesteps (`List[int]`, *optional*): |
| 158 | + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default |
| 159 | + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` |
| 160 | + must be `None`. |
| 161 | +
|
| 162 | + Returns: |
| 163 | + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the |
| 164 | + second element is the number of inference steps. |
| 165 | + """ |
| 166 | + if timesteps is not None: |
| 167 | + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) |
| 168 | + if not accepts_timesteps: |
| 169 | + raise ValueError( |
| 170 | + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" |
| 171 | + f" timestep schedules. Please check whether you are using the correct scheduler." |
| 172 | + ) |
| 173 | + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) |
| 174 | + timesteps = scheduler.timesteps |
| 175 | + num_inference_steps = len(timesteps) |
| 176 | + else: |
| 177 | + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) |
| 178 | + timesteps = scheduler.timesteps |
| 179 | + return timesteps, num_inference_steps |
| 180 | + |
| 181 | + |
137 | 182 | class PixArtAlphaPipeline(DiffusionPipeline):
|
138 | 183 | r"""
|
139 | 184 | Pipeline for text-to-image generation using PixArt-Alpha.
|
@@ -783,8 +828,7 @@ def __call__(
|
783 | 828 | prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
784 | 829 |
|
785 | 830 | # 4. Prepare timesteps
|
786 |
| - self.scheduler.set_timesteps(num_inference_steps, device=device) |
787 |
| - timesteps = self.scheduler.timesteps |
| 831 | + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) |
788 | 832 |
|
789 | 833 | # 5. Prepare latents.
|
790 | 834 | latent_channels = self.transformer.config.in_channels
|
|
0 commit comments