@@ -83,14 +83,16 @@ def forward(
         hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )
 
         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
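The hunk above uses a common optional-kwargs pattern: default the dict with `or {}`, then unpack it into the attention call with `**`, so existing callers that pass nothing see unchanged behavior. A minimal self-contained sketch of that pattern (stand-in class, not diffusers code; the `scale` kwarg is hypothetical and only illustrates how a dict entry reaches the attention call):

# Sketch of the optional-kwargs forwarding pattern; `Block` and `scale`
# are illustrative stand-ins, not the real FluxSingleTransformerBlock.
from typing import Optional


class Block:
    def attn(self, hidden_states, image_rotary_emb=None, scale=1.0):
        # stand-in for the attention module; scales inputs so the effect is visible
        return [h * scale for h in hidden_states]

    def forward(self, hidden_states, image_rotary_emb=None, joint_attention_kwargs: Optional[dict] = None):
        # `or {}` makes `**joint_attention_kwargs` a no-op when nothing is passed
        joint_attention_kwargs = joint_attention_kwargs or {}
        return self.attn(
            hidden_states=hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )


block = Block()
print(block.forward([1.0, 2.0]))  # [1.0, 2.0]: default behavior unchanged
print(block.forward([1.0, 2.0], joint_attention_kwargs={"scale": 0.5}))  # [0.5, 1.0]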
@@ -161,18 +163,20 @@ def forward(
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         # Attention.
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
@@ -497,6 +501,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
             # controlnet residual
@@ -533,6 +538,7 @@ def custom_forward(*inputs):
                     hidden_states=hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
             # controlnet residual
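The last two hunks make the top-level forward pass hand the same `joint_attention_kwargs` dict to every block call. A runnable sketch of that threading (assumed toy classes, not the real FluxTransformer2DModel; which keys are valid depends on the attention processor that is installed):

# Toy threading sketch: one joint_attention_kwargs dict is forwarded
# unchanged into every block, as in the two hunks above.
class TinyBlock:
    def forward(self, hidden_states, temb=None, image_rotary_emb=None, joint_attention_kwargs=None):
        joint_attention_kwargs = joint_attention_kwargs or {}
        print("block saw:", joint_attention_kwargs)  # a real block unpacks these into attention
        return hidden_states


class TinyTransformer:
    def __init__(self, num_blocks=2):
        self.blocks = [TinyBlock() for _ in range(num_blocks)]

    def forward(self, hidden_states, temb=None, image_rotary_emb=None, joint_attention_kwargs=None):
        for block in self.blocks:
            hidden_states = block.forward(
                hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,  # same dict for every block
            )
        return hidden_states


# "ip_adapter_scale" is an illustrative key, not a guaranteed processor kwarg.
TinyTransformer().forward([1.0], joint_attention_kwargs={"ip_adapter_scale": 0.7})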