@@ -85,10 +85,12 @@ class MODEL_ARCH(IntEnum):
     GPTNEOX   : int = auto()
     MPT       : int = auto()
     STARCODER : int = auto()
+    BERT      : int = auto()


 class MODEL_TENSOR(IntEnum):
     TOKEN_EMBD  : int = auto()
+    TOKEN_TYPES : int = auto()
     POS_EMBD    : int = auto()
     OUTPUT      : int = auto()
     OUTPUT_NORM : int = auto()
@@ -116,10 +118,12 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.GPTNEOX:   "gptneox",
     MODEL_ARCH.MPT:       "mpt",
     MODEL_ARCH.STARCODER: "starcoder",
+    MODEL_ARCH.BERT:      "bert",
 }

 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.TOKEN_EMBD:  "token_embd",
+    MODEL_TENSOR.TOKEN_TYPES: "token_types",
     MODEL_TENSOR.POS_EMBD:    "position_embd",
     MODEL_TENSOR.OUTPUT_NORM: "output_norm",
     MODEL_TENSOR.OUTPUT:      "output",
@@ -206,6 +210,43 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.BERT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_TYPES,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.MPT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.GPTJ: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.GPT2: [
         # TODO
     ],
@@ -229,31 +270,40 @@ class TensorNameMap:
     mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
         # Token embeddings
         MODEL_TENSOR.TOKEN_EMBD: (
-            "gpt_neox.embed_in", # gptneox
-            "transformer.wte", # gpt2 mpt
-            "transformer.word_embeddings", # falcon
-            "model.embed_tokens", # llama-hf
-            "tok_embeddings", # llama-pth
+            "gpt_neox.embed_in",            # gptneox
+            "transformer.wte",              # gpt2 gpt-j mpt
+            "transformer.word_embeddings",  # falcon
+            "model.embed_tokens",           # llama-hf
+            "tok_embeddings",               # llama-pth
+            "embeddings.word_embeddings",   # bert
+        ),
+
+        # Token type embeddings
+        MODEL_TENSOR.TOKEN_TYPES: (
+            "embeddings.token_type_embeddings",  # bert
         ),

         # Position embeddings
         MODEL_TENSOR.POS_EMBD: (
-            "transformer.wpe", # gpt2
+            "transformer.wpe",                 # gpt2
+            "embeddings.position_embeddings",  # bert
         ),

         # Output
         MODEL_TENSOR.OUTPUT: (
-            "embed_out", # gptneox
-            "lm_head", # gpt2 mpt falcon llama-hf baichuan
-            "output", # llama-pth
+            "embed_out",  # gptneox
+            "lm_head",    # gpt2 gpt-j mpt falcon llama-hf baichuan
+            "output",     # llama-pth
         ),

         # Output norm
         MODEL_TENSOR.OUTPUT_NORM: (
-            "gpt_neox.final_layer_norm", # gptneox
-            "transformer.ln_f", # gpt2 falcon
-            "model.norm", # llama-hf baichuan
-            "norm", # llama-pth
+            "gpt_neox.final_layer_norm",  # gptneox
+            "transformer.ln_f",           # gpt2 gpt-j falcon
+            "model.norm",                 # llama-hf baichuan
+            "norm",                       # llama-pth
+            "embeddings.LayerNorm",       # bert
+            "transformer.norm_f",         # mpt
         ),

         # Rope frequencies
@@ -265,13 +315,14 @@ class TensorNameMap:
     block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
         # Attention norm
        MODEL_TENSOR.ATTN_NORM: (
-            "gpt_neox.layers.{bid}.input_layernorm", # gptneox
-            "transformer.h.{bid}.ln_1", # gpt2
-            "transformer.blocks.{bid}.norm_1", # mpt
-            "transformer.h.{bid}.input_layernorm", # falcon7b
-            "transformer.h.{bid}.ln_mlp", # falcon40b
-            "model.layers.{bid}.input_layernorm", # llama-hf
-            "layers.{bid}.attention_norm", # llama-pth
+            "gpt_neox.layers.{bid}.input_layernorm",           # gptneox
+            "transformer.h.{bid}.ln_1",                        # gpt2 gpt-j
+            "transformer.blocks.{bid}.norm_1",                 # mpt
+            "transformer.h.{bid}.input_layernorm",             # falcon7b
+            "transformer.h.{bid}.ln_mlp",                      # falcon40b
+            "model.layers.{bid}.input_layernorm",              # llama-hf
+            "layers.{bid}.attention_norm",                     # llama-pth
+            "encoder.layer.{bid}.attention.output.LayerNorm",  # bert
         ),

         # Attention norm 2
@@ -281,38 +332,46 @@ class TensorNameMap:

         # Attention query-key-value
         MODEL_TENSOR.ATTN_QKV: (
-            "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
-            "transformer.h.{bid}.attn.c_attn", # gpt2
-            "transformer.blocks.{bid}.attn.Wqkv", # mpt
-            "transformer.h.{bid}.self_attention.query_key_value", # falcon
+            "gpt_neox.layers.{bid}.attention.query_key_value",     # gptneox
+            "transformer.h.{bid}.attn.c_attn",                     # gpt2
+            "transformer.blocks.{bid}.attn.Wqkv",                  # mpt
+            "transformer.h.{bid}.self_attention.query_key_value",  # falcon
         ),

         # Attention query
         MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj", # llama-hf
-            "layers.{bid}.attention.wq", # llama-pth
+            "model.layers.{bid}.self_attn.q_proj",       # llama-hf
+            "layers.{bid}.attention.wq",                 # llama-pth
+            "encoder.layer.{bid}.attention.self.query",  # bert
+            "transformer.h.{bid}.attn.q_proj",           # gpt-j
         ),

         # Attention key
         MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj", # llama-hf
-            "layers.{bid}.attention.wk", # llama-pth
+            "model.layers.{bid}.self_attn.k_proj",     # llama-hf
+            "layers.{bid}.attention.wk",               # llama-pth
+            "encoder.layer.{bid}.attention.self.key",  # bert
+            "transformer.h.{bid}.attn.k_proj",         # gpt-j
         ),

         # Attention value
         MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj", # llama-hf
-            "layers.{bid}.attention.wv", # llama-pth
+            "model.layers.{bid}.self_attn.v_proj",       # llama-hf
+            "layers.{bid}.attention.wv",                 # llama-pth
+            "encoder.layer.{bid}.attention.self.value",  # bert
+            "transformer.h.{bid}.attn.v_proj",           # gpt-j
         ),

         # Attention output
         MODEL_TENSOR.ATTN_OUT: (
-            "gpt_neox.layers.{bid}.attention.dense", # gptneox
-            "transformer.h.{bid}.attn.c_proj", # gpt2
-            "transformer.blocks.{bid}.attn.out_proj", # mpt
-            "transformer.h.{bid}.self_attention.dense", # falcon
-            "model.layers.{bid}.self_attn.o_proj", # llama-hf
-            "layers.{bid}.attention.wo", # llama-pth
+            "gpt_neox.layers.{bid}.attention.dense",       # gptneox
+            "transformer.h.{bid}.attn.c_proj",             # gpt2
+            "transformer.blocks.{bid}.attn.out_proj",      # mpt
+            "transformer.h.{bid}.self_attention.dense",    # falcon
+            "model.layers.{bid}.self_attn.o_proj",         # llama-hf
+            "layers.{bid}.attention.wo",                   # llama-pth
+            "encoder.layer.{bid}.attention.output.dense",  # bert
+            "transformer.h.{bid}.attn.out_proj",           # gpt-j
         ),

         # Rotary embeddings
@@ -323,21 +382,24 @@ class TensorNameMap:

         # Feed-forward norm
         MODEL_TENSOR.FFN_NORM: (
-            "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
-            "transformer.h.{bid}.ln_2", # gpt2
-            "transformer.blocks.{bid}.norm_2", # mpt
-            "model.layers.{bid}.post_attention_layernorm", # llama-hf
-            "layers.{bid}.ffn_norm", # llama-pth
+            "gpt_neox.layers.{bid}.post_attention_layernorm",  # gptneox
+            "transformer.h.{bid}.ln_2",                        # gpt2
+            "transformer.blocks.{bid}.norm_2",                 # mpt
+            "model.layers.{bid}.post_attention_layernorm",     # llama-hf
+            "layers.{bid}.ffn_norm",                           # llama-pth
+            "encoder.layer.{bid}.output.LayerNorm",            # bert
         ),

         # Feed-forward up
         MODEL_TENSOR.FFN_UP: (
-            "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
-            "transformer.h.{bid}.mlp.c_fc", # gpt2
-            "transformer.blocks.{bid}.ffn.up_proj", # mpt
-            "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
-            "model.layers.{bid}.mlp.up_proj", # llama-hf
-            "layers.{bid}.feed_forward.w3", # llama-pth
+            "gpt_neox.layers.{bid}.mlp.dense_h_to_4h",  # gptneox
+            "transformer.h.{bid}.mlp.c_fc",             # gpt2
+            "transformer.blocks.{bid}.ffn.up_proj",     # mpt
+            "transformer.h.{bid}.mlp.dense_h_to_4h",    # falcon
+            "model.layers.{bid}.mlp.up_proj",           # llama-hf
+            "layers.{bid}.feed_forward.w3",             # llama-pth
+            "encoder.layer.{bid}.intermediate.dense",   # bert
+            "transformer.h.{bid}.mlp.fc_in",            # gpt-j
         ),

         # Feed-forward gate
@@ -348,12 +410,14 @@ class TensorNameMap:

         # Feed-forward down
         MODEL_TENSOR.FFN_DOWN: (
-            "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
-            "transformer.h.{bid}.mlp.c_proj", # gpt2
-            "transformer.blocks.{bid}.ffn.down_proj", # mpt
-            "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
-            "model.layers.{bid}.mlp.down_proj", # llama-hf
-            "layers.{bid}.feed_forward.w2", # llama-pth
+            "gpt_neox.layers.{bid}.mlp.dense_4h_to_h",  # gptneox
+            "transformer.h.{bid}.mlp.c_proj",           # gpt2
+            "transformer.blocks.{bid}.ffn.down_proj",   # mpt
+            "transformer.h.{bid}.mlp.dense_4h_to_h",    # falcon
+            "model.layers.{bid}.mlp.down_proj",         # llama-hf
+            "layers.{bid}.feed_forward.w2",             # llama-pth
+            "encoder.layer.{bid}.output.dense",         # bert
+            "transformer.h.{bid}.mlp.fc_out",           # gpt-j
         ),
     }
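
Usage sketch (not part of this diff): the snippet below shows how tables shaped like mappings_cfg and block_mappings_cfg above can resolve a Hugging Face tensor name to a GGUF base name, using the BERT source names added in this change. The helper resolve_tensor_name, the trimmed tables, and the GGUF-style per-block names such as "blk.{bid}.attn_q" are illustrative assumptions, not code from this PR.

# Minimal, self-contained sketch; names below marked "illustrative" are assumptions.
from __future__ import annotations

# Illustrative subset of the non-block BERT entries added above,
# keyed by the GGUF base names from TENSOR_NAMES.
MAPPINGS: dict[str, tuple[str, ...]] = {
    "token_embd":    ("embeddings.word_embeddings",),
    "token_types":   ("embeddings.token_type_embeddings",),
    "position_embd": ("embeddings.position_embeddings",),
    "output_norm":   ("embeddings.LayerNorm",),
}

# Illustrative subset of the per-block BERT entries; the "blk.{bid}.*"
# target names are assumed GGUF-style names, not taken from this diff.
BLOCK_MAPPINGS: dict[str, tuple[str, ...]] = {
    "blk.{bid}.attn_q":      ("encoder.layer.{bid}.attention.self.query",),
    "blk.{bid}.attn_k":      ("encoder.layer.{bid}.attention.self.key",),
    "blk.{bid}.attn_v":      ("encoder.layer.{bid}.attention.self.value",),
    "blk.{bid}.attn_output": ("encoder.layer.{bid}.attention.output.dense",),
}

def resolve_tensor_name(hf_name: str, n_blocks: int) -> str | None:
    """Return the GGUF base name for hf_name, or None if it is unmapped."""
    # Non-block tensors: direct lookup against each candidate source name.
    for gguf_name, candidates in MAPPINGS.items():
        if hf_name in candidates:
            return gguf_name
    # Per-block tensors: substitute the block index into the {bid} placeholder.
    for gguf_template, candidates in BLOCK_MAPPINGS.items():
        for bid in range(n_blocks):
            if any(c.format(bid=bid) == hf_name for c in candidates):
                return gguf_template.format(bid=bid)
    return None

# Example: the query projection of BERT encoder block 3 maps to "blk.3.attn_q".
assert resolve_tensor_name("encoder.layer.3.attention.self.query", 12) == "blk.3.attn_q"

In practice such a lookup would be built once per architecture and block count rather than searched linearly on every call; the sketch only illustrates how the {bid} placeholder in block_mappings_cfg ties a checkpoint tensor name to its GGUF counterpart.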