[polynomial] distribute add/sub through ntt to reduce ntts #93132
Conversation
@llvm/pr-subscribers-mlir

Author: Jeremy Kun (j2kun)

Changes

Addresses google/heir#542 (comment)

Full diff: https://github.com/llvm/llvm-project/pull/93132.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index 9d09799c1763a..e37bcf76a20f2 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -9,11 +9,14 @@
#ifndef POLYNOMIAL_CANONICALIZATION
#define POLYNOMIAL_CANONICALIZATION
-include "mlir/Dialect/Polynomial/IR/Polynomial.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
+include "mlir/Dialect/Polynomial/IR/Polynomial.td"
+include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/PatternBase.td"
+defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
+
// Get a -1 integer attribute of the same type as the polynomial SSA value's
// ring coefficient type.
def getMinusOne
@@ -39,4 +42,40 @@ def NTTAfterINTT : Pat<
[]
>;
+// NTTs are expensive, while addition costs the same in either the coefficient
+// or NTT domain, so reducing the number of NTTs is beneficial.
+// ntt(a) + ntt(b) -> ntt(a + b)
+def NTTOfAdd : Pat<
+ (Arith_AddIOp
+ (Polynomial_NTTOp $p1),
+ (Polynomial_NTTOp $p2),
+ $overflow),
+ (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2)),
+ []
+>;
+// intt(a) + intt(b) -> intt(a + b)
+def INTTOfAdd : Pat<
+ (Polynomial_AddOp
+ (Polynomial_INTTOp $t1),
+ (Polynomial_INTTOp $t2)),
+ (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow)),
+ []
+>;
+// repeated for sub
+def NTTOfSub : Pat<
+ (Arith_SubIOp
+ (Polynomial_NTTOp $p1),
+ (Polynomial_NTTOp $p2),
+ $overflow),
+ (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2)),
+ []
+>;
+def INTTOfSub : Pat<
+ (Polynomial_SubOp
+ (Polynomial_INTTOp $t1),
+ (Polynomial_INTTOp $t2)),
+ (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow)),
+ []
+>;
+
#endif // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 1a2439fe810b5..98263732da8a9 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -201,10 +201,10 @@ void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<NTTAfterINTT>(context);
+ results.add<NTTAfterINTT, NTTOfAdd, NTTOfSub>(context);
}
void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<INTTAfterNTT>(context);
+ results.add<INTTAfterNTT, INTTOfAdd, INTTOfSub>(context);
}
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index dbfbf2d93f111..489d9ec2720d6 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -43,3 +43,60 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
// CHECK: [[ADD:%.+]] = polynomial.add %[[p0]], %[[p1neg]]
return %0 : !sub_ty
}
+
+// CHECK-LABEL: test_canonicalize_fold_add_through_ntt
+// CHECK: polynomial.add
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+func.func @test_canonicalize_fold_add_through_ntt(
+ %poly0 : !ntt_poly_ty,
+ %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+ %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
+ %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+ %a_plus_b = arith.addi %0, %1 : !tensor_ty
+ %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+ return %out : !ntt_poly_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_add_through_intt
+// CHECK: arith.addi
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.ntt
+func.func @test_canonicalize_fold_add_through_intt(
+ %tensor0 : !tensor_ty,
+ %tensor1 : !tensor_ty) -> !tensor_ty {
+ %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
+ %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+ %a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
+ %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+ return %out : !tensor_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_sub_through_ntt
+// CHECK: polynomial.mul_scalar
+// CHECK: polynomial.add
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+func.func @test_canonicalize_fold_sub_through_ntt(
+ %poly0 : !ntt_poly_ty,
+ %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+ %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
+ %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+  %a_minus_b = arith.subi %0, %1 : !tensor_ty
+  %out = polynomial.intt %a_minus_b : !tensor_ty -> !ntt_poly_ty
+ return %out : !ntt_poly_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_sub_through_intt
+// CHECK: arith.subi
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.ntt
+func.func @test_canonicalize_fold_sub_through_intt(
+ %tensor0 : !tensor_ty,
+ %tensor1 : !tensor_ty) -> !tensor_ty {
+ %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
+ %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+  %a_minus_b = polynomial.sub %0, %1 : !ntt_poly_ty
+  %out = polynomial.ntt %a_minus_b : !ntt_poly_ty -> !tensor_ty
+ return %out : !tensor_ty
+}
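For background, the identity these patterns rely on (a generic sketch, not tied to the dialect's ring attributes): the NTT and its inverse are linear maps over Z_q, so addition and subtraction commute with them. With a fixed primitive n-th root of unity ω modulo q:

```latex
\mathrm{NTT}(a)_k = \sum_{j=0}^{n-1} a_j\,\omega^{jk} \pmod{q},
\qquad
\mathrm{NTT}(a+b)_k = \sum_{j=0}^{n-1} (a_j + b_j)\,\omega^{jk}
                    = \mathrm{NTT}(a)_k + \mathrm{NTT}(b)_k \pmod{q}.
```

The same identity holds for `intt` and for subtraction, which is what lets `ntt(a) + ntt(b)` fold to `ntt(a + b)` (one transform instead of two) and a `polynomial.add` of two `intt` results fold to a single `intt` of an `arith.addi`.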
That looks good from an MLIR point of view; I don't understand NTT well enough to know whether the patterns are correct, though.
So LG if you're owning the maths correctness :)
Speaking of which, should I solicit someone who can review the math aspects of these PRs? For now they are mainly coming from out-of-tree ports, but some of the larger parts (e.g., lowering …) are not. If you know of folks who already have commit access who would be interested in this, please let me know; otherwise I can ask some of my colleagues to request membership.
It's up to you on this, and how confident you feel about it!
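On the math question: the add/sub patterns rest on the linearity identity sketched above, which is easy to spot-check numerically. Below is a minimal standalone sketch, assuming a naive O(n²) transform; the prime q = 7681 and the input vectors are made-up illustrative values, not anything taken from the dialect or its lowerings.

```cpp
// Spot-check that a number-theoretic transform is linear:
// NTT(a + b) == NTT(a) + NTT(b), elementwise mod q.
// q, n, and the input vectors are illustrative values only.
#include <cassert>
#include <cstdint>
#include <vector>

static uint64_t powmod(uint64_t base, uint64_t exp, uint64_t q) {
  uint64_t acc = 1 % q;
  base %= q;
  while (exp) {
    if (exp & 1) acc = acc * base % q;
    base = base * base % q;
    exp >>= 1;
  }
  return acc;
}

// Brute-force search for an element of multiplicative order exactly n mod q.
static uint64_t findRoot(uint64_t n, uint64_t q) {
  for (uint64_t g = 2; g < q; ++g) {
    bool ok = powmod(g, n, q) == 1;
    for (uint64_t d = 1; ok && d < n; ++d)
      if (powmod(g, d, q) == 1) ok = false;
    if (ok) return g;
  }
  return 0;
}

// Naive O(n^2) forward transform: out[k] = sum_j in[j] * w^(j*k) mod q.
static std::vector<uint64_t> ntt(const std::vector<uint64_t> &in, uint64_t w,
                                 uint64_t q) {
  std::vector<uint64_t> out(in.size(), 0);
  for (uint64_t k = 0; k < in.size(); ++k)
    for (uint64_t j = 0; j < in.size(); ++j)
      out[k] = (out[k] + in[j] * powmod(w, j * k, q)) % q;
  return out;
}

int main() {
  const uint64_t q = 7681;  // made-up NTT-friendly prime (7681 = 15 * 2^9 + 1)
  std::vector<uint64_t> a = {1, 2, 3, 4, 5, 6, 7, 0};
  std::vector<uint64_t> b = {7, 0, 0, 2, 1, 5, 5, 5};
  const uint64_t w = findRoot(a.size(), q);  // primitive 8th root of unity mod q

  // NTT(a) + NTT(b), elementwise mod q.
  std::vector<uint64_t> lhs = ntt(a, w, q), nb = ntt(b, w, q);
  for (size_t i = 0; i < lhs.size(); ++i) lhs[i] = (lhs[i] + nb[i]) % q;

  // NTT(a + b), adding coefficients mod q first.
  std::vector<uint64_t> sum(a.size());
  for (size_t i = 0; i < a.size(); ++i) sum[i] = (a[i] + b[i]) % q;

  assert(lhs == ntt(sum, w, q));  // linearity: both orders agree
  return 0;
}
```

Linearity is only the algebraic half of the story, though: the rewrite also assumes the fixed-width `arith` adds it introduces cannot wrap, which is the concern raised in the revert note below.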
arith.addi on tensors does not reduce modulo the coefficientModulus and may overflow, so the result could be incorrect. These patterns should be rewritten in terms of modular arithmetic rather than plain arith ops. Revert #93132. Addresses google/heir#749. Cc @j2kun
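To make the overflow concern concrete, here is a small sketch. All values are made up: `uint16_t` stands in for whatever integer type the ring's coefficients use, and q = 65521 is a hypothetical coefficientModulus close to the type's maximum.

```cpp
// Why a plain fixed-width add on NTT-domain coefficients can be wrong:
// it wraps mod 2^16, not mod q, and the two differ once the true sum
// exceeds the type's range.
// q = 65521 and the coefficient values below are made-up examples.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t q = 65521;  // hypothetical coefficientModulus, close to 2^16
  const uint16_t a = 40000;  // two reduced NTT-domain coefficients, both < q
  const uint16_t b = 40000;

  // Modular addition (what a polynomial-level add means): 80000 mod q = 14479.
  uint32_t modular = (uint32_t(a) + uint32_t(b)) % q;

  // Plain 16-bit add: 80000 wraps to 80000 - 65536 = 14464, and reducing that
  // mod q afterwards still gives 14464 -- off by 65536 mod q = 15.
  uint16_t wrapped = static_cast<uint16_t>(a + b);
  uint32_t wrappedThenReduced = wrapped % q;

  assert(modular == 14479);
  assert(wrappedThenReduced == 14464);
  assert(modular != wrappedThenReduced);
  return 0;
}
```

Even with a modulus that leaves enough headroom for the add itself not to wrap, the sum is still unreduced, so the fold is only safe if restated in terms of modular arithmetic, as suggested above.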