import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from contextlib import nullcontext
-from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache, get_zeroed_kv_cache_inputs
+from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache
+import sys
+import os

+# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+from register_sdpa import *

DEVICE = torch.device("cuda:0")

def get_model(args):
    with torch.no_grad():
-        if args.model == "meta-llama/Llama-2-7b-chat-hf":
-            model = (
-                AutoModelForCausalLM.from_pretrained(
-                    args.model,
-                    use_cache=False,
-                    attn_implementation="sdpa",
-                    num_hidden_layers=1
-                )
-                .eval()
-                .cuda()
-            )
-        elif args.model == "meta-llama/Llama-3.2-1B-Instruct":
-            model = (
-                AutoModelForCausalLM.from_pretrained(
-                    args.model,
-                    use_cache=False,
-                    attn_implementation="sdpa",
-                    num_hidden_layers=1
-                )
-                .eval()
-                .cuda()
-            )
-
-        elif args.model == "meta-llama/Llama-3.2-3B-Instruct":
-            model = (
+        # Supported list of models:
+        # - meta-llama/Llama-3.2-1B-Instruct
+        # - meta-llama/Llama-3.2-3B-Instruct
+        # - meta-llama/Llama-3.1-8B-Instruct
+        # - Qwen/Qwen2.5-1.5B-Instruct
+        model = (
            AutoModelForCausalLM.from_pretrained(
                args.model,
                use_cache=False,
                attn_implementation="sdpa",
-                # num_hidden_layers=2
-            )
-            .eval()
-            .cuda()
-        )
-        elif args.model == "meta-llama/Llama-3.1-8B-Instruct":
-            model = (
-                AutoModelForCausalLM.from_pretrained(
-                    args.model,
-                    use_cache=False,
-                    attn_implementation="sdpa",  # num_hidden_layers=1
-                )
-                .eval()
-                .cuda()
-            )
-        elif args.model == "google/gemma-3-1b-it":
-            model = (
-                AutoModelForCausalLM.from_pretrained(
-                    "google/gemma-3-1b-it",
-                    use_cache=False,
-                    attn_implementation="sdpa"
+                # num_hidden_layers=1
            )
            .eval()
            .cuda()
@@ -91,9 +57,9 @@ def get_model(args):


def compile_torchtrt(model, input_ids, args):
-    max_seq_len = input_ids.shape[1] + args.max_tokens
+    max_seq_len = input_ids.shape[1] + args.num_tokens
    ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
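    # export_llm (from utils.py) exports the model with torch.export; max_seq_len
    # (prompt length + requested new tokens) presumably bounds the dynamic sequence
    # dimension of the exported graph.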
-
+
    # Set precision specific flags
    use_fp32_acc = False
    use_explicit_typing = False
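    # use_fp32_acc keeps matmul accumulation in FP32 and use_explicit_typing enables
    # TensorRT strong typing; both default to False here and are expected to be enabled
    # by the precision-specific branches that follow (not shown in this hunk).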
@@ -119,6 +85,7 @@ def compile_torchtrt(model, input_ids, args):
        disable_tf32=True,
        use_python_runtime=True,
        debug=args.debug,
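+        # Offload the PyTorch module to CPU while the TensorRT engine is built, reducing peak GPU memory usage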
+        offload_module_to_cpu=True,
        min_block_size=args.min_block_size,
    )

@@ -170,23 +137,29 @@ def measure_perf(trt_model, input_signature, backend_name):
        "--model", type=str, default="meta-llama/Llama-3.2-1B-Instruct", help="Name of LLM model"
    )
    arg_parser.add_argument(
-        "--tokenizer_path",
+        "--tokenizer",
        type=str,
-        default="meta-llama/Llama-3.2-1B-Instruct",
+        default="",
        help="Name of LLM model tokenizer",
    )
    arg_parser.add_argument(
        "--prompt", type=str, default="What is parallel programming ?", help="Prompt"
    )
-    arg_parser.add_argument("--precision", type=str, default="FP16", help="Prompt")
+    arg_parser.add_argument("--precision", type=str, default="FP16", help="Precision to use in the model. Options: FP16, BF16, FP32")
    arg_parser.add_argument(
        "--iterations", type=int, default=5, help="no. of iterations to run"
    )
    arg_parser.add_argument(
        "--min_block_size", type=int, default=1, help="minimum number of operators per TRT engine block"
    )
    arg_parser.add_argument(
-        "--max_tokens", type=int, default=128, help="no. of max tokens to be generated"
+        "--num_tokens", type=int, default=128, help="no. of output tokens to be generated"
+    )
+    arg_parser.add_argument(
+        "--batch_size", type=int, default=1, help="Batch size used for benchmarking"
+    )
+    arg_parser.add_argument(
+        "--isl", type=int, default=2048, help="Input sequence length used for benchmarking"
    )
    arg_parser.add_argument(
        "--enable_pytorch_run",
@@ -196,8 +169,8 @@ def measure_perf(trt_model, input_signature, backend_name):
    arg_parser.add_argument(
        "--cache",
        type=str,
-        default="static",
-        help="Type of KV cache to use",
+        default="",
+        help="Type of KV cache to use. Options: static_v1, static_v2, dynamic",
    )
    arg_parser.add_argument(
        "--cudagraph",
@@ -214,22 +187,24 @@ def measure_perf(trt_model, input_signature, backend_name):
        action="store_true",
        help="Enable benchmark (default: False)"
    )
+
    args = arg_parser.parse_args()
    with torch.inference_mode():
        model = get_model(args)

-        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model)
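+        # When --tokenizer is left empty, the tokenizer is loaded from the --model name itself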

-        prompt = "What is parallel programming ?"
-        # prompt = "What is the capital of France ?"
-        model_inputs = tokenizer(prompt, return_tensors="pt")
-        input_ids = model_inputs["input_ids"].to(DEVICE)
-        # Prepare input prompt
-        # word = "What"
-        # word_ids = tokenizer(word, return_tensors="pt").input_ids[0] # Get the first (and only) sequence
-        # input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device) # Add batch dimension and move to device
+        # Prepare input for benchmarking or evaluation
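+        # Benchmark mode uses a synthetic batch of random token IDs of shape (batch_size, isl);
+        # otherwise the --prompt string is tokenized and moved to the GPU.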
+        if args.benchmark:
+            input_ids = torch.randint(1, 10000, (args.batch_size, args.isl), dtype=torch.int64).to(model.device)
+            position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
+        else:
+            model_inputs = tokenizer(args.prompt, return_tensors="pt")
+            input_ids = model_inputs["input_ids"].to(DEVICE)
+            position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
+

-        MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.max_tokens
+        MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens
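+        # Total sequence length after decoding: input length plus --num_tokens generated tokens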
        # Pyt
        pyt_gen_tokens = None
        pyt_timings = None
@@ -238,7 +213,6 @@ def measure_perf(trt_model, input_signature, backend_name):
            pyt_gen_tokens = generate(
                model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id
            )
-
            if args.benchmark:
                pyt_timings = time_generate(
                    generate,
@@ -249,71 +223,22 @@ def measure_perf(trt_model, input_signature, backend_name):
                    iterations=args.iterations,
                )
                pyt_stats = recordStats(
-                    "PyTorch", pyt_timings, args.precision, batch_size=1, compile_time_s=None
+                    "PyTorch", pyt_timings, args.precision, batch_size=args.batch_size, compile_time_s=None
                )

-        # TRT
-        pyt_logits_tok1 = model.cuda()(input_ids)
-        next_tokens = torch.argmax(pyt_logits_tok1.logits[:, -1, :], dim=-1)
-        input_seq = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-        pyt_logits_tok2 = model.cuda()(input_seq)
-        from lower_sdpa import *
-        if args.cache == "static":
-            # This import is required to register static KV cache transformations as lowering passes
-            from static_cache2 import *
-            trt_model = compile_torchtrt(model, input_ids, args)
-            kv_cache = get_zeroed_kv_cache_inputs(trt_model)
-
-            # First token generation
-            pyt_keys = torch.load("key.pt"); pyt_values = torch.load("value.pt")
-            trt_logits, key_cache, value_cache, trt_keys_1, trt_values_1 = trt_model(input_ids.clone(), True, *kv_cache, 0, input_ids.shape[1])
-            print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok1.logits - trt_logits))}")
-            print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys - trt_keys_1))}")
-            print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys - key_cache[:, :, :-2, :]))}")
-            print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values - trt_values_1))}")
-            print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values - value_cache[:, :, :-2, :]))}")
-            next_tokens = torch.argmax(trt_logits[:, -1, :], dim=-1)
-
-            # Second token generation
-            trt_logits_2, key_cache2, value_cache2, trt_keys_2, trt_values_2 = trt_model(next_tokens[:, None], False, key_cache.clone(), value_cache.clone(), input_ids.shape[1], input_ids.shape[1] + 1)
-            pyt_keys2 = torch.load("key2.pt"); pyt_values2 = torch.load("value2.pt")
-            print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok2.logits[:, -1:, :] - trt_logits_2))}")
-            print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys2[:, :, -2:-1, :] - trt_keys_2))}")
-            print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys2 - key_cache2[:, :, :-1, :]))}")
-            print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values2[:, :, -2:-1, :] - trt_values_2))}")
-            print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values2 - value_cache2[:, :, :-1, :]))}")
-            breakpoint()
+        if args.cache == "static_v1":
+            # This import is required to register static v1 KV cache transformations as lowering passes
+            import static_cache_v1
+        if args.cache == "static_v2":
+            # This import is required to register static v2 KV cache transformations as lowering passes
+            import static_cache_v2
        elif args.cache == "dynamic":
-            from dynamic_cache import *
-            trt_model = compile_torchtrt(model, input_ids, args)
-            breakpoint()
-            kv_cache = get_zeroed_kv_cache_inputs(trt_model)
-        else:
-            # pyt_logits = model.cuda()(input_ids.clone())
-            trt_model = compile_torchtrt(model, input_ids, args)
-            # trt_logits = trt_model(input_ids.clone(), True)
-            # print(f"Diff between pyt and trt: {torch.mean(torch.abs(pyt_logits - trt_logits))}")
-            # print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits.logits - trt_logits.logits))}")
-        if args.cache == "static":
-            if args.cudagraph:
-                # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
-                # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
-                torch_tensorrt.runtime.set_cudagraphs_mode(True)
-
-            trt_gen_tokens = generate_with_kv_cache(
-                trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
-            )
+            import dynamic_cache
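+            # Like the static variants above, importing dynamic_cache registers its KV cache lowering passes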

-            if args.benchmark:
-                trt_timings = time_generate(
-                    generate_with_kv_cache,
-                    trt_model,
-                    input_ids.clone(),
-                    MAX_OUTPUT_SEQ_LENGTH,
-                    tokenizer.eos_token_id,
-                    iterations=args.iterations,
-                )
-        elif args.cache == "dynamic":
+
+        trt_model = compile_torchtrt(model, input_ids, args)
+
+        if args.cache == "static_v1" or args.cache == "static_v2" or args.cache == "dynamic":
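+            # The KV-cache path presumably decodes via generate_with_kv_cache (imported from utils),
+            # reusing the cached key/value tensors across decode steps.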
            if args.cudagraph:
                # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
                # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
@@ -332,7 +257,6 @@ def measure_perf(trt_model, input_signature, backend_name):
                    tokenizer.eos_token_id,
                    iterations=args.iterations,
                )
-
        else:
            trt_gen_tokens = generate(
                trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
@@ -349,14 +273,20 @@ def measure_perf(trt_model, input_signature, backend_name):

        if args.benchmark:
            trt_stats = recordStats(
-                "TensorRT", trt_timings, args.precision, batch_size=1, compile_time_s=None
+                "TensorRT", trt_timings, args.precision, batch_size=args.batch_size, compile_time_s=None
            )

-        if args.enable_pytorch_run:
-            print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
-        print_outputs("TensorRT", trt_gen_tokens, tokenizer)
+
+        if not args.benchmark:
+            if args.enable_pytorch_run:
+                print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
+
+            print_outputs("TensorRT", trt_gen_tokens, tokenizer)

-        if args.benchmark:
+            if args.enable_pytorch_run:
+                print(f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}")
+
+        if args.benchmark:
            if args.enable_pytorch_run:
                print("=========PyTorch PERFORMANCE============\n")
                print(pyt_stats)