
Commit c71bdb7

sayakpaul and DN6 committed

[tests] refactor vae tests (#9808)

* add: autoencoderkl tests
* autoencodertiny.
* fix
* asymmetric autoencoder.
* more
* integration tests for stable audio decoder.
* consistency decoder vae tests
* remove grad check from consistency decoder.
* cog
* bye test_models_vae.py
* fix
* fix
* remove allegro
* fixes
* fixes
* fixes

---------

Co-authored-by: Dhruv Nair <[email protected]>

1 parent c33fd8b · commit c71bdb7

16 files changed (+1863, -1277 lines)

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 10 additions & 10 deletions

@@ -433,7 +433,7 @@ def create_forward(*inputs):
                     hidden_states,
                     temb,
                     zq,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -531,7 +531,7 @@ def create_forward(*inputs):
                 return create_forward

             hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+                create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
             )
         else:
             hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -649,7 +649,7 @@ def create_forward(*inputs):
                     hidden_states,
                     temb,
                     zq,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -789,7 +789,7 @@ def custom_forward(*inputs):
                     hidden_states,
                     temb,
                     None,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
                 )

             # 2. Mid
@@ -798,14 +798,14 @@ def custom_forward(*inputs):
                 hidden_states,
                 temb,
                 None,
-                conv_cache=conv_cache.get("mid_block"),
+                conv_cache.get("mid_block"),
             )
         else:
             # 1. Down
             for i, down_block in enumerate(self.down_blocks):
                 conv_cache_key = f"down_block_{i}"
                 hidden_states, new_conv_cache[conv_cache_key] = down_block(
-                    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
+                    hidden_states, temb, None, conv_cache.get(conv_cache_key)
                 )

             # 2. Mid
@@ -953,7 +953,7 @@ def custom_forward(*inputs):
                 hidden_states,
                 temb,
                 sample,
-                conv_cache=conv_cache.get("mid_block"),
+                conv_cache.get("mid_block"),
             )

             # 2. Up
@@ -964,7 +964,7 @@ def custom_forward(*inputs):
                     hidden_states,
                     temb,
                     sample,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
                 )
         else:
             # 1. Mid
@@ -1476,7 +1476,7 @@ def forward(
             z = posterior.sample(generator=generator)
         else:
             z = posterior.mode()
-        dec = self.decode(z)
+        dec = self.decode(z).sample
         if not return_dict:
             return (dec,)
-        return dec
+        return DecoderOutput(sample=dec)
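
The recurring change in this file passes `conv_cache` positionally instead of as a keyword. A plausible reason, offered here as an assumption rather than something stated in the commit: PyTorch's reentrant `torch.utils.checkpoint.checkpoint` rejects extra keyword arguments for the wrapped function, so anything routed through a `create_custom_forward` wrapper has to travel through `*inputs`. A minimal sketch of the pattern (the `Block` module is hypothetical):

import torch
from torch.utils.checkpoint import checkpoint


class Block(torch.nn.Module):
    def forward(self, hidden_states, conv_cache=None):
        # conv_cache arrives positionally when called under checkpointing
        return hidden_states * 2 if conv_cache is None else hidden_states + conv_cache


def create_custom_forward(module):
    # Wrap the module so checkpoint() calls it with positional args only.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


block = Block()
x = torch.randn(2, 4, requires_grad=True)
cache = torch.randn(2, 4)

# Keyword arguments fail under reentrant checkpointing; positional ones work.
out = checkpoint(create_custom_forward(block), x, cache, use_reentrant=True)
out.sum().backward()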

src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py

Lines changed: 0 additions & 8 deletions

@@ -229,14 +229,6 @@ def __init__(

         self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)

-        sample_size = (
-            self.config.sample_size[0]
-            if isinstance(self.config.sample_size, (list, tuple))
-            else self.config.sample_size
-        )
-        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
-        self.tile_overlap_factor = 0.25
-
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (Encoder, TemporalDecoder)):
             module.gradient_checkpointing = value
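
The deleted block computed tiling attributes that `AutoencoderKLTemporalDecoder` appears to use nowhere (it exposes no tiled decode path), so this reads as dead-code removal. For reference, the arithmetic the deleted lines performed, shown with hypothetical config values:

# Arithmetic from the deleted lines; the config values here are made up.
sample_size = 768                          # self.config.sample_size
block_out_channels = [128, 256, 512, 512]  # self.config.block_out_channels

downscale = 2 ** (len(block_out_channels) - 1)       # 8x spatial downsampling
tile_latent_min_size = int(sample_size / downscale)  # 96
tile_overlap_factor = 0.25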

src/diffusers/models/autoencoders/autoencoder_tiny.py

Lines changed: 4 additions & 2 deletions

@@ -310,7 +310,9 @@ def decode(
         self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
     ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         if self.use_slicing and x.shape[0] > 1:
-            output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
+            output = [
+                self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)
+            ]
             output = torch.cat(output)
         else:
             output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
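
Note that this hunk is a genuine bug fix, not just a reflow: with slicing enabled, the old comprehension decoded the full batch `x` once per slice instead of decoding each `x_slice`. A standalone sketch of the corrected pattern (the one-layer decoder is a stand-in for AutoencoderTiny's decoder):

import torch

decoder = torch.nn.Conv2d(4, 3, kernel_size=3, padding=1)  # stand-in decoder
x = torch.randn(4, 4, 8, 8)  # batch of 4 latents

# Decode one sample at a time to cap peak memory, then reassemble the batch.
output = torch.cat([decoder(x_slice) for x_slice in x.split(1)])
assert output.shape == (4, 3, 8, 8)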
@@ -341,7 +343,7 @@ def forward(
         # as if we were loading the latents from an RGBA uint8 image.
         unscaled_enc = self.unscale_latents(scaled_enc / 255.0)

-        dec = self.decode(unscaled_enc)
+        dec = self.decode(unscaled_enc).sample

         if not return_dict:
             return (dec,)
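
This hunk and the `forward()` change in autoencoder_kl_cogvideox.py enforce the same contract: `decode()` returns a `DecoderOutput` wrapper, so internal callers unwrap `.sample` before re-wrapping or returning a tuple. A minimal sketch of that contract (the placeholder `decode` stands in for the real decoder; diffusers' actual `DecoderOutput` is a `BaseOutput` subclass rather than a plain dataclass):

from dataclasses import dataclass

import torch


@dataclass
class DecoderOutput:
    sample: torch.Tensor


def decode(z: torch.Tensor) -> DecoderOutput:
    return DecoderOutput(sample=z * 2.0)  # placeholder decoder


def forward(z: torch.Tensor, return_dict: bool = True):
    dec = decode(z).sample  # unwrap before re-wrapping or tuple return
    if not return_dict:
        return (dec,)
    return DecoderOutput(sample=dec)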
tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py (new file)

Lines changed: 261 additions & 0 deletions

# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import torch
from parameterized import parameterized

from diffusers import AsymmetricAutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    floats_tensor,
    load_hf_numpy,
    require_torch_accelerator,
    require_torch_gpu,
    skip_mps,
    slow,
    torch_all_close,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin


enable_full_determinism()


class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = AsymmetricAutoencoderKL
    main_input_name = "sample"
    base_precision = 1e-2

    def get_asym_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
        block_out_channels = block_out_channels or [2, 4]
        norm_num_groups = norm_num_groups or 2
        init_dict = {
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
            "down_block_out_channels": block_out_channels,
            "layers_per_down_block": 1,
            "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
            "up_block_out_channels": block_out_channels,
            "layers_per_up_block": 1,
            "act_fn": "silu",
            "latent_channels": 4,
            "norm_num_groups": norm_num_groups,
            "sample_size": 32,
            "scaling_factor": 0.18215,
        }
        return init_dict

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        mask = torch.ones((batch_size, 1) + sizes).to(torch_device)

        return {"sample": image, "mask": mask}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_asym_autoencoder_kl_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skip("Unsupported test.")
    def test_forward_with_norm_groups(self):
        pass
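
The class above mostly supplies hooks (`dummy_input`, `prepare_init_args_and_inputs_for_common`) that the shared `ModelTesterMixin` tests consume. Roughly what those common tests do with this config, as a simplified standalone sketch rather than the mixin's actual code:

import torch

from diffusers import AsymmetricAutoencoderKL

# Build the model from the same tiny config the test class returns, then
# run a forward pass and check the reconstruction shape, much as the
# shared mixin tests do via prepare_init_args_and_inputs_for_common().
init_dict = {
    "in_channels": 3,
    "out_channels": 3,
    "down_block_types": ["DownEncoderBlock2D"] * 2,
    "down_block_out_channels": [2, 4],
    "layers_per_down_block": 1,
    "up_block_types": ["UpDecoderBlock2D"] * 2,
    "up_block_out_channels": [2, 4],
    "layers_per_up_block": 1,
    "act_fn": "silu",
    "latent_channels": 4,
    "norm_num_groups": 2,
    "sample_size": 32,
    "scaling_factor": 0.18215,
}
model = AsymmetricAutoencoderKL(**init_dict).eval()
sample = torch.randn(4, 3, 32, 32)
mask = torch.ones(4, 1, 32, 32)
with torch.no_grad():
    out = model(sample, mask).sample
assert out.shape == sample.shape  # the VAE reconstructs at input resolution

The file then continues with the slow integration tests: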
@slow
class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
    def get_file_format(self, seed, shape):
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32
        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return image

    def get_sd_vae_model(self, model_id="cross-attention/asymmetric-autoencoder-kl-x-1-5", fp16=False):
        revision = "main"
        torch_dtype = torch.float32

        model = AsymmetricAutoencoderKL.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            revision=revision,
        )
        model.to(torch_device).eval()

        return model

    def get_generator(self, seed=0):
        generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
        if torch_device != "mps":
            return torch.Generator(device=generator_device).manual_seed(seed)
        return torch.manual_seed(seed)

    @parameterized.expand(
        [
            # fmt: off
            [
                33,
                [-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205],
                [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824],
            ],
            [
                47,
                [0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529],
                [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089],
            ],
            # fmt: on
        ]
    )
    def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)
        generator = self.get_generator(seed)

        with torch.no_grad():
            sample = model(image, generator=generator, sample_posterior=True).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

    @parameterized.expand(
        [
            # fmt: off
            [
                33,
                [-0.0340, 0.2870, 0.1698, -0.0105, -0.3448, 0.3529, -0.1321, 0.1097],
                [-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078],
            ],
            [
                47,
                [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531],
                [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531],
            ],
            # fmt: on
        ]
    )
    def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)

        with torch.no_grad():
            sample = model(image).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)

    @parameterized.expand(
        [
            # fmt: off
            [13, [-0.0521, -0.2939, 0.1540, -0.1855, -0.5936, -0.3138, -0.4579, -0.2275]],
            [37, [-0.1820, -0.4345, -0.0455, -0.2923, -0.8035, -0.5089, -0.4795, -0.3106]],
            # fmt: on
        ]
    )
    @require_torch_accelerator
    @skip_mps
    def test_stable_diffusion_decode(self, seed, expected_slice):
        model = self.get_sd_vae_model()
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))

        with torch.no_grad():
            sample = model.decode(encoding).sample

        assert list(sample.shape) == [3, 3, 512, 512]

        output_slice = sample[-1, -2:, :2, -2:].flatten().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=2e-3)

    @parameterized.expand([(13,), (16,), (37,)])
    @require_torch_gpu
    @unittest.skipIf(
        not is_xformers_available(),
        reason="xformers is not required when using PyTorch 2.0.",
    )
    def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
        model = self.get_sd_vae_model()
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))

        with torch.no_grad():
            sample = model.decode(encoding).sample

        model.enable_xformers_memory_efficient_attention()
        with torch.no_grad():
            sample_2 = model.decode(encoding).sample

        assert list(sample.shape) == [3, 3, 512, 512]

        assert torch_all_close(sample, sample_2, atol=5e-2)

    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.3001, 0.0918, -2.6984, -3.9720, -3.2099, -5.0353, 1.7338, -0.2065, 3.4267]],
            [47, [-1.5030, -4.3871, -6.0355, -9.1157, -1.6661, -2.7853, 2.1607, -5.0823, 2.5633]],
            # fmt: on
        ]
    )
    def test_stable_diffusion_encode_sample(self, seed, expected_slice):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)
        generator = self.get_generator(seed)

        with torch.no_grad():
            dist = model.encode(image).latent_dist
            sample = dist.sample(generator=generator)

        assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]

        output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        tolerance = 3e-3 if torch_device != "mps" else 1e-2
        assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
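
The shape assertion in test_stable_diffusion_encode_sample encodes the usual Stable Diffusion VAE geometry, 4 latent channels and 8x spatial downsampling; spelled out for the default image shape used above:

image_shape = (4, 3, 512, 512)  # default shape from get_sd_image
expected = [image_shape[0], 4] + [s // 8 for s in image_shape[2:]]
assert expected == [4, 4, 64, 64]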
