Skip to content

Experimental per control type scale for ControlNet Union #10723

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions src/diffusers/models/controlnets/controlnet_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,12 +605,13 @@ def forward(
controlnet_cond: List[torch.Tensor],
control_type: torch.Tensor,
control_type_idx: List[int],
conditioning_scale: float = 1.0,
conditioning_scale: Union[float, List[float]] = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
from_multi: bool = False,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
Expand Down Expand Up @@ -647,6 +648,8 @@ def forward(
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
from_multi (`bool`, defaults to `False`):
Use standard scaling when called from `MultiControlNetUnionModel`.
guess_mode (`bool`, defaults to `False`):
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
Expand All @@ -658,6 +661,9 @@ def forward(
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
if isinstance(conditioning_scale, float):
conditioning_scale = [conditioning_scale] * len(controlnet_cond)

# check channel order
channel_order = self.config.controlnet_conditioning_channel_order

Expand Down Expand Up @@ -742,12 +748,16 @@ def forward(
inputs = []
condition_list = []

for cond, control_idx in zip(controlnet_cond, control_type_idx):
for cond, control_idx, scale in zip(controlnet_cond, control_type_idx, conditioning_scale):
condition = self.controlnet_cond_embedding(cond)
feat_seq = torch.mean(condition, dim=(2, 3))
feat_seq = feat_seq + self.task_embedding[control_idx]
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition)
if from_multi:
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition)
else:
inputs.append(feat_seq.unsqueeze(1) * scale)
condition_list.append(condition * scale)

condition = sample
feat_seq = torch.mean(condition, dim=(2, 3))
Expand All @@ -759,10 +769,13 @@ def forward(
x = layer(x)

controlnet_cond_fuser = sample * 0.0
for idx, condition in enumerate(condition_list[:-1]):
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
alpha = self.spatial_ch_projs(x[:, idx])
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
controlnet_cond_fuser += condition + alpha
if from_multi:
controlnet_cond_fuser += condition + alpha
else:
controlnet_cond_fuser += condition + alpha * scale

sample = sample + controlnet_cond_fuser

Expand Down Expand Up @@ -806,12 +819,13 @@ def forward(
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale
if from_multi:
scales = scales * conditioning_scale[0]
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale
elif from_multi:
down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]

if self.config.global_pool_conditions:
down_block_res_samples = [
Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/models/controlnets/multicontrolnet_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ def forward(
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
down_block_res_samples, mid_block_res_sample = None, None
for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate(
zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
):
if scale == 0.0:
continue
down_samples, mid_sample = controlnet(
sample=sample,
timestep=timestep,
Expand All @@ -63,12 +66,13 @@ def forward(
attention_mask=attention_mask,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
from_multi=True,
guess_mode=guess_mode,
return_dict=return_dict,
)

# merge samples
if i == 0:
if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
down_block_res_samples = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -757,15 +757,9 @@ def check_inputs(
for images_ in image:
for image_ in images_:
self.check_image(image_, prompt, prompt_embeds)
else:
assert False

# Check `controlnet_conditioning_scale`
# TODO Update for https://github.com/huggingface/diffusers/pull/10723
if isinstance(controlnet, ControlNetUnionModel):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif isinstance(controlnet, MultiControlNetUnionModel):
if isinstance(controlnet, MultiControlNetUnionModel):
if isinstance(controlnet_conditioning_scale, list):
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
Expand All @@ -776,8 +770,6 @@ def check_inputs(
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
" the same length as the number of controlnets"
)
else:
assert False

if len(control_guidance_start) != len(control_guidance_end):
raise ValueError(
Expand Down Expand Up @@ -808,8 +800,6 @@ def check_inputs(
for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
if max(_control_mode) >= _controlnet.config.num_control_type:
raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
else:
assert False

# Equal number of `image` and `control_mode` elements
if isinstance(controlnet, ControlNetUnionModel):
Expand All @@ -823,8 +813,6 @@ def check_inputs(

elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
raise ValueError("Expected len(control_image) == len(control_mode)")
else:
assert False

if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
Expand Down Expand Up @@ -1201,28 +1189,33 @@ def __call__(

controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

if not isinstance(control_image, list):
control_image = [control_image]
else:
control_image = control_image.copy()

if not isinstance(control_mode, list):
control_mode = [control_mode]

if isinstance(controlnet, MultiControlNetUnionModel):
control_image = [[item] for item in control_image]
control_mode = [[item] for item in control_mode]

# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else 1
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)

if not isinstance(control_image, list):
control_image = [control_image]
else:
control_image = control_image.copy()

if not isinstance(control_mode, list):
control_mode = [control_mode]

if isinstance(controlnet, MultiControlNetUnionModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
if isinstance(controlnet_conditioning_scale, float):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult

# 1. Check inputs
self.check_inputs(
Expand Down Expand Up @@ -1357,9 +1350,6 @@ def __call__(
control_image = control_images
height, width = control_image[0][0].shape[-2:]

else:
assert False

# 5. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
Expand Down Expand Up @@ -1397,7 +1387,7 @@ def __call__(
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetUnionModel) else keeps)
controlnet_keep.append(keeps)

# 7.2 Prepare added time ids & embeddings
original_size = original_size or (height, width)
Expand Down
Loading