Commit 8ae8008

speedup hunyuan encoder causal mask generation (#10764)
* speedup causal mask generation
* fixing hunyuan attn mask test case
1 parent c80eda9 · commit 8ae8008

File tree

2 files changed: +31 −5 lines


src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py

Lines changed: 5 additions & 5 deletions
@@ -36,11 +36,11 @@
 def prepare_causal_attention_mask(
     num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
 ) -> torch.Tensor:
-    seq_len = num_frames * height_width
-    mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
-    for i in range(seq_len):
-        i_frame = i // height_width
-        mask[i, : (i_frame + 1) * height_width] = 0
+    indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)
+    indices_blocks = indices.repeat_interleave(height_width)
+    x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
+    mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype)
+
     if batch_size is not None:
         mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
     return mask
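
For context when reading the hunk above, here is a minimal standalone sketch (not part of the commit) of what the vectorized construction produces, assuming a toy case of 3 frames with 2 spatial tokens per frame. Each query token gets 0 (may attend) for every key in its own or an earlier frame and -inf otherwise, which is the same block-causal pattern the removed loop built row by row.

# Toy reproduction of the new meshgrid-based mask (assumed sizes: 3 frames, 2 tokens per frame).
import torch

num_frames, height_width = 3, 2

# Tag every token with the 1-based index of the frame it belongs to: [1, 1, 2, 2, 3, 3]
indices = torch.arange(1, num_frames + 1, dtype=torch.int32)
indices_blocks = indices.repeat_interleave(height_width)

# With indexing="xy", x[i, j] is the key token's frame and y[i, j] the query token's frame,
# so "x <= y" keeps every key that lies in the query's frame or an earlier one.
x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
mask = torch.where(x <= y, 0, -float("inf"))

print(mask)
# Expected pattern (exact formatting may differ):
# tensor([[0., 0., -inf, -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf, -inf],
#         [0., 0., 0.,   0.,   -inf, -inf],
#         [0., 0., 0.,   0.,   -inf, -inf],
#         [0., 0., 0.,   0.,   0.,   0.],
#         [0., 0., 0.,   0.,   0.,   0.]])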

tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py

Lines changed: 26 additions & 0 deletions
@@ -18,6 +18,7 @@
 import torch
 
 from diffusers import AutoencoderKLHunyuanVideo
+from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
@@ -182,3 +183,28 @@ def test_forward_with_norm_groups(self):
     @unittest.skip("Unsupported test.")
     def test_outputs_equivalence(self):
         pass
+
+    def test_prepare_causal_attention_mask(self):
+        def prepare_causal_attention_mask_orig(
+            num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
+        ) -> torch.Tensor:
+            seq_len = num_frames * height_width
+            mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
+            for i in range(seq_len):
+                i_frame = i // height_width
+                mask[i, : (i_frame + 1) * height_width] = 0
+            if batch_size is not None:
+                mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
+            return mask
+
+        # test with some odd shapes
+        original_mask = prepare_causal_attention_mask_orig(
+            num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
+        )
+        new_mask = prepare_causal_attention_mask(
+            num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
+        )
+        self.assertTrue(
+            torch.allclose(original_mask, new_mask),
+            "Causal attention mask should be the same",
+        )
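
To complement the regression test above, here is a rough, self-contained comparison of the two constructions (not part of the test suite). The function names are local to this sketch, the check runs on CPU, and absolute timings depend on the machine; only the ratio between the two is of interest.

# Quick equivalence-plus-timing check mirroring the test's odd shape.
import time

import torch


def mask_loop(num_frames, height_width, dtype=torch.float32, device="cpu"):
    # Row-by-row construction from the removed code.
    seq_len = num_frames * height_width
    mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
    for i in range(seq_len):
        i_frame = i // height_width
        mask[i, : (i_frame + 1) * height_width] = 0
    return mask


def mask_vectorized(num_frames, height_width, dtype=torch.float32, device="cpu"):
    # Meshgrid-based construction from the added code.
    indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)
    indices_blocks = indices.repeat_interleave(height_width)
    x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
    return torch.where(x <= y, 0, -float("inf")).to(dtype=dtype)


args = dict(num_frames=31, height_width=111)  # same odd shape as the test case

t0 = time.perf_counter()
loop_mask = mask_loop(**args)
t1 = time.perf_counter()
vec_mask = mask_vectorized(**args)
t2 = time.perf_counter()

assert torch.allclose(loop_mask, vec_mask)
print(f"loop: {t1 - t0:.4f}s  vectorized: {t2 - t1:.4f}s")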

0 commit comments
