
New embedding quant fusion #10325


Merged
merged 3 commits into pytorch:main on Apr 23, 2025
Conversation

metascroy (Contributor)

Summary:
The diff adds new quant fusion passes to recognize 2-, 4-, and 8-bit quantized embeddings (per-group and per-channel) and fuse them into ExecuTorch kernels. This makes torchao's quantize_ integrate with ExecuTorch:

```
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
    lambda m, fqn: isinstance(m, torch.nn.Embedding),
)

# lower model to executorch
```

For the model to lower, we need to run QuantFusionPass. For sub-byte widths, we also need to run constant_prop_pass (see the new unit tests for examples). In follow-up diffs, we will enable these passes by default in to_executorch, before the memory-planning and out-variant passes.
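For reference, a minimal sketch of that flow based on the new unit tests; the torchao import paths and the exact way the passes are invoked here are assumptions and may differ from the final API:

```
import torch
from executorch.exir import to_edge
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

# Toy model with an embedding table whose weights we quantize to 4 bits.
model = torch.nn.Sequential(torch.nn.Embedding(100, 64))
indices = torch.randint(0, 100, (8,))

quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
    lambda m, fqn: isinstance(m, torch.nn.Embedding),
)

# QuantFusionPass rewrites the dequantize + embedding pattern into the
# quantized_decomposed embedding ops; constant_prop_pass folds the
# weight-packing ops into constants (needed for the sub-byte variants).
ep = torch.export.export(model, (indices,))
edge = to_edge(ep).transform([QuantFusionPass()])
fused_ep = constant_prop_pass(edge.exported_program())
```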

Differential Revision: D73381542


pytorch-bot bot commented Apr 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/10325

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit d88601f with merge base ad1b154:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Apr 21, 2025
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D73381542

@facebook-github-bot (Contributor)

@metascroy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

weight_1 = weight_view[:, :, 1] << 2
weight_2 = weight_view[:, :, 2] << 4
weight_3 = weight_view[:, :, 3] << 6
packed_weight = weight_0 + weight_1 + weight_2 + weight_3
Contributor
nit: Just do bitwise OR.
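As a standalone sketch of what the OR-based version looks like (hypothetical helper name; the packed bytes are identical since the shifted 2-bit fields never overlap):

```
import torch

def pack_2bit(weight_q: torch.Tensor) -> torch.Tensor:
    # weight_q: uint8 tensor of shape (rows, cols) holding values in [0, 3];
    # four 2-bit values are packed per byte, element k into bits [2k, 2k+1].
    view = weight_q.view(weight_q.shape[0], weight_q.shape[1] // 4, 4)
    return (
        view[:, :, 0]
        | (view[:, :, 1] << 2)
        | (view[:, :, 2] << 4)
        | (view[:, :, 3] << 6)
    )
```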

weight_view = weight_range_shifted.view(
weight.shape[0], weight.shape[1] // 2, 2
)
weight_even = weight_view[:, :, 0] * 16 # left shift 4
Contributor
Why * 16 here but shift by 4 above?

Contributor Author
Isn't * 16 the same as shift by 4?

I just copied the code from here: https://github.com/pytorch/executorch/blob/main/examples/models/llama/source_transformation/quantize.py#L659-L683

But I can clean it up and change it to use a left shift.

Contributor
Yeah, I was just highlighting it for consistency.

)
weight_even = weight_view[:, :, 0] * 16 # left shift 4
weight_odd = weight_view[:, :, 1]
packed_weight = weight_even + weight_odd
Contributor
bitwise OR
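Similarly, a standalone sketch of the 4-bit packing with both suggestions applied, i.e. an explicit left shift instead of * 16 and bitwise OR instead of addition (hypothetical helper name):

```
import torch

def pack_4bit(weight_q: torch.Tensor) -> torch.Tensor:
    # weight_q: uint8 tensor of shape (rows, cols) holding values in [0, 15];
    # two 4-bit values are packed per byte, the even element in the high nibble.
    view = weight_q.view(weight_q.shape[0], weight_q.shape[1] // 2, 2)
    return (view[:, :, 0] << 4) | view[:, :, 1]
```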

@@ -296,6 +362,22 @@ def embedding_2bit_dtype(
return torch.ops.aten.embedding.default(weight, indices)


@register_fake("quantized_decomposed::embedding_2bit.dtype")
Contributor
this is needed?

Contributor Author
I think so? You cannot dynamo trace code that doesn't have meta kernels registered.

Contributor
Yeah, sorry; I think I asked because I wasn't sure how it was working before without it. But yes, you do need a meta kernel.
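For illustration, a rough sketch of the kind of fake (meta) kernel registration being discussed; the op name matches the diff, but the argument list and output-shape logic below are assumptions, not the code in this PR:

```
import torch
from torch.library import register_fake

# torch.export/dynamo traces custom ops through their fake impl, which must
# compute output shape and dtype without touching real data.
@register_fake("quantized_decomposed::embedding_2bit.dtype")
def _(weight, weight_scales, weight_zero_points, weight_quant_min, weight_quant_max, indices, dtype=None):
    out_dtype = dtype if dtype is not None else weight_scales.dtype
    # 2-bit values are packed four per byte, so the unpacked embedding dim is
    # 4x the packed dim (assumption based on the packing scheme above).
    packed_dim = weight.shape[1]
    return torch.empty(*indices.shape, packed_dim * 4, dtype=out_dtype, device=weight.device)
```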

Comment on lines +548 to +550
num_embeddings, packed_embedding_dim = weight.shape
embedding_dim = packed_embedding_dim * 2
embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
Contributor
Very small nit: a small refactor could abstract these 3 lines out.
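One possible shape for that helper (the name and the values_per_byte parameter are hypothetical):

```
import torch

def _make_embedding_for_packed_weight(weight: torch.Tensor, values_per_byte: int) -> torch.nn.Embedding:
    # The packed weight stores `values_per_byte` quantized values per byte
    # (4 for 2-bit, 2 for 4-bit, 1 for 8-bit), so the unpacked embedding dim
    # is the packed dim scaled by that factor.
    num_embeddings, packed_embedding_dim = weight.shape
    embedding_dim = packed_embedding_dim * values_per_byte
    return torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
```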

def embedding_byte_dtype_pattern(
indices, int_data, group_size, scale, zero_point, output_dtype
):
dq = torch.ops.torchao.dequantize_affine.default(
Contributor
what does "INT" mean here?

Contributor Author
It's the zero point domain. It means the zero points are integers. But this is an arg torchao is going to change in their quant primitives when they clean them up.
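For context, this is the standard affine dequantization with an integer-domain zero point (illustrative numbers only, not the torchao API):

```
import torch

# "INT" zero point domain: the zero point is an integer subtracted in
# quantized space before scaling back to floating point.
q = torch.tensor([3, 7, 12], dtype=torch.int8)
scale, zero_point = 0.05, 8
dequant = (q - zero_point) * scale  # tensor([-0.2500, -0.0500,  0.2000])
```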

return torch.ops.aten.embedding.default(dq, indices)

def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point):
zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
Contributor
embedding byte ops take float zero point?

@@ -298,6 +298,8 @@ python_unittest(
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir/passes:quant_fusion_pass",
"//pytorch/ao:torchao",
"//executorch/exir/passes:constant_prop_pass",
Contributor
Shouldn't you call const prop in the QuantFusionPass?

Contributor Author (@metascroy, Apr 22, 2025)
Unfortunately, QuantFusionPass works on a graph module and const_prop_pass works on an exported_program. I thought we could enable both in to_executorch by default, rather than having users call the passes separately as is done in the unit test?

Contributor
I think const prop has some nuances which may make it non-trivial, for example for q/dq nodes.

My concern here would be the perf cliff for the uninitiated.

Can we not just const prop manually, instead of even introducing this op in the graph in the first place?

Contributor Author
Manual const propagation would still require updating the signature of the exported program because we're changing the weights. So it couldn't be done on the graph module.

In terms of perf cliff, it will not lower to ExecuTorch without const propagation because the pack embedding op has no out-variant.

Contributor
I see, OK. But let's plan to have this const prop done appropriately. I highly doubt that this can be done transparently; for example, quantized models have dq ops on weights that might get const propagated.

self._test_embedding_torchao(bit_width, test_dtype_variant, test_per_group)

def _test_embedding_torchao(
self, bit_width: int, test_dtype_variant: bool, test_per_group: bool
Contributor
test_dtype_variant as a bool feels like a bit of a misnomer.

Contributor Author
I'll change to use_dtype_variant?

Contributor
I think I misunderstood this, probably because it was a bit late when I was reviewing it. You just meant to check the .dtype variant of the op, and I thought you were checking fp16 vs fp32 vs other dtype variants.

@kimishpatel (Contributor) left a comment
Sending back for a question on const prop.

metascroy and others added 3 commits April 22, 2025 19:41
@facebook-github-bot (Contributor)
@metascroy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@kimishpatel (Contributor) left a comment
On the const prop question, let's follow up in a separate PR.

@facebook-github-bot merged commit 28ee6f5 into pytorch:main on Apr 23, 2025 (82 of 86 checks passed).
Labels: CLA Signed, fb-exported, topic: not user facing