Skip to content

Commit ca918e4

Browse files
committed
fix: missing AutoencoderKL lora adapter
1 parent 0d1d267 commit ca918e4

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import torch.nn as nn
1818

1919
from ...configuration_utils import ConfigMixin, register_to_config
20+
from ...loaders import PeftAdapterMixin
2021
from ...loaders.single_file_model import FromOriginalModelMixin
2122
from ...utils import deprecate
2223
from ...utils.accelerate_utils import apply_forward_hook
@@ -34,7 +35,7 @@
3435
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
3536

3637

37-
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
38+
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
3839
r"""
3940
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
4041

tests/models/autoencoders/test_models_vae.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch
2121
from datasets import load_dataset
2222
from parameterized import parameterized
23+
from peft import LoraConfig
2324

2425
from diffusers import (
2526
AsymmetricAutoencoderKL,
@@ -299,6 +300,37 @@ def test_output_pretrained(self):
299300

300301
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
301302

303+
def test_lora_adapter(self):
304+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
305+
vae = self.model_class(**init_dict)
306+
307+
target_modules_vae = [
308+
"conv1",
309+
"conv2",
310+
"conv_in",
311+
"conv_shortcut",
312+
"conv",
313+
"conv_out",
314+
"skip_conv_1",
315+
"skip_conv_2",
316+
"skip_conv_3",
317+
"skip_conv_4",
318+
"to_k",
319+
"to_q",
320+
"to_v",
321+
"to_out.0",
322+
]
323+
vae_lora_config = LoraConfig(
324+
r=16,
325+
init_lora_weights="gaussian",
326+
target_modules=target_modules_vae,
327+
)
328+
329+
vae.add_adapter(vae_lora_config, adapter_name="vae_lora")
330+
active_lora = vae.active_adapters()
331+
self.assertTrue(len(active_lora) == 1)
332+
self.assertTrue(active_lora[0] == "vae_lora")
333+
302334

303335
class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
304336
model_class = AsymmetricAutoencoderKL

0 commit comments

Comments
 (0)