
Commit acd6d2c

Fix the bug that joint_attention_kwargs is not passed to the FLUX transformer's attention processors (#9517)
* Update transformer_flux.py
1 parent 86bd991 · commit acd6d2c

File tree: 1 file changed, +8 −2 lines


src/diffusers/models/transformers/transformer_flux.py (+8 −2)
@@ -83,14 +83,16 @@ def forward(
         hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )
 
         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
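
For context on the added default above: joint_attention_kwargs = joint_attention_kwargs or {} normalizes a missing argument to an empty dict so the ** unpacking into the attention call is always valid. A tiny standalone sketch of the pattern (my illustration, not the diffusers code itself):

def call_attention(attn, hidden_states, image_rotary_emb=None, joint_attention_kwargs=None):
    # Treat a missing kwargs dict as empty so the ** unpacking below never fails.
    joint_attention_kwargs = joint_attention_kwargs or {}
    return attn(
        hidden_states=hidden_states,
        image_rotary_emb=image_rotary_emb,
        **joint_attention_kwargs,  # extra keys (if any) flow through to the attention processor
    )

With joint_attention_kwargs=None the attention module receives no extra keyword arguments; with e.g. {"attention_mask": mask} the extra key is forwarded to the attention call.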
@@ -161,18 +163,20 @@ def forward(
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         # Attention.
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
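
The effect of the two **joint_attention_kwargs lines is that extra keys passed by the caller now reach each block's attention processor instead of never arriving there. A hedged sketch of a processor that consumes such a key; the class name KwargAwareFluxProcessor and the attn_scale key are hypothetical and not part of this commit, and the parameter is declared by name because, to my understanding, diffusers' Attention.forward filters forwarded kwargs against the processor's signature:

import torch
from diffusers.models.attention_processor import FluxAttnProcessor2_0


class KwargAwareFluxProcessor(FluxAttnProcessor2_0):
    """Hypothetical processor: consumes a custom `attn_scale` kwarg forwarded
    via joint_attention_kwargs, then delegates to the stock FLUX processor."""

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,
        image_rotary_emb: torch.Tensor = None,
        attn_scale: float = 1.0,  # hypothetical key, declared by name so it survives kwarg filtering
    ):
        out = super().__call__(
            attn,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        # The stock processor returns (hidden_states, encoder_hidden_states) for the
        # dual-stream blocks and a single tensor for the single-stream blocks.
        if isinstance(out, tuple):
            return tuple(attn_scale * o for o in out)
        return attn_scale * out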
@@ -497,6 +501,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
             # controlnet residual
@@ -533,6 +538,7 @@ def custom_forward(*inputs):
                     hidden_states=hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
             # controlnet residual

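End to end, a hedged usage sketch: a joint_attention_kwargs dict passed to FluxPipeline is handed to FluxTransformer2DModel.forward and, with this fix, unpacked into the attention call of every dual-stream and single-stream block. The attn_scale key and KwargAwareFluxProcessor are the hypothetical examples from the sketch above, not diffusers APIs; the model id and sampler settings are illustrative only.

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

# Install a processor that declares the extra parameter by name (see the
# KwargAwareFluxProcessor sketch above); unknown keys are otherwise filtered
# out with a warning before they reach the processor.
pipe.transformer.set_attn_processor(KwargAwareFluxProcessor())

image = pipe(
    "a cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=3.5,
    joint_attention_kwargs={"attn_scale": 0.8},  # hypothetical key, now forwarded block -> processor
).images[0]
image.save("flux_demo.png")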