Skip to content

[polynomial] distribute add/sub through ntt to reduce ntts #93132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 23, 2024

Conversation

j2kun
Copy link
Contributor

@j2kun j2kun commented May 23, 2024

@j2kun j2kun requested a review from makslevental May 23, 2024 04:16
@llvmbot llvmbot added the mlir label May 23, 2024
@llvmbot
Copy link
Member

llvmbot commented May 23, 2024

@llvm/pr-subscribers-mlir

Author: Jeremy Kun (j2kun)

Changes

Addresses google/heir#542 (comment)


Full diff: https://github.com/llvm/llvm-project/pull/93132.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td (+40-1)
  • (modified) mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp (+2-2)
  • (modified) mlir/test/Dialect/Polynomial/canonicalization.mlir (+57)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index 9d09799c1763a..e37bcf76a20f2 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -9,11 +9,14 @@
 #ifndef POLYNOMIAL_CANONICALIZATION
 #define POLYNOMIAL_CANONICALIZATION
 
-include "mlir/Dialect/Polynomial/IR/Polynomial.td"
 include "mlir/Dialect/Arith/IR/ArithOps.td"
+include "mlir/Dialect/Polynomial/IR/Polynomial.td"
+include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/PatternBase.td"
 
+defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
+
 // Get a -1 integer attribute of the same type as the polynomial SSA value's
 // ring coefficient type.
 def getMinusOne
@@ -39,4 +42,40 @@ def NTTAfterINTT : Pat<
   []
 >;
 
+// NTTs are expensive, and addition in coefficient or NTT domain should be
+// equivalently expensive, so reducing the number of NTTs is optimal.
+// ntt(a) + ntt(b) -> ntt(a + b)
+def NTTOfAdd : Pat<
+  (Arith_AddIOp
+    (Polynomial_NTTOp $p1),
+    (Polynomial_NTTOp $p2),
+    $overflow),
+  (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2)),
+  []
+>;
+// intt(a) + intt(b) -> intt(a + b)
+def INTTOfAdd : Pat<
+  (Polynomial_AddOp
+    (Polynomial_INTTOp $t1),
+    (Polynomial_INTTOp $t2)),
+  (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow)),
+  []
+>;
+// repeated for sub
+def NTTOfSub : Pat<
+  (Arith_SubIOp
+    (Polynomial_NTTOp $p1),
+    (Polynomial_NTTOp $p2),
+    $overflow),
+  (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2)),
+  []
+>;
+def INTTOfSub : Pat<
+  (Polynomial_SubOp
+    (Polynomial_INTTOp $t1),
+    (Polynomial_INTTOp $t2)),
+  (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow)),
+  []
+>;
+
 #endif  // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 1a2439fe810b5..98263732da8a9 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -201,10 +201,10 @@ void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<NTTAfterINTT>(context);
+  results.add<NTTAfterINTT, NTTOfAdd, NTTOfSub>(context);
 }
 
 void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<INTTAfterNTT>(context);
+  results.add<INTTAfterNTT, INTTOfAdd, INTTOfSub>(context);
 }
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index dbfbf2d93f111..489d9ec2720d6 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -43,3 +43,60 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
   // CHECK: [[ADD:%.+]] = polynomial.add %[[p0]], %[[p1neg]]
   return %0 : !sub_ty
 }
+
+// CHECK-LABEL: test_canonicalize_fold_add_through_ntt
+// CHECK: polynomial.add
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+func.func @test_canonicalize_fold_add_through_ntt(
+    %poly0 : !ntt_poly_ty,
+    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+  %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
+  %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+  %a_plus_b = arith.addi %0, %1 : !tensor_ty
+  %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+  return %out : !ntt_poly_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_add_through_intt
+// CHECK: arith.addi
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.iintt
+func.func @test_canonicalize_fold_add_through_intt(
+    %tensor0 : !tensor_ty,
+    %tensor1 : !tensor_ty) -> !tensor_ty {
+  %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
+  %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+  %a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
+  %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+  return %out : !tensor_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_sub_through_ntt
+// CHECK: polynomial.mul_scalar
+// CHECK: polynomial.add
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+func.func @test_canonicalize_fold_sub_through_ntt(
+    %poly0 : !ntt_poly_ty,
+    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+  %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
+  %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+  %a_plus_b = arith.subi %0, %1 : !tensor_ty
+  %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+  return %out : !ntt_poly_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_sub_through_intt
+// CHECK: arith.subi
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.iintt
+func.func @test_canonicalize_fold_sub_through_intt(
+    %tensor0 : !tensor_ty,
+    %tensor1 : !tensor_ty) -> !tensor_ty {
+  %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
+  %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+  %a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
+  %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+  return %out : !tensor_ty
+}

Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

@j2kun j2kun requested a review from joker-eph May 23, 2024 17:09
Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That looks good from an MLIR point of view, I don't understand NTT enough to know if the patterns are correct though.
So LG if you're owning the maths correctness :)

@j2kun
Copy link
Contributor Author

j2kun commented May 23, 2024

That looks good from an MLIR point of view, I don't understand NTT enough to know if the patterns are correct though. So LG if you're owning the maths correctness :)

Speaking of which, should I solicit someone who can review the math aspects of these PRs? For now they are mainly coming from out-of-tree ports, but some of the larger parts (e.g., lowering ntt to an optimized CPU implementation) are more sophisticated.

If you know of folks who already have commit access who would be interested in this, please let me know, otherwise I can ask some of my colleagues to request membership.

@j2kun j2kun merged commit 1a28f26 into llvm:main May 23, 2024
5 checks passed
@joker-eph
Copy link
Collaborator

should I solicit someone who can review the math aspects of these PRs?

It's up to you on this and how confident you feel about it!
(It's not like we'd be in an environment where these would just be continuously shipped to production either)

ZenithalHourlyRate added a commit to ZenithalHourlyRate/llvm-project that referenced this pull request Sep 27, 2024
arith.add for tensor does not mod coefficientModulus, and it may
overflow; the result could be incorrect

It should be rewritten as modular arithmetic instead of arith

Revert llvm#93132
Addresses google/heir#749
j2kun pushed a commit that referenced this pull request Sep 28, 2024
arith.add for tensor does not mod coefficientModulus, and it may
overflow; the result could be incorrect

It should be rewritten as modular arithmetic instead of arith

Revert #93132
Addresses google/heir#749

Cc @j2kun
puja2196 pushed a commit to puja2196/LLVM-tutorial that referenced this pull request Sep 30, 2024
arith.add for tensor does not mod coefficientModulus, and it may
overflow; the result could be incorrect

It should be rewritten as modular arithmetic instead of arith

Revert llvm/llvm-project#93132
Addresses google/heir#749

Cc @j2kun
puja2196 pushed a commit to puja2196/LLVM-tutorial that referenced this pull request Oct 2, 2024
arith.add for tensor does not mod coefficientModulus, and it may
overflow; the result could be incorrect

It should be rewritten as modular arithmetic instead of arith

Revert llvm/llvm-project#93132
Addresses google/heir#749

Cc @j2kun
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants