Commit 363127b

Fix typo and format BasicTransformerBlock attributes (huggingface#2953)
* ⚙️ chore(train_controlnet): fix typo in logger message
* ⚙️ chore(models): refactor module order so it matches the calling order. When printing the BasicTransformerBlock to stdout, the attributes should be shown in the order they are called. Previously the "3. Feed Forward" comment also made no sense: it should have been close to self.ff, but it sat next to self.norm3 instead.
* correct many tests
* remove bogus file
* make style
* correct more tests
* finish tests
* fix one more
* make style
* make unclip deterministic
* ⚙️ chore(models/attention): reorganize comments in BasicTransformerBlock class

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 6fc3f72 commit 363127b
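The reordering matters for inspection because `torch.nn.Module` registers child modules in the order they are assigned in `__init__`, and `print(module)` lists them in that registration order. A minimal, hypothetical sketch of the idea (plain `torch.nn` layers standing in for the diffusers classes; the class names below are made up for illustration):

```python
import torch.nn as nn


# Two toy blocks with the same children, assigned in a different order.
class NormFirst(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)    # defined first, called first
        self.attn1 = nn.Linear(dim, dim)  # stand-in for an attention layer


class NormLast(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.attn1 = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim)    # defined last, even though it runs first


# print() lists children in registration order, so only NormFirst reads
# in the same order as the forward pass would execute them.
print(NormFirst())
print(NormLast())
```

The commit applies the same idea to BasicTransformerBlock: each norm is defined immediately before the module it feeds, so the printed block reads in calling order.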

File tree

1 file changed (+20, -23 lines)

models/attention.py

Lines changed: 20 additions & 23 deletions
@@ -224,7 +224,14 @@ def __init__(
                 f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
             )
 
+        # Define 3 blocks. Each block has its own normalization layer.
         # 1. Self-Attn
+        if self.use_ada_layer_norm:
+            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+        elif self.use_ada_layer_norm_zero:
+            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+        else:
+            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
         self.attn1 = Attention(
             query_dim=dim,
             heads=num_attention_heads,
@@ -235,10 +242,16 @@ def __init__(
             upcast_attention=upcast_attention,
         )
 
-        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
-
         # 2. Cross-Attn
         if cross_attention_dim is not None or double_self_attention:
+            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+            # the second cross attention block.
+            self.norm2 = (
+                AdaLayerNorm(dim, num_embeds_ada_norm)
+                if self.use_ada_layer_norm
+                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+            )
             self.attn2 = Attention(
                 query_dim=dim,
                 cross_attention_dim=cross_attention_dim if not double_self_attention else None,
@@ -248,30 +261,13 @@ def __init__(
                 bias=attention_bias,
                 upcast_attention=upcast_attention,
             )  # is self-attn if encoder_hidden_states is none
-        else:
-            self.attn2 = None
-
-        if self.use_ada_layer_norm:
-            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
-        elif self.use_ada_layer_norm_zero:
-            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
-        else:
-            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
-
-        if cross_attention_dim is not None or double_self_attention:
-            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
-            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
-            # the second cross attention block.
-            self.norm2 = (
-                AdaLayerNorm(dim, num_embeds_ada_norm)
-                if self.use_ada_layer_norm
-                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
-            )
         else:
             self.norm2 = None
+            self.attn2 = None
 
         # 3. Feed-forward
         self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
 
     def forward(
         self,
@@ -283,6 +279,8 @@ def forward(
         cross_attention_kwargs=None,
         class_labels=None,
     ):
+        # Notice that normalization is always applied before the real computation in the following blocks.
+        # 1. Self-Attention
         if self.use_ada_layer_norm:
             norm_hidden_states = self.norm1(hidden_states, timestep)
         elif self.use_ada_layer_norm_zero:
@@ -292,7 +290,6 @@ def forward(
         else:
             norm_hidden_states = self.norm1(hidden_states)
 
-        # 1. Self-Attention
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         attn_output = self.attn1(
             norm_hidden_states,
@@ -304,14 +301,14 @@ def forward(
             attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = attn_output + hidden_states
 
+        # 2. Cross-Attention
         if self.attn2 is not None:
             norm_hidden_states = (
                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
             )
             # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
             # prepare attention mask here
 
-            # 2. Cross-Attention
             attn_output = self.attn2(
                 norm_hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
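For orientation, the structure the diff converges on is the usual pre-norm layout: each of the three sub-blocks normalizes its input, runs its computation, and adds a residual. A simplified, hypothetical sketch with plain `torch.nn` layers standing in for diffusers' `Attention`, `AdaLayerNorm`/`AdaLayerNormZero`, and `FeedForward` (the ada-norm branches and attention masks are omitted):

```python
import torch
import torch.nn as nn


class PreNormBlockSketch(nn.Module):
    def __init__(self, dim: int = 32, heads: int = 4):
        super().__init__()
        # 1. Self-Attn
        self.norm1 = nn.LayerNorm(dim)
        self.attn1 = nn.MultiheadAttention(dim, heads, batch_first=True)
        # 2. Cross-Attn
        self.norm2 = nn.LayerNorm(dim)
        self.attn2 = nn.MultiheadAttention(dim, heads, batch_first=True)
        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, hidden_states, encoder_hidden_states):
        # 1. Self-Attention: normalize, attend over the sequence itself, add residual.
        norm_hidden_states = self.norm1(hidden_states)
        attn_output, _ = self.attn1(norm_hidden_states, norm_hidden_states, norm_hidden_states)
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention: normalize, attend over the encoder states, add residual.
        norm_hidden_states = self.norm2(hidden_states)
        attn_output, _ = self.attn2(norm_hidden_states, encoder_hidden_states, encoder_hidden_states)
        hidden_states = attn_output + hidden_states

        # 3. Feed-forward: normalize, project, add residual.
        norm_hidden_states = self.norm3(hidden_states)
        hidden_states = self.ff(norm_hidden_states) + hidden_states
        return hidden_states


# Toy usage: shapes are (batch, sequence, dim).
block = PreNormBlockSketch()
out = block(torch.randn(2, 16, 32), torch.randn(2, 8, 32))
print(out.shape)  # torch.Size([2, 16, 32])
```

With the modules defined in this order, the "1. Self-Attn", "2. Cross-Attn", and "3. Feed-forward" comments in `__init__` line up with the matching steps in `forward`, which is what the reorganized comments in the diff express.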
