Skip to content

Commit 6f0d3ea

Browse files
Support ControlNet v1.1 shuffle properly (huggingface#3340)
* add inferring_controlnet_cond_batch * Revert "add inferring_controlnet_cond_batch" This reverts commit abe8d63. * set guess_mode to True whenever global_pool_conditions is True Co-authored-by: Patrick von Platen <[email protected]> * nit * add integration test --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent f9f041d commit 6f0d3ea

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

models/controlnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def forward(
558558
mid_block_res_sample = self.controlnet_mid_block(sample)
559559

560560
# 6. scaling
561-
if guess_mode:
561+
if guess_mode and not self.config.global_pool_conditions:
562562
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
563563

564564
scales = scales * conditioning_scale

pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,13 @@ def __call__(
930930
if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
931931
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
932932

933+
global_pool_conditions = (
934+
self.controlnet.config.global_pool_conditions
935+
if isinstance(self.controlnet, ControlNetModel)
936+
else self.controlnet.nets[0].config.global_pool_conditions
937+
)
938+
guess_mode = guess_mode or global_pool_conditions
939+
933940
# 3. Encode input prompt
934941
prompt_embeds = self._encode_prompt(
935942
prompt,

0 commit comments

Comments
 (0)