Impl denoising_start/end for SDXL ControlNet pipeline #3

Merged · 1 commit · Nov 22, 2023
@@ -782,13 +782,40 @@ def prepare_control_image(
         return image
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
-    def get_timesteps(self, num_inference_steps, strength, device):
+    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
         # get the original timestep using init_timestep
-        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+        if denoising_start is None:
+            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+            t_start = max(num_inference_steps - init_timestep, 0)
+        else:
+            t_start = 0
 
-        t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
 
+        # Strength is irrelevant if we directly request a timestep to start at;
+        # that is, strength is determined by the denoising_start instead.
+        if denoising_start is not None:
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (denoising_start * self.scheduler.config.num_train_timesteps)
+                )
+            )
+
+            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+            if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
+                # if the scheduler is a 2nd order scheduler we might have to do +1
+                # because `num_inference_steps` might be even given that every timestep
+                # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+                # mean that we cut the timesteps in the middle of the denoising step
+                # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+                # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+                num_inference_steps = num_inference_steps + 1
+
+            # because t_n+1 >= t_n, we slice the timesteps starting from the end
+            timesteps = timesteps[-num_inference_steps:]
+            return timesteps, num_inference_steps
 
         return timesteps, num_inference_steps - t_start
 
     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
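For intuition, here is a small self-contained sketch (not part of this PR) of the arithmetic get_timesteps performs when denoising_start is given: the fraction is mapped onto the scheduler's training-timestep range and only the timesteps below that cutoff are kept. The evenly spaced schedule and the concrete numbers are stand-ins, not values any real scheduler is guaranteed to produce.

import torch

# Stand-in values; a real pipeline reads these from self.scheduler.config.
num_train_timesteps = 1000
num_inference_steps = 50
denoising_start = 0.8  # resume in the last 20% of the noise schedule

# A descending, evenly spaced schedule, similar to what a 1st-order scheduler produces.
timesteps = torch.linspace(num_train_timesteps - 1, 0, num_inference_steps).round().long()

# Same cutoff computation as in get_timesteps above.
discrete_timestep_cutoff = int(round(num_train_timesteps - denoising_start * num_train_timesteps))
remaining = (timesteps < discrete_timestep_cutoff).sum().item()

print(discrete_timestep_cutoff)  # 200
print(remaining)                 # 10 of the 50 steps are left to run
print(timesteps[-remaining:])    # the low-noise tail of the schedule that will actually run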
@@ -979,6 +1006,14 @@ def do_classifier_free_guidance(self):
     def cross_attention_kwargs(self):
         return self._cross_attention_kwargs
 
+    @property
+    def denoising_end(self):
+        return self._denoising_end
+
+    @property
+    def denoising_start(self):
+        return self._denoising_start
+
     @property
     def num_timesteps(self):
         return self._num_timesteps
@@ -995,6 +1030,8 @@ def __call__(
         width: Optional[int] = None,
         strength: float = 0.8,
         num_inference_steps: int = 50,
+        denoising_start: Optional[float] = None,
+        denoising_end: Optional[float] = None,
         guidance_scale: float = 5.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         negative_prompt_2: Optional[Union[str, List[str]]] = None,
@@ -1236,6 +1273,8 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._denoising_end = denoising_end
+        self._denoising_start = denoising_start
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1322,11 +1361,20 @@ def __call__(
             assert False
 
         # 5. Prepare timesteps
+        def denoising_value_valid(dnv):
+            return isinstance(dnv, float) and 0 < dnv < 1
+
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps,
+            strength,
+            device,
+            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+        )
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         self._num_timesteps = len(timesteps)
 
+        add_noise = True if self.denoising_start is None else False
         # 6. Prepare latent variables
         latents = self.prepare_latents(
             image,
@@ -1336,7 +1384,7 @@ def __call__(
             prompt_embeds.dtype,
             device,
             generator,
-            True,
+            add_noise,
         )
 
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
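The add_noise flag passed to prepare_latents above controls what happens to the incoming image latents. A rough paraphrase of that branch (a sketch assuming a standard diffusers scheduler with an add_noise method, not code from this PR): noise is only added when a fresh denoising run starts; with denoising_start set, the latents are expected to already carry the right amount of noise, for example the output_type="latent" result of a pipeline that stopped early via denoising_end.

import torch

def maybe_noise_latents(init_latents, scheduler, timestep, add_noise, generator=None):
    # Sketch of the add_noise branch in prepare_latents.
    if add_noise:
        # Fresh run: noise the clean image latents up to the starting timestep.
        noise = torch.randn(init_latents.shape, generator=generator, dtype=init_latents.dtype)
        return scheduler.add_noise(init_latents, noise, timestep)
    # denoising_start is set: the latents are already partially denoised, so use them as-is.
    return init_latents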
@@ -1395,6 +1443,28 @@ def __call__(

         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+        if (
+            self.denoising_end is not None
+            and self.denoising_start is not None
+            and denoising_value_valid(self.denoising_end)
+            and denoising_value_valid(self.denoising_start)
+            and self.denoising_start >= self.denoising_end
+        ):
+            raise ValueError(
+                f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+                + f" {self.denoising_end} when using type float."
+            )
+        elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
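Finally, a hedged end-to-end sketch of how denoising_end and denoising_start are meant to be used together once this change is in place. The pipeline classes, checkpoints, and placeholder inputs below are assumptions following the diffusers SDXL base + refiner pattern, not something this PR prescribes; the only requirement is that denoising_end on the first pipeline and denoising_start on the second name the same split point, so the second pipeline resumes exactly where the first stopped and no noise is re-added in between.

import torch
from PIL import Image
from diffusers import (
    ControlNetModel,
    StableDiffusionXLControlNetImg2ImgPipeline,
    StableDiffusionXLImg2ImgPipeline,
)

# Placeholder inputs; in practice these come from the user.
init_image = Image.new("RGB", (1024, 1024))
control_image = Image.new("RGB", (1024, 1024))
prompt = "a photo of an astronaut riding a horse"

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
base = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
).to("cuda")

# The ControlNet img2img pipeline handles the high-noise part of the schedule
# and hands off latents instead of a decoded image...
latents = base(
    prompt=prompt,
    image=init_image,
    control_image=control_image,
    num_inference_steps=50,
    denoising_end=0.8,
    output_type="latent",
).images

# ...and the refiner resumes at the same split point; strength is ignored
# because denoising_start fixes the starting timestep directly.
refined = refiner(
    prompt=prompt,
    image=latents,
    num_inference_steps=50,
    denoising_start=0.8,
).images[0]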