@@ -18,6 +18,7 @@
 import torch

 from diffusers import AutoencoderKLHunyuanVideo
+from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
@@ -182,3 +183,28 @@ def test_forward_with_norm_groups(self):
     @unittest.skip("Unsupported test.")
     def test_outputs_equivalence(self):
         pass
+
+    def test_prepare_causal_attention_mask(self):
+        def prepare_causal_attention_mask_orig(
+            num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
+        ) -> torch.Tensor:
+            seq_len = num_frames * height_width
+            mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
+            for i in range(seq_len):
+                i_frame = i // height_width
+                mask[i, : (i_frame + 1) * height_width] = 0
+            if batch_size is not None:
+                mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
+            return mask
+
+        # test with some odd shapes
+        original_mask = prepare_causal_attention_mask_orig(
+            num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
+        )
+        new_mask = prepare_causal_attention_mask(
+            num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
+        )
+        self.assertTrue(
+            torch.allclose(original_mask, new_mask),
+            "Causal attention mask should be the same",
+        )
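For context, the helper under test builds a block-causal mask over the flattened video sequence: token i may attend to token j only when j belongs to the same or an earlier frame, which is what the reference loop above encodes one row at a time. Below is a minimal vectorized sketch of that construction; the function name prepare_causal_attention_mask_vectorized is illustrative and is not the actual diffusers implementation.

import torch


def prepare_causal_attention_mask_vectorized(
    num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
) -> torch.Tensor:
    # Frame index of every token in the flattened (frame, spatial) sequence.
    frame_idx = torch.arange(num_frames, device=device).repeat_interleave(height_width)
    # Token i may attend to token j only when j's frame is not later than i's frame.
    allowed = frame_idx.unsqueeze(1) >= frame_idx.unsqueeze(0)
    seq_len = num_frames * height_width
    mask = torch.zeros(seq_len, seq_len, dtype=dtype, device=device)
    mask.masked_fill_(~allowed, float("-inf"))
    if batch_size is not None:
        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
    return mask

Called with the same odd shapes as the test (num_frames=31, height_width=111), this sketch should produce a tensor that matches prepare_causal_attention_mask_orig element for element, while avoiding the Python loop over the full sequence length.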