Commit a26530e

patrickvonplaten authored and sayakpaul committed
Add controlnet and vae from single file (huggingface#4084)
* Add controlnet from single file
* Updates
* make style
* finish
* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>
1 parent ad8467f commit a26530e

14 files changed: +576 −29 lines

docs/source/en/api/loaders.mdx

Lines changed: 8 additions & 0 deletions
@@ -35,3 +35,11 @@ Adapters (textual inversion, LoRA, hypernetworks) allow you to modify a diffusio
 ## FromSingleFileMixin
 
 [[autodoc]] loaders.FromSingleFileMixin
+
+## FromOriginalControlnetMixin
+
+[[autodoc]] loaders.FromOriginalControlnetMixin
+
+## FromOriginalVAEMixin
+
+[[autodoc]] loaders.FromOriginalVAEMixin

docs/source/en/api/models/autoencoderkl.mdx

Lines changed: 13 additions & 1 deletion
@@ -6,6 +6,18 @@ The abstract from the paper is:
 
 *How can we perform efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables with intractable posterior distributions, and large datasets? We introduce a stochastic variational inference and learning algorithm that scales to large datasets and, under some mild differentiability conditions, even works in the intractable case. Our contributions are two-fold. First, we show that a reparameterization of the variational lower bound yields a lower bound estimator that can be straightforwardly optimized using standard stochastic gradient methods. Second, we show that for i.i.d. datasets with continuous latent variables per datapoint, posterior inference can be made especially efficient by fitting an approximate inference model (also called a recognition model) to the intractable posterior using the proposed lower bound estimator. Theoretical advantages are reflected in experimental results.*
 
+## Loading from the original format
+
+By default the [`AutoencoderKL`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
+from the original format using [`FromOriginalVAEMixin.from_single_file`] as follows:
+
+```py
+from diffusers import AutoencoderKL
+
+url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"  # can also be a local file
+model = AutoencoderKL.from_single_file(url)
+```
+
 ## AutoencoderKL
 
 [[autodoc]] AutoencoderKL
@@ -28,4 +40,4 @@ The abstract from the paper is:
 
 ## FlaxDecoderOutput
 
-[[autodoc]] models.vae_flax.FlaxDecoderOutput
\ No newline at end of file
+[[autodoc]] models.vae_flax.FlaxDecoderOutput
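For context: a model loaded this way behaves like any other [`AutoencoderKL`], so it can be swapped into an existing pipeline. A minimal sketch, not part of this commit; the `runwayml/stable-diffusion-v1-5` pipeline checkpoint is an assumption for illustration:

```py
from diffusers import AutoencoderKL, StableDiffusionPipeline

# Load the VAE from the original single-file checkpoint (URL from the docs above).
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
vae = AutoencoderKL.from_single_file(url)

# Use it as a drop-in replacement for the pipeline's default VAE.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae)
```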

docs/source/en/api/models/controlnet.mdx

Lines changed: 16 additions & 1 deletion
@@ -6,6 +6,21 @@ The abstract from the paper is:
 
 *We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal device. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.*
 
+## Loading from the original format
+
+By default the [`ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
+from the original format using [`FromOriginalControlnetMixin.from_single_file`] as follows:
+
+```py
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+
+url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"  # can also be a local path
+controlnet = ControlNetModel.from_single_file(url)
+
+url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors"  # can also be a local path
+pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
+```
+
 ## ControlNetModel
 
 [[autodoc]] ControlNetModel
@@ -20,4 +35,4 @@ The abstract from the paper is:
 
 ## FlaxControlNetOutput
 
-[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
\ No newline at end of file
+[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
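For context: a minimal end-to-end sketch of the new loaders, not part of this commit. It reuses the checkpoint and conditioning-image URLs from this commit's tests; the prompt, seed, and step count are assumptions:

```py
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# Both the ControlNet and the full pipeline come from original-format checkpoints.
controlnet = ControlNetModel.from_single_file(
    "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe = StableDiffusionControlNetPipeline.from_single_file(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors",
    controlnet=controlnet,
)
pipe.enable_model_cpu_offload()

# Condition generation on a canny edge map.
canny = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe("bird", canny, generator=generator, num_inference_steps=20).images[0]
```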

src/diffusers/loaders.py

Lines changed: 358 additions & 5 deletions
Large diffs are not rendered by default.
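Since this diff is collapsed, here is a rough sketch of the shape the new mixins take, inferred from the docs and call sites elsewhere in this commit rather than from the hidden diff itself; the real implementation also handles Hub downloads, config inference, and the key conversion in convert_from_ckpt.py:

```py
# Inferred sketch only -- not the actual code in src/diffusers/loaders.py.
class FromOriginalVAEMixin:
    @classmethod
    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
        # 1. Resolve the Hub URL or local path to one checkpoint file.
        # 2. Load the original-format state dict (.safetensors or .ckpt/.pth).
        # 3. Convert the keys (convert_ldm_vae_checkpoint) and instantiate
        #    the diffusers AutoencoderKL with the converted weights.
        ...


class FromOriginalControlnetMixin:
    @classmethod
    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
        # Same flow, but routed through convert_controlnet_checkpoint.
        ...
```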

src/diffusers/models/autoencoder_kl.py

Lines changed: 2 additions & 1 deletion
@@ -18,6 +18,7 @@
 import torch.nn as nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import FromOriginalVAEMixin
 from ..utils import BaseOutput, apply_forward_hook
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .modeling_utils import ModelMixin
@@ -38,7 +39,7 @@ class AutoencoderKLOutput(BaseOutput):
     latent_dist: "DiagonalGaussianDistribution"
 
 
-class AutoencoderKL(ModelMixin, ConfigMixin):
+class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

src/diffusers/models/controlnet.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 from torch.nn import functional as F
 
 from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import FromOriginalControlnetMixin
 from ..utils import BaseOutput, logging
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
@@ -100,7 +101,7 @@ def forward(self, conditioning):
         return embedding
 
 
-class ControlNetModel(ModelMixin, ConfigMixin):
+class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
     """
     A ControlNet model.

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 4 additions & 2 deletions
@@ -24,7 +24,7 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -90,7 +90,9 @@
 """
 
 
-class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionControlNetPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 4 additions & 2 deletions
@@ -24,7 +24,7 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -116,7 +116,9 @@ def prepare_image(image):
     return image
 
 
-class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionControlNetImg2ImgPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 4 additions & 2 deletions
@@ -25,7 +25,7 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -222,7 +222,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
     return mask, masked_image
 
 
-class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionControlNetInpaintPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
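With `FromSingleFileMixin` added, all three ControlNet pipelines gain the `from_single_file` constructor. A minimal sketch, not part of this commit, using the inpaint variant with the checkpoint URLs from this commit's tests:

```py
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline

# Same single-file pattern as the text-to-image and img2img tests below.
controlnet = ControlNetModel.from_single_file(
    "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe = StableDiffusionControlNetInpaintPipeline.from_single_file(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
    controlnet=controlnet,
    safety_checker=None,
)
```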

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 12 additions & 13 deletions
@@ -621,8 +621,8 @@ def convert_ldm_unet_checkpoint(
 def convert_ldm_vae_checkpoint(checkpoint, config):
     # extract state dict for VAE
     vae_state_dict = {}
-    vae_key = "first_stage_model."
     keys = list(checkpoint.keys())
+    vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
     for key in keys:
         if key.startswith(vae_key):
             vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
@@ -1064,7 +1064,7 @@ def convert_controlnet_checkpoint(
     if cross_attention_dim is not None:
         ctrlnet_config["cross_attention_dim"] = cross_attention_dim
 
-    controlnet_model = ControlNetModel(**ctrlnet_config)
+    controlnet = ControlNetModel(**ctrlnet_config)
 
     # Some controlnet ckpt files are distributed independently from the rest of the
     # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
@@ -1082,9 +1082,9 @@ def convert_controlnet_checkpoint(
         skip_extract_state_dict=skip_extract_state_dict,
     )
 
-    controlnet_model.load_state_dict(converted_ctrl_checkpoint)
+    controlnet.load_state_dict(converted_ctrl_checkpoint)
 
-    return controlnet_model
+    return controlnet
 
 
 def download_from_original_stable_diffusion_ckpt(
@@ -1181,7 +1181,7 @@ def download_from_original_stable_diffusion_ckpt(
     )
 
     if pipeline_class is None:
-        pipeline_class = StableDiffusionPipeline
+        pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
 
     if prediction_type == "v-prediction":
         prediction_type = "v_prediction"
@@ -1288,8 +1288,7 @@
     if controlnet is None:
         controlnet = "control_stage_config" in original_config.model.params
 
-    if controlnet:
-        controlnet_model = convert_controlnet_checkpoint(
+    controlnet = convert_controlnet_checkpoint(
         checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
     )
@@ -1400,13 +1399,13 @@
 
     if stable_unclip is None:
         if controlnet:
-            pipe = StableDiffusionControlNetPipeline(
+            pipe = pipeline_class(
                 vae=vae,
                 text_encoder=text_model,
                 tokenizer=tokenizer,
                 unet=unet,
                 scheduler=scheduler,
-                controlnet=controlnet_model,
+                controlnet=controlnet,
                 safety_checker=None,
                 feature_extractor=None,
                 requires_safety_checker=False,
@@ -1503,12 +1502,12 @@
         feature_extractor = None
 
     if controlnet:
-        pipe = StableDiffusionControlNetPipeline(
+        pipe = pipeline_class(
             vae=vae,
             text_encoder=text_model,
             tokenizer=tokenizer,
             unet=unet,
-            controlnet=controlnet_model,
+            controlnet=controlnet,
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
@@ -1623,7 +1622,7 @@ def download_controlnet_from_original_ckpt(
     if "control_stage_config" not in original_config.model.params:
         raise ValueError("`control_stage_config` not present in original config")
 
-    controlnet_model = convert_controlnet_checkpoint(
+    controlnet = convert_controlnet_checkpoint(
         checkpoint,
         original_config,
         checkpoint_path,
@@ -1634,4 +1633,4 @@
         cross_attention_dim=cross_attention_dim,
     )
 
-    return controlnet_model
+    return controlnet
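The `vae_key` change in `convert_ldm_vae_checkpoint` is what allows standalone VAE checkpoints through the conversion path: full Stable Diffusion checkpoints prefix VAE weights with `first_stage_model.`, while standalone VAE dumps do not. A small self-contained illustration of the prefix handling (integer values stand in for tensors):

```py
# Full SD checkpoints prefix VAE weights; standalone VAE checkpoints do not.
full_ckpt = {"first_stage_model.encoder.conv_in.weight": 1, "model.diffusion_model.out": 2}
vae_only_ckpt = {"encoder.conv_in.weight": 1}


def extract_vae_state_dict(checkpoint):
    keys = list(checkpoint.keys())
    # Fall back to an empty prefix when no key carries "first_stage_model.".
    vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
    return {k.replace(vae_key, ""): checkpoint[k] for k in keys if k.startswith(vae_key)}


assert extract_vae_state_dict(full_ckpt) == {"encoder.conv_in.weight": 1}
assert extract_vae_state_dict(vae_only_ckpt) == {"encoder.conv_in.weight": 1}
```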

tests/models/test_models_vae.py

Lines changed: 20 additions & 1 deletion
@@ -199,7 +199,7 @@ def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False)
             torch_dtype=torch_dtype,
             revision=revision,
         )
-        model.to(torch_device).eval()
+        model.to(torch_device)
 
         return model
 
@@ -383,3 +383,22 @@ def test_stable_diffusion_encode_sample(self, seed, expected_slice):
 
         tolerance = 3e-3 if torch_device != "mps" else 1e-2
         assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
+
+    def test_stable_diffusion_model_local(self):
+        model_id = "stabilityai/sd-vae-ft-mse"
+        model_1 = AutoencoderKL.from_pretrained(model_id).to(torch_device)
+
+        url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
+        model_2 = AutoencoderKL.from_single_file(url).to(torch_device)
+        image = self.get_sd_image(33)
+
+        with torch.no_grad():
+            sample_1 = model_1(image).sample
+            sample_2 = model_2(image).sample
+
+        assert sample_1.shape == sample_2.shape
+
+        output_slice_1 = sample_1[-1, -2:, -2:, :2].flatten().float().cpu()
+        output_slice_2 = sample_2[-1, -2:, -2:, :2].flatten().float().cpu()
+
+        assert torch_all_close(output_slice_1, output_slice_2, atol=3e-3)

tests/pipelines/controlnet/test_controlnet.py

Lines changed: 36 additions & 0 deletions
@@ -752,6 +752,42 @@ def test_v11_shuffle_global_pool_conditions(self):
         expected_slice = np.array([0.1338, 0.1597, 0.1202, 0.1687, 0.1377, 0.1017, 0.2070, 0.1574, 0.1348])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
+    def test_load_local(self):
+        controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
+        pipe_1 = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
+        )
+
+        controlnet = ControlNetModel.from_single_file(
+            "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
+        )
+        pipe_2 = StableDiffusionControlNetPipeline.from_single_file(
+            "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
+            safety_checker=None,
+            controlnet=controlnet,
+        )
+        pipes = [pipe_1, pipe_2]
+        images = []
+
+        for pipe in pipes:
+            pipe.enable_model_cpu_offload()
+            pipe.set_progress_bar_config(disable=None)
+
+            generator = torch.Generator(device="cpu").manual_seed(0)
+            prompt = "bird"
+            image = load_image(
+                "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+            )
+
+            output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
+            images.append(output.images[0])
+
+            del pipe
+            gc.collect()
+            torch.cuda.empty_cache()
+
+        assert np.abs(images[0] - images[1]).sum() < 1e-3
+
 
 @slow
 @require_torch_gpu

tests/pipelines/controlnet/test_controlnet_img2img.py

Lines changed: 46 additions & 0 deletions
@@ -401,3 +401,49 @@ def test_canny(self):
         )
 
         assert np.abs(expected_image - image).max() < 9e-2
+
+    def test_load_local(self):
+        controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
+        pipe_1 = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
+        )
+
+        controlnet = ControlNetModel.from_single_file(
+            "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
+        )
+        pipe_2 = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
+            "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
+            safety_checker=None,
+            controlnet=controlnet,
+        )
+        control_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+        ).resize((512, 512))
+        image = load_image(
+            "https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
+        ).resize((512, 512))
+
+        pipes = [pipe_1, pipe_2]
+        images = []
+        for pipe in pipes:
+            pipe.enable_model_cpu_offload()
+            pipe.set_progress_bar_config(disable=None)
+
+            generator = torch.Generator(device="cpu").manual_seed(0)
+            prompt = "bird"
+            output = pipe(
+                prompt,
+                image=image,
+                control_image=control_image,
+                strength=0.9,
+                generator=generator,
+                output_type="np",
+                num_inference_steps=3,
+            )
+            images.append(output.images[0])
+
+            del pipe
+            gc.collect()
+            torch.cuda.empty_cache()
+
+        assert np.abs(images[0] - images[1]).sum() < 1e-3