@@ -782,13 +782,40 @@ def prepare_control_image(
         return image

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
-    def get_timesteps(self, num_inference_steps, strength, device):
+    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
         # get the original timestep using init_timestep
-        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+        if denoising_start is None:
+            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+            t_start = max(num_inference_steps - init_timestep, 0)
+        else:
+            t_start = 0

-        t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

+        # Strength is irrelevant if we directly request a timestep to start at;
+        # that is, strength is determined by the denoising_start instead.
+        if denoising_start is not None:
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (denoising_start * self.scheduler.config.num_train_timesteps)
+                )
+            )
+
+            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+            if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
+                # if the scheduler is a 2nd order scheduler we might have to do +1
+                # because `num_inference_steps` might be even given that every timestep
+                # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+                # mean that we cut the timesteps in the middle of the denoising step
+                # (between the 1st and 2nd derivative) which leads to incorrect results. By adding 1
+                # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+                num_inference_steps = num_inference_steps + 1
+
+            # because t_n+1 >= t_n, we slice the timesteps starting from the end
+            timesteps = timesteps[-num_inference_steps:]
+            return timesteps, num_inference_steps
+
         return timesteps, num_inference_steps - t_start

     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
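For intuition, the cutoff arithmetic above can be exercised in isolation. A minimal sketch, assuming a 1000-step training schedule, 50 inference steps, and a 1st-order scheduler (all values hypothetical; real schedulers space timesteps slightly differently):

```python
import torch

num_train_timesteps = 1000
num_inference_steps = 50
denoising_start = 0.8  # skip the first 80% of the denoising range

# Descending, evenly spaced inference timesteps, roughly what set_timesteps produces.
timesteps = torch.linspace(num_train_timesteps - 1, 0, num_inference_steps).long()

# Same formula as the diff: translate the fraction into a training-timestep cutoff.
discrete_timestep_cutoff = int(round(num_train_timesteps - denoising_start * num_train_timesteps))

# Keep only the timesteps below the cutoff; they sit at the end of the tensor
# because the schedule is descending.
num_steps = (timesteps < discrete_timestep_cutoff).sum().item()
timesteps = timesteps[-num_steps:]

print(discrete_timestep_cutoff)  # 200
print(num_steps)                 # 10 -> the last 20% of the 50 steps
```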
@@ -979,6 +1006,14 @@ def do_classifier_free_guidance(self):
     def cross_attention_kwargs(self):
         return self._cross_attention_kwargs

+    @property
+    def denoising_end(self):
+        return self._denoising_end
+
+    @property
+    def denoising_start(self):
+        return self._denoising_start
+
     @property
     def num_timesteps(self):
         return self._num_timesteps
@@ -995,6 +1030,8 @@ def __call__(
         width: Optional[int] = None,
         strength: float = 0.8,
         num_inference_steps: int = 50,
+        denoising_start: Optional[float] = None,
+        denoising_end: Optional[float] = None,
         guidance_scale: float = 5.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         negative_prompt_2: Optional[Union[str, List[str]]] = None,
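Both new arguments are fractions of the whole schedule, so a caller can make two pipelines meet at the same point. Roughly, for a 50-step schedule (a hypothetical illustration; the exact split depends on the scheduler's timestep spacing):

```python
num_inference_steps = 50
handoff = 0.8  # pass denoising_end=0.8 to the first pipeline,
               # denoising_start=0.8 to the second

first_stage_steps = int(num_inference_steps * handoff)        # ~40 steps
second_stage_steps = num_inference_steps - first_stage_steps  # ~10 steps
```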
@@ -1236,6 +1273,8 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._denoising_end = denoising_end
+        self._denoising_start = denoising_start

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1322,11 +1361,20 @@ def __call__(
             assert False

         # 5. Prepare timesteps
+        def denoising_value_valid(dnv):
+            return isinstance(dnv, float) and 0 < dnv < 1
+
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps,
+            strength,
+            device,
+            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+        )
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         self._num_timesteps = len(timesteps)

+        add_noise = True if self.denoising_start is None else False
         # 6. Prepare latent variables
         latents = self.prepare_latents(
             image,
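The gate is deliberately strict: anything other than a float strictly between 0 and 1 falls back to plain `strength`-based img2img. A standalone check of those semantics:

```python
def denoising_value_valid(dnv):
    return isinstance(dnv, float) and 0 < dnv < 1

assert denoising_value_valid(0.8)
assert not denoising_value_valid(None)  # unset -> ordinary img2img behavior
assert not denoising_value_valid(0.0)   # boundaries are excluded
assert not denoising_value_valid(1.0)
```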
@@ -1336,7 +1384,7 @@ def __call__(
             prompt_embeds.dtype,
             device,
             generator,
-            True,
+            add_noise,
         )

         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
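`add_noise` toggles whether `prepare_latents` noises the encoded image to the starting timestep. A hedged sketch of the branch it controls (not the pipeline's actual `prepare_latents`, which also handles VAE encoding and batching):

```python
def prepare_latents_sketch(image_latents, noise, latent_timestep, scheduler, add_noise=True):
    if add_noise:
        # Plain img2img: noise the encoded image up to the starting timestep.
        return scheduler.add_noise(image_latents, noise, latent_timestep)
    # Hand-off case (denoising_start set): `image` is already a partially
    # denoised latent at the right noise level, so it must be used untouched.
    return image_latents
```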
@@ -1395,6 +1443,28 @@ def __call__(

         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+        if (
+            self.denoising_end is not None
+            and self.denoising_start is not None
+            and denoising_value_valid(self.denoising_end)
+            and denoising_value_valid(self.denoising_start)
+            and self.denoising_start >= self.denoising_end
+        ):
+            raise ValueError(
+                f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+                + f" {self.denoising_end} when using type float."
+            )
+        elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
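Taken together, the changes let this pipeline sit in the second slot of the documented SDXL base/refiner split. A hedged usage sketch, with the stock refiner standing in for the second stage (model ids illustrative; the ControlNet inputs this pipeline additionally needs are omitted):

```python
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline

base = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
).to("cuda")

prompt = "a photo of an astronaut riding a horse"

# Stage 1: stop at 80% of the schedule and hand over the raw latent.
latent = base(
    prompt=prompt, num_inference_steps=50, denoising_end=0.8, output_type="latent"
).images

# Stage 2: resume at the same fraction; with denoising_start set, add_noise is
# False and the latent is consumed as-is.
image = refiner(
    prompt=prompt, image=latent, num_inference_steps=50, denoising_start=0.8
).images[0]
```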