New embedding quant fusion #10325
Changes from all commits
@@ -22,6 +22,56 @@
    "get_quant_patterns_and_replacements",
]


from torch import Tensor
from torch.library import custom_op


@custom_op("quant_fusion::_pack_embedding_weight", mutates_args=())
def _pack_embedding_weight(weight: Tensor, bitwidth: int) -> Tensor:
    num_embeddings, embedding_dim = weight.shape

    if bitwidth == 2:
        assert embedding_dim % 4 == 0, "embedding_dim must be divisible by 4"
        weight_range_shifted = weight.add(2).view(torch.uint8)
        weight_view = weight_range_shifted.view(num_embeddings, embedding_dim // 4, 4)
        weight_0 = weight_view[:, :, 0]
        weight_1 = weight_view[:, :, 1] << 2
        weight_2 = weight_view[:, :, 2] << 4
        weight_3 = weight_view[:, :, 3] << 6
        packed_weight = weight_0 | weight_1 | weight_2 | weight_3
        return packed_weight
    elif bitwidth == 4:
        assert embedding_dim % 2 == 0, "embedding_dim must be divisible by 2"
        weight_range_shifted = weight.add(8).view(torch.uint8)
        weight_view = weight_range_shifted.view(
            weight.shape[0], weight.shape[1] // 2, 2
        )
        weight_even = weight_view[:, :, 0] << 4
        weight_odd = weight_view[:, :, 1]
        packed_weight = weight_even | weight_odd
        return packed_weight
    elif bitwidth == 8:
        return weight

    raise RuntimeError(f"Unsupported bitwidth {bitwidth}")


# Use register_fake to add a ``FakeTensor`` kernel for the operator
@_pack_embedding_weight.register_fake
def _(weight, bit_width):
    assert bit_width in [2, 4, 8]
    num_embeddings, embedding_dim = weight.shape
    values_per_byte = 8 // bit_width
    assert embedding_dim % values_per_byte == 0
    return torch.empty(
        num_embeddings,
        embedding_dim // values_per_byte,
        dtype=torch.uint8,
        device=weight.device,
    )


# TODO: extending an existing library that is defined in OSS might be a bit
# confusing, we can investigate if it is possible to define a new library
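For readers double-checking the bit layout, the following small sketch (illustrative only, not part of the PR) replays the 4-bit branch of _pack_embedding_weight on a single row: each signed value is shifted by +8 into [0, 15] and packed two per byte, high nibble first.

import torch

row = torch.tensor([[-8, 7, 0, -1]], dtype=torch.int8)  # one embedding row, embedding_dim = 4
shifted = row.add(8).view(torch.uint8)                   # values now in [0, 15]
pairs = shifted.view(1, 2, 2)                            # group into (embedding_dim // 2) pairs
packed = (pairs[:, :, 0] << 4) | pairs[:, :, 1]          # high nibble first
print(packed)  # tensor([[ 15, 135]], dtype=torch.uint8), i.e. 0x0F and 0x87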
@@ -69,9 +119,10 @@ def embedding_weight_checks(weight, weight_scales, weight_zero_points):
    assert (
        weight_zero_points is None or weight_zero_points.dtype == weight_scales.dtype
    ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
-    assert (
-        weight_zero_points is None or weight_zero_points.dim() == 1
-    ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"
+    assert weight_zero_points is None or weight_zero_points.dim() in [
+        1,
+        2,
+    ], f"Expecting weight_zero_points tensor to be None or have dim() in [1, 2], but found {weight_zero_points.dim()}"
    assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
        0
    ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"
@@ -234,6 +285,21 @@ def embedding_2bit(
    return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_2bit")
def _(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
):
    num_embeddings, packed_embedding_dim = weight.shape
    embedding_dim = packed_embedding_dim * 4
    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
    return embedding(indices)


@register_fake("quantized_decomposed::embedding_2bit.out")
def embedding_2bit_out_meta(
    weight: torch.Tensor,
@@ -296,6 +362,22 @@ def embedding_2bit_dtype(
    return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_2bit.dtype")
def _(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
) -> torch.Tensor:
    num_embeddings, packed_embedding_dim = weight.shape
    embedding_dim = packed_embedding_dim * 4
    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
    return embedding(indices).to(dtype)

Review thread on the embedding_2bit.dtype fake kernel:
Reviewer: Is this needed?
Author: I think so? You cannot dynamo-trace code that doesn't have meta kernels registered.
Reviewer: Yeah, sorry. I think I asked because I wasn't sure how it was working before without it. But yes, you do need a meta kernel.
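A minimal standalone sketch of the point made in the thread above (hypothetical op names, not from this PR): torch.export traces with FakeTensors, so a custom op only traces once a fake/meta kernel describing its output shape has been registered.

import torch
from torch.library import custom_op


@custom_op("demo::double_rows", mutates_args=())
def double_rows(x: torch.Tensor) -> torch.Tensor:
    return torch.cat([x, x], dim=0)


# Without this registration, exporting the module below fails, because export
# has no rule for computing the op's output shape on FakeTensors.
@double_rows.register_fake
def _(x):
    return torch.empty(2 * x.shape[0], x.shape[1], dtype=x.dtype, device=x.device)


class M(torch.nn.Module):
    def forward(self, x):
        return torch.ops.demo.double_rows(x)


ep = torch.export.export(M(), (torch.randn(3, 4),))  # traces cleanly once the fake kernel exists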
@register_fake("quantized_decomposed::embedding_2bit.dtype_out")
def embedding_2bit_dtype_out_meta(
    weight: torch.Tensor,

@@ -378,6 +460,21 @@ def embedding_4bit(
    return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_4bit")
def _(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
):
    num_embeddings, packed_embedding_dim = weight.shape
    embedding_dim = packed_embedding_dim * 2
    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
    return embedding(indices)


@register_fake("quantized_decomposed::embedding_4bit.out")
def embedding_4bit_out_meta(
    weight: torch.Tensor,
@@ -438,6 +535,22 @@ def embedding_4bit_dtype(
    return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_4bit.dtype")
def _(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
) -> torch.Tensor:
    num_embeddings, packed_embedding_dim = weight.shape
    embedding_dim = packed_embedding_dim * 2
    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
    return embedding(indices).to(dtype)

Review comment on lines +548 to +550:
Reviewer: Very small nit: a small refactor could abstract these three lines out.
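One possible shape for the refactor suggested in that nit (a hypothetical helper, not code from the PR): the repeated unpack-shape, scale-embedding_dim, and build-nn.Embedding lines in the fake kernels could be pulled into a single function parameterized by the packing ratio.

def _fake_packed_embedding_lookup(
    weight: torch.Tensor, indices: torch.Tensor, values_per_byte: int
) -> torch.Tensor:
    # Shared meta-kernel body: the packed dim stores `values_per_byte` values per byte.
    num_embeddings, packed_embedding_dim = weight.shape
    embedding_dim = packed_embedding_dim * values_per_byte
    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
    return embedding(indices)

The 2-bit fake kernels would then call it with values_per_byte=4 and the 4-bit ones with values_per_byte=2, with the dtype variants adding .to(dtype) on the result.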
@@ -873,6 +986,186 @@ def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
    ]


def _get_embedding_ops_patterns_and_replacements_torchao() -> (  # noqa C901
    List[Tuple[Callable, Callable, List[Callable]]]
):
    def embedding_byte_pattern(indices, int_data, group_size, scale, zero_point):
        dq = torch.ops.torchao.dequantize_affine.default(
            int_data, [1, group_size], scale, zero_point, torch.int8, -128, 127
        )
        return torch.ops.aten.embedding.default(dq, indices)

    def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point):
        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
        return torch.ops.quantized_decomposed.embedding_byte.default(
            int_data,
            scale,
            zero_point_dtype_cast,
            -128,
            127,
            indices,
        )

Review thread on the zero-point cast in embedding_byte_replacement:
Reviewer: Do the embedding byte ops take a float zero point?
Author: The ones in the ET kernels do: https://github.com/pytorch/executorch/blob/main/kernels/quantized/cpu/embeddingxb.cpp#L200
    def embedding_byte_dtype_pattern(
        indices, int_data, group_size, scale, zero_point, output_dtype
    ):
        dq = torch.ops.torchao.dequantize_affine.default(
            int_data,
            [1, group_size],
            scale,
            zero_point,
            torch.int8,
            -128,
            127,
            "INT",
            output_dtype,
        )
        return torch.ops.aten.embedding.default(dq, indices)

Review thread on the "INT" argument to dequantize_affine:
Reviewer: What does "INT" mean here?
Author: It's the zero-point domain; it means the zero points are integers. But this is an arg torchao is going to change in their quant primitives when they clean them up.

    def embedding_byte_dtype_replacement(
        indices, int_data, group_size, scale, zero_point, output_dtype
    ):
        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
        return torch.ops.quantized_decomposed.embedding_byte.dtype(
            int_data,
            scale,
            zero_point_dtype_cast,
            -128,
            127,
            indices,
            dtype=output_dtype,
        )
    def embedding_2bit_pattern(indices, int_data, group_size, scale, zero_point):
        dq = torch.ops.torchao.dequantize_affine.default(
            int_data, [1, group_size], scale, zero_point, torch.int8, -2, 1
        )
        return torch.ops.aten.embedding.default(dq, indices)

    def embedding_2bit_replacement(indices, int_data, group_size, scale, zero_point):
        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
            int_data, 2
        )
        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
        return torch.ops.quantized_decomposed.embedding_2bit.default(
            packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices
        )

    def embedding_2bit_dtype_pattern(
        indices, int_data, group_size, scale, zero_point, output_dtype
    ):
        dq = torch.ops.torchao.dequantize_affine.default(
            int_data,
            [1, group_size],
            scale,
            zero_point,
            torch.int8,
            -2,
            1,
            "INT",
            output_dtype,
        )
        return torch.ops.aten.embedding.default(dq, indices)

    def embedding_2bit_dtype_replacement(
        indices, int_data, group_size, scale, zero_point, output_dtype
    ):
        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
            int_data, 2
        )
        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
        return torch.ops.quantized_decomposed.embedding_2bit.dtype(
            packed_int_data,
            scale,
            zero_point_dtype_cast,
            -2,
            1,
            indices,
            dtype=output_dtype,
        )
    def embedding_4bit_pattern(indices, int_data, group_size, scale, zero_point):
        dq = torch.ops.torchao.dequantize_affine.default(
            int_data, [1, group_size], scale, zero_point, torch.int8, -8, 7
        )
        return torch.ops.aten.embedding.default(dq, indices)

    def embedding_4bit_replacement(indices, int_data, group_size, scale, zero_point):
        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
            int_data, 4
        )
        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
        return torch.ops.quantized_decomposed.embedding_4bit.default(
            packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices
        )

    def embedding_4bit_dtype_pattern(
        indices, int_data, group_size, scale, zero_point, output_dtype
    ):
        dq = torch.ops.torchao.dequantize_affine.default(
            int_data,
            [1, group_size],
            scale,
            zero_point,
            torch.int8,
            -8,
            7,
            "INT",
            output_dtype,
        )
        return torch.ops.aten.embedding.default(dq, indices)

    def embedding_4bit_dtype_replacement(
        indices, int_data, group_size, scale, zero_point, output_dtype
    ):
        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
            int_data, 4
        )
        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
        return torch.ops.quantized_decomposed.embedding_4bit.dtype(
            packed_int_data,
            scale,
            zero_point_dtype_cast,
            -8,
            7,
            indices,
            dtype=output_dtype,
        )
    return [
        (
            _trace_and_lower_to_edge_ops(embedding_byte_pattern),
            _trace_and_lower_to_edge_ops(embedding_byte_replacement),
            [],
        ),
        (
            _trace_and_lower_to_edge_ops(embedding_byte_dtype_pattern),
            _trace_and_lower_to_edge_ops(embedding_byte_dtype_replacement),
            [],
        ),
        (
            _trace_and_lower_to_edge_ops(embedding_2bit_pattern),
            _trace_and_lower_to_edge_ops(embedding_2bit_replacement),
            [],
        ),
        (
            _trace_and_lower_to_edge_ops(embedding_2bit_dtype_pattern),
            _trace_and_lower_to_edge_ops(embedding_2bit_dtype_replacement),
            [],
        ),
        (
            _trace_and_lower_to_edge_ops(embedding_4bit_pattern),
            _trace_and_lower_to_edge_ops(embedding_4bit_replacement),
            [],
        ),
        (
            _trace_and_lower_to_edge_ops(embedding_4bit_dtype_pattern),
            _trace_and_lower_to_edge_ops(embedding_4bit_dtype_replacement),
            [],
        ),
    ]


def _get_embedding_ops_patterns_and_replacements() -> (
    List[Tuple[Callable, Callable, List[Callable]]]
):
@@ -1167,5 +1460,6 @@ def get_quant_patterns_and_replacements() -> (
            *_get_slice_patterns_and_replacements(),
            # *_get_fixed_qparams_ops_patterns_and_replacements(),
            *_get_embedding_ops_patterns_and_replacements(),
            *_get_embedding_ops_patterns_and_replacements_torchao(),
        ]
    )
@@ -298,6 +298,8 @@ python_unittest(
        "//caffe2:torch",
        "//executorch/exir:lib",
        "//executorch/exir/passes:quant_fusion_pass",
        "//pytorch/ao:torchao",
        "//executorch/exir/passes:constant_prop_pass",
    ],
)

Review thread on the new constant_prop_pass dependency:
Reviewer: Shouldn't you call const prop in the QuantFusionPass?
Author: Unfortunately QuantFusionPass works on a graph module and const_prop_pass works on an exported_program. I thought we could enable both in to_executorch by default, rather than have users call the passes separately like is done in the unit test?
Reviewer: I think const prop has some nuances which may make it non-trivial, for example for q/dq nodes. My concern here would be the perf cliff for the uninitiated. Can we not just const prop manually, instead of even introducing this op in the graph in the first place?
Author: Manual const propagation would still require updating the signature of the exported program because we're changing the weights, so it couldn't be done on the graph module. In terms of perf cliff, it will not lower to ExecuTorch without const propagation because the pack embedding op has no out-variant.
Reviewer: I see, ok, but let's plan to have this const prop done appropriately. I highly doubt that this can be done transparently; for example, quantized models have dq on weights that might get const propagated.
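For context, a rough sketch of the two-pass flow the author describes (API usage assumed from the discussion, not copied from the PR's unit test): QuantFusionPass rewrites the graph module, and constant_prop_pass then runs on the exported program so the _pack_embedding_weight call is folded into a packed-weight constant before lowering, which is needed because the pack op has no out-variant.

import torch
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass

# `model` and `example_inputs` are assumed to be a quantized eager model and its inputs.
ep = torch.export.export(model, example_inputs)
QuantFusionPass()(ep.graph_module)   # fuse dequantize_affine + embedding into quantized embedding ops
ep = constant_prop_pass(ep)          # fold _pack_embedding_weight into a constant packed weight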