1
1
# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
2
2
3
3
import inspect
4
- from typing import Any , Callable , Dict , List , Optional , Union
4
+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
5
5
6
6
import numpy as np
7
7
import PIL .Image
10
10
11
11
from diffusers import AutoencoderKL , ControlNetModel , DiffusionPipeline , UNet2DConditionModel , logging
12
12
from diffusers .pipelines .stable_diffusion import StableDiffusionPipelineOutput , StableDiffusionSafetyChecker
13
+ from diffusers .pipelines .stable_diffusion .pipeline_stable_diffusion_controlnet import MultiControlNetModel
13
14
from diffusers .schedulers import KarrasDiffusionSchedulers
14
15
from diffusers .utils import (
15
16
PIL_INTERPOLATION ,
@@ -86,7 +87,14 @@ def prepare_image(image):
86
87
87
88
88
89
def prepare_controlnet_conditioning_image (
89
- controlnet_conditioning_image , width , height , batch_size , num_images_per_prompt , device , dtype
90
+ controlnet_conditioning_image ,
91
+ width ,
92
+ height ,
93
+ batch_size ,
94
+ num_images_per_prompt ,
95
+ device ,
96
+ dtype ,
97
+ do_classifier_free_guidance ,
90
98
):
91
99
if not isinstance (controlnet_conditioning_image , torch .Tensor ):
92
100
if isinstance (controlnet_conditioning_image , PIL .Image .Image ):
@@ -116,6 +124,9 @@ def prepare_controlnet_conditioning_image(
116
124
117
125
controlnet_conditioning_image = controlnet_conditioning_image .to (device = device , dtype = dtype )
118
126
127
+ if do_classifier_free_guidance :
128
+ controlnet_conditioning_image = torch .cat ([controlnet_conditioning_image ] * 2 )
129
+
119
130
return controlnet_conditioning_image
120
131
121
132
@@ -132,7 +143,7 @@ def __init__(
132
143
text_encoder : CLIPTextModel ,
133
144
tokenizer : CLIPTokenizer ,
134
145
unet : UNet2DConditionModel ,
135
- controlnet : ControlNetModel ,
146
+ controlnet : Union [ ControlNetModel , List [ ControlNetModel ], Tuple [ ControlNetModel ], MultiControlNetModel ] ,
136
147
scheduler : KarrasDiffusionSchedulers ,
137
148
safety_checker : StableDiffusionSafetyChecker ,
138
149
feature_extractor : CLIPImageProcessor ,
@@ -156,6 +167,9 @@ def __init__(
156
167
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
157
168
)
158
169
170
+ if isinstance (controlnet , (list , tuple )):
171
+ controlnet = MultiControlNetModel (controlnet )
172
+
159
173
self .register_modules (
160
174
vae = vae ,
161
175
text_encoder = text_encoder ,
@@ -424,6 +438,42 @@ def prepare_extra_step_kwargs(self, generator, eta):
424
438
extra_step_kwargs ["generator" ] = generator
425
439
return extra_step_kwargs
426
440
441
+ def check_controlnet_conditioning_image (self , image , prompt , prompt_embeds ):
442
+ image_is_pil = isinstance (image , PIL .Image .Image )
443
+ image_is_tensor = isinstance (image , torch .Tensor )
444
+ image_is_pil_list = isinstance (image , list ) and isinstance (image [0 ], PIL .Image .Image )
445
+ image_is_tensor_list = isinstance (image , list ) and isinstance (image [0 ], torch .Tensor )
446
+
447
+ if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list :
448
+ raise TypeError (
449
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
450
+ )
451
+
452
+ if image_is_pil :
453
+ image_batch_size = 1
454
+ elif image_is_tensor :
455
+ image_batch_size = image .shape [0 ]
456
+ elif image_is_pil_list :
457
+ image_batch_size = len (image )
458
+ elif image_is_tensor_list :
459
+ image_batch_size = len (image )
460
+ else :
461
+ raise ValueError ("controlnet condition image is not valid" )
462
+
463
+ if prompt is not None and isinstance (prompt , str ):
464
+ prompt_batch_size = 1
465
+ elif prompt is not None and isinstance (prompt , list ):
466
+ prompt_batch_size = len (prompt )
467
+ elif prompt_embeds is not None :
468
+ prompt_batch_size = prompt_embeds .shape [0 ]
469
+ else :
470
+ raise ValueError ("prompt or prompt_embeds are not valid" )
471
+
472
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size :
473
+ raise ValueError (
474
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: { image_batch_size } , prompt batch size: { prompt_batch_size } "
475
+ )
476
+
427
477
def check_inputs (
428
478
self ,
429
479
prompt ,
@@ -438,6 +488,7 @@ def check_inputs(
438
488
strength = None ,
439
489
controlnet_guidance_start = None ,
440
490
controlnet_guidance_end = None ,
491
+ controlnet_conditioning_scale = None ,
441
492
):
442
493
if height % 8 != 0 or width % 8 != 0 :
443
494
raise ValueError (f"`height` and `width` have to be divisible by 8 but are { height } and { width } ." )
@@ -476,58 +527,51 @@ def check_inputs(
476
527
f" { negative_prompt_embeds .shape } ."
477
528
)
478
529
479
- controlnet_cond_image_is_pil = isinstance (controlnet_conditioning_image , PIL .Image .Image )
480
- controlnet_cond_image_is_tensor = isinstance (controlnet_conditioning_image , torch .Tensor )
481
- controlnet_cond_image_is_pil_list = isinstance (controlnet_conditioning_image , list ) and isinstance (
482
- controlnet_conditioning_image [0 ], PIL .Image .Image
483
- )
484
- controlnet_cond_image_is_tensor_list = isinstance (controlnet_conditioning_image , list ) and isinstance (
485
- controlnet_conditioning_image [0 ], torch .Tensor
486
- )
530
+ # check controlnet condition image
487
531
488
- if (
489
- not controlnet_cond_image_is_pil
490
- and not controlnet_cond_image_is_tensor
491
- and not controlnet_cond_image_is_pil_list
492
- and not controlnet_cond_image_is_tensor_list
493
- ):
494
- raise TypeError (
495
- "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
496
- )
532
+ if isinstance (self .controlnet , ControlNetModel ):
533
+ self .check_controlnet_conditioning_image (controlnet_conditioning_image , prompt , prompt_embeds )
534
+ elif isinstance (self .controlnet , MultiControlNetModel ):
535
+ if not isinstance (controlnet_conditioning_image , list ):
536
+ raise TypeError ("For multiple controlnets: `image` must be type `list`" )
497
537
498
- if controlnet_cond_image_is_pil :
499
- controlnet_cond_image_batch_size = 1
500
- elif controlnet_cond_image_is_tensor :
501
- controlnet_cond_image_batch_size = controlnet_conditioning_image .shape [0 ]
502
- elif controlnet_cond_image_is_pil_list :
503
- controlnet_cond_image_batch_size = len (controlnet_conditioning_image )
504
- elif controlnet_cond_image_is_tensor_list :
505
- controlnet_cond_image_batch_size = len (controlnet_conditioning_image )
538
+ if len (controlnet_conditioning_image ) != len (self .controlnet .nets ):
539
+ raise ValueError (
540
+ "For multiple controlnets: `image` must have the same length as the number of controlnets."
541
+ )
506
542
507
- if prompt is not None and isinstance (prompt , str ):
508
- prompt_batch_size = 1
509
- elif prompt is not None and isinstance (prompt , list ):
510
- prompt_batch_size = len (prompt )
511
- elif prompt_embeds is not None :
512
- prompt_batch_size = prompt_embeds .shape [0 ]
543
+ for image_ in controlnet_conditioning_image :
544
+ self .check_controlnet_conditioning_image (image_ , prompt , prompt_embeds )
545
+ else :
546
+ assert False
513
547
514
- if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size :
515
- raise ValueError (
516
- f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: { controlnet_cond_image_batch_size } , prompt batch size: { prompt_batch_size } "
517
- )
548
+ # Check `controlnet_conditioning_scale`
549
+
550
+ if isinstance (self .controlnet , ControlNetModel ):
551
+ if not isinstance (controlnet_conditioning_scale , float ):
552
+ raise TypeError ("For single controlnet: `controlnet_conditioning_scale` must be type `float`." )
553
+ elif isinstance (self .controlnet , MultiControlNetModel ):
554
+ if isinstance (controlnet_conditioning_scale , list ) and len (controlnet_conditioning_scale ) != len (
555
+ self .controlnet .nets
556
+ ):
557
+ raise ValueError (
558
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
559
+ " the same length as the number of controlnets"
560
+ )
561
+ else :
562
+ assert False
518
563
519
564
if isinstance (image , torch .Tensor ):
520
565
if image .ndim != 3 and image .ndim != 4 :
521
566
raise ValueError ("`image` must have 3 or 4 dimensions" )
522
567
523
- # if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
524
- # raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")
525
-
526
568
if image .ndim == 3 :
527
569
image_batch_size = 1
528
570
image_channels , image_height , image_width = image .shape
529
571
elif image .ndim == 4 :
530
572
image_batch_size , image_channels , image_height , image_width = image .shape
573
+ else :
574
+ assert False
531
575
532
576
if image_channels != 3 :
533
577
raise ValueError ("`image` must have 3 channels" )
@@ -659,7 +703,7 @@ def __call__(
659
703
callback : Optional [Callable [[int , int , torch .FloatTensor ], None ]] = None ,
660
704
callback_steps : int = 1 ,
661
705
cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
662
- controlnet_conditioning_scale : float = 1.0 ,
706
+ controlnet_conditioning_scale : Union [ float , List [ float ]] = 1.0 ,
663
707
controlnet_guidance_start : float = 0.0 ,
664
708
controlnet_guidance_end : float = 1.0 ,
665
709
):
@@ -759,7 +803,6 @@ def __call__(
759
803
self .check_inputs (
760
804
prompt ,
761
805
image ,
762
- # mask_image,
763
806
controlnet_conditioning_image ,
764
807
height ,
765
808
width ,
@@ -770,6 +813,7 @@ def __call__(
770
813
strength ,
771
814
controlnet_guidance_start ,
772
815
controlnet_guidance_end ,
816
+ controlnet_conditioning_scale ,
773
817
)
774
818
775
819
# 2. Define call parameters
@@ -786,6 +830,9 @@ def __call__(
786
830
# corresponds to doing no classifier free guidance.
787
831
do_classifier_free_guidance = guidance_scale > 1.0
788
832
833
+ if isinstance (self .controlnet , MultiControlNetModel ) and isinstance (controlnet_conditioning_scale , float ):
834
+ controlnet_conditioning_scale = [controlnet_conditioning_scale ] * len (self .controlnet .nets )
835
+
789
836
# 3. Encode input prompt
790
837
prompt_embeds = self ._encode_prompt (
791
838
prompt ,
@@ -797,22 +844,41 @@ def __call__(
797
844
negative_prompt_embeds = negative_prompt_embeds ,
798
845
)
799
846
800
- # 4. Prepare mask, image, and controlnet_conditioning_image
847
+ # 4. Prepare image, and controlnet_conditioning_image
801
848
image = prepare_image (image )
802
849
803
- # mask_image = prepare_mask_image(mask_image)
850
+ # condition image(s)
851
+ if isinstance (self .controlnet , ControlNetModel ):
852
+ controlnet_conditioning_image = prepare_controlnet_conditioning_image (
853
+ controlnet_conditioning_image = controlnet_conditioning_image ,
854
+ width = width ,
855
+ height = height ,
856
+ batch_size = batch_size * num_images_per_prompt ,
857
+ num_images_per_prompt = num_images_per_prompt ,
858
+ device = device ,
859
+ dtype = self .controlnet .dtype ,
860
+ do_classifier_free_guidance = do_classifier_free_guidance ,
861
+ )
862
+ elif isinstance (self .controlnet , MultiControlNetModel ):
863
+ controlnet_conditioning_images = []
864
+
865
+ for image_ in controlnet_conditioning_image :
866
+ image_ = prepare_controlnet_conditioning_image (
867
+ controlnet_conditioning_image = image_ ,
868
+ width = width ,
869
+ height = height ,
870
+ batch_size = batch_size * num_images_per_prompt ,
871
+ num_images_per_prompt = num_images_per_prompt ,
872
+ device = device ,
873
+ dtype = self .controlnet .dtype ,
874
+ do_classifier_free_guidance = do_classifier_free_guidance ,
875
+ )
804
876
805
- controlnet_conditioning_image = prepare_controlnet_conditioning_image (
806
- controlnet_conditioning_image ,
807
- width ,
808
- height ,
809
- batch_size * num_images_per_prompt ,
810
- num_images_per_prompt ,
811
- device ,
812
- self .controlnet .dtype ,
813
- )
877
+ controlnet_conditioning_images .append (image_ )
814
878
815
- # masked_image = image * (mask_image < 0.5)
879
+ controlnet_conditioning_image = controlnet_conditioning_images
880
+ else :
881
+ assert False
816
882
817
883
# 5. Prepare timesteps
818
884
self .scheduler .set_timesteps (num_inference_steps , device = device )
@@ -830,9 +896,6 @@ def __call__(
830
896
generator ,
831
897
)
832
898
833
- if do_classifier_free_guidance :
834
- controlnet_conditioning_image = torch .cat ([controlnet_conditioning_image ] * 2 )
835
-
836
899
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
837
900
extra_step_kwargs = self .prepare_extra_step_kwargs (generator , eta )
838
901
@@ -862,15 +925,10 @@ def __call__(
862
925
t ,
863
926
encoder_hidden_states = prompt_embeds ,
864
927
controlnet_cond = controlnet_conditioning_image ,
928
+ conditioning_scale = controlnet_conditioning_scale ,
865
929
return_dict = False ,
866
930
)
867
931
868
- down_block_res_samples = [
869
- down_block_res_sample * controlnet_conditioning_scale
870
- for down_block_res_sample in down_block_res_samples
871
- ]
872
- mid_block_res_sample *= controlnet_conditioning_scale
873
-
874
932
# predict the noise residual
875
933
noise_pred = self .unet (
876
934
latent_model_input ,
0 commit comments