@@ -17,7 +17,8 @@ def __init__(
         self,
         prefix_length: int,
         prefix_hidden_dim: Optional[int] = None,
-        n_positions: int = 1024,  # Start of GPT2 config args
+        vocab_size: int = 50257,  # Start of GPT2 config args
+        n_positions: int = 1024,
         n_embd: int = 768,
         n_layer: int = 12,
         n_head: int = 12,
@@ -28,6 +29,10 @@ def __init__(
         attn_pdrop: float = 0.1,
         layer_norm_epsilon: float = 1e-5,
         initializer_range: float = 0.02,
+        scale_attn_weights: bool = True,
+        use_cache: bool = True,
+        scale_attn_by_inverse_layer_idx: bool = False,
+        reorder_and_upcast_attn: bool = False,
     ):
         """
         Text decoder model for an image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
@@ -52,6 +57,7 @@ def __init__(
             )
 
         gpt_config = GPT2Config(
+            vocab_size=vocab_size,
             n_positions=n_positions,
             n_embd=n_embd,
             n_layer=n_layer,
@@ -63,6 +69,10 @@ def __init__(
             attn_pdrop=attn_pdrop,
             layer_norm_epsilon=layer_norm_epsilon,
             initializer_range=initializer_range,
+            scale_attn_weights=scale_attn_weights,
+            use_cache=use_cache,
+            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
+            reorder_and_upcast_attn=reorder_and_upcast_attn,
         )
         self.transformer = GPT2LMHeadModel(gpt_config)
 
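For reference, the new keyword defaults mirror `GPT2Config`'s own defaults, so existing callers that omit these arguments build an identical model. A quick sanity-check sketch (not part of the PR, just `transformers`):

```python
from transformers import GPT2Config

# GPT2Config's built-in defaults match the new constructor defaults above,
# so leaving the new kwargs unset is backward compatible.
config = GPT2Config()
assert config.vocab_size == 50257
assert config.n_positions == 1024
assert config.scale_attn_weights is True
assert config.use_cache is True
assert config.scale_attn_by_inverse_layer_idx is False
assert config.reorder_and_upcast_attn is False
```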
@@ -143,8 +153,7 @@ def generate_beam(
         TODO: args
         """
         # Generates text until stop_token is reached using beam search with the desired beam size.
-        # TODO: get the stop token index directly from tokenizer rather than manually specifying the EOS token?
-        stop_token_index = tokenizer.encode(stop_token)[0]
+        stop_token_index = tokenizer.eos_token_id
         tokens = None
         scores = None
         seq_lengths = torch.ones(beam_size, device=device)
@@ -159,7 +168,7 @@ def generate_beam(
         generated = self.transformer.transformer.wte(tokens)
 
         for i in range(entry_length):
-            outputs = self.transformer(input_embeds=generated)
+            outputs = self.transformer(inputs_embeds=generated)
             logits = outputs.logits
             logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
             logits = logits.softmax(-1).log()
@@ -198,8 +207,12 @@ def generate_beam(
 
         scores = scores / seq_lengths
         output_list = tokens.cpu().numpy()
+        # print(f"Output list: {output_list}")
+        # print(f"Output list length: {len(output_list)}")
+        # print(f"Seq lengths: {seq_lengths}")
+        # print(f"Seq lengths length: {len(seq_lengths)}")
         output_texts = [
-            self.tokenizer.decode(output[: int(length)], skip_special_tokens=True)
+            tokenizer.decode(output[: int(length)], skip_special_tokens=True)
             for output, length in zip(output_list, seq_lengths)
         ]
         order = scores.argsort(descending=True)
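The two keyword fixes in `generate_beam` can be checked in isolation against plain GPT-2. A minimal sketch (a tiny randomly initialized GPT-2 as a stand-in, not the UniDiffuser decoder itself): the forward pass takes `inputs_embeds`, and the EOS id is available directly on the tokenizer, so there is no need to encode a stop-token string by hand.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer

# Tiny randomly initialized GPT-2 as a stand-in for self.transformer.
model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64))
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

token_ids = torch.tensor([tokenizer.encode("An image of")])
embeds = model.transformer.wte(token_ids)   # token embeddings, as in generate_beam
outputs = model(inputs_embeds=embeds)       # keyword is `inputs_embeds`, not `input_embeds`
logits = outputs.logits                     # shape: (batch, seq_len, vocab_size)

# GPT-2's tokenizer exposes the <|endoftext|> id directly.
assert tokenizer.eos_token_id == 50256
```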