Commit 3f9ff19

Minor Gemma 3 fixes (#36884)
fix attention mask dtype + outputs type
1 parent: f94b0c5

2 files changed, 10 insertions(+), 6 deletions(-)

src/transformers/models/gemma3/modeling_gemma3.py (5 additions, 3 deletions)

@@ -361,13 +361,15 @@ def forward(
                 )
             else:
                 attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
+        if attention_mask is not None:
+            # backwards compatibility
+            attention_mask = attention_mask.to(query_states)
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
             key_states,
             value_states,
-            attention_mask.to(query_states),
+            attention_mask,
             dropout=self.attention_dropout if self.training else 0.0,
             scaling=self.scaling,
             sliding_window=self.sliding_window,
@@ -1360,7 +1362,7 @@ def forward(
             **lm_kwargs,
         )
 
-        logits = outputs.logits
+        logits = outputs[0]
         loss = None
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues
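The first hunk moves the cast onto its own guarded statement, so the mask is only converted when one is actually passed (the old inline `attention_mask.to(query_states)` would raise an AttributeError on a `None` mask); the diff's own comment marks the conversion as a backwards-compatibility measure. Below is a minimal sketch of the underlying PyTorch behaviour, not taken from the repository, with tensor shapes and dtypes invented for illustration: `Tensor.to(other)` converts a tensor to the dtype and device of `other`.

import torch

# Hypothetical shapes and dtypes, chosen only for illustration; not Gemma 3 code.
query_states = torch.randn(1, 8, 4, 16, dtype=torch.bfloat16)
attention_mask = torch.zeros(1, 1, 4, 4, dtype=torch.float32)

# Same pattern as the new hunk: cast only when a mask was actually provided.
# Tensor.to(other) matches both the dtype and the device of `other`.
if attention_mask is not None:
    attention_mask = attention_mask.to(query_states)

assert attention_mask.dtype == torch.bfloat16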

src/transformers/models/gemma3/modular_gemma3.py (5 additions, 3 deletions)

@@ -418,13 +418,15 @@ def forward(
                 )
             else:
                 attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
+        if attention_mask is not None:
+            # backwards compatibility
+            attention_mask = attention_mask.to(query_states)
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
             key_states,
             value_states,
-            attention_mask.to(query_states),
+            attention_mask,
             dropout=self.attention_dropout if self.training else 0.0,
             scaling=self.scaling,
             sliding_window=self.sliding_window,
@@ -974,7 +976,7 @@ def forward(
             **lm_kwargs,
         )
 
-        logits = outputs.logits
+        logits = outputs[0]
         loss = None
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues
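The second hunk in each file reads the logits by position instead of through the `.logits` attribute. A minimal sketch of why positional access is the more tolerant choice, not taken from the repository (the logits tensor is invented): a `ModelOutput` subclass such as `CausalLMOutputWithPast` supports integer indexing over its non-`None` fields, while a plain tuple, as a submodule may return when `return_dict=False`, has no `.logits` attribute at all.

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

# Invented logits tensor, for illustration only.
logits = torch.randn(1, 4, 32)

# Case 1: a ModelOutput. Integer indexing walks the non-None fields, so with
# loss=None the element at index 0 is the logits.
outputs = CausalLMOutputWithPast(logits=logits)
assert torch.equal(outputs[0], outputs.logits)

# Case 2: a plain tuple. Attribute access would fail, positional access works.
outputs = (logits,)
assert torch.equal(outputs[0], logits)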
