Skip to content

Commit 1d033a9

Browse files
mikegartsmishka
and
mishka
authored
img2img.multiple.controlnets.pipeline (#2833)
* img2img.multiple.controlnets.pipeline * remove comments --------- Co-authored-by: mishka <[email protected]>
1 parent 4960976 commit 1d033a9

File tree

1 file changed

+122
-64
lines changed

1 file changed

+122
-64
lines changed

examples/community/stable_diffusion_controlnet_img2img.py

Lines changed: 122 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
22

33
import inspect
4-
from typing import Any, Callable, Dict, List, Optional, Union
4+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
55

66
import numpy as np
77
import PIL.Image
@@ -10,6 +10,7 @@
1010

1111
from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
1212
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
13+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
1314
from diffusers.schedulers import KarrasDiffusionSchedulers
1415
from diffusers.utils import (
1516
PIL_INTERPOLATION,
@@ -86,7 +87,14 @@ def prepare_image(image):
8687

8788

8889
def prepare_controlnet_conditioning_image(
89-
controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
90+
controlnet_conditioning_image,
91+
width,
92+
height,
93+
batch_size,
94+
num_images_per_prompt,
95+
device,
96+
dtype,
97+
do_classifier_free_guidance,
9098
):
9199
if not isinstance(controlnet_conditioning_image, torch.Tensor):
92100
if isinstance(controlnet_conditioning_image, PIL.Image.Image):
@@ -116,6 +124,9 @@ def prepare_controlnet_conditioning_image(
116124

117125
controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
118126

127+
if do_classifier_free_guidance:
128+
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
129+
119130
return controlnet_conditioning_image
120131

121132

@@ -132,7 +143,7 @@ def __init__(
132143
text_encoder: CLIPTextModel,
133144
tokenizer: CLIPTokenizer,
134145
unet: UNet2DConditionModel,
135-
controlnet: ControlNetModel,
146+
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
136147
scheduler: KarrasDiffusionSchedulers,
137148
safety_checker: StableDiffusionSafetyChecker,
138149
feature_extractor: CLIPImageProcessor,
@@ -156,6 +167,9 @@ def __init__(
156167
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
157168
)
158169

170+
if isinstance(controlnet, (list, tuple)):
171+
controlnet = MultiControlNetModel(controlnet)
172+
159173
self.register_modules(
160174
vae=vae,
161175
text_encoder=text_encoder,
@@ -424,6 +438,42 @@ def prepare_extra_step_kwargs(self, generator, eta):
424438
extra_step_kwargs["generator"] = generator
425439
return extra_step_kwargs
426440

441+
def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
442+
image_is_pil = isinstance(image, PIL.Image.Image)
443+
image_is_tensor = isinstance(image, torch.Tensor)
444+
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
445+
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
446+
447+
if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
448+
raise TypeError(
449+
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
450+
)
451+
452+
if image_is_pil:
453+
image_batch_size = 1
454+
elif image_is_tensor:
455+
image_batch_size = image.shape[0]
456+
elif image_is_pil_list:
457+
image_batch_size = len(image)
458+
elif image_is_tensor_list:
459+
image_batch_size = len(image)
460+
else:
461+
raise ValueError("controlnet condition image is not valid")
462+
463+
if prompt is not None and isinstance(prompt, str):
464+
prompt_batch_size = 1
465+
elif prompt is not None and isinstance(prompt, list):
466+
prompt_batch_size = len(prompt)
467+
elif prompt_embeds is not None:
468+
prompt_batch_size = prompt_embeds.shape[0]
469+
else:
470+
raise ValueError("prompt or prompt_embeds are not valid")
471+
472+
if image_batch_size != 1 and image_batch_size != prompt_batch_size:
473+
raise ValueError(
474+
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
475+
)
476+
427477
def check_inputs(
428478
self,
429479
prompt,
@@ -438,6 +488,7 @@ def check_inputs(
438488
strength=None,
439489
controlnet_guidance_start=None,
440490
controlnet_guidance_end=None,
491+
controlnet_conditioning_scale=None,
441492
):
442493
if height % 8 != 0 or width % 8 != 0:
443494
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -476,58 +527,51 @@ def check_inputs(
476527
f" {negative_prompt_embeds.shape}."
477528
)
478529

479-
controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
480-
controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
481-
controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
482-
controlnet_conditioning_image[0], PIL.Image.Image
483-
)
484-
controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
485-
controlnet_conditioning_image[0], torch.Tensor
486-
)
530+
# check controlnet condition image
487531

488-
if (
489-
not controlnet_cond_image_is_pil
490-
and not controlnet_cond_image_is_tensor
491-
and not controlnet_cond_image_is_pil_list
492-
and not controlnet_cond_image_is_tensor_list
493-
):
494-
raise TypeError(
495-
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
496-
)
532+
if isinstance(self.controlnet, ControlNetModel):
533+
self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)
534+
elif isinstance(self.controlnet, MultiControlNetModel):
535+
if not isinstance(controlnet_conditioning_image, list):
536+
raise TypeError("For multiple controlnets: `image` must be type `list`")
497537

498-
if controlnet_cond_image_is_pil:
499-
controlnet_cond_image_batch_size = 1
500-
elif controlnet_cond_image_is_tensor:
501-
controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
502-
elif controlnet_cond_image_is_pil_list:
503-
controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
504-
elif controlnet_cond_image_is_tensor_list:
505-
controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
538+
if len(controlnet_conditioning_image) != len(self.controlnet.nets):
539+
raise ValueError(
540+
"For multiple controlnets: `image` must have the same length as the number of controlnets."
541+
)
506542

507-
if prompt is not None and isinstance(prompt, str):
508-
prompt_batch_size = 1
509-
elif prompt is not None and isinstance(prompt, list):
510-
prompt_batch_size = len(prompt)
511-
elif prompt_embeds is not None:
512-
prompt_batch_size = prompt_embeds.shape[0]
543+
for image_ in controlnet_conditioning_image:
544+
self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
545+
else:
546+
assert False
513547

514-
if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
515-
raise ValueError(
516-
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
517-
)
548+
# Check `controlnet_conditioning_scale`
549+
550+
if isinstance(self.controlnet, ControlNetModel):
551+
if not isinstance(controlnet_conditioning_scale, float):
552+
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
553+
elif isinstance(self.controlnet, MultiControlNetModel):
554+
if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
555+
self.controlnet.nets
556+
):
557+
raise ValueError(
558+
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
559+
" the same length as the number of controlnets"
560+
)
561+
else:
562+
assert False
518563

519564
if isinstance(image, torch.Tensor):
520565
if image.ndim != 3 and image.ndim != 4:
521566
raise ValueError("`image` must have 3 or 4 dimensions")
522567

523-
# if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
524-
# raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")
525-
526568
if image.ndim == 3:
527569
image_batch_size = 1
528570
image_channels, image_height, image_width = image.shape
529571
elif image.ndim == 4:
530572
image_batch_size, image_channels, image_height, image_width = image.shape
573+
else:
574+
assert False
531575

532576
if image_channels != 3:
533577
raise ValueError("`image` must have 3 channels")
@@ -659,7 +703,7 @@ def __call__(
659703
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
660704
callback_steps: int = 1,
661705
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
662-
controlnet_conditioning_scale: float = 1.0,
706+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
663707
controlnet_guidance_start: float = 0.0,
664708
controlnet_guidance_end: float = 1.0,
665709
):
@@ -759,7 +803,6 @@ def __call__(
759803
self.check_inputs(
760804
prompt,
761805
image,
762-
# mask_image,
763806
controlnet_conditioning_image,
764807
height,
765808
width,
@@ -770,6 +813,7 @@ def __call__(
770813
strength,
771814
controlnet_guidance_start,
772815
controlnet_guidance_end,
816+
controlnet_conditioning_scale,
773817
)
774818

775819
# 2. Define call parameters
@@ -786,6 +830,9 @@ def __call__(
786830
# corresponds to doing no classifier free guidance.
787831
do_classifier_free_guidance = guidance_scale > 1.0
788832

833+
if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
834+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
835+
789836
# 3. Encode input prompt
790837
prompt_embeds = self._encode_prompt(
791838
prompt,
@@ -797,22 +844,41 @@ def __call__(
797844
negative_prompt_embeds=negative_prompt_embeds,
798845
)
799846

800-
# 4. Prepare mask, image, and controlnet_conditioning_image
847+
# 4. Prepare image, and controlnet_conditioning_image
801848
image = prepare_image(image)
802849

803-
# mask_image = prepare_mask_image(mask_image)
850+
# condition image(s)
851+
if isinstance(self.controlnet, ControlNetModel):
852+
controlnet_conditioning_image = prepare_controlnet_conditioning_image(
853+
controlnet_conditioning_image=controlnet_conditioning_image,
854+
width=width,
855+
height=height,
856+
batch_size=batch_size * num_images_per_prompt,
857+
num_images_per_prompt=num_images_per_prompt,
858+
device=device,
859+
dtype=self.controlnet.dtype,
860+
do_classifier_free_guidance=do_classifier_free_guidance,
861+
)
862+
elif isinstance(self.controlnet, MultiControlNetModel):
863+
controlnet_conditioning_images = []
864+
865+
for image_ in controlnet_conditioning_image:
866+
image_ = prepare_controlnet_conditioning_image(
867+
controlnet_conditioning_image=image_,
868+
width=width,
869+
height=height,
870+
batch_size=batch_size * num_images_per_prompt,
871+
num_images_per_prompt=num_images_per_prompt,
872+
device=device,
873+
dtype=self.controlnet.dtype,
874+
do_classifier_free_guidance=do_classifier_free_guidance,
875+
)
804876

805-
controlnet_conditioning_image = prepare_controlnet_conditioning_image(
806-
controlnet_conditioning_image,
807-
width,
808-
height,
809-
batch_size * num_images_per_prompt,
810-
num_images_per_prompt,
811-
device,
812-
self.controlnet.dtype,
813-
)
877+
controlnet_conditioning_images.append(image_)
814878

815-
# masked_image = image * (mask_image < 0.5)
879+
controlnet_conditioning_image = controlnet_conditioning_images
880+
else:
881+
assert False
816882

817883
# 5. Prepare timesteps
818884
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -830,9 +896,6 @@ def __call__(
830896
generator,
831897
)
832898

833-
if do_classifier_free_guidance:
834-
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
835-
836899
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
837900
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
838901

@@ -862,15 +925,10 @@ def __call__(
862925
t,
863926
encoder_hidden_states=prompt_embeds,
864927
controlnet_cond=controlnet_conditioning_image,
928+
conditioning_scale=controlnet_conditioning_scale,
865929
return_dict=False,
866930
)
867931

868-
down_block_res_samples = [
869-
down_block_res_sample * controlnet_conditioning_scale
870-
for down_block_res_sample in down_block_res_samples
871-
]
872-
mid_block_res_sample *= controlnet_conditioning_scale
873-
874932
# predict the noise residual
875933
noise_pred = self.unet(
876934
latent_model_input,

0 commit comments

Comments
 (0)