
Commit 19bc80b (1 parent: 095b5cf)
Author: Chengzhe Xu
feat: Refactor LLM runner and implemented support for Qwen family

File tree: 13 files changed (+681, -312 lines)


examples/dynamo/llama3_trt.py renamed to examples/dynamo/llm/run_llm.py

Lines changed: 63 additions & 133 deletions
@@ -19,63 +19,29 @@
 import torch_tensorrt
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from contextlib import nullcontext
-from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache, get_zeroed_kv_cache_inputs
+from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache
+import sys
+import os
 
+# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+from register_sdpa import *
 
 DEVICE = torch.device("cuda:0")
 
 def get_model(args):
 with torch.no_grad():
-if args.model == "meta-llama/Llama-2-7b-chat-hf":
-model = (
-AutoModelForCausalLM.from_pretrained(
-args.model,
-use_cache=False,
-attn_implementation="sdpa",
-num_hidden_layers=1
-)
-.eval()
-.cuda()
-)
-elif args.model == "meta-llama/Llama-3.2-1B-Instruct":
-model = (
-AutoModelForCausalLM.from_pretrained(
-args.model,
-use_cache=False,
-attn_implementation="sdpa",
-num_hidden_layers=1
-)
-.eval()
-.cuda()
-)
-
-elif args.model == "meta-llama/Llama-3.2-3B-Instruct":
-model = (
+# Supported list of models:
+# - meta-llama/Llama-3.2-1B-Instruct
+# - meta-llama/Llama-3.2-3B-Instruct
+# - meta-llama/Llama-3.1-8B-Instruct
+# - Qwen/Qwen2.5-1.5B-Instruct
+model = (
 AutoModelForCausalLM.from_pretrained(
 args.model,
 use_cache=False,
 attn_implementation="sdpa",
-# num_hidden_layers=2
-)
-.eval()
-.cuda()
-)
-elif args.model == "meta-llama/Llama-3.1-8B-Instruct":
-model = (
-AutoModelForCausalLM.from_pretrained(
-args.model,
-use_cache=False,
-attn_implementation="sdpa", # num_hidden_layers=1
-)
-.eval()
-.cuda()
-)
-elif args.model == "google/gemma-3-1b-it":
-model = (
-AutoModelForCausalLM.from_pretrained(
-"google/gemma-3-1b-it",
-use_cache=False,
-attn_implementation="sdpa"
+# num_hidden_layers=1
 )
 .eval()
 .cuda()
@@ -91,9 +57,9 @@ def get_model(args):
 
 
 def compile_torchtrt(model, input_ids, args):
-max_seq_len = input_ids.shape[1] + args.max_tokens
+max_seq_len = input_ids.shape[1] + args.num_tokens
 ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
-
+
 # Set precision specific flags
 use_fp32_acc = False
 use_explicit_typing = False
@@ -119,6 +85,7 @@ def compile_torchtrt(model, input_ids, args):
 disable_tf32=True,
 use_python_runtime=True,
 debug=args.debug,
+offload_module_to_cpu=True,
 min_block_size=args.min_block_size,
 )
 
@@ -170,23 +137,29 @@ def measure_perf(trt_model, input_signature, backend_name):
 "--model", type=str, default="meta-llama/Llama-3.2-1B-Instruct", help="Name of LLM model"
 )
 arg_parser.add_argument(
-"--tokenizer_path",
+"--tokenizer",
 type=str,
-default="meta-llama/Llama-3.2-1B-Instruct",
+default="",
 help="Name of LLM model tokenizer",
 )
 arg_parser.add_argument(
 "--prompt", type=str, default="What is parallel programming ?", help="Prompt"
 )
-arg_parser.add_argument("--precision", type=str, default="FP16", help="Prompt")
+arg_parser.add_argument("--precision", type=str, default="FP16", help="Precision to use in the model. Options: FP16, BF16, FP32")
 arg_parser.add_argument(
 "--iterations", type=int, default=5, help="no. of iterations to run"
 )
 arg_parser.add_argument(
 "--min_block_size", type=int, default=1, help="no. of iterations to run"
 )
 arg_parser.add_argument(
-"--max_tokens", type=int, default=128, help="no. of max tokens to be generated"
+"--num_tokens", type=int, default=128, help="no. of output tokens to be generated"
+)
+arg_parser.add_argument(
+"--batch_size", type=int, default=1, help="Batch size used for benchmarking"
+)
+arg_parser.add_argument(
+"--isl", type=int, default=2048, help="Input sequence length used for benchmarking"
 )
 arg_parser.add_argument(
 "--enable_pytorch_run",
@@ -196,8 +169,8 @@ def measure_perf(trt_model, input_signature, backend_name):
 arg_parser.add_argument(
 "--cache",
 type=str,
-default="static",
-help="Type of KV cache to use",
+default="",
+help="Type of KV cache to use. Options: static_v1, static_v2, dynamic",
 )
 arg_parser.add_argument(
 "--cudagraph",
@@ -214,22 +187,24 @@ def measure_perf(trt_model, input_signature, backend_name):
 action="store_true",
 help="Enable benchmark (default: False)"
 )
+
 args = arg_parser.parse_args()
 with torch.inference_mode():
 model = get_model(args)
 
-tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
+tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model)
 
-prompt = "What is parallel programming ?"
-# prompt = "What is the capital of France ?"
-model_inputs = tokenizer(prompt, return_tensors="pt")
-input_ids = model_inputs["input_ids"].to(DEVICE)
-# Prepare input prompt
-# word = "What"
-# word_ids = tokenizer(word, return_tensors="pt").input_ids[0] # Get the first (and only) sequence
-# input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device) # Add batch dimension and move to device
+# Prepare input for benchmarking or evaluation
+if args.benchmark:
+input_ids = torch.randint(1, 10000, (args.batch_size, args.isl), dtype=torch.int64).to(model.device)
+position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
+else:
+model_inputs = tokenizer(args.prompt, return_tensors="pt")
+input_ids = model_inputs["input_ids"].to(DEVICE)
+position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
+
 
-MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.max_tokens
+MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens
 # Pyt
 pyt_gen_tokens = None
 pyt_timings = None
@@ -238,7 +213,6 @@ def measure_perf(trt_model, input_signature, backend_name):
 pyt_gen_tokens = generate(
 model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id
 )
-
 if args.benchmark:
 pyt_timings = time_generate(
 generate,
@@ -249,71 +223,22 @@ def measure_perf(trt_model, input_signature, backend_name):
 iterations=args.iterations,
 )
 pyt_stats = recordStats(
-"PyTorch", pyt_timings, args.precision, batch_size=1, compile_time_s=None
+"PyTorch", pyt_timings, args.precision, batch_size=args.batch_size, compile_time_s=None
 )
 
-# TRT
-pyt_logits_tok1 = model.cuda()(input_ids)
-next_tokens = torch.argmax(pyt_logits_tok1.logits[:, -1, :], dim=-1)
-input_seq = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-pyt_logits_tok2 = model.cuda()(input_seq)
-from lower_sdpa import *
-if args.cache == "static":
-# This import is required to register static KV cache transformations as lowering passes
-from static_cache2 import *
-trt_model = compile_torchtrt(model, input_ids, args)
-kv_cache = get_zeroed_kv_cache_inputs(trt_model)
-
-# First token generation
-pyt_keys = torch.load("key.pt"); pyt_values = torch.load("value.pt")
-trt_logits, key_cache, value_cache, trt_keys_1, trt_values_1 = trt_model(input_ids.clone(), True, *kv_cache, 0, input_ids.shape[1])
-print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok1.logits - trt_logits))}")
-print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys - trt_keys_1))}")
-print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys - key_cache[:, :, :-2, :]))}")
-print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values - trt_values_1))}")
-print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values - value_cache[:, :, :-2, :]))}")
-next_tokens = torch.argmax(trt_logits[:, -1, :], dim=-1)
-
-# Second token generation
-trt_logits_2, key_cache2, value_cache2, trt_keys_2, trt_values_2 = trt_model(next_tokens[:, None], False, key_cache.clone(), value_cache.clone(), input_ids.shape[1], input_ids.shape[1]+1)
-pyt_keys2 = torch.load("key2.pt"); pyt_values2 = torch.load("value2.pt")
-print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok2.logits[:, -1:, :] - trt_logits_2))}")
-print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys2[:, :, -2:-1, :] - trt_keys_2))}")
-print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys2 - key_cache2[:, :, :-1, :]))}")
-print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values2[:, :, -2:-1, :] - trt_values_2))}")
-print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values2 - value_cache2[:, :, :-1, :]))}")
-breakpoint()
+if args.cache == "static_v1":
+# This import is required to register static v1 KV cache transformations as lowering passes
+import static_cache_v1
+if args.cache == "static_v2":
+# This import is required to register static v2 KV cache transformations as lowering passes
+import static_cache_v2
 elif args.cache == "dynamic":
-from dynamic_cache import *
-trt_model = compile_torchtrt(model, input_ids, args)
-breakpoint()
-kv_cache = get_zeroed_kv_cache_inputs(trt_model)
-else:
-# pyt_logits = model.cuda()(input_ids.clone())
-trt_model = compile_torchtrt(model, input_ids, args)
-# trt_logits = trt_model(input_ids.clone(), True)
-# print(f"Diff between pyt and trt: {torch.mean(torch.abs(pyt_logits - trt_logits))}")
-# print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits.logits - trt_logits.logits))}")
-if args.cache == "static":
-if args.cudagraph:
-# Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
-# trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
-torch_tensorrt.runtime.set_cudagraphs_mode(True)
-
-trt_gen_tokens = generate_with_kv_cache(
-trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
-)
+import dynamic_cache
 
-if args.benchmark:
-trt_timings = time_generate(
-generate_with_kv_cache,
-trt_model,
-input_ids.clone(),
-MAX_OUTPUT_SEQ_LENGTH,
-tokenizer.eos_token_id,
-iterations=args.iterations,
-)
-elif args.cache == "dynamic":
+
+trt_model = compile_torchtrt(model, input_ids, args)
+
+if args.cache == "static_v1" or args.cache == "static_v2" or args.cache == "dynamic":
 if args.cudagraph:
 # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
 # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
@@ -332,7 +257,6 @@ def measure_perf(trt_model, input_signature, backend_name):
 tokenizer.eos_token_id,
 iterations=args.iterations,
 )
-
 else:
 trt_gen_tokens = generate(
 trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
@@ -349,14 +273,20 @@ def measure_perf(trt_model, input_signature, backend_name):
 
 if args.benchmark:
 trt_stats = recordStats(
-"TensorRT", trt_timings, args.precision, batch_size=1, compile_time_s=None
+"TensorRT", trt_timings, args.precision, batch_size=args.batch_size, compile_time_s=None
 )
 
-if args.enable_pytorch_run:
-print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
-print_outputs("TensorRT", trt_gen_tokens, tokenizer)
+
+if not args.benchmark:
+if args.enable_pytorch_run:
+print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
+
+print_outputs("TensorRT", trt_gen_tokens, tokenizer)
 
-if args.benchmark:
+if args.enable_pytorch_run:
+print(f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}")
+
+if args.benchmark:
 if args.enable_pytorch_run:
 print("=========PyTorch PERFORMANCE============ \n")
 print(pyt_stats)
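With the per-model if/elif ladder gone, the refactored script is driven entirely by command-line flags, e.g. python run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP16 --cache static_v2 --num_tokens 128. The sketch below mirrors that flow in plain Python as a reading aid: the helpers (export_llm, generate_with_kv_cache) come from the utils module imported above, the cache lowering pass is registered simply by importing its module, and the torch_tensorrt.dynamo.compile call is assembled from the keyword arguments visible in the compile_torchtrt hunk, so the exact argument list in the real script may differ.

# Minimal sketch of the refactored run_llm.py flow (not a drop-in copy of the script).
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import export_llm, generate_with_kv_cache  # helpers shipped alongside run_llm.py
from register_sdpa import *  # registers the standalone SDPA converter and lowering pass
import static_cache_v2  # importing the module registers the static_v2 KV-cache lowering pass

DEVICE = torch.device("cuda:0")
MODEL = "Qwen/Qwen2.5-1.5B-Instruct"

with torch.inference_mode():
    model = (
        AutoModelForCausalLM.from_pretrained(
            MODEL, use_cache=False, attn_implementation="sdpa"
        )
        .eval()
        .cuda()
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL)

    input_ids = tokenizer("What is parallel programming ?", return_tensors="pt").input_ids.to(DEVICE)
    max_seq_len = input_ids.shape[1] + 128  # prompt length + --num_tokens

    # Export, then compile; the kwargs below are the ones visible in the
    # compile_torchtrt hunk and may not be the full set used by the script.
    ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
    trt_model = torch_tensorrt.dynamo.compile(
        ep,
        inputs=[input_ids],
        enabled_precisions={torch.float16},
        disable_tf32=True,
        use_python_runtime=True,
        offload_module_to_cpu=True,
        min_block_size=1,
    )

    # Greedy decode through the KV-cache-aware generation loop.
    trt_gen_tokens = generate_with_kv_cache(
        trt_model, input_ids.clone(), max_seq_len, tokenizer.eos_token_id
    )
    # The script prints results via print_outputs(); decoding directly works the
    # same way assuming the generator returns a [batch, seq] token tensor.
    print(tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True))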

examples/dynamo/static_cache.py renamed to examples/dynamo/llm/static_cache_v1.py

Lines changed: 11 additions & 6 deletions
@@ -118,11 +118,15 @@ def get_static_tensor(tensor: torch.Tensor):
 start_idx_input.meta["val"] = start_idx_unbacked_symint
 end_idx_input.meta["val"] = end_idx_unbacked_symint
 
-return kv_inputs, start_idx_input, end_idx_input
+# Add is_causal as input
+is_causal_input = add_graph_input(gm, "is_causal", True)
+is_causal_input.meta["val"] = torch.tensor(True)
 
+return kv_inputs, start_idx_input, end_idx_input, is_causal_input
 
 
-def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node):
+
+def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node, is_causal_input: Node):
 """
 Insert slicing operations before each scaled_dot_product_attention operation.
 """
@@ -133,7 +137,8 @@ def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten
 sdpa_nodes.append(node)
 kv_cache_for_graph = []
 for idx, sdpa_node in enumerate(sdpa_nodes):
-q_node, k_node, v_node = sdpa_node.args[:3]
+assert len(sdpa_node.args) == 6, f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments"
+q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args
 incoming_key, incoming_value = incoming_keys_values[idx]
 kv_cache_for_sdpa_node = []
 new_keys_values = []
@@ -231,7 +236,7 @@ def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten
 
 kv_cache_for_graph.extend(kv_cache_for_sdpa_node)
 
-sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + sdpa_node.args[3:]
+sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + (attn_mask, dropout_p, is_causal_input)
 
 return gm, kv_cache_for_graph
 
@@ -243,11 +248,11 @@ def insert_kv_cache(
 """Insert KV cache ops in the graph"""
 """Perform insertion of kv-caches and attention kernel."""
 # Add static key and value as inputs to the graph
-kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True)
+kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(gm, fixed_kv=True)
 
 # Build and update the KV cache using computed KV inputs for current token and
 # incoming keys and values from previous tokens (which were added as inputs)
-gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input)
+gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input)
 
 # Call the function to add KV as outputs
 logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph)
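Both static cache passes (v1 here and v2 below) now assume the SDPA node carries all six positional arguments (query, key, value, attn_mask, dropout_p, is_causal), so the rewrite only has to swap in the cache-updated key/value nodes and the new is_causal graph input. The toy torch.fx snippet below only illustrates that unpack/repack pattern under that assumption; the actual passes run on the exported ATen graph and use helpers such as add_graph_input that are not reproduced here.

# Toy illustration of the 6-argument unpack/repack; not the actual lowering pass.
import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace


class ToyAttention(torch.nn.Module):
    def forward(self, q, k, v):
        # Call SDPA with all six positional arguments so the traced node
        # matches the 6-argument form the lowering pass asserts on.
        return F.scaled_dot_product_attention(q, k, v, None, 0.0, True)


gm = symbolic_trace(ToyAttention())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == F.scaled_dot_product_attention:
        assert len(node.args) == 6, f"SDPA node should have 6 arguments but got {len(node.args)}"
        q_node, k_node, v_node, attn_mask, dropout_p, is_causal = node.args
        # The real pass replaces k_node/v_node with the cache-updated key/value
        # nodes and is_causal with the new "is_causal" graph input:
        #   node.args = (q_node, new_k, new_v) + (attn_mask, dropout_p, is_causal_input)
        print(q_node, k_node, v_node, attn_mask, dropout_p, is_causal)
gm.graph.lint()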

examples/dynamo/static_cache2.py renamed to examples/dynamo/llm/static_cache_v2.py

Lines changed: 5 additions & 4 deletions
@@ -97,7 +97,7 @@ def get_static_tensor(tensor: torch.Tensor):
 start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0))
 end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1))
 
-# Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache1, value_cache1, key_cache2, value_cache2, ..
+# Get the max sequence length from the first key_cache node. The order of input nodes is: input_ids, key_cache1, value_cache1, key_cache2, value_cache2, start_idx, end_idx
 input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
 # Get the third last input which should be the last value cache node and store the max_seq_len
 input_ids_meta = input_nodes[-3].meta["val"]
@@ -232,7 +232,8 @@ def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten
 sdpa_nodes.append(node)
 kv_cache_for_graph = []
 for idx, sdpa_node in enumerate(sdpa_nodes):
-q_node, k_node, v_node = sdpa_node.args[:3]
+assert len(sdpa_node.args) == 6, f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments"
+q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args
 incoming_key, incoming_value = incoming_keys_values[idx]
 # For keys
 new_current_key_node, new_incoming_key_cache_node = create_kv_cache_update_nodes(gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input)
@@ -243,9 +244,9 @@ def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten
 kv_cache_for_graph.extend([new_incoming_key_cache_node, new_incoming_value_cache_node])
 
 # Update the SDPA node arguments with current key and value nodes
-sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (None, is_causal_input) # + sdpa_node.args[3:]
+sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (attn_mask, dropout_p, is_causal_input)
 
-kv_cache_for_graph.extend([k_node, v_node])
+# kv_cache_for_graph.extend([k_node, v_node])
 return gm, kv_cache_for_graph
 
 