@@ -658,8 +658,7 @@ def prepare_image(
         num_images_per_prompt,
         device,
         dtype,
-        do_classifier_free_guidance=False,
-        guess_mode=False,
+        inferring_controlnet_cond_batch=False,
     ):
         if not isinstance(image, torch.Tensor):
             if isinstance(image, PIL.Image.Image):
@@ -696,7 +695,7 @@ def prepare_image(

         image = image.to(device=device, dtype=dtype)

-        if do_classifier_free_guidance and not guess_mode:
+        if not inferring_controlnet_cond_batch:
             image = torch.cat([image] * 2)

         return image
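For reference, a minimal standalone sketch of the batching behaviour the renamed flag controls, assuming a hypothetical conditioning tensor `cond_image` of shape (batch, channels, height, width); this is an illustration, not the pipeline's actual helper:

import torch

def duplicate_for_cfg(cond_image: torch.Tensor, inferring_controlnet_cond_batch: bool) -> torch.Tensor:
    # Classifier-free guidance stacks [unconditional, conditional] along the batch
    # axis, so the ControlNet condition normally has to be doubled to match.
    if not inferring_controlnet_cond_batch:
        return torch.cat([cond_image] * 2)
    # When the ControlNet only sees the conditional half, keep the original batch.
    return cond_image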
@@ -898,7 +897,16 @@ def __call__(
         if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)

-        # 3. Encode input prompt
+        # 3. Determine whether to infer the ControlNet only for the conditional batch.
+        global_pool_conditions = False
+        if isinstance(self.controlnet, ControlNetModel):
+            global_pool_conditions = self.controlnet.config.global_pool_conditions
+        else:
+            ...  # TODO: Implement for MultiControlNetModel
+
+        inferring_controlnet_cond_batch = (guess_mode or global_pool_conditions) and do_classifier_free_guidance
+
+        # 4. Encode input prompt
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
@@ -909,7 +917,7 @@ def __call__(
             negative_prompt_embeds=negative_prompt_embeds,
         )

-        # 4. Prepare image
+        # 5. Prepare image
         if isinstance(self.controlnet, ControlNetModel):
             image = self.prepare_image(
                 image=image,
@@ -919,8 +927,7 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=self.controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
-                guess_mode=guess_mode,
+                inferring_controlnet_cond_batch=inferring_controlnet_cond_batch,
             )
         elif isinstance(self.controlnet, MultiControlNetModel):
             images = []
@@ -934,8 +941,7 @@ def __call__(
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=self.controlnet.dtype,
-                    do_classifier_free_guidance=do_classifier_free_guidance,
-                    guess_mode=guess_mode,
+                    inferring_controlnet_cond_batch=inferring_controlnet_cond_batch,
                 )

                 images.append(image_)
@@ -944,11 +950,11 @@ def __call__(
         else:
             assert False

-        # 5. Prepare timesteps
+        # 6. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps

-        # 6. Prepare latent variables
+        # 7. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
@@ -961,10 +967,10 @@ def __call__(
             latents,
         )

-        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        # 8. Denoising loop
+        # 9. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -973,8 +979,8 @@ def __call__(
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
-                    # Infer ControlNet only for the conditional batch.
+                if inferring_controlnet_cond_batch:
+                    # Inferring ControlNet only for the conditional batch.
                     controlnet_latent_model_input = latents
                     controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                 else:
@@ -991,7 +997,7 @@ def __call__(
                     return_dict=False,
                 )

-                if guess_mode and do_classifier_free_guidance:
+                if inferring_controlnet_cond_batch:
                     # Infered ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                     # add 0 to the unconditional batch to keep it unchanged.
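For reference, a minimal sketch of the zero-padding the comment above describes, assuming the ControlNet residuals are tensors shaped (batch, channels, height, width) and the UNet batch is stacked as [unconditional, conditional]; the helper name `pad_residuals_for_uncond` is illustrative, not part of this diff:

import torch

def pad_residuals_for_uncond(down_block_res_samples, mid_block_res_sample):
    # Prepend zeros so the unconditional half of the UNet batch receives no
    # ControlNet contribution, while the conditional half keeps its residuals.
    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
    return down_block_res_samples, mid_block_res_sample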