
Commit 28ee6f5

New embedding quant fusion
Differential Revision: D73381542
Pull Request resolved: #10325
1 parent ad1b154 commit 28ee6f5

File tree

6 files changed: +431 -6 lines changed

exir/TARGETS

+1

@@ -16,6 +16,7 @@ python_library(
         "//caffe2:torch",
         "//executorch/exir/operator:convert",
         "//executorch/extension/pytree:pylib",
+        "//pytorch/ao:torchao",
     ],
 )

exir/passes/_quant_patterns_and_replacements.py

+297 -3
@@ -22,6 +22,56 @@
     "get_quant_patterns_and_replacements",
 ]
 
+
+from torch import Tensor
+from torch.library import custom_op
+
+
+@custom_op("quant_fusion::_pack_embedding_weight", mutates_args=())
+def _pack_embedding_weight(weight: Tensor, bitwidth: int) -> Tensor:
+    num_embeddings, embedding_dim = weight.shape
+
+    if bitwidth == 2:
+        assert embedding_dim % 4 == 0, "embedding_dim must be divisible by 4"
+        weight_range_shifted = weight.add(2).view(torch.uint8)
+        weight_view = weight_range_shifted.view(num_embeddings, embedding_dim // 4, 4)
+        weight_0 = weight_view[:, :, 0]
+        weight_1 = weight_view[:, :, 1] << 2
+        weight_2 = weight_view[:, :, 2] << 4
+        weight_3 = weight_view[:, :, 3] << 6
+        packed_weight = weight_0 | weight_1 | weight_2 | weight_3
+        return packed_weight
+    elif bitwidth == 4:
+        assert embedding_dim % 2 == 0, "embedding_dim must be divisible by 2"
+        weight_range_shifted = weight.add(8).view(torch.uint8)
+        weight_view = weight_range_shifted.view(
+            weight.shape[0], weight.shape[1] // 2, 2
+        )
+        weight_even = weight_view[:, :, 0] << 4
+        weight_odd = weight_view[:, :, 1]
+        packed_weight = weight_even | weight_odd
+        return packed_weight
+    elif bitwidth == 8:
+        return weight
+
+    raise RuntimeError(f"Unsupported bitwidth {bitwidth}")
+
+
+# Use register_fake to add a ``FakeTensor`` kernel for the operator
+@_pack_embedding_weight.register_fake
+def _(weight, bit_width):
+    assert bit_width in [2, 4, 8]
+    num_embeddings, embedding_dim = weight.shape
+    values_per_byte = 8 // bit_width
+    assert embedding_dim % values_per_byte == 0
+    return torch.empty(
+        num_embeddings,
+        embedding_dim // values_per_byte,
+        dtype=torch.uint8,
+        device=weight.device,
+    )
+
+
 # TODO: extending an existing library that is defined in OSS might be a bit
 # confusing, we can investigate if it is possible to define a new library

@@ -69,9 +119,10 @@ def embedding_weight_checks(weight, weight_scales, weight_zero_points):
     assert (
         weight_zero_points is None or weight_zero_points.dtype == weight_scales.dtype
     ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
-    assert (
-        weight_zero_points is None or weight_zero_points.dim() == 1
-    ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"
+    assert weight_zero_points is None or weight_zero_points.dim() in [
+        1,
+        2,
+    ], f"Expecting weight_zero_points tensor to be None or have dim() in [1, 2], but found {weight_zero_points.dim()}"
     assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
         0
     ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"
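The relaxed dim() check admits groupwise quantization parameters. A shape-only sketch (the groupwise reading is inferred from the [1, group_size] block size used in the torchao patterns below):

import torch

num_embeddings, embedding_dim, group_size = 10, 64, 32
# dim() == 1: one zero point per embedding row (channelwise)
rowwise_zp = torch.zeros(num_embeddings)
# dim() == 2: one zero point per group per row (groupwise)
groupwise_zp = torch.zeros(num_embeddings, embedding_dim // group_size)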
@@ -234,6 +285,21 @@ def embedding_2bit(
     return torch.ops.aten.embedding.default(weight, indices)
 
 
+@register_fake("quantized_decomposed::embedding_2bit")
+def _(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+):
+    num_embeddings, packed_embedding_dim = weight.shape
+    embedding_dim = packed_embedding_dim * 4
+    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
+    return embedding(indices)
+
+
 @register_fake("quantized_decomposed::embedding_2bit.out")
 def embedding_2bit_out_meta(
     weight: torch.Tensor,
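This fake kernel lets shape propagation see through the packed layout: 2-bit weights store four values per byte, so it reports embedding_dim = packed_embedding_dim * 4 (the 4-bit variants below use * 2). A hedged sketch of what this enables, assuming this module has been imported so the ops are registered:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    packed = torch.empty(100, 16, dtype=torch.uint8)  # 16 packed bytes per row
    scales = torch.empty(100, 1)
    indices = torch.zeros(5, dtype=torch.long)
    out = torch.ops.quantized_decomposed.embedding_2bit(
        packed, scales, None, -2, 1, indices
    )
    assert out.shape == (5, 64)  # 16 bytes * 4 values per byte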
@@ -296,6 +362,22 @@ def embedding_2bit_dtype(
     return torch.ops.aten.embedding.default(weight, indices)
 
 
+@register_fake("quantized_decomposed::embedding_2bit.dtype")
+def _(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+) -> torch.Tensor:
+    num_embeddings, packed_embedding_dim = weight.shape
+    embedding_dim = packed_embedding_dim * 4
+    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
+    return embedding(indices).to(dtype)
+
+
 @register_fake("quantized_decomposed::embedding_2bit.dtype_out")
 def embedding_2bit_dtype_out_meta(
     weight: torch.Tensor,
@@ -378,6 +460,21 @@ def embedding_4bit(
     return torch.ops.aten.embedding.default(weight, indices)
 
 
+@register_fake("quantized_decomposed::embedding_4bit")
+def _(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+):
+    num_embeddings, packed_embedding_dim = weight.shape
+    embedding_dim = packed_embedding_dim * 2
+    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
+    return embedding(indices)
+
+
 @register_fake("quantized_decomposed::embedding_4bit.out")
 def embedding_4bit_out_meta(
     weight: torch.Tensor,
@@ -438,6 +535,22 @@ def embedding_4bit_dtype(
     return torch.ops.aten.embedding.default(weight, indices)
 
 
+@register_fake("quantized_decomposed::embedding_4bit.dtype")
+def _(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+) -> torch.Tensor:
+    num_embeddings, packed_embedding_dim = weight.shape
+    embedding_dim = packed_embedding_dim * 2
+    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
+    return embedding(indices).to(dtype)
+
+
 @register_fake("quantized_decomposed::embedding_4bit.dtype_out")
 def embedding_4bit_dtype_out_meta(
     weight: torch.Tensor,
@@ -873,6 +986,186 @@ def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
     ]
 
 
+def _get_embedding_ops_patterns_and_replacements_torchao() -> (  # noqa C901
+    List[Tuple[Callable, Callable, List[Callable]]]
+):
+    def embedding_byte_pattern(indices, int_data, group_size, scale, zero_point):
+        dq = torch.ops.torchao.dequantize_affine.default(
+            int_data, [1, group_size], scale, zero_point, torch.int8, -128, 127
+        )
+        return torch.ops.aten.embedding.default(dq, indices)
+
+    def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point):
+        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+        return torch.ops.quantized_decomposed.embedding_byte.default(
+            int_data,
+            scale,
+            zero_point_dtype_cast,
+            -128,
+            127,
+            indices,
+        )
+
+    def embedding_byte_dtype_pattern(
+        indices, int_data, group_size, scale, zero_point, output_dtype
+    ):
+        dq = torch.ops.torchao.dequantize_affine.default(
+            int_data,
+            [1, group_size],
+            scale,
+            zero_point,
+            torch.int8,
+            -128,
+            127,
+            "INT",
+            output_dtype,
+        )
+        return torch.ops.aten.embedding.default(dq, indices)
+
+    def embedding_byte_dtype_replacement(
+        indices, int_data, group_size, scale, zero_point, output_dtype
+    ):
+        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+        return torch.ops.quantized_decomposed.embedding_byte.dtype(
+            int_data,
+            scale,
+            zero_point_dtype_cast,
+            -128,
+            127,
+            indices,
+            dtype=output_dtype,
+        )
+
+    def embedding_2bit_pattern(indices, int_data, group_size, scale, zero_point):
+        dq = torch.ops.torchao.dequantize_affine.default(
+            int_data, [1, group_size], scale, zero_point, torch.int8, -2, 1
+        )
+        return torch.ops.aten.embedding.default(dq, indices)
+
+    def embedding_2bit_replacement(indices, int_data, group_size, scale, zero_point):
+        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
+            int_data, 2
+        )
+        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+        return torch.ops.quantized_decomposed.embedding_2bit.default(
+            packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices
+        )
+
+    def embedding_2bit_dtype_pattern(
+        indices, int_data, group_size, scale, zero_point, output_dtype
+    ):
+        dq = torch.ops.torchao.dequantize_affine.default(
+            int_data,
+            [1, group_size],
+            scale,
+            zero_point,
+            torch.int8,
+            -2,
+            1,
+            "INT",
+            output_dtype,
+        )
+        return torch.ops.aten.embedding.default(dq, indices)
+
+    def embedding_2bit_dtype_replacement(
+        indices, int_data, group_size, scale, zero_point, output_dtype
+    ):
+        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
+            int_data, 2
+        )
+        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+        return torch.ops.quantized_decomposed.embedding_2bit.dtype(
+            packed_int_data,
+            scale,
+            zero_point_dtype_cast,
+            -2,
+            1,
+            indices,
+            dtype=output_dtype,
+        )
+
+    def embedding_4bit_pattern(indices, int_data, group_size, scale, zero_point):
+        dq = torch.ops.torchao.dequantize_affine.default(
+            int_data, [1, group_size], scale, zero_point, torch.int8, -8, 7
+        )
+        return torch.ops.aten.embedding.default(dq, indices)
+
+    def embedding_4bit_replacement(indices, int_data, group_size, scale, zero_point):
+        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
+            int_data, 4
+        )
+        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+        return torch.ops.quantized_decomposed.embedding_4bit.default(
+            packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices
+        )
+
+    def embedding_4bit_dtype_pattern(
+        indices, int_data, group_size, scale, zero_point, output_dtype
+    ):
+        dq = torch.ops.torchao.dequantize_affine.default(
+            int_data,
+            [1, group_size],
+            scale,
+            zero_point,
+            torch.int8,
+            -8,
+            7,
+            "INT",
+            output_dtype,
+        )
+        return torch.ops.aten.embedding.default(dq, indices)
+
+    def embedding_4bit_dtype_replacement(
+        indices, int_data, group_size, scale, zero_point, output_dtype
+    ):
+        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
+            int_data, 4
+        )
+        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+        return torch.ops.quantized_decomposed.embedding_4bit.dtype(
+            packed_int_data,
+            scale,
+            zero_point_dtype_cast,
+            -8,
+            7,
+            indices,
+            dtype=output_dtype,
+        )
+
+    return [
+        (
+            _trace_and_lower_to_edge_ops(embedding_byte_pattern),
+            _trace_and_lower_to_edge_ops(embedding_byte_replacement),
+            [],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_byte_dtype_pattern),
+            _trace_and_lower_to_edge_ops(embedding_byte_dtype_replacement),
+            [],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_2bit_pattern),
+            _trace_and_lower_to_edge_ops(embedding_2bit_replacement),
+            [],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_2bit_dtype_pattern),
+            _trace_and_lower_to_edge_ops(embedding_2bit_dtype_replacement),
+            [],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_4bit_pattern),
+            _trace_and_lower_to_edge_ops(embedding_4bit_replacement),
+            [],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_4bit_dtype_pattern),
+            _trace_and_lower_to_edge_ops(embedding_4bit_dtype_replacement),
+            [],
+        ),
+    ]
+
+
 def _get_embedding_ops_patterns_and_replacements() -> (
     List[Tuple[Callable, Callable, List[Callable]]]
 ):
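Each (pattern, replacement, match_filters) triple is consumed by torch.fx subgraph rewriting after both callables are traced and lowered to edge ops. A generic sketch of the underlying mechanism, with relu/clamp standing in for the dequantize+embedding patterns:

import torch
from torch.fx import symbolic_trace, subgraph_rewriter

def pattern(x):
    return torch.relu(x)

def replacement(x):
    return torch.clamp(x, min=0.0)

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

gm = symbolic_trace(M())
matches = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
assert len(matches) == 1  # the relu subgraph was rewritten to clamp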
@@ -1167,5 +1460,6 @@ def get_quant_patterns_and_replacements() -> (
         *_get_slice_patterns_and_replacements(),
         # *_get_fixed_qparams_ops_patterns_and_replacements(),
         *_get_embedding_ops_patterns_and_replacements(),
+        *_get_embedding_ops_patterns_and_replacements_torchao(),
     ]
 )

exir/passes/quant_fusion_pass.py

+13
@@ -90,6 +90,18 @@ def _get_qparams(node):
             model.graph.erase_node(qnode)
 
 
+def _remove_dtype_getattr_nodes(model: GraphModule) -> None:
+    for n in model.graph.nodes:
+        if n.op == "call_function" and n.target == getattr:
+            if isinstance(n.args[0], torch.fx.Node) and n.args[1] == "dtype":
+                dtype = n.args[0].meta["val"].dtype
+                n.replace_all_uses_with(dtype)
+                model.graph.erase_node(n)
+    model.graph.eliminate_dead_code()
+    model.graph.lint()
+    model.recompile()
+
+
 class QuantFusionPass(ExportPass):
     def __init__(self, _fix_node_meta_val=False):
         super().__init__()
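Context for this helper, as a hedged sketch: tracing a replacement such as zero_point.to(scale.dtype) records a call_function node whose target is builtins.getattr reading "dtype" from a tensor node; after fusion, the helper folds that node to the concrete dtype held in the producer's meta["val"].

import torch
from torch.fx import symbolic_trace

def replacement(zero_point, scale):
    return zero_point.to(scale.dtype)

gm = symbolic_trace(replacement)
print(gm.graph)
# Expected to contain, roughly:
#   %dtype := call_function[target=getattr](args = (%scale, dtype))
#   %to := call_method[target=to](args = (%zero_point, %dtype))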
@@ -123,6 +135,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
                     torch.fx.Node, lambda x: x.meta["val"], (n.args, n.kwargs)
                 )
                 n.meta["val"] = n.target(*args, **kwargs)
+        _remove_dtype_getattr_nodes(graph_module)
         graph_module.graph.lint()
         graph_module.graph.eliminate_dead_code()
         return PassResult(graph_module, True)
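A hedged usage sketch of the pass (API names from executorch.exir; `quantized_model` is an assumed nn.Module whose graph already contains torchao dequantize_affine + aten.embedding patterns):

import torch
from executorch.exir import to_edge
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass

ep = torch.export.export(quantized_model, (torch.tensor([[0, 1, 2]]),))
edge = to_edge(ep).transform([QuantFusionPass()])
# After the pass, dequantize_affine + embedding pairs appear as
# quantized_decomposed.embedding_{byte,2bit,4bit} ops on packed weights.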

exir/tests/TARGETS

+2
@@ -298,6 +298,8 @@ python_unittest(
         "//caffe2:torch",
         "//executorch/exir:lib",
         "//executorch/exir/passes:quant_fusion_pass",
+        "//pytorch/ao:torchao",
+        "//executorch/exir/passes:constant_prop_pass",
     ],
 )
