File tree: src/transformers/models/gemma3 (2 files changed, +10 -6 lines)

First changed file:

@@ -361,13 +361,15 @@ def forward(
                 )
             else:
                 attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
+        if attention_mask is not None:
+            # backwards compatibility
+            attention_mask = attention_mask.to(query_states)
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
             key_states,
             value_states,
-            attention_mask.to(query_states),
+            attention_mask,
             dropout=self.attention_dropout if self.training else 0.0,
             scaling=self.scaling,
             sliding_window=self.sliding_window,
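
Why the guard matters: the dtype/device cast now runs only when a mask exists, so code paths where attention_mask is None (for example, attention backends that build no explicit mask) no longer hit AttributeError on None.to(...). A minimal sketch of the before/after behaviour, assuming only PyTorch; the helper name cast_mask and the tensor shapes are ours, not the model's:

    import torch

    def cast_mask(attention_mask, query_states):
        # guarded cast, mirroring the added lines above
        if attention_mask is not None:
            # match the mask to the query's dtype and device
            attention_mask = attention_mask.to(query_states)
        return attention_mask

    query_states = torch.randn(1, 8, 4, 16, dtype=torch.bfloat16)

    # The old call site did attention_mask.to(query_states) unconditionally,
    # which raises AttributeError when the mask is None.
    assert cast_mask(None, query_states) is None

    mask = torch.zeros(1, 1, 4, 4, dtype=torch.float32)
    assert cast_mask(mask, query_states).dtype == torch.bfloat16
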
@@ -1360,7 +1362,7 @@ def forward(
             **lm_kwargs,
         )

-        logits = outputs.logits
+        logits = outputs[0]
         loss = None
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues
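
Why outputs[0] instead of outputs.logits: a plain tuple (returned when the inner call uses return_dict=False) has no .logits attribute, while both tuples and transformers ModelOutput objects support integer indexing, and ModelOutput indexing skips fields that are None. A minimal sketch, assuming only the standard library and PyTorch; the namedtuple below is a stand-in for a dict-style output, not the real ModelOutput class:

    from collections import namedtuple

    import torch

    logits = torch.randn(1, 4, 32)

    tuple_outputs = (logits,)                      # return_dict=False style
    DictStyle = namedtuple("DictStyle", ["logits"])
    attr_outputs = DictStyle(logits=logits)        # attribute-style output

    for outputs in (tuple_outputs, attr_outputs):
        assert torch.equal(outputs[0], logits)     # indexing works for both

    # The old access would fail on the tuple:
    #   tuple_outputs.logits -> AttributeError: 'tuple' object has no attribute 'logits'
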
Second changed file:

@@ -418,13 +418,15 @@ def forward(
                 )
             else:
                 attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
+        if attention_mask is not None:
+            # backwards compatibility
+            attention_mask = attention_mask.to(query_states)
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
             key_states,
             value_states,
-            attention_mask.to(query_states),
+            attention_mask,
             dropout=self.attention_dropout if self.training else 0.0,
             scaling=self.scaling,
             sliding_window=self.sliding_window,
@@ -974,7 +976,7 @@ def forward(
             **lm_kwargs,
         )

-        logits = outputs.logits
+        logits = outputs[0]
         loss = None
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues