Skip to content

Commit 0fb7068

Browse files
authored
[tests] use proper gemma class and config in lumina2 tests. (#10828)
use proper gemma class and config in lumina2 tests.
1 parent f8b54cf commit 0fb7068

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

tests/pipelines/lumina2/test_pipeline_lumina2.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44
import torch
5-
from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM
5+
from transformers import AutoTokenizer, Gemma2Config, Gemma2Model
66

77
from diffusers import (
88
AutoencoderKL,
@@ -81,15 +81,16 @@ def get_dummy_components(self):
8181
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
8282

8383
torch.manual_seed(0)
84-
config = GemmaConfig(
85-
head_dim=2,
84+
config = Gemma2Config(
85+
head_dim=4,
8686
hidden_size=8,
87-
intermediate_size=37,
88-
num_attention_heads=4,
87+
intermediate_size=8,
88+
num_attention_heads=2,
8989
num_hidden_layers=2,
90-
num_key_value_heads=4,
90+
num_key_value_heads=2,
91+
sliding_window=2,
9192
)
92-
text_encoder = GemmaForCausalLM(config)
93+
text_encoder = Gemma2Model(config)
9394

9495
components = {
9596
"transformer": transformer.eval(),

0 commit comments

Comments
 (0)