
Add ControlNetUnion #10131


Merged 33 commits into huggingface:main on Dec 11, 2024

Conversation

@hlky (Contributor) commented Dec 5, 2024

What does this PR do?

Model

Original code
Weights

Changes include:

  • Fixup imports.
    • ControlNetConditioningEmbedding, ControlNetOutput, zero_module imported from controlnet.
  • ResidualAttentionMlp instead of nn.Sequential/OrderedDict.
  • Remove subclassed nn.LayerNorm.
  • Add num_trans_channel etc. to __init__/config.
  • transformer_layers as nn.ModuleList instead of nn.Sequential (see the sketch after this list).
  • Fix copies and other comparisons against ControlNetModel (type hints, etc.).
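
For illustration, a generic PyTorch sketch of the two patterns mentioned above (not the actual ControlNetUnion code; the class names here are made up): an nn.ModuleList is iterated explicitly in forward, unlike nn.Sequential, and a small residual MLP block replaces the nn.Sequential/OrderedDict construction.

import torch
from torch import nn


# Illustrative residual MLP block, in the spirit of "ResidualAttentionMlp instead of
# nn.Sequential/OrderedDict"; the real module's layers and names may differ.
class ResidualMlpSketch(nn.Module):
    def __init__(self, dim: int, mlp_ratio: int = 4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, dim * mlp_ratio)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim * mlp_ratio, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.fc2(self.act(self.fc1(self.norm(x))))


# nn.ModuleList requires explicit iteration in forward (unlike nn.Sequential),
# which leaves room for per-layer logic where needed.
class TransformerStackSketch(nn.Module):
    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.transformer_layers = nn.ModuleList(
            [ResidualMlpSketch(dim) for _ in range(num_layers)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.transformer_layers:
            x = layer(x)
        return x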

Pipelines

StableDiffusionXLControlNetUnionPipeline

  • Fixup compared to StableDiffusionXLControlNetPipeline

StableDiffusionXLControlNetUnionImg2ImgPipeline

  • Fixup compared to StableDiffusionXLControlNetImg2ImgPipeline

StableDiffusionXLControlNetUnionInpaintPipeline

  • Fixup compared to StableDiffusionXLControlNetInpaintPipeline

ControlNetUnionInput/ControlNetUnionInputProMax

The original code requires pipeline input where image_list is a list like [0, 0, 0, controlnet_img, 0, 0], with the ControlNet conditioning image at the correct index, and union_control_type is a tensor like torch.Tensor([0, 0, 0, 1, 0, 0]) using the same indexing.

This has been replaced by ControlNetUnionInput/ControlNetUnionInputProMax, which allow the user to provide a single input like:

ControlNetUnionInput(
    openpose=...,
    canny=...,
)

The appropriate control_type tensor is produced inside the pipeline.
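
For illustration only, a rough sketch of the kind of mapping the pipeline performs (this assumes the input classes behave like dataclasses, as their keyword-style construction suggests; the actual pipeline code may differ):

# Illustrative sketch: turn a union input object into the index-based image_list
# and control_type tensor that the original model expects.
from dataclasses import fields

import torch


def to_original_format(union_input):
    values = [getattr(union_input, f.name) for f in fields(union_input)]
    image_list = [v if v is not None else 0 for v in values]
    control_type = torch.tensor([1.0 if v is not None else 0.0 for v in values])
    return image_list, control_type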

Example

Note: a better example prompt/image would be great

# pip install controlnet_aux
import torch
import random
from diffusers import AutoencoderKL
from controlnet_aux import LineartAnimeDetector
from diffusers import (
  EulerAncestralDiscreteScheduler,
  ControlNetUnionModel,
  StableDiffusionXLControlNetUnionPipeline,
)
from diffusers.models.controlnets import ControlNetUnionInput
from diffusers.utils import load_image


device = torch.device("cuda:0")

scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
  "gsdf/CounterfeitXL", subfolder="scheduler"
)
vae = AutoencoderKL.from_pretrained(
  "gsdf/CounterfeitXL", subfolder="vae", torch_dtype=torch.float16
)
controlnet_model = ControlNetUnionModel.from_pretrained(
  "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True
)

pipe: StableDiffusionXLControlNetUnionPipeline = (
  StableDiffusionXLControlNetUnionPipeline.from_pretrained(
      "gsdf/CounterfeitXL",
      controlnet=controlnet_model,
      vae=vae,
      torch_dtype=torch.float16,
      variant="fp16",
      scheduler=scheduler,
  )
)
pipe = pipe.to(device)

processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators").to(device)

prompt = "A cat"

controlnet_img = load_image(
  "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
).resize((1024, 1024))
controlnet_img = processor(controlnet_img, output_type="pil")

seed = random.randint(0, 2147483647)
generator = torch.Generator("cuda").manual_seed(seed)

union_input = ControlNetUnionInput(
  canny=controlnet_img,
)
image = pipe(
  prompt=prompt,
  image=union_input,
  generator=generator,
  height=1024,
  width=1024,
  num_inference_steps=30,
).images[0]

image.save("anime_lineart.png")
Example output

[image: anime_lineart]

StableDiffusionXLControlNetUnionImg2ImgPipeline example (`tile`)
from diffusers import (
    StableDiffusionXLControlNetUnionImg2ImgPipeline,
    ControlNetUnionModel,
    AutoencoderKL,
)
from diffusers.models.controlnets import ControlNetUnionInputProMax
from diffusers.utils import load_image
import torch
from PIL import Image
import numpy as np

prompt = "A cat"
# download an image
image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
# initialize the models and pipeline
controlnet = ControlNetUnionModel.from_pretrained(
    "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
)
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetUnionImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
).to("cuda")
height = image.height
width = image.width
ratio = np.sqrt(1024.0 * 1024.0 / (width * height))
# a 3 * 3 upscale needs dimensions that are a multiple of 16 * 3, a 2 * 2 upscale a multiple of 16 * 2, and so on.
scale_image_factor = 3
base_factor = 16
factor = scale_image_factor * base_factor
W, H = int(width * ratio) // factor * factor, int(height * ratio) // factor * factor
image = image.resize((W, H))
target_width = W // scale_image_factor
target_height = H // scale_image_factor
images = []
# crops_coords_list is carried over from the original example; crops_coords is unpacked below but not otherwise used
crops_coords_list = [
    (0, 0),
    (0, width // 2),
    (height // 2, 0),
    (width // 2, height // 2),
    0,
    0,
    0,
    0,
    0,
]
# split the resized image into a 3x3 grid of tiles, each resized back up to (W, H)
for i in range(scale_image_factor):
    for j in range(scale_image_factor):
        left = j * target_width
        top = i * target_height
        right = left + target_width
        bottom = top + target_height
        cropped_image = image.crop((left, top, right, bottom))
        cropped_image = cropped_image.resize((W, H))
        images.append(cropped_image)
# set ControlNetUnion input
result_images = []
for sub_img, crops_coords in zip(images, crops_coords_list):
    union_input = ControlNetUnionInputProMax(
        tile=sub_img,
    )
    new_width, new_height = W, H
    out = pipe(
        prompt=[prompt] * 1,
        image=sub_img,
        control_image_list=union_input,
        width=new_width,
        height=new_height,
        num_inference_steps=30,
        crops_coords_top_left=(W, H),
        target_size=(W, H),
        original_size=(W * 2, H * 2),
    )
    result_images.append(out.images[0])
# stitch the upscaled tiles back into one image
new_im = Image.new(
    "RGB", (new_width * scale_image_factor, new_height * scale_image_factor)
)
new_im.paste(result_images[0], (0, 0))
new_im.paste(result_images[1], (new_width, 0))
new_im.paste(result_images[2], (new_width * 2, 0))
new_im.paste(result_images[3], (0, new_height))
new_im.paste(result_images[4], (new_width, new_height))
new_im.paste(result_images[5], (new_width * 2, new_height))
new_im.paste(result_images[6], (0, new_height * 2))
new_im.paste(result_images[7], (new_width, new_height * 2))
new_im.paste(result_images[8], (new_width * 2, new_height * 2))
new_im.save("upscaled.png")
Tile example output

[image: upscaled]

StableDiffusionXLControlNetUnionInpaintPipeline example (`inpaint`, `outpaint`)

Note: a better example prompt/image would be great; results in this example are non-optimal due to the resized image and mask.

from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL
from diffusers.models.controlnets import ControlNetUnionInputProMax
from diffusers.utils import load_image
import torch
import numpy as np
from PIL import Image
prompt = "A cat"
# download an image
image = load_image(
  "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo.png"
).resize((1024, 1024))
mask = load_image(
  "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
).resize((1024, 1024))
# initialize the models and pipeline
controlnet = ControlNetUnionModel.from_pretrained(
  "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
)
vae = AutoencoderKL.from_pretrained(
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetUnionInpaintPipeline.from_pretrained(
  "stabilityai/stable-diffusion-xl-base-1.0",
  controlnet=controlnet,
  vae=vae,
  torch_dtype=torch.float16,
  variant="fp16",
)
pipe.enable_model_cpu_offload()
controlnet_img = image.copy()
controlnet_img_np = np.array(controlnet_img)
mask_np = np.array(mask)
controlnet_img_np[mask_np > 0] = 0
controlnet_img = Image.fromarray(controlnet_img_np)
union_input = ControlNetUnionInputProMax(
  repaint=controlnet_img,
)
# generate image
image = pipe(prompt, image=image, mask_image=mask, control_image_list=union_input).images[0]
image.save("inpaint.png")
Inpaint example output

[image: inpaint]

Fixes #9772

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@@ -1,80 +1,86 @@
from typing import TYPE_CHECKING
hlky (Contributor, Author) commented:

Not sure why the whitespace changed here 🤷‍♂️

@hlky (Contributor, Author) commented Dec 5, 2024

Ready for review.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu (Collaborator) left a review comment:

thanks!!

@yiyixuxu (Collaborator) commented Dec 5, 2024

cc @asomoza
can you help test these controlnet union pipelines a bit?

@hlky (Contributor, Author) commented Dec 5, 2024

@asomoza There are more test examples in the original repo. Note that the inpaint pipeline is only used with the promax version for inpaint/outpaint, and the img2img pipeline is only used with the promax version for tile super-resolution.

@yiyixuxu added the roadmap label (Add to current release roadmap) on Dec 5, 2024
@asomoza (Member) commented Dec 6, 2024

I'll test as soon as I can. This is the ControlNet I use the most on a daily basis, so I love the idea of having it in the core.

Just a bit more information on the official repository: at first I tried using it without needing a different repository, which led me to use the code in this space. I asked the model author to create a separate model repository for the promax version but got no response, so in the end I also had to make my own to avoid needing custom code just to load it.

There's quite a bit of discussion about changing the model naming, but the author never answers any of it.

Also, the promax version came after the original one, but once we have it there's not much incentive to use the older one when the newer one does the same and more.

@asomoza (Member) commented Dec 11, 2024

I did the tests and it works as expected, really nice!

An example of the diffusers image fill with this:

[images: source, mask, result]

Code

import random

import torch
from PIL import Image, ImageChops

from diffusers import AutoencoderKL, ControlNetUnionModel, StableDiffusionXLControlNetUnionPipeline
from diffusers.models.controlnets import ControlNetUnionInputProMax
from diffusers.utils import load_image


device = torch.device("cuda:0")

source_image = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/diffusers_fill/jefferson-sees-OCQjiB4tG5c-unsplash.jpg"
)

width, height = source_image.size
min_dimension = min(width, height)

left = (width - min_dimension) / 2
top = (height - min_dimension) / 2
right = (width + min_dimension) / 2
bottom = (height + min_dimension) / 2

final_source = source_image.crop((left, top, right, bottom))
final_source = final_source.resize((1024, 1024), Image.LANCZOS).convert("RGBA")

final_source.save("source.png")

mask = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/diffusers_fill/car_mask_good.png"
).convert("L")

binary_mask = mask.point(lambda p: 255 if p > 0 else 0)
inverted_mask = ImageChops.invert(binary_mask)

alpha_image = Image.new("RGBA", final_source.size, (0, 0, 0, 0))
cnet_image = Image.composite(final_source, alpha_image, inverted_mask)

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
controlnet_model = ControlNetUnionModel.from_pretrained(
    "OzzyGT/controlnet-union-promax-sdxl-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)

pipe: StableDiffusionXLControlNetUnionPipeline = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0",
    controlnet=controlnet_model,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe = pipe.to(device)

prompt = "high quality"

seed = random.randint(0, 2147483647)
generator = torch.Generator("cuda").manual_seed(seed)

union_input = ControlNetUnionInputProMax(
    repaint=cnet_image,
)
image = pipe(
    prompt=prompt,
    image=union_input,
    generator=generator,
    height=1024,
    width=1024,
    num_inference_steps=30,
).images[0]

image.save("fill.png")

@yiyixuxu merged commit 914a585 into huggingface:main on Dec 11, 2024 (15 checks passed)
@yiyixuxu (Collaborator) commented:

ok merged! thank you both
still needs tests, but they can be added in a follow-up PR

@vladmandic (Contributor) commented Dec 11, 2024

@hlky @yiyixuxu @asomoza

the design here makes it a nightmare to use and is very different from any other controlnet model.
i'm talking specifically about this concept:

image: Union[ControlNetUnionInput, ControlNetUnionInputProMax]

so image is some new class and you need to set the image via its properties instead? how does that make sense?
it's not even a new argument, it's still called image, so there's no way to differentiate.
and everywhere else in diffusers, image is exactly that: an image!

i understand that you need to signal the union controlnet what the input type is, but this is not a unique scenario.
the FluxControlNetPipeline introduced the concept of control_mode, so why not use that here?
if not, at the very least, do not call it image (it's not!) so users can tell what to set and how.
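
for reference, roughly the control_mode style i mean (a sketch only; the checkpoint name and mode index below are illustrative, check the diffusers docs for the exact FluxControlNetPipeline signature):

import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

# assumed union checkpoint name, for illustration only
controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Union", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

control_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
image = pipe(
    "A cat",
    control_image=control_image,
    control_mode=0,  # a plain integer signals the conditioning type, no wrapper class
    num_inference_steps=28,
).images[0]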

@hlky (Contributor, Author) commented Dec 12, 2024

This ControlNet works differently from Flux ControlNet Union. The design is explained in the opening comment; also see the original code below. We avoid a device sync on every step.

https://github.com/xinsir6/ControlNetPlus/blob/b48420576eac63c04388cb65fb74513cbd17405a/models/controlnet_union.py#L842-L872

https://github.com/xinsir6/ControlNetPlus/blob/b48420576eac63c04388cb65fb74513cbd17405a/models/controlnet_union.py#L854
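
To illustrate the sync point (a rough sketch of the issue, not the pipeline's actual code): deciding Python control flow from a CUDA tensor forces a device synchronization on every step, while a Python-level input object does not.

import torch

union_control_type = torch.tensor([0, 0, 0, 1, 0, 0], device="cuda")

# Reading a CUDA tensor back into Python to drive a branch blocks until the GPU
# catches up, i.e. a sync on every denoising step:
if union_control_type[3].item() == 1:
    pass  # run the branch for this conditioning type

# With a Python-level input object the active modes are plain attributes,
# so the same decision needs no device round trip:
active_modes = {"canny"}
if "canny" in active_modes:
    pass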

@vladmandic (Contributor) commented:

I understand it works differently, but you can't override the image param to become something completely different. At least create a new property then.

@asomoza (Member) commented Dec 12, 2024

Just to take note of this: we also have some inconsistencies with the naming of the normal ControlNet conditioning images across some pipelines. IMO it should never be image, but this is a breaking change, so at some point we will need to fix it for the 1.0 release.

@yiyixuxu (Collaborator) commented Dec 12, 2024

the controlnet union pipelines introduced a new parameter called control_image_list (not control_image/image like our regular ones), so I don't think we override the image param, no?

we can change the argument for check_input here to something else if that's what you want

image: Union[ControlNetUnionInput, ControlNetUnionInputProMax],
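
for example, reusing the objects from the tile example above, the img2img call shape looks like this (trimmed):

# image is the regular img2img source image; control_image_list carries the union input
out = pipe(
    prompt=prompt,
    image=sub_img,
    control_image_list=union_input,
    height=H,
    width=W,
    num_inference_steps=30,
)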

@yiyixuxu (Collaborator) commented:

Agree with @asomoza. We have some inconsistency with the parameter names; control_image and image are sometimes used interchangeably, and we should fix it.

But that is irrelevant to this PR; this PR does not have this issue at all.

@yiyixuxu (Collaborator) commented:

ok, I just saw the refactor
that works too :)

sayakpaul pushed a commit that referenced this pull request on Dec 23, 2024:
* ControlNetUnion model
@john09282922 commented:

Hi, can I use the promax version for now?
I would also like to check the promax code, but I can't find it.
Also, does promax have ControlNet options like canny, pose, and depth?

Thanks,

@vladmandic (Contributor) commented:

Hi, Can I use promax version for now? And I would like to check the promax code, But I can't find it. Also, in the promax, is there controlnet option, like canny, pose, depth?

there is no separate promax code, it's the same architecture so the same limitations apply.
the only difference is the control_mode identifiers, since union and promax don't share the same conditioning models.

@john09282922 commented:

Hi, Can I use promax version for now? And I would like to check the promax code, But I can't find it. Also, in the promax, is there controlnet option, like canny, pose, depth?

there is no separate promax code, its the same architecture so same limitations. only difference is control_mode identifiers since union and promax don't have same conditioning models.

Thanks for replying to my question. I can't find the code for the promax input and the union input. Where are those?

@hlky (Contributor, Author) commented Jan 30, 2025

@john09282922 There is no separate promax code.

@guiyrt mentioned this pull request on Feb 8, 2025.
Labels: need-test, roadmap (Add to current release roadmap)
Successfully merging this pull request may close this issue: Support ControlNetPlus Union if not already supported
6 participants