Skip to content

Commit 963ffca

Browse files
beniz and sayakpaul authored
fix: missing AutoencoderKL lora adapter (#9807)
* fix: missing AutoencoderKL lora adapter * fix --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 30f2e9b commit 963ffca

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import torch.nn as nn
1818

1919
from ...configuration_utils import ConfigMixin, register_to_config
20+
from ...loaders import PeftAdapterMixin
2021
from ...loaders.single_file_model import FromOriginalModelMixin
2122
from ...utils import deprecate
2223
from ...utils.accelerate_utils import apply_forward_hook
@@ -34,7 +35,7 @@
3435
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
3536

3637

37-
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
38+
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
3839
r"""
3940
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
4041

tests/models/autoencoders/test_models_vae.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@
3636
backend_empty_cache,
3737
enable_full_determinism,
3838
floats_tensor,
39+
is_peft_available,
3940
load_hf_numpy,
41+
require_peft_backend,
4042
require_torch_accelerator,
4143
require_torch_accelerator_with_fp16,
4244
require_torch_gpu,
@@ -50,6 +52,10 @@
5052
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
5153

5254

55+
if is_peft_available():
56+
from peft import LoraConfig
57+
58+
5359
enable_full_determinism()
5460

5561

@@ -263,6 +269,38 @@ def test_output_pretrained(self):
263269

264270
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
265271

272+
@require_peft_backend
273+
def test_lora_adapter(self):
274+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
275+
vae = self.model_class(**init_dict)
276+
277+
target_modules_vae = [
278+
"conv1",
279+
"conv2",
280+
"conv_in",
281+
"conv_shortcut",
282+
"conv",
283+
"conv_out",
284+
"skip_conv_1",
285+
"skip_conv_2",
286+
"skip_conv_3",
287+
"skip_conv_4",
288+
"to_k",
289+
"to_q",
290+
"to_v",
291+
"to_out.0",
292+
]
293+
vae_lora_config = LoraConfig(
294+
r=16,
295+
init_lora_weights="gaussian",
296+
target_modules=target_modules_vae,
297+
)
298+
299+
vae.add_adapter(vae_lora_config, adapter_name="vae_lora")
300+
active_lora = vae.active_adapters()
301+
self.assertTrue(len(active_lora) == 1)
302+
self.assertTrue(active_lora[0] == "vae_lora")
303+
266304

267305
class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
268306
model_class = AsymmetricAutoencoderKL

0 commit comments

Comments (0)