Skip to content

Commit 3745995

Browse files
authored
Merge pull request #3 from playgroundai/pai/add-denoising-start-end-cn-pipeline
Impl denoising_start/end for SDXL ControlNet pipeline
2 parents b07c1fe + a29d651 commit 3745995

File tree

1 file changed

+75
-5
lines changed

1 file changed

+75
-5
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -782,13 +782,40 @@ def prepare_control_image(
782782
return image
783783

784784
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
# NOTE(review): with the added `denoising_start` support this no longer matches
# the SD img2img source named above — it presumably now mirrors the SDXL
# img2img variant; confirm and update the "Copied from" target accordingly.
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
    """Return the slice of the scheduler's timesteps to run and the step count.

    When `denoising_start` is None, `strength` decides how far into the
    schedule denoising begins. When `denoising_start` is given, `strength`
    is ignored and the starting point is derived from `denoising_start`,
    interpreted as a fraction of the training schedule.
    """
    if denoising_start is not None:
        # The real cutoff is computed below from `denoising_start`.
        t_start = 0
    else:
        # Original strength-based start: run the last `init_timestep` steps.
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)

    timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

    if denoising_start is not None:
        # Strength is irrelevant here: map the fractional start onto the
        # discrete training schedule instead.
        train_steps = self.scheduler.config.num_train_timesteps
        discrete_timestep_cutoff = int(round(train_steps - denoising_start * train_steps))

        num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
        if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
            # 2nd-order schedulers duplicate every timestep except the highest
            # one, so an even count could cut between a step's 1st and 2nd
            # derivative evaluations; extend by one so the loop always ends
            # after the 2nd derivative step.
            num_inference_steps += 1

        # Timesteps run from high to low, so the steps past the cutoff form
        # the tail of the array.
        timesteps = timesteps[-num_inference_steps:]
        return timesteps, num_inference_steps

    return timesteps, num_inference_steps - t_start
793820

794821
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
@@ -979,6 +1006,14 @@ def do_classifier_free_guidance(self):
9791006
def cross_attention_kwargs(self):
9801007
return self._cross_attention_kwargs
9811008

1009+
@property
def denoising_end(self):
    # Set from the `denoising_end` argument of `__call__`; fraction of the
    # schedule at which denoising stops early, or None to run to completion.
    return self._denoising_end

@property
def denoising_start(self):
    # Set from the `denoising_start` argument of `__call__`; fraction of the
    # schedule at which denoising begins, or None for the default start.
    return self._denoising_start
1016+
9821017
@property
9831018
def num_timesteps(self):
9841019
return self._num_timesteps
@@ -995,6 +1030,8 @@ def __call__(
9951030
width: Optional[int] = None,
9961031
strength: float = 0.8,
9971032
num_inference_steps: int = 50,
1033+
denoising_start: Optional[float] = None,
1034+
denoising_end: Optional[float] = None,
9981035
guidance_scale: float = 5.0,
9991036
negative_prompt: Optional[Union[str, List[str]]] = None,
10001037
negative_prompt_2: Optional[Union[str, List[str]]] = None,
@@ -1236,6 +1273,8 @@ def __call__(
12361273
self._guidance_scale = guidance_scale
12371274
self._clip_skip = clip_skip
12381275
self._cross_attention_kwargs = cross_attention_kwargs
1276+
self._denoising_end = denoising_end
1277+
self._denoising_start = denoising_start
12391278

12401279
# 2. Define call parameters
12411280
if prompt is not None and isinstance(prompt, str):
@@ -1322,11 +1361,20 @@ def __call__(
13221361
assert False
13231362

13241363
# 5. Prepare timesteps
def denoising_value_valid(dnv):
    # A denoising boundary is only honored when it is a float strictly
    # between 0 and 1; anything else (None, ints, out-of-range) is ignored.
    # BUGFIX: was `isinstance(self.denoising_end, float)`, which validated
    # the wrong attribute regardless of the argument passed in.
    return isinstance(dnv, float) and 0 < dnv < 1

self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(
    num_inference_steps,
    strength,
    device,
    # BUGFIX: the validator must be *called*; the bare function reference
    # `if denoising_value_valid` was always truthy.
    denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
self._num_timesteps = len(timesteps)

# Noise is only added when denoising starts from scratch; with
# `denoising_start` set, the incoming latents are already partially denoised.
add_noise = self.denoising_start is None
13301378
# 6. Prepare latent variables
13311379
latents = self.prepare_latents(
13321380
image,
@@ -1336,7 +1384,7 @@ def __call__(
13361384
prompt_embeds.dtype,
13371385
device,
13381386
generator,
1339-
True,
1387+
add_noise,
13401388
)
13411389

13421390
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -1395,6 +1443,28 @@ def __call__(
13951443

13961444
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

both_bounds_valid = (
    self.denoising_end is not None
    and self.denoising_start is not None
    and denoising_value_valid(self.denoising_end)
    and denoising_value_valid(self.denoising_start)
)
if both_bounds_valid and self.denoising_start >= self.denoising_end:
    # A later start than end leaves nothing to denoise — reject it.
    raise ValueError(
        f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
        + f" {self.denoising_end} when using type float."
    )
elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
    # Translate the fractional `denoising_end` into a discrete timestep on
    # the training schedule and truncate the inference schedule there.
    train_steps = self.scheduler.config.num_train_timesteps
    discrete_timestep_cutoff = int(round(train_steps - self.denoising_end * train_steps))
    num_inference_steps = len([ts for ts in timesteps if ts >= discrete_timestep_cutoff])
    timesteps = timesteps[:num_inference_steps]
1467+
13981468
with self.progress_bar(total=num_inference_steps) as progress_bar:
13991469
for i, t in enumerate(timesteps):
14001470
# expand the latents if we are doing classifier free guidance

0 commit comments

Comments
 (0)