Commit abe8d63

add inferring_controlnet_cond_batch
1 parent 364d59d commit abe8d63

File tree

1 file changed: +22 -16 lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Lines changed: 22 additions & 16 deletions
@@ -658,8 +658,7 @@ def prepare_image(
         num_images_per_prompt,
         device,
         dtype,
-        do_classifier_free_guidance=False,
-        guess_mode=False,
+        inferring_controlnet_cond_batch=False,
     ):
         if not isinstance(image, torch.Tensor):
             if isinstance(image, PIL.Image.Image):
@@ -696,7 +695,7 @@ def prepare_image(
 
         image = image.to(device=device, dtype=dtype)
 
-        if do_classifier_free_guidance and not guess_mode:
+        if not inferring_controlnet_cond_batch:
             image = torch.cat([image] * 2)
 
         return image
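For context, a minimal sketch (not part of this commit) of what the new flag controls in prepare_image: when ControlNet will be run on both the unconditional and conditional batches, the conditioning image is duplicated to match the classifier-free-guidance batch; when ControlNet is inferred only for the conditional batch, the image stays single-batch. Shapes below are illustrative.

import torch

# Illustrative: the prepared conditioning image for one prompt.
image = torch.randn(1, 3, 512, 512)

inferring_controlnet_cond_batch = False  # ControlNet will see uncond + cond batches
if not inferring_controlnet_cond_batch:
    image = torch.cat([image] * 2)  # (2, 3, 512, 512), matches the CFG latent batch

print(image.shape)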
@@ -898,7 +897,16 @@ def __call__(
         if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
 
-        # 3. Encode input prompt
+        # 3. Determine whether to infer ControlNet only for the conditional batch.
+        global_pool_conditions = False
+        if isinstance(self.controlnet, ControlNetModel):
+            global_pool_conditions = self.controlnet.config.global_pool_conditions
+        else:
+            ...  # TODO: Implement for MultiControlNetModel
+
+        inferring_controlnet_cond_batch = (guess_mode or global_pool_conditions) and do_classifier_free_guidance
+
+        # 4. Encode input prompt
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
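A hedged sketch of the flag the new step computes (the helper function below is hypothetical; only the expression mirrors the diff): ControlNet is restricted to the conditional batch only when classifier-free guidance is active and either guess mode or the model's global_pool_conditions config requests it.

def should_infer_cond_batch_only(guess_mode, global_pool_conditions, do_classifier_free_guidance):
    # Mirrors: inferring_controlnet_cond_batch = (guess_mode or global_pool_conditions) and do_classifier_free_guidance
    return (guess_mode or global_pool_conditions) and do_classifier_free_guidance

assert should_infer_cond_batch_only(True, False, True)
assert should_infer_cond_batch_only(False, True, True)
assert not should_infer_cond_batch_only(True, False, False)  # no CFG -> regular full-batch path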
@@ -909,7 +917,7 @@ def __call__(
             negative_prompt_embeds=negative_prompt_embeds,
         )
 
-        # 4. Prepare image
+        # 5. Prepare image
         if isinstance(self.controlnet, ControlNetModel):
             image = self.prepare_image(
                 image=image,
@@ -919,8 +927,7 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=self.controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
-                guess_mode=guess_mode,
+                inferring_controlnet_cond_batch=inferring_controlnet_cond_batch,
             )
         elif isinstance(self.controlnet, MultiControlNetModel):
             images = []
@@ -934,8 +941,7 @@ def __call__(
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=self.controlnet.dtype,
-                    do_classifier_free_guidance=do_classifier_free_guidance,
-                    guess_mode=guess_mode,
+                    inferring_controlnet_cond_batch=inferring_controlnet_cond_batch,
                 )
 
                 images.append(image_)
@@ -944,11 +950,11 @@ def __call__(
         else:
             assert False
 
-        # 5. Prepare timesteps
+        # 6. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
 
-        # 6. Prepare latent variables
+        # 7. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
@@ -961,10 +967,10 @@ def __call__(
             latents,
         )
 
-        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 8. Denoising loop
+        # 9. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -973,8 +979,8 @@ def __call__(
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
-                    # Infer ControlNet only for the conditional batch.
+                if inferring_controlnet_cond_batch:
+                    # Inferring ControlNet only for the conditional batch.
                     controlnet_latent_model_input = latents
                     controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                 else:
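The branch above hands ControlNet only the conditional half of the prompt embeddings; a small illustrative sketch of what prompt_embeds.chunk(2)[1] selects (shapes are assumptions, not taken from the pipeline):

import torch

# With classifier-free guidance the embeddings are stacked [uncond; cond] along dim 0.
uncond = torch.zeros(1, 77, 768)
cond = torch.ones(1, 77, 768)
prompt_embeds = torch.cat([uncond, cond])              # (2, 77, 768)

controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]   # conditional half, (1, 77, 768)
assert torch.equal(controlnet_prompt_embeds, cond)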
@@ -991,7 +997,7 @@ def __call__(
                     return_dict=False,
                 )
 
-                if guess_mode and do_classifier_free_guidance:
+                if inferring_controlnet_cond_batch:
                     # Inferred ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                     # add 0 to the unconditional batch to keep it unchanged.
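The hunk ends at the comment, before the padding itself; a hedged sketch of the zero-padding that comment describes (residual shapes are illustrative, and the exact lines in the file may differ):

import torch

# ControlNet residuals were computed for the conditional batch only (batch size 1 here).
down_block_res_samples = [torch.randn(1, 320, 64, 64), torch.randn(1, 640, 32, 32)]
mid_block_res_sample = torch.randn(1, 1280, 8, 8)

# Prepend zeros for the unconditional batch so adding the residuals leaves it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

assert mid_block_res_sample.shape[0] == 2  # now matches the CFG UNet batch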
