Skip to content

Commit 7275de1

Browse files
takuma104patrickvonplaten
authored andcommitted
Support ControlNet v1.1 shuffle properly (huggingface#3340)
* add inferring_controlnet_cond_batch * Revert "add inferring_controlnet_cond_batch" This reverts commit abe8d63. * set guess_mode to True whenever global_pool_conditions is True Co-authored-by: Patrick von Platen <[email protected]> * nit * add integration test --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent 1e49073 commit 7275de1

File tree

3 files changed

+39
-1
lines changed

3 files changed

+39
-1
lines changed

src/diffusers/models/controlnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def forward(
558558
mid_block_res_sample = self.controlnet_mid_block(sample)
559559

560560
# 6. scaling
561-
if guess_mode:
561+
if guess_mode and not self.config.global_pool_conditions:
562562
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
563563

564564
scales = scales * conditioning_scale

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,13 @@ def __call__(
930930
if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
931931
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
932932

933+
global_pool_conditions = (
934+
self.controlnet.config.global_pool_conditions
935+
if isinstance(self.controlnet, ControlNetModel)
936+
else self.controlnet.nets[0].config.global_pool_conditions
937+
)
938+
guess_mode = guess_mode or global_pool_conditions
939+
933940
# 3. Encode input prompt
934941
prompt_embeds = self._encode_prompt(
935942
prompt,

tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,37 @@ def test_stable_diffusion_compile(self):
623623

624624
assert np.abs(expected_image - image).max() < 1e-1
625625

626+
def test_v11_shuffle_global_pool_conditions(self):
627+
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_shuffle")
628+
629+
pipe = StableDiffusionControlNetPipeline.from_pretrained(
630+
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
631+
)
632+
pipe.enable_model_cpu_offload()
633+
pipe.set_progress_bar_config(disable=None)
634+
635+
generator = torch.Generator(device="cpu").manual_seed(0)
636+
prompt = "New York"
637+
image = load_image(
638+
"https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/control.png"
639+
)
640+
641+
output = pipe(
642+
prompt,
643+
image,
644+
generator=generator,
645+
output_type="np",
646+
num_inference_steps=3,
647+
guidance_scale=7.0,
648+
)
649+
650+
image = output.images[0]
651+
assert image.shape == (512, 640, 3)
652+
653+
image_slice = image[-3:, -3:, -1]
654+
expected_slice = np.array([0.1338, 0.1597, 0.1202, 0.1687, 0.1377, 0.1017, 0.2070, 0.1574, 0.1348])
655+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
656+
626657

627658
@slow
628659
@require_torch_gpu

0 commit comments

Comments
 (0)