
New embedding quant fusion #10325


Merged: 3 commits (Apr 23, 2025)
1 change: 1 addition & 0 deletions exir/TARGETS
@@ -16,6 +16,7 @@ python_library(
"//caffe2:torch",
"//executorch/exir/operator:convert",
"//executorch/extension/pytree:pylib",
"//pytorch/ao:torchao",
],
)

300 changes: 297 additions & 3 deletions exir/passes/_quant_patterns_and_replacements.py
@@ -22,6 +22,56 @@
"get_quant_patterns_and_replacements",
]


from torch import Tensor
from torch.library import custom_op


@custom_op("quant_fusion::_pack_embedding_weight", mutates_args=())
def _pack_embedding_weight(weight: Tensor, bitwidth: int) -> Tensor:
num_embeddings, embedding_dim = weight.shape

if bitwidth == 2:
assert embedding_dim % 4 == 0, "embedding_dim must be divisible by 4"
weight_range_shifted = weight.add(2).view(torch.uint8)
weight_view = weight_range_shifted.view(num_embeddings, embedding_dim // 4, 4)
weight_0 = weight_view[:, :, 0]
weight_1 = weight_view[:, :, 1] << 2
weight_2 = weight_view[:, :, 2] << 4
weight_3 = weight_view[:, :, 3] << 6
packed_weight = weight_0 | weight_1 | weight_2 | weight_3
return packed_weight
elif bitwidth == 4:
assert embedding_dim % 2 == 0, "embedding_dim must be divisible by 2"
weight_range_shifted = weight.add(8).view(torch.uint8)
weight_view = weight_range_shifted.view(
weight.shape[0], weight.shape[1] // 2, 2
)
weight_even = weight_view[:, :, 0] << 4
weight_odd = weight_view[:, :, 1]
packed_weight = weight_even | weight_odd
return packed_weight
elif bitwidth == 8:
return weight

raise RuntimeError(f"Unsupported bitwidth {bitwidth}")


# Use register_fake to add a ``FakeTensor`` kernel for the operator
@_pack_embedding_weight.register_fake
def _(weight, bit_width):
assert bit_width in [2, 4, 8]
num_embeddings, embedding_dim = weight.shape
values_per_byte = 8 // bit_width
assert embedding_dim % values_per_byte == 0
return torch.empty(
num_embeddings,
embedding_dim // values_per_byte,
dtype=torch.uint8,
device=weight.device,
)
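
For illustration, here is a minimal round-trip sketch of the 4-bit packing above (made-up shapes and values; it assumes this module has been imported so that quant_fusion::_pack_embedding_weight is registered). Even-indexed columns land in the high nibble and odd-indexed columns in the low nibble, so unpacking recovers the original int8 table.

import torch

weight = torch.randint(-8, 8, (4, 6), dtype=torch.int8)  # 4-bit values in [-8, 7]
packed = torch.ops.quant_fusion._pack_embedding_weight(weight, 4)
assert packed.shape == (4, 3) and packed.dtype == torch.uint8

# Unpack: the high nibble holds even columns, the low nibble holds odd columns.
high = (packed >> 4).to(torch.int16) - 8
low = (packed & 0x0F).to(torch.int16) - 8
unpacked = torch.stack([high, low], dim=-1).reshape(4, 6).to(torch.int8)
assert torch.equal(unpacked, weight)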


# TODO: extending an existing library that is defined in OSS might be a bit
# confusing, we can investigate if it is possible to define a new library

@@ -69,9 +119,10 @@ def embedding_weight_checks(weight, weight_scales, weight_zero_points):
assert (
weight_zero_points is None or weight_zero_points.dtype == weight_scales.dtype
), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
assert (
weight_zero_points is None or weight_zero_points.dim() == 1
), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"
assert weight_zero_points is None or weight_zero_points.dim() in [1, 2], f"Expecting weight_zero_points tensor to be None or have dim() in [1, 2], but found {weight_zero_points.dim()}"
assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
0
), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"
@@ -234,6 +285,21 @@ def embedding_2bit(
return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_2bit")
def _(
weight: torch.Tensor,
weight_scales: torch.Tensor,
weight_zero_points: Optional[torch.Tensor],
weight_quant_min: int,
weight_quant_max: int,
indices: torch.Tensor,
):
num_embeddings, packed_embedding_dim = weight.shape
embedding_dim = packed_embedding_dim * 4
embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
return embedding(indices)


@register_fake("quantized_decomposed::embedding_2bit.out")
def embedding_2bit_out_meta(
weight: torch.Tensor,
@@ -296,6 +362,22 @@ def embedding_2bit_dtype(
return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_2bit.dtype")
Contributor: this is needed?

Contributor (author): I think so? You cannot dynamo trace code that doesn't have meta kernels registered. (A small standalone sketch of this requirement follows after this kernel.)

Contributor: Yeah, sorry, I think I asked because I wasn't sure how it was working before without it. But yes, you do need the meta kernel.

def _(
weight: torch.Tensor,
weight_scales: torch.Tensor,
weight_zero_points: Optional[torch.Tensor],
weight_quant_min: int,
weight_quant_max: int,
indices: torch.Tensor,
dtype: Optional[torch.dtype],
) -> torch.Tensor:
num_embeddings, packed_embedding_dim = weight.shape
embedding_dim = packed_embedding_dim * 4
embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
return embedding(indices).to(dtype)
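
On the meta-kernel question above: torch.export traces with fake tensors, so any custom op reached during tracing needs a fake (meta) kernel that produces output shapes and dtypes without real data. A minimal standalone sketch of that requirement (the demo::double op below is hypothetical, not part of this PR):

import torch
from torch.library import custom_op

@custom_op("demo::double", mutates_args=())
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

# Fake kernel: only shapes and dtypes matter during export tracing.
@double.register_fake
def _(x):
    return torch.empty_like(x)

class M(torch.nn.Module):
    def forward(self, x):
        return torch.ops.demo.double(x)

# Without the register_fake above, this export would typically fail with a
# "no fake impl registered" style error during fake-tensor tracing.
ep = torch.export.export(M(), (torch.randn(2, 2),))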


@register_fake("quantized_decomposed::embedding_2bit.dtype_out")
def embedding_2bit_dtype_out_meta(
weight: torch.Tensor,
@@ -378,6 +460,21 @@ def embedding_4bit(
return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_4bit")
def _(
weight: torch.Tensor,
weight_scales: torch.Tensor,
weight_zero_points: Optional[torch.Tensor],
weight_quant_min: int,
weight_quant_max: int,
indices: torch.Tensor,
):
num_embeddings, packed_embedding_dim = weight.shape
embedding_dim = packed_embedding_dim * 2
embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
return embedding(indices)


@register_fake("quantized_decomposed::embedding_4bit.out")
def embedding_4bit_out_meta(
weight: torch.Tensor,
@@ -438,6 +535,22 @@ def embedding_4bit_dtype(
return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_4bit.dtype")
def _(
weight: torch.Tensor,
weight_scales: torch.Tensor,
weight_zero_points: Optional[torch.Tensor],
weight_quant_min: int,
weight_quant_max: int,
indices: torch.Tensor,
dtype: Optional[torch.dtype],
) -> torch.Tensor:
num_embeddings, packed_embedding_dim = weight.shape
embedding_dim = packed_embedding_dim * 2
embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
Contributor (on lines +548 to +550): very small nit: a small refactor could abstract these 3 lines out.

return embedding(indices).to(dtype)


@register_fake("quantized_decomposed::embedding_4bit.dtype_out")
def embedding_4bit_dtype_out_meta(
weight: torch.Tensor,
@@ -873,6 +986,186 @@ def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
]


def _get_embedding_ops_patterns_and_replacements_torchao() -> ( # noqa C901
List[Tuple[Callable, Callable, List[Callable]]]
):
def embedding_byte_pattern(indices, int_data, group_size, scale, zero_point):
dq = torch.ops.torchao.dequantize_affine.default(
int_data, [1, group_size], scale, zero_point, torch.int8, -128, 127
)
return torch.ops.aten.embedding.default(dq, indices)

def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point):
zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
Contributor: embedding byte ops take float zero point?
return torch.ops.quantized_decomposed.embedding_byte.default(
int_data,
scale,
zero_point_dtype_cast,
-128,
127,
indices,
)

def embedding_byte_dtype_pattern(
indices, int_data, group_size, scale, zero_point, output_dtype
):
dq = torch.ops.torchao.dequantize_affine.default(
Contributor: what does "INT" mean here?

Contributor (author): It's the zero point domain. It means the zero points are integers. But this is an arg torchao is going to change in their quant primitives when they clean them up. (A short numeric sketch of the INT domain follows after this function.)
int_data,
[1, group_size],
scale,
zero_point,
torch.int8,
-128,
127,
"INT",
output_dtype,
)
return torch.ops.aten.embedding.default(dq, indices)

def embedding_byte_dtype_replacement(
indices, int_data, group_size, scale, zero_point, output_dtype
):
zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
return torch.ops.quantized_decomposed.embedding_byte.dtype(
int_data,
scale,
zero_point_dtype_cast,
-128,
127,
indices,
dtype=output_dtype,
)

def embedding_2bit_pattern(indices, int_data, group_size, scale, zero_point):
dq = torch.ops.torchao.dequantize_affine.default(
int_data, [1, group_size], scale, zero_point, torch.int8, -2, 1
)
return torch.ops.aten.embedding.default(dq, indices)

def embedding_2bit_replacement(indices, int_data, group_size, scale, zero_point):
packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
int_data, 2
)
zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
return torch.ops.quantized_decomposed.embedding_2bit.default(
packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices
)

def embedding_2bit_dtype_pattern(
indices, int_data, group_size, scale, zero_point, output_dtype
):
dq = torch.ops.torchao.dequantize_affine.default(
int_data,
[1, group_size],
scale,
zero_point,
torch.int8,
-2,
1,
"INT",
output_dtype,
)
return torch.ops.aten.embedding.default(dq, indices)

def embedding_2bit_dtype_replacement(
indices, int_data, group_size, scale, zero_point, output_dtype
):
packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
int_data, 2
)
zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
return torch.ops.quantized_decomposed.embedding_2bit.dtype(
packed_int_data,
scale,
zero_point_dtype_cast,
-2,
1,
indices,
dtype=output_dtype,
)

def embedding_4bit_pattern(indices, int_data, group_size, scale, zero_point):
dq = torch.ops.torchao.dequantize_affine.default(
int_data, [1, group_size], scale, zero_point, torch.int8, -8, 7
)
return torch.ops.aten.embedding.default(dq, indices)

def embedding_4bit_replacement(indices, int_data, group_size, scale, zero_point):
packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
int_data, 4
)
zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
return torch.ops.quantized_decomposed.embedding_4bit.default(
packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices
)

def embedding_4bit_dtype_pattern(
indices, int_data, group_size, scale, zero_point, output_dtype
):
dq = torch.ops.torchao.dequantize_affine.default(
int_data,
[1, group_size],
scale,
zero_point,
torch.int8,
-8,
7,
"INT",
output_dtype,
)
return torch.ops.aten.embedding.default(dq, indices)

def embedding_4bit_dtype_replacement(
indices, int_data, group_size, scale, zero_point, output_dtype
):
packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
int_data, 4
)
zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
return torch.ops.quantized_decomposed.embedding_4bit.dtype(
packed_int_data,
scale,
zero_point_dtype_cast,
-8,
7,
indices,
dtype=output_dtype,
)

return [
(
_trace_and_lower_to_edge_ops(embedding_byte_pattern),
_trace_and_lower_to_edge_ops(embedding_byte_replacement),
[],
),
(
_trace_and_lower_to_edge_ops(embedding_byte_dtype_pattern),
_trace_and_lower_to_edge_ops(embedding_byte_dtype_replacement),
[],
),
(
_trace_and_lower_to_edge_ops(embedding_2bit_pattern),
_trace_and_lower_to_edge_ops(embedding_2bit_replacement),
[],
),
(
_trace_and_lower_to_edge_ops(embedding_2bit_dtype_pattern),
_trace_and_lower_to_edge_ops(embedding_2bit_dtype_replacement),
[],
),
(
_trace_and_lower_to_edge_ops(embedding_4bit_pattern),
_trace_and_lower_to_edge_ops(embedding_4bit_replacement),
[],
),
(
_trace_and_lower_to_edge_ops(embedding_4bit_dtype_pattern),
_trace_and_lower_to_edge_ops(embedding_4bit_dtype_replacement),
[],
),
]
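
On the zero_point_domain question above: with "INT", the zero point is an integer offset applied before scaling, i.e. dq = (q - zero_point) * scale. A small numeric sketch with made-up values (simplified relative to torchao's per-group handling):

import torch

q = torch.tensor([[-8, 0, 7]], dtype=torch.int8)
scale = torch.tensor([0.05])
zero_point = torch.tensor([2], dtype=torch.int32)
dq = (q.to(torch.float32) - zero_point) * scale
print(dq)  # tensor([[-0.5000, -0.1000,  0.2500]])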


def _get_embedding_ops_patterns_and_replacements() -> (
List[Tuple[Callable, Callable, List[Callable]]]
):
@@ -1167,5 +1460,6 @@ def get_quant_patterns_and_replacements() -> (
*_get_slice_patterns_and_replacements(),
# *_get_fixed_qparams_ops_patterns_and_replacements(),
*_get_embedding_ops_patterns_and_replacements(),
*_get_embedding_ops_patterns_and_replacements_torchao(),
]
)
13 changes: 13 additions & 0 deletions exir/passes/quant_fusion_pass.py
@@ -90,6 +90,18 @@ def _get_qparams(node):
model.graph.erase_node(qnode)


def _remove_dtype_getattr_nodes(model: GraphModule) -> None:
for n in model.graph.nodes:
if n.op == "call_function" and n.target == getattr:
if isinstance(n.args[0], torch.fx.Node) and n.args[1] == "dtype":
dtype = n.args[0].meta["val"].dtype
n.replace_all_uses_with(dtype)
model.graph.erase_node(n)
model.graph.eliminate_dead_code()
model.graph.lint()
model.recompile()


class QuantFusionPass(ExportPass):
def __init__(self, _fix_node_meta_val=False):
super().__init__()
@@ -123,6 +135,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
torch.fx.Node, lambda x: x.meta["val"], (n.args, n.kwargs)
)
n.meta["val"] = n.target(*args, **kwargs)
_remove_dtype_getattr_nodes(graph_module)
graph_module.graph.lint()
graph_module.graph.eliminate_dead_code()
return PassResult(graph_module, True)
2 changes: 2 additions & 0 deletions exir/tests/TARGETS
@@ -298,6 +298,8 @@ python_unittest(
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir/passes:quant_fusion_pass",
"//pytorch/ao:torchao",
"//executorch/exir/passes:constant_prop_pass",
Contributor: shouldn't you call const prop in the QuantFusionPass?

Contributor (author, @metascroy, Apr 22, 2025): Unfortunately QuantFusionPass works on a graph module and const_prop_pass works on an exported_program. I thought we could enable both in to_executorch by default, rather than have users call the passes separately like is done in the unit test?

Contributor: I think const prop has some nuances which may make it non-trivial, for example for q/dq nodes. My concern here would be the perf cliff for the uninitiated. Can we not just const prop manually, instead of even introducing this op in the graph in the first place?

Contributor (author): Manual const propagation would still require updating the signature of the exported program because we're changing the weights, so it couldn't be done on the graph module. In terms of the perf cliff, it will not lower to ExecuTorch without const propagation because the pack embedding op has no out-variant.

Contributor: I see, OK, but let's plan to have this const prop done appropriately. I highly doubt that this can be done transparently. For example, quantized models have dq nodes on weights that might get const propagated.

(A rough sketch of chaining the two passes follows below.)

],
)
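
Following the const-prop discussion above, here is a rough sketch of how the two passes could be chained the way the unit test presumably does. The entry points used here (to_edge, constant_prop_pass, and the way the graph module is reached) are assumptions about the ExecuTorch API surface, not copied from this PR:

import torch
from executorch.exir import to_edge
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass

def fuse_and_const_prop(model: torch.nn.Module, example_inputs):
    edge = to_edge(torch.export.export(model, example_inputs))
    # QuantFusionPass runs on the underlying GraphModule...
    QuantFusionPass(_fix_node_meta_val=True)(edge.exported_program().graph_module)
    # ...while constant_prop_pass runs on the ExportedProgram, folding the
    # _pack_embedding_weight calls (which have no out-variant) into constants.
    return constant_prop_pass(edge.exported_program())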
