
Commit 9a4f32d

Qualcomm AI Engine Direct - Add source transform for kv cache and sdpa
Differential Revision: D60913232
Pull Request resolved: #4555
1 parent 5c27045 commit 9a4f32d

File tree: 8 files changed, +279 −28 lines changed


backends/qualcomm/passes/replace_inf_buffer.py

Lines changed: 6 additions & 2 deletions
@@ -8,14 +8,18 @@


 class ReplaceInfBuffer(ExportPass):
+    """
+    Due to a limitation in QNN, replace inf or -inf with an arbitrary finite value for quantization.
+    """
+
     def __init__(self):
         super(ReplaceInfBuffer, self).__init__()

     def call(self, graph_module: torch.fx.GraphModule):
         for buf_name, tensor in graph_module.named_buffers():
             if tensor.is_floating_point():
-                tensor[tensor == float("inf")] = torch.finfo(torch.float32).max
-                tensor[tensor == float("-inf")] = torch.finfo(torch.float32).min
+                tensor[tensor == float("inf")] = 255
+                tensor[tensor == float("-inf")] = -255
                 setattr(graph_module, buf_name, tensor)

         graph_module.recompile()
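
In the QNN flow this pass runs inside _transform() with the other graph passes (see backends/qualcomm/utils/utils.py below). A standalone sketch of its effect, assuming only the pass itself and a toy module with a -inf mask buffer:

import torch
from torch.fx import symbolic_trace

from executorch.backends.qualcomm.passes.replace_inf_buffer import ReplaceInfBuffer


class Mask(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # attention-style mask buffer containing a -inf entry
        self.register_buffer("mask", torch.tensor([0.0, float("-inf")]))

    def forward(self, x):
        return x + self.mask


gm = symbolic_trace(Mask())
ReplaceInfBuffer().call(gm)
print(gm.mask)  # the -inf entry is now -255; the 0.0 entry is untouched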
Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Sequence
+
+import torch
+from executorch.backends.qualcomm.quantizer.quantizer import (
+    get_16a8w_qnn_ptq_config,
+    get_default_8bit_qnn_ptq_config,
+    QuantizationConfig,
+)
+from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY
+from torch.ao.quantization.quantizer import (
+    QuantizationAnnotation,
+    SharedQuantizationSpec,
+)
+from torch.fx import Node
+
+
+def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
+    """
+    This function is specific for llama matmul op 16a8w.
+    """
+
+    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_spec = quantization_config.input_activation
+        input_qspec_map[input_act] = input_spec
+        input_act1 = node.args[1]
+        input_spec1 = quantization_config.weight
+        input_qspec_map[input_act1] = input_spec1
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None:
+        input = node.args[0]
+        value = node.args[2]
+        input_qspec_map = {}
+        input_qspec_map[input] = quantization_config.input_activation
+        input_qspec_map[value] = SharedQuantizationSpec((input, node))
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=SharedQuantizationSpec((input, node)),
+            _annotated=True,
+        )
+
+    def annotate_single_in_single_out(
+        node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_qspec_map[input_act] = quantization_config.input_activation
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    def annotate_cat(node: Node, quantization_config: QuantizationConfig):
+        input_nodes = node.args[0]
+        assert isinstance(input_nodes, Sequence)
+        first_input_node = input_nodes[0]
+        input_qspec_map = {}
+        assert isinstance(first_input_node, Node)
+        assert isinstance(node, Node)
+        input_qspec_map[first_input_node] = quantization_config.input_activation
+        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
+            (first_input_node, node)
+        )
+        for input_node in input_nodes[1:]:
+            if input_node not in input_qspec_map:
+                assert isinstance(input_node, Node)
+                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=share_qparams_with_input_act0_qspec,
+            _annotated=True,
+        )
+
+    def is_edge_condition(node: Node):
+        if not isinstance(node, Node) or node.op != "call_function":
+            return True
+        return False
+
+    def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig):
+        if is_edge_condition(node):
+            return
+        if node.target == torch.ops.aten.index_put_.default:
+            annotate_index_put(node, quantization_config)
+            annotate_matmul_input1(node.args[0], quantization_config)
+        elif node.target == torch.ops.aten.cat.default:
+            annotate_cat(node, quantization_config)
+            # Expect that the inputs of the cat op are select ops
+            for arg in node.args[0][1:]:
+                annotate_single_in_single_out(arg, quantization_config)
+            annotate_matmul_input1(node.args[0][0], quantization_config)
+        else:
+            annotate_single_in_single_out(node, quantization_config)
+            annotate_matmul_input1(node.args[0], quantization_config)
+
+    # Annotate 16a8w for matmul op to get better performance
+    quantization_config_16a8w = get_16a8w_qnn_ptq_config()
+    # Annotate 8a8w for second input of matmul until past_kv_cache
+    quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True)
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
+            if "nn_module_stack" in node.meta:
+                module_values_list = list(node.meta["nn_module_stack"].values())
+                full_qualified_name = module_values_list[-1][0]
+                if "SDPA" in full_qualified_name:
+                    annotate_matmul(node, quantization_config_16a8w)
+                    annotate_matmul_input1(node.args[1], quantization_config_8a8w)
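
A hedged sketch of how this annotator could be wired into the PT2E flow. The QnnQuantizer hook name (add_custom_quant_annotations), the capture call, and the model/input names are assumptions about the surrounding API, not part of this diff:

import torch
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

quantizer = QnnQuantizer()
# Assumption: the quantizer exposes a hook for extra annotation callbacks.
quantizer.add_custom_quant_annotations((custom_annotate_llama_matmul_16a8w,))

# `llama_model` / `example_inputs` are placeholders for the Llama module and its inputs.
captured = torch.export.export(llama_model, example_inputs).module()
prepared = prepare_pt2e(captured, quantizer)  # SDPA matmuls get 16a8w, their KV-cache path 8a8w
prepared(*example_inputs)                     # calibration
quantized = convert_pt2e(prepared)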

backends/qualcomm/utils/utils.py

Lines changed: 7 additions & 6 deletions
@@ -206,6 +206,13 @@ def _transform(edge_program: ExportedProgram) -> None:
     FoldQDQ()(graph_module)
     LayoutTransform(edge_program)(graph_module)

+    # Since QDQ nodes are stripped, update graph signature again to validate program
+    edge_program._graph_signature = _get_updated_graph_signature(
+        edge_program.graph_signature,
+        edge_program.graph_module,
+    )
+    edge_program._validate()
+

 def capture_program(
     module: torch.nn.Module,
@@ -222,12 +229,6 @@ def capture_program(
     core_ep.transform(ConvertBinaryOpsWithScalar())
     edge_ep = core_ep.to_edge(qnn_edge_config())
     _transform(edge_ep.exported_program)
-    # Since QDQ nodes are stripped, update graph signature again to validate program
-    edge_ep.exported_program._graph_signature = _get_updated_graph_signature(
-        edge_ep.exported_program.graph_signature,
-        edge_ep.exported_program.graph_module,
-    )
-    edge_ep.exported_program._validate()
     return edge_ep
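
With the signature update and validation moved into _transform, every caller of _transform now gets a validated program, not only capture_program. A minimal usage sketch; the second argument (example inputs) is an assumption, since capture_program's full signature is truncated in this hunk:

# `model` and `sample_inputs` are placeholders for a module and its example inputs.
edge_prog = capture_program(model, sample_inputs)
# The exported program has already been transformed and re-validated inside _transform.
print(edge_prog.exported_program.graph_signature)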

examples/models/llama2/export_llama_lib.py

Lines changed: 8 additions & 1 deletion
@@ -52,7 +52,9 @@
 from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
 from .source_transformation.sdpa import (
     replace_causal_mask,
+    replace_kv_cache_with_simple_kv_cache,
     replace_sdpa_with_custom_op,
+    replace_sdpa_with_flex_sdpa,
     replace_sdpa_with_simple_sdpa,
 )

@@ -385,7 +387,12 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         transforms.append(replace_sdpa_with_custom_op)

     if args.use_kv_cache:
-        if args.qnn or args.coreml or args.mps:
+        if args.qnn:
+            transforms.append(replace_kv_cache_with_simple_kv_cache)
+            transforms.append(replace_sdpa_with_flex_sdpa)
+            transforms.append(replace_causal_mask)
+
+        elif args.coreml or args.mps:
             # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
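
Each replace_* helper appended here takes the model and returns it rewritten (see source_transformation/sdpa.py below), so the export manager can fold the list in order. An illustrative sketch of that folding; the loader is hypothetical:

model = load_llama_model()  # hypothetical loader standing in for the export flow's model
for transform in (
    replace_kv_cache_with_simple_kv_cache,
    replace_sdpa_with_flex_sdpa,
    replace_causal_mask,
):
    model = transform(model)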

examples/models/llama2/llama_transformer.py

Lines changed: 3 additions & 0 deletions
@@ -161,6 +161,9 @@ def __init__(
         else:
             cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)

+        self.max_batch_size = max_batch_size
+        self.n_heads = n_heads
+        self.head_dim = head_dim
         self.transpose_cache = transpose_cache
         self.enable_dynamic_shape = enable_dynamic_shape
         self.register_buffer(

examples/models/llama2/source_transformation/sdpa.py

Lines changed: 121 additions & 0 deletions
@@ -9,6 +9,7 @@
 # Example script for exporting Llama2 to flatbuffer

 import math
+from typing import Tuple

 import torch

@@ -112,6 +113,61 @@ def forward(
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    if n_rep == 1:
+        return hidden_states
+
+    new_kv = []
+    batch, n_heads, seqlen, head_dim = hidden_states.shape
+    n_heads *= n_rep
+    for h in hidden_states[0]:
+        new_kv += [h] * n_rep
+    return torch.cat(new_kv, 0).reshape(batch, n_heads, seqlen, head_dim)
+
+
+class SDPAFlex(torch.nn.Module):
+
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.dim = dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+        mask,
+    ):
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        k = repeat_kv(k, self.n_rep)
+        v = repeat_kv(v, self.n_rep)
+        attn_mask = mask[input_pos]
+
+        scale_factor = 1 / math.sqrt(q.size(-1))
+        attn_weight = q @ k.transpose(-2, -1) * scale_factor
+        attn_weight += attn_mask
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        y = attn_weight @ v
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
 def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
     for name, child in module.named_children():
         if isinstance(child, SDPA):
@@ -125,6 +181,71 @@ def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
     return module


+def replace_sdpa_with_flex_sdpa(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPAFlex(child.kv_cache, child.dim, child.n_rep),
+            )
+        else:
+            replace_sdpa_with_flex_sdpa(child)
+    return module
+
+
+class KVCacheSimple(torch.nn.Module):
+    def __init__(
+        self,
+        max_batch_size: int,
+        max_seq_length: int,
+        n_heads: int,
+        head_dim: int,
+        dtype=torch.float32,
+    ):
+        super().__init__()
+        cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+        self.register_buffer(
+            "past_k_caches",
+            torch.zeros(cache_shape, dtype=dtype, device="cpu"),
+            persistent=False,
+        )
+        self.register_buffer(
+            "past_v_caches",
+            torch.zeros(cache_shape, dtype=dtype, device="cpu"),
+            persistent=False,
+        )
+
+    def update(
+        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        k_out = torch.ops.aten.index_put_(self.past_k_caches, [None, input_pos], k_val)
+        v_out = torch.ops.aten.index_put_(self.past_v_caches, [None, input_pos], v_val)
+
+        k_out = k_out.transpose(1, 2)
+        v_out = v_out.transpose(1, 2)
+        return k_out, v_out
+
+
+def replace_kv_cache_with_simple_kv_cache(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, KVCache):
+            setattr(
+                module,
+                name,
+                KVCacheSimple(
+                    child.max_batch_size,
+                    child.max_seq_length,
+                    child.n_heads,
+                    child.head_dim,
+                    child.k_cache.dtype,
+                ),
+            )
+        else:
+            replace_kv_cache_with_simple_kv_cache(child)
+    return module
+
+
 def replace_causal_mask(module: torch.nn.Module):
     for buffer_fqn_name, buffer in module.named_buffers():
         buffer_name = buffer_fqn_name.split(".")[-1]
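
A quick sanity check (not part of the commit) of the claim in repeat_kv's docstring: for the batch-size-1 caches this path targets, the hand-rolled expansion matches torch.repeat_interleave along the head dimension:

import torch

x = torch.randn(1, 8, 16, 64)  # (batch=1, n_kv_heads, seqlen, head_dim)
assert torch.equal(repeat_kv(x, n_rep=4), torch.repeat_interleave(x, repeats=4, dim=1))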

extension/llm/export/partitioner_lib.py

Lines changed: 0 additions & 13 deletions
@@ -116,9 +116,6 @@ def get_qnn_partitioner(
         QnnPartitioner,
     )

-    # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer`
-    from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
-
     # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.serialization.qnn_compile_spec_schema`
     from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import (
         QcomChipset,
@@ -138,16 +135,6 @@ def get_qnn_partitioner(
     skip_node_op_set = {}
     if pt2e_quantize is not None:
         use_fp16 = False
-        # TODO: fix the lowering error without skipping nodes
-
-        if quant_dtype == QuantDtype.use_8a8w:
-            raise NotImplementedError("8a8w for llama is still under development")
-
-        elif quant_dtype == QuantDtype.use_16a16w:
-            raise NotImplementedError("16a16w for llama is still under development")
-
-        elif quant_dtype == QuantDtype.use_16a4w:
-            raise NotImplementedError("16a4w for llama is still under development")

     return QnnPartitioner(
         generate_qnn_executorch_compiler_spec(
