Skip to content

Commit 5105b5a

Browse files
guiyrthlky and authored
MultiControlNetUnionModel on SDXL (#10747)
* SDXL with MultiControlNetUnionModel --------- Co-authored-by: hlky <[email protected]>
1 parent ca6330d commit 5105b5a

File tree

4 files changed

+400
-80
lines changed

4 files changed

+400
-80
lines changed

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
_import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
5252
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
5353
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
54+
_import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"]
5455
_import_structure["embeddings"] = ["ImageProjection"]
5556
_import_structure["modeling_utils"] = ["ModelMixin"]
5657
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
@@ -122,6 +123,7 @@
122123
HunyuanDiT2DControlNetModel,
123124
HunyuanDiT2DMultiControlNetModel,
124125
MultiControlNetModel,
126+
MultiControlNetUnionModel,
125127
SD3ControlNetModel,
126128
SD3MultiControlNetModel,
127129
SparseControlNetModel,

src/diffusers/models/controlnets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .controlnet_union import ControlNetUnionModel
1919
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
2020
from .multicontrolnet import MultiControlNetModel
21+
from .multicontrolnet_union import MultiControlNetUnionModel
2122

2223
if is_flax_available():
2324
from .controlnet_flax import FlaxControlNetModel
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
import os
2+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3+
4+
import torch
5+
from torch import nn
6+
7+
from ...models.controlnets.controlnet import ControlNetOutput
8+
from ...models.controlnets.controlnet_union import ControlNetUnionModel
9+
from ...models.modeling_utils import ModelMixin
10+
from ...utils import logging
11+
12+
13+
logger = logging.get_logger(__name__)
14+
15+
16+
class MultiControlNetUnionModel(ModelMixin):
    r"""
    Multiple `ControlNetUnionModel` wrapper class for Multi-ControlNet-Union.

    This module is a wrapper for multiple instances of the `ControlNetUnionModel`. The `forward()` API is designed to
    be compatible with `ControlNetUnionModel`.

    Args:
        controlnets (`List[ControlNetUnionModel]`):
            Provides additional conditioning to the unet during the denoising process. You must set multiple
            `ControlNetUnionModel` as a list.
    """

    def __init__(self, controlnets: Union[List[ControlNetUnionModel], Tuple[ControlNetUnionModel]]):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: List[torch.Tensor],
        control_type: List[torch.Tensor],
        control_type_idx: List[List[int]],
        conditioning_scale: List[float],
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        r"""
        Run every wrapped `ControlNetUnionModel` and sum their residuals elementwise.

        The i-th entries of `controlnet_cond`, `control_type`, `control_type_idx` and `conditioning_scale` are
        routed to the i-th controlnet in `self.nets`; all other arguments are shared across controlnets.

        Args:
            sample (`torch.Tensor`): The noisy input latent passed to each controlnet.
            timestep (`torch.Tensor` or `float` or `int`): The current denoising timestep.
            encoder_hidden_states (`torch.Tensor`): Text-encoder hidden states (cross-attention conditioning).
            controlnet_cond (`List[torch.Tensor]`): One conditioning image (or batch) per controlnet.
            control_type (`List[torch.Tensor]`): One control-type tensor per controlnet.
            control_type_idx (`List[List[int]]`): One list of control-type indices per controlnet.
            conditioning_scale (`List[float]`): One residual scale per controlnet.

        Returns:
            A `(down_block_res_samples, mid_block_res_sample)` tuple with the elementwise sum of every
            controlnet's residuals.

        NOTE(review): the sub-controlnet output is tuple-unpacked below even when `return_dict=True` is forwarded;
        this assumes callers (e.g. the SDXL pipeline) effectively rely on tuple-style outputs — confirm before
        calling with `return_dict=True`.
        """
        for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate(
            zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
        ):
            down_samples, mid_sample = controlnet(
                sample=sample,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=image,
                control_type=ctype,
                control_type_idx=ctype_idx,
                conditioning_scale=scale,
                class_labels=class_labels,
                timestep_cond=timestep_cond,
                attention_mask=attention_mask,
                added_cond_kwargs=added_cond_kwargs,
                cross_attention_kwargs=cross_attention_kwargs,
                guess_mode=guess_mode,
                return_dict=return_dict,
            )

            # merge samples: first controlnet initializes the accumulators, the rest add into them
            if i == 0:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample += mid_sample

        return down_block_res_samples, mid_block_res_sample

    # Copied from diffusers.models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained with ControlNet->ControlNetUnion
    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        save_function: Callable = None,
        safe_serialization: bool = True,
        variant: Optional[str] = None,
    ):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        `[`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.from_pretrained`]` class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful when in distributed training like
                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
                the main process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
                need to replace `torch.save` by another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
            variant (`str`, *optional*):
                If specified, weights are saved in the format pytorch_model.<variant>.bin.
        """
        # The first controlnet is saved at `save_directory` itself; subsequent ones at `..._1`, `..._2`, ...
        # NOTE(review): `save_directory + suffix` assumes `save_directory` is a `str`; a genuine `os.PathLike`
        # would raise TypeError here — confirm whether PathLike support is needed.
        for idx, controlnet in enumerate(self.nets):
            suffix = "" if idx == 0 else f"_{idx}"
            controlnet.save_pretrained(
                save_directory + suffix,
                is_main_process=is_main_process,
                save_function=save_function,
                safe_serialization=safe_serialization,
                variant=variant,
            )

    @classmethod
    # Copied from diffusers.models.controlnets.multicontrolnet.MultiControlNetModel.from_pretrained with ControlNet->ControlNetUnion
    def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
        r"""
        Instantiate a pretrained MultiControlNetUnion model from multiple pre-trained controlnet models.

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
        the model, you should first set it back in training mode with `model.train()`.

        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_path (`os.PathLike`):
                A path to a *directory* containing model weights saved using
                [`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.save_pretrained`], e.g.,
                `./my_model_directory/controlnet`.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
                will be automatically derived from the model's weights.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be refined to each
                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
                same device.

                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
                GPU and the available CPU RAM if unset.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
                setting this argument to `True` will raise an error.
            variant (`str`, *optional*):
                If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
                ignored when using `from_flax`.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
                `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
                `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
        """
        idx = 0
        controlnets = []

        # load controlnet and append to list until no controlnet directory exists anymore
        # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_pretrained`
        # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
        model_path_to_load = pretrained_model_path
        while os.path.isdir(model_path_to_load):
            controlnet = ControlNetUnionModel.from_pretrained(model_path_to_load, **kwargs)
            controlnets.append(controlnet)

            idx += 1
            # NOTE(review): string concatenation assumes `pretrained_model_path` is a `str`, not `os.PathLike`.
            model_path_to_load = pretrained_model_path + f"_{idx}"

        logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")

        if len(controlnets) == 0:
            # The first checkpoint lives at `pretrained_model_path` itself (no suffix) — mirror `save_pretrained`.
            raise ValueError(
                f"No ControlNetUnions found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path}."
            )

        return cls(controlnets)

0 commit comments

Comments
 (0)