
Add controlnet and vae from single file #4084


Merged · 6 commits · Jul 19, 2023
Changes from 5 commits
8 changes: 8 additions & 0 deletions docs/source/en/api/loaders.mdx
@@ -35,3 +35,11 @@ Adapters (textual inversion, LoRA, hypernetworks) allow you to modify a diffusion
## FromSingleFileMixin

[[autodoc]] loaders.FromSingleFileMixin

## FromOriginalControlnetMixin

[[autodoc]] loaders.FromOriginalControlnetMixin

## FromOriginalVAEMixin

[[autodoc]] loaders.FromOriginalVAEMixin
14 changes: 13 additions & 1 deletion docs/source/en/api/models/autoencoderkl.mdx
@@ -6,6 +6,18 @@ The abstract from the paper is:

*How can we perform efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables with intractable posterior distributions, and large datasets? We introduce a stochastic variational inference and learning algorithm that scales to large datasets and, under some mild differentiability conditions, even works in the intractable case. Our contributions are two-fold. First, we show that a reparameterization of the variational lower bound yields a lower bound estimator that can be straightforwardly optimized using standard stochastic gradient methods. Second, we show that for i.i.d. datasets with continuous latent variables per datapoint, posterior inference can be made especially efficient by fitting an approximate inference model (also called a recognition model) to the intractable posterior using the proposed lower bound estimator. Theoretical advantages are reflected in experimental results.*
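As an illustrative aside (not part of the docs page being edited): the reparameterization the abstract refers to rewrites sampling as a deterministic function of the distribution parameters plus independent noise, which is also how the `DiagonalGaussianDistribution` used by [`AutoencoderKL`] samples latents. A minimal sketch:

```py
import torch

# toy diagonal-Gaussian parameters, as a VAE encoder would predict them
mean = torch.zeros(4, requires_grad=True)
logvar = torch.zeros(4, requires_grad=True)

std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)  # randomness isolated in eps
z = mean + std * eps         # sample is differentiable w.r.t. mean and logvar
```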

## Loading from the original format

By default, [`AutoencoderKL`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
from the original format using [`FromOriginalVAEMixin.from_single_file`] as follows:

```py
from diffusers import AutoencoderKL

url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
model = AutoencoderKL.from_single_file(url)
```
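The converted model behaves like any other [`AutoencoderKL`]. For instance (a small sketch; the local directory name is illustrative), you can save it in the diffusers layout with [`~ModelMixin.save_pretrained`] so later loads can use [`~ModelMixin.from_pretrained`] directly:

```py
from diffusers import AutoencoderKL

url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
vae = AutoencoderKL.from_single_file(url)

# persist in the diffusers format; "./sd-vae-ft-mse" is an illustrative path
vae.save_pretrained("./sd-vae-ft-mse")

# subsequent loads can skip the single-file conversion
vae = AutoencoderKL.from_pretrained("./sd-vae-ft-mse")
```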

## AutoencoderKL

[[autodoc]] AutoencoderKL
@@ -28,4 +40,4 @@ The abstract from the paper is:

## FlaxDecoderOutput

[[autodoc]] models.vae_flax.FlaxDecoderOutput
17 changes: 16 additions & 1 deletion docs/source/en/api/models/controlnet.mdx
@@ -6,6 +6,21 @@ The abstract from the paper is:

*We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal device. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.*

## Loading from the original format

By default, [`ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
from the original format using [`FromOriginalControlnetMixin.from_single_file`] as follows:

```py
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel

url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"  # can also be a local path
controlnet = ControlNetModel.from_single_file(url)

url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors"  # can also be a local path
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
```
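Once loaded, the single-file pipeline works like one created with [`~DiffusionPipeline.from_pretrained`]. A usage sketch (the conditioning-image URL comes from the diffusers test assets; the inference settings are illustrative, not part of this PR):

```py
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_single_file(
    "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe = StableDiffusionControlNetPipeline.from_single_file(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors",
    controlnet=controlnet,
).to("cuda")

# a canny edge map used as the ControlNet conditioning image
canny_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
image = pipe("bird", image=canny_image, num_inference_steps=20).images[0]
```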

## ControlNetModel

[[autodoc]] ControlNetModel
@@ -20,4 +35,4 @@ The abstract from the paper is:

## FlaxControlNetOutput

[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
363 changes: 358 additions & 5 deletions src/diffusers/loaders.py

Large diffs are not rendered by default.
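Since the `loaders.py` diff is not rendered, here is a rough, hypothetical sketch of the general shape of single-file loading (resolving a Hub blob URL or local path to a checkpoint state dict). It is explicitly not the actual implementation in this PR:

```py
# Hypothetical sketch only; the real FromOriginal*Mixin code lives in src/diffusers/loaders.py.
import torch
from huggingface_hub import hf_hub_download


def load_single_file_state_dict(link_or_path: str) -> dict:
    """Resolve a Hub blob URL (or a local path) to a checkpoint state dict."""
    prefix = "https://huggingface.co/"
    if link_or_path.startswith(prefix):
        # e.g. https://huggingface.co/<repo_id>/blob/main/<filename>
        repo_id, filename = link_or_path[len(prefix):].split("/blob/main/")
        link_or_path = hf_hub_download(repo_id, filename=filename)
    if link_or_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        return load_file(link_or_path)
    return torch.load(link_or_path, map_location="cpu")
```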

3 changes: 2 additions & 1 deletion src/diffusers/models/autoencoder_kl.py
@@ -18,6 +18,7 @@
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import FromOriginalVAEMixin
from ..utils import BaseOutput, apply_forward_hook
from .attention_processor import AttentionProcessor, AttnProcessor
from .modeling_utils import ModelMixin
@@ -38,7 +39,7 @@ class AutoencoderKLOutput(BaseOutput):
latent_dist: "DiagonalGaussianDistribution"


-class AutoencoderKL(ModelMixin, ConfigMixin):
+class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

3 changes: 2 additions & 1 deletion src/diffusers/models/controlnet.py
@@ -19,6 +19,7 @@
from torch.nn import functional as F

from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import FromOriginalControlnetMixin
from ..utils import BaseOutput, logging
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
@@ -100,7 +101,7 @@ def forward(self, conditioning):
return embedding


-class ControlNetModel(ModelMixin, ConfigMixin):
+class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
"""
A ControlNet model.

6 changes: 4 additions & 2 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -24,7 +24,7 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -90,7 +90,9 @@
"""


-class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionControlNetPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -24,7 +24,7 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -116,7 +116,9 @@ def prepare_image(image):
return image


-class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionControlNetImg2ImgPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
r"""
Pipeline for image-to-image generation using Stable Diffusion with ControlNet guidance.

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -25,7 +25,7 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -222,7 +222,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
return mask, masked_image


-class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionControlNetInpaintPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
r"""
Pipeline for image inpainting using Stable Diffusion with ControlNet guidance.

25 changes: 12 additions & 13 deletions src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -621,8 +621,8 @@ def convert_ldm_unet_checkpoint(
def convert_ldm_vae_checkpoint(checkpoint, config):
# extract state dict for VAE
vae_state_dict = {}
vae_key = "first_stage_model."
keys = list(checkpoint.keys())
vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
@@ -1064,7 +1064,7 @@ def convert_controlnet_checkpoint(
if cross_attention_dim is not None:
ctrlnet_config["cross_attention_dim"] = cross_attention_dim

-controlnet_model = ControlNetModel(**ctrlnet_config)
+controlnet = ControlNetModel(**ctrlnet_config)

# Some controlnet ckpt files are distributed independently from the rest of the
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
@@ -1082,9 +1082,9 @@
skip_extract_state_dict=skip_extract_state_dict,
)

-controlnet_model.load_state_dict(converted_ctrl_checkpoint)
+controlnet.load_state_dict(converted_ctrl_checkpoint)

-return controlnet_model
+return controlnet


def download_from_original_stable_diffusion_ckpt(
@@ -1182,7 +1182,7 @@ def download_from_original_stable_diffusion_ckpt(
)

if pipeline_class is None:
-pipeline_class = StableDiffusionPipeline
+pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline

if prediction_type == "v-prediction":
prediction_type = "v_prediction"
@@ -1289,8 +1289,7 @@ def download_from_original_stable_diffusion_ckpt(
if controlnet is None:
controlnet = "control_stage_config" in original_config.model.params

-if controlnet:
-    controlnet_model = convert_controlnet_checkpoint(
+controlnet = convert_controlnet_checkpoint(
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
)
Comment on lines -1292 to 1294
@williamberman (Contributor), Jul 20, 2023:

iamwavecut#1 this might have introduced a bug @patrickvonplaten

Reply from the PR author (Contributor):

It did indeed haha. Solved as explained in iamwavecut#1
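For readers following along: the change above dropped the `if controlnet:` guard, so the conversion runs even when `controlnet` is `False`. A guarded version (a sketch of the likely shape of the fix; the actual fix is the one in iamwavecut#1) would be:

```py
if controlnet:
    # only run the conversion when the checkpoint actually carries a ControlNet
    controlnet = convert_controlnet_checkpoint(
        checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
    )
```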


@@ -1401,13 +1400,13 @@ def download_from_original_stable_diffusion_ckpt(

if stable_unclip is None:
if controlnet:
-pipe = StableDiffusionControlNetPipeline(
+pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
-controlnet=controlnet_model,
+controlnet=controlnet,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
@@ -1504,12 +1503,12 @@
feature_extractor = None

if controlnet:
-pipe = StableDiffusionControlNetPipeline(
+pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
-controlnet=controlnet_model,
+controlnet=controlnet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
@@ -1624,7 +1623,7 @@ def download_controlnet_from_original_ckpt(
if "control_stage_config" not in original_config.model.params:
raise ValueError("`control_stage_config` not present in original config")

-controlnet_model = convert_controlnet_checkpoint(
+controlnet = convert_controlnet_checkpoint(
checkpoint,
original_config,
checkpoint_path,
@@ -1635,4 +1634,4 @@
cross_attention_dim=cross_attention_dim,
)

-return controlnet_model
+return controlnet
21 changes: 20 additions & 1 deletion tests/models/test_models_vae.py
@@ -199,7 +199,7 @@ def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False)
torch_dtype=torch_dtype,
revision=revision,
)
-model.to(torch_device).eval()
+model.to(torch_device)

return model

@@ -383,3 +383,22 @@ def test_stable_diffusion_encode_sample(self, seed, expected_slice):

tolerance = 3e-3 if torch_device != "mps" else 1e-2
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)

def test_stable_diffusion_model_local(self):
model_id = "stabilityai/sd-vae-ft-mse"
model_1 = AutoencoderKL.from_pretrained(model_id).to(torch_device)

url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
model_2 = AutoencoderKL.from_single_file(url).to(torch_device)
image = self.get_sd_image(33)

with torch.no_grad():
sample_1 = model_1(image).sample
sample_2 = model_2(image).sample

assert sample_1.shape == sample_2.shape

output_slice_1 = sample_1[-1, -2:, -2:, :2].flatten().float().cpu()
output_slice_2 = sample_2[-1, -2:, -2:, :2].flatten().float().cpu()

assert torch_all_close(output_slice_1, output_slice_2, atol=3e-3)
36 changes: 36 additions & 0 deletions tests/pipelines/controlnet/test_controlnet.py
@@ -752,6 +752,42 @@ def test_v11_shuffle_global_pool_conditions(self):
expected_slice = np.array([0.1338, 0.1597, 0.1202, 0.1687, 0.1377, 0.1017, 0.2070, 0.1574, 0.1348])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_load_local(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe_1 = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)

controlnet = ControlNetModel.from_single_file(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe_2 = StableDiffusionControlNetPipeline.from_single_file(
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
safety_checker=None,
controlnet=controlnet,
)
pipes = [pipe_1, pipe_2]
images = []

for pipe in pipes:
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)

generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "bird"
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)

output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
images.append(output.images[0])

del pipe
gc.collect()
torch.cuda.empty_cache()

assert np.abs(images[0] - images[1]).sum() < 1e-3


@slow
@require_torch_gpu
46 changes: 46 additions & 0 deletions tests/pipelines/controlnet/test_controlnet_img2img.py
@@ -401,3 +401,49 @@ def test_canny(self):
)

assert np.abs(expected_image - image).max() < 9e-2

def test_load_local(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe_1 = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)

controlnet = ControlNetModel.from_single_file(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe_2 = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
safety_checker=None,
controlnet=controlnet,
)
control_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
).resize((512, 512))
image = load_image(
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
).resize((512, 512))

pipes = [pipe_1, pipe_2]
images = []
for pipe in pipes:
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)

generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "bird"
output = pipe(
prompt,
image=image,
control_image=control_image,
strength=0.9,
generator=generator,
output_type="np",
num_inference_steps=3,
)
images.append(output.images[0])

del pipe
gc.collect()
torch.cuda.empty_cache()

assert np.abs(images[0] - images[1]).sum() < 1e-3