Skip to content

Commit 2638d54

Browse files
authored
Gemma 3 tests expect greedy decoding (#36882)
tests expect greedy decoding
1 parent b8aadc3 commit 2638d54

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

tests/models/gemma3/test_modeling_gemma3.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ def test_generation_beyond_sliding_window(self, attn_implementation: str):
567567
input_size = inputs.input_ids.shape[-1]
568568
self.assertTrue(input_size > model.config.sliding_window)
569569

570-
out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
570+
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)[:, input_size:]
571571
output_text = tokenizer.batch_decode(out)
572572

573573
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
@@ -599,6 +599,11 @@ def test_generation_beyond_sliding_window_with_generation_config(self):
599599
generation_config = GenerationConfig(max_new_tokens=5, min_new_tokens=5)
600600
out = model.generate(**inputs, generation_config=generation_config)
601601

602+
out = model.generate(**inputs, generation_config=generation_config, do_sample=False)[:, input_size:]
603+
output_text = tokenizer.batch_decode(out)
604+
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
605+
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
606+
602607
# Generation works beyond sliding window
603608
self.assertGreater(out.shape[1], model.config.sliding_window)
604609
self.assertEqual(out.shape[1], input_size + 5)

0 commit comments

Comments
 (0)