
Commit 8fb2b6b

yiyixuxu authored and sayakpaul committed
pass attn mask arg for flux (#10122)
1 parent 053a48d commit 8fb2b6b

File tree

1 file changed: +3 -1 lines changed


src/diffusers/models/attention_processor.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -1908,7 +1908,9 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
```
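
For context on what the one-line change does: below is a minimal, self-contained sketch of how passing a boolean `attn_mask` into `torch.nn.functional.scaled_dot_product_attention` suppresses attention to masked key/value positions. The tensor shapes and the padding pattern here are illustrative assumptions for the demo, not values taken from the Flux attention processor.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes (assumptions, not from the commit): 2 sequences,
# 8 heads, 10 tokens, 64-dim heads; the last 3 tokens of sequence 1
# are treated as padding.
batch_size, heads, seq_len, head_dim = 2, 8, 10, 64
query = torch.randn(batch_size, heads, seq_len, head_dim)
key = torch.randn(batch_size, heads, seq_len, head_dim)
value = torch.randn(batch_size, heads, seq_len, head_dim)

# Boolean mask broadcastable to (batch, heads, q_len, kv_len):
# True = attend, False = ignore. Here we mask out padded key positions.
attention_mask = torch.ones(batch_size, 1, 1, seq_len, dtype=torch.bool)
attention_mask[1, ..., -3:] = False

# Old behavior: no attn_mask, so padded tokens receive attention weight.
out_unmasked = F.scaled_dot_product_attention(
    query, key, value, dropout_p=0.0, is_causal=False
)

# New behavior: masked key/value positions get zero attention weight.
out_masked = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
print(out_unmasked.shape, out_masked.shape)  # torch.Size([2, 8, 10, 64]) twice
```

Before this commit the processor accepted an `attention_mask` argument but silently ignored it in the SDPA call; forwarding it makes masking (e.g. of padded prompt tokens) actually take effect.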
