|
22 | 22 | "get_quant_patterns_and_replacements",
|
23 | 23 | ]
|
24 | 24 |
|
| 25 | + |
| 26 | +from torch import Tensor |
| 27 | +from torch.library import custom_op |
| 28 | + |
| 29 | + |
| 30 | +@custom_op("quant_fusion::_pack_embedding_weight", mutates_args=()) |
| 31 | +def _pack_embedding_weight(weight: Tensor, bitwidth: int) -> Tensor: |
| 32 | + num_embeddings, embedding_dim = weight.shape |
| 33 | + |
| 34 | + if bitwidth == 2: |
| 35 | + assert embedding_dim % 4 == 0, "embedding_dim must be divisible by 4" |
| 36 | + weight_range_shifted = weight.add(2).view(torch.uint8) |
| 37 | + weight_view = weight_range_shifted.view(num_embeddings, embedding_dim // 4, 4) |
| 38 | + weight_0 = weight_view[:, :, 0] |
| 39 | + weight_1 = weight_view[:, :, 1] << 2 |
| 40 | + weight_2 = weight_view[:, :, 2] << 4 |
| 41 | + weight_3 = weight_view[:, :, 3] << 6 |
| 42 | + packed_weight = weight_0 | weight_1 | weight_2 | weight_3 |
| 43 | + return packed_weight |
| 44 | + elif bitwidth == 4: |
| 45 | + assert embedding_dim % 2 == 0, "embedding_dim must be divisible by 2" |
| 46 | + weight_range_shifted = weight.add(8).view(torch.uint8) |
| 47 | + weight_view = weight_range_shifted.view( |
| 48 | + weight.shape[0], weight.shape[1] // 2, 2 |
| 49 | + ) |
| 50 | + weight_even = weight_view[:, :, 0] << 4 |
| 51 | + weight_odd = weight_view[:, :, 1] |
| 52 | + packed_weight = weight_even | weight_odd |
| 53 | + return packed_weight |
| 54 | + elif bitwidth == 8: |
| 55 | + return weight |
| 56 | + |
| 57 | + raise RuntimeError(f"Unsupported bitwidth {bitwidth}") |
| 58 | + |
| 59 | + |
| 60 | +# Use register_fake to add a ``FakeTensor`` kernel for the operator |
| 61 | +@_pack_embedding_weight.register_fake |
| 62 | +def _(weight, bit_width): |
| 63 | + assert bit_width in [2, 4, 8] |
| 64 | + num_embeddings, embedding_dim = weight.shape |
| 65 | + values_per_byte = 8 // bit_width |
| 66 | + assert embedding_dim % values_per_byte == 0 |
| 67 | + return torch.empty( |
| 68 | + num_embeddings, |
| 69 | + embedding_dim // values_per_byte, |
| 70 | + dtype=torch.uint8, |
| 71 | + device=weight.device, |
| 72 | + ) |
| 73 | + |
| 74 | + |
25 | 75 | # TODO: extending an existing library that is defined in OSS might be a bit
|
26 | 76 | # confusing, we can investigate if it is possible to define a new library
|
27 | 77 |
|
@@ -69,9 +119,10 @@ def embedding_weight_checks(weight, weight_scales, weight_zero_points):
|
69 | 119 | assert (
|
70 | 120 | weight_zero_points is None or weight_zero_points.dtype == weight_scales.dtype
|
71 | 121 | ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
|
72 |
| - assert ( |
73 |
| - weight_zero_points is None or weight_zero_points.dim() == 1 |
74 |
| - ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}" |
| 122 | + assert weight_zero_points is None or weight_zero_points.dim() in [ |
| 123 | + 1, |
| 124 | + 2, |
| 125 | + ], f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}" |
75 | 126 | assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
|
76 | 127 | 0
|
77 | 128 | ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"
|
@@ -234,6 +285,21 @@ def embedding_2bit(
|
234 | 285 | return torch.ops.aten.embedding.default(weight, indices)
|
235 | 286 |
|
236 | 287 |
|
| 288 | +@register_fake("quantized_decomposed::embedding_2bit") |
| 289 | +def _( |
| 290 | + weight: torch.Tensor, |
| 291 | + weight_scales: torch.Tensor, |
| 292 | + weight_zero_points: Optional[torch.Tensor], |
| 293 | + weight_quant_min: int, |
| 294 | + weight_quant_max: int, |
| 295 | + indices: torch.Tensor, |
| 296 | +): |
| 297 | + num_embeddings, packed_embedding_dim = weight.shape |
| 298 | + embedding_dim = packed_embedding_dim * 4 |
| 299 | + embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device) |
| 300 | + return embedding(indices) |
| 301 | + |
| 302 | + |
237 | 303 | @register_fake("quantized_decomposed::embedding_2bit.out")
|
238 | 304 | def embedding_2bit_out_meta(
|
239 | 305 | weight: torch.Tensor,
|
@@ -296,6 +362,22 @@ def embedding_2bit_dtype(
|
296 | 362 | return torch.ops.aten.embedding.default(weight, indices)
|
297 | 363 |
|
298 | 364 |
|
| 365 | +@register_fake("quantized_decomposed::embedding_2bit.dtype") |
| 366 | +def _( |
| 367 | + weight: torch.Tensor, |
| 368 | + weight_scales: torch.Tensor, |
| 369 | + weight_zero_points: Optional[torch.Tensor], |
| 370 | + weight_quant_min: int, |
| 371 | + weight_quant_max: int, |
| 372 | + indices: torch.Tensor, |
| 373 | + dtype: Optional[torch.dtype], |
| 374 | +) -> torch.Tensor: |
| 375 | + num_embeddings, packed_embedding_dim = weight.shape |
| 376 | + embedding_dim = packed_embedding_dim * 4 |
| 377 | + embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device) |
| 378 | + return embedding(indices).to(dtype) |
| 379 | + |
| 380 | + |
299 | 381 | @register_fake("quantized_decomposed::embedding_2bit.dtype_out")
|
300 | 382 | def embedding_2bit_dtype_out_meta(
|
301 | 383 | weight: torch.Tensor,
|
@@ -378,6 +460,21 @@ def embedding_4bit(
|
378 | 460 | return torch.ops.aten.embedding.default(weight, indices)
|
379 | 461 |
|
380 | 462 |
|
| 463 | +@register_fake("quantized_decomposed::embedding_4bit") |
| 464 | +def _( |
| 465 | + weight: torch.Tensor, |
| 466 | + weight_scales: torch.Tensor, |
| 467 | + weight_zero_points: Optional[torch.Tensor], |
| 468 | + weight_quant_min: int, |
| 469 | + weight_quant_max: int, |
| 470 | + indices: torch.Tensor, |
| 471 | +): |
| 472 | + num_embeddings, packed_embedding_dim = weight.shape |
| 473 | + embedding_dim = packed_embedding_dim * 2 |
| 474 | + embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device) |
| 475 | + return embedding(indices) |
| 476 | + |
| 477 | + |
381 | 478 | @register_fake("quantized_decomposed::embedding_4bit.out")
|
382 | 479 | def embedding_4bit_out_meta(
|
383 | 480 | weight: torch.Tensor,
|
@@ -438,6 +535,22 @@ def embedding_4bit_dtype(
|
438 | 535 | return torch.ops.aten.embedding.default(weight, indices)
|
439 | 536 |
|
440 | 537 |
|
| 538 | +@register_fake("quantized_decomposed::embedding_4bit.dtype") |
| 539 | +def _( |
| 540 | + weight: torch.Tensor, |
| 541 | + weight_scales: torch.Tensor, |
| 542 | + weight_zero_points: Optional[torch.Tensor], |
| 543 | + weight_quant_min: int, |
| 544 | + weight_quant_max: int, |
| 545 | + indices: torch.Tensor, |
| 546 | + dtype: Optional[torch.dtype], |
| 547 | +) -> torch.Tensor: |
| 548 | + num_embeddings, packed_embedding_dim = weight.shape |
| 549 | + embedding_dim = packed_embedding_dim * 2 |
| 550 | + embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device) |
| 551 | + return embedding(indices).to(dtype) |
| 552 | + |
| 553 | + |
441 | 554 | @register_fake("quantized_decomposed::embedding_4bit.dtype_out")
|
442 | 555 | def embedding_4bit_dtype_out_meta(
|
443 | 556 | weight: torch.Tensor,
|
@@ -873,6 +986,186 @@ def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
|
873 | 986 | ]
|
874 | 987 |
|
875 | 988 |
|
| 989 | +def _get_embedding_ops_patterns_and_replacements_torchao() -> ( # noqa C901 |
| 990 | + List[Tuple[Callable, Callable, List[Callable]]] |
| 991 | +): |
| 992 | + def embedding_byte_pattern(indices, int_data, group_size, scale, zero_point): |
| 993 | + dq = torch.ops.torchao.dequantize_affine.default( |
| 994 | + int_data, [1, group_size], scale, zero_point, torch.int8, -128, 127 |
| 995 | + ) |
| 996 | + return torch.ops.aten.embedding.default(dq, indices) |
| 997 | + |
| 998 | + def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point): |
| 999 | + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) |
| 1000 | + return torch.ops.quantized_decomposed.embedding_byte.default( |
| 1001 | + int_data, |
| 1002 | + scale, |
| 1003 | + zero_point_dtype_cast, |
| 1004 | + -128, |
| 1005 | + 127, |
| 1006 | + indices, |
| 1007 | + ) |
| 1008 | + |
| 1009 | + def embedding_byte_dtype_pattern( |
| 1010 | + indices, int_data, group_size, scale, zero_point, output_dtype |
| 1011 | + ): |
| 1012 | + dq = torch.ops.torchao.dequantize_affine.default( |
| 1013 | + int_data, |
| 1014 | + [1, group_size], |
| 1015 | + scale, |
| 1016 | + zero_point, |
| 1017 | + torch.int8, |
| 1018 | + -128, |
| 1019 | + 127, |
| 1020 | + "INT", |
| 1021 | + output_dtype, |
| 1022 | + ) |
| 1023 | + return torch.ops.aten.embedding.default(dq, indices) |
| 1024 | + |
| 1025 | + def embedding_byte_dtype_replacement( |
| 1026 | + indices, int_data, group_size, scale, zero_point, output_dtype |
| 1027 | + ): |
| 1028 | + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) |
| 1029 | + return torch.ops.quantized_decomposed.embedding_byte.dtype( |
| 1030 | + int_data, |
| 1031 | + scale, |
| 1032 | + zero_point_dtype_cast, |
| 1033 | + -128, |
| 1034 | + 127, |
| 1035 | + indices, |
| 1036 | + dtype=output_dtype, |
| 1037 | + ) |
| 1038 | + |
| 1039 | + def embedding_2bit_pattern(indices, int_data, group_size, scale, zero_point): |
| 1040 | + dq = torch.ops.torchao.dequantize_affine.default( |
| 1041 | + int_data, [1, group_size], scale, zero_point, torch.int8, -2, 1 |
| 1042 | + ) |
| 1043 | + return torch.ops.aten.embedding.default(dq, indices) |
| 1044 | + |
| 1045 | + def embedding_2bit_replacement(indices, int_data, group_size, scale, zero_point): |
| 1046 | + packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default( |
| 1047 | + int_data, 2 |
| 1048 | + ) |
| 1049 | + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) |
| 1050 | + return torch.ops.quantized_decomposed.embedding_2bit.default( |
| 1051 | + packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices |
| 1052 | + ) |
| 1053 | + |
| 1054 | + def embedding_2bit_dtype_pattern( |
| 1055 | + indices, int_data, group_size, scale, zero_point, output_dtype |
| 1056 | + ): |
| 1057 | + dq = torch.ops.torchao.dequantize_affine.default( |
| 1058 | + int_data, |
| 1059 | + [1, group_size], |
| 1060 | + scale, |
| 1061 | + zero_point, |
| 1062 | + torch.int8, |
| 1063 | + -2, |
| 1064 | + 1, |
| 1065 | + "INT", |
| 1066 | + output_dtype, |
| 1067 | + ) |
| 1068 | + return torch.ops.aten.embedding.default(dq, indices) |
| 1069 | + |
| 1070 | + def embedding_2bit_dtype_replacement( |
| 1071 | + indices, int_data, group_size, scale, zero_point, output_dtype |
| 1072 | + ): |
| 1073 | + packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default( |
| 1074 | + int_data, 2 |
| 1075 | + ) |
| 1076 | + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) |
| 1077 | + return torch.ops.quantized_decomposed.embedding_2bit.dtype( |
| 1078 | + packed_int_data, |
| 1079 | + scale, |
| 1080 | + zero_point_dtype_cast, |
| 1081 | + -2, |
| 1082 | + 1, |
| 1083 | + indices, |
| 1084 | + dtype=output_dtype, |
| 1085 | + ) |
| 1086 | + |
| 1087 | + def embedding_4bit_pattern(indices, int_data, group_size, scale, zero_point): |
| 1088 | + dq = torch.ops.torchao.dequantize_affine.default( |
| 1089 | + int_data, [1, group_size], scale, zero_point, torch.int8, -8, 7 |
| 1090 | + ) |
| 1091 | + return torch.ops.aten.embedding.default(dq, indices) |
| 1092 | + |
| 1093 | + def embedding_4bit_replacement(indices, int_data, group_size, scale, zero_point): |
| 1094 | + packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default( |
| 1095 | + int_data, 4 |
| 1096 | + ) |
| 1097 | + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) |
| 1098 | + return torch.ops.quantized_decomposed.embedding_4bit.default( |
| 1099 | + packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices |
| 1100 | + ) |
| 1101 | + |
| 1102 | + def embedding_4bit_dtype_pattern( |
| 1103 | + indices, int_data, group_size, scale, zero_point, output_dtype |
| 1104 | + ): |
| 1105 | + dq = torch.ops.torchao.dequantize_affine.default( |
| 1106 | + int_data, |
| 1107 | + [1, group_size], |
| 1108 | + scale, |
| 1109 | + zero_point, |
| 1110 | + torch.int8, |
| 1111 | + -8, |
| 1112 | + 7, |
| 1113 | + "INT", |
| 1114 | + output_dtype, |
| 1115 | + ) |
| 1116 | + return torch.ops.aten.embedding.default(dq, indices) |
| 1117 | + |
| 1118 | + def embedding_4bit_dtype_replacement( |
| 1119 | + indices, int_data, group_size, scale, zero_point, output_dtype |
| 1120 | + ): |
| 1121 | + packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default( |
| 1122 | + int_data, 4 |
| 1123 | + ) |
| 1124 | + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) |
| 1125 | + return torch.ops.quantized_decomposed.embedding_4bit.dtype( |
| 1126 | + packed_int_data, |
| 1127 | + scale, |
| 1128 | + zero_point_dtype_cast, |
| 1129 | + -8, |
| 1130 | + 7, |
| 1131 | + indices, |
| 1132 | + dtype=output_dtype, |
| 1133 | + ) |
| 1134 | + |
| 1135 | + return [ |
| 1136 | + ( |
| 1137 | + _trace_and_lower_to_edge_ops(embedding_byte_pattern), |
| 1138 | + _trace_and_lower_to_edge_ops(embedding_byte_replacement), |
| 1139 | + [], |
| 1140 | + ), |
| 1141 | + ( |
| 1142 | + _trace_and_lower_to_edge_ops(embedding_byte_dtype_pattern), |
| 1143 | + _trace_and_lower_to_edge_ops(embedding_byte_dtype_replacement), |
| 1144 | + [], |
| 1145 | + ), |
| 1146 | + ( |
| 1147 | + _trace_and_lower_to_edge_ops(embedding_2bit_pattern), |
| 1148 | + _trace_and_lower_to_edge_ops(embedding_2bit_replacement), |
| 1149 | + [], |
| 1150 | + ), |
| 1151 | + ( |
| 1152 | + _trace_and_lower_to_edge_ops(embedding_2bit_dtype_pattern), |
| 1153 | + _trace_and_lower_to_edge_ops(embedding_2bit_dtype_replacement), |
| 1154 | + [], |
| 1155 | + ), |
| 1156 | + ( |
| 1157 | + _trace_and_lower_to_edge_ops(embedding_4bit_pattern), |
| 1158 | + _trace_and_lower_to_edge_ops(embedding_4bit_replacement), |
| 1159 | + [], |
| 1160 | + ), |
| 1161 | + ( |
| 1162 | + _trace_and_lower_to_edge_ops(embedding_4bit_dtype_pattern), |
| 1163 | + _trace_and_lower_to_edge_ops(embedding_4bit_dtype_replacement), |
| 1164 | + [], |
| 1165 | + ), |
| 1166 | + ] |
| 1167 | + |
| 1168 | + |
876 | 1169 | def _get_embedding_ops_patterns_and_replacements() -> (
|
877 | 1170 | List[Tuple[Callable, Callable, List[Callable]]]
|
878 | 1171 | ):
|
@@ -1167,5 +1460,6 @@ def get_quant_patterns_and_replacements() -> (
|
1167 | 1460 | *_get_slice_patterns_and_replacements(),
|
1168 | 1461 | # *_get_fixed_qparams_ops_patterns_and_replacements(),
|
1169 | 1462 | *_get_embedding_ops_patterns_and_replacements(),
|
| 1463 | + *_get_embedding_ops_patterns_and_replacements_torchao(), |
1170 | 1464 | ]
|
1171 | 1465 | )
|
0 commit comments