Skip to content

Commit 1a28f26

Browse files
authored
[polynomial] distribute add/sub through ntt to reduce ntts (#93132)
Addresses google/heir#542 (comment) Co-authored-by: Jeremy Kun <[email protected]>
1 parent 779be6f commit 1a28f26

File tree

3 files changed

+99
-3
lines changed

3 files changed

+99
-3
lines changed

mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@
99
#ifndef POLYNOMIAL_CANONICALIZATION
1010
#define POLYNOMIAL_CANONICALIZATION
1111

12-
include "mlir/Dialect/Polynomial/IR/Polynomial.td"
1312
include "mlir/Dialect/Arith/IR/ArithOps.td"
13+
include "mlir/Dialect/Polynomial/IR/Polynomial.td"
14+
include "mlir/IR/EnumAttr.td"
1415
include "mlir/IR/OpBase.td"
1516
include "mlir/IR/PatternBase.td"
1617

18+
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
19+
1720
// Get a -1 integer attribute of the same type as the polynomial SSA value's
1821
// ring coefficient type.
1922
def getMinusOne
@@ -39,4 +42,40 @@ def NTTAfterINTT : Pat<
3942
[]
4043
>;
4144

45+
// NTTs are expensive, and addition in coefficient or NTT domain should be
46+
// equivalently expensive, so reducing the number of NTTs is optimal.
47+
// ntt(a) + ntt(b) -> ntt(a + b)
48+
def NTTOfAdd : Pat<
49+
(Arith_AddIOp
50+
(Polynomial_NTTOp $p1),
51+
(Polynomial_NTTOp $p2),
52+
$overflow),
53+
(Polynomial_NTTOp (Polynomial_AddOp $p1, $p2)),
54+
[]
55+
>;
56+
// intt(a) + intt(b) -> intt(a + b)
57+
def INTTOfAdd : Pat<
58+
(Polynomial_AddOp
59+
(Polynomial_INTTOp $t1),
60+
(Polynomial_INTTOp $t2)),
61+
(Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow)),
62+
[]
63+
>;
64+
// repeated for sub
65+
def NTTOfSub : Pat<
66+
(Arith_SubIOp
67+
(Polynomial_NTTOp $p1),
68+
(Polynomial_NTTOp $p2),
69+
$overflow),
70+
(Polynomial_NTTOp (Polynomial_SubOp $p1, $p2)),
71+
[]
72+
>;
73+
def INTTOfSub : Pat<
74+
(Polynomial_SubOp
75+
(Polynomial_INTTOp $t1),
76+
(Polynomial_INTTOp $t2)),
77+
(Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow)),
78+
[]
79+
>;
80+
4281
#endif // POLYNOMIAL_CANONICALIZATION

mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,10 +283,10 @@ void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
283283

284284
void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
285285
MLIRContext *context) {
286-
results.add<NTTAfterINTT>(context);
286+
results.add<NTTAfterINTT, NTTOfAdd, NTTOfSub>(context);
287287
}
288288

289289
void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
290290
MLIRContext *context) {
291-
results.add<INTTAfterNTT>(context);
291+
results.add<INTTAfterNTT, INTTOfAdd, INTTOfSub>(context);
292292
}

mlir/test/Dialect/Polynomial/canonicalization.mlir

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,60 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
4343
// CHECK: [[ADD:%.+]] = polynomial.add %[[p0]], %[[p1neg]]
4444
return %0 : !sub_ty
4545
}
46+
47+
// CHECK-LABEL: test_canonicalize_fold_add_through_ntt
48+
// CHECK: polynomial.add
49+
// CHECK-NOT: polynomial.ntt
50+
// CHECK-NOT: polynomial.intt
51+
func.func @test_canonicalize_fold_add_through_ntt(
52+
%poly0 : !ntt_poly_ty,
53+
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
54+
%0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
55+
%1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
56+
%a_plus_b = arith.addi %0, %1 : !tensor_ty
57+
%out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
58+
return %out : !ntt_poly_ty
59+
}
60+
61+
// CHECK-LABEL: test_canonicalize_fold_add_through_intt
62+
// CHECK: arith.addi
63+
// CHECK-NOT: polynomial.intt
64+
// CHECK-NOT: polynomial.iintt
65+
func.func @test_canonicalize_fold_add_through_intt(
66+
%tensor0 : !tensor_ty,
67+
%tensor1 : !tensor_ty) -> !tensor_ty {
68+
%0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
69+
%1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
70+
%a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
71+
%out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
72+
return %out : !tensor_ty
73+
}
74+
75+
// CHECK-LABEL: test_canonicalize_fold_sub_through_ntt
76+
// CHECK: polynomial.mul_scalar
77+
// CHECK: polynomial.add
78+
// CHECK-NOT: polynomial.ntt
79+
// CHECK-NOT: polynomial.intt
80+
func.func @test_canonicalize_fold_sub_through_ntt(
81+
%poly0 : !ntt_poly_ty,
82+
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
83+
%0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
84+
%1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
85+
%a_plus_b = arith.subi %0, %1 : !tensor_ty
86+
%out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
87+
return %out : !ntt_poly_ty
88+
}
89+
90+
// CHECK-LABEL: test_canonicalize_fold_sub_through_intt
91+
// CHECK: arith.subi
92+
// CHECK-NOT: polynomial.intt
93+
// CHECK-NOT: polynomial.iintt
94+
func.func @test_canonicalize_fold_sub_through_intt(
95+
%tensor0 : !tensor_ty,
96+
%tensor1 : !tensor_ty) -> !tensor_ty {
97+
%0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
98+
%1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
99+
%a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
100+
%out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
101+
return %out : !tensor_ty
102+
}

0 commit comments

Comments
 (0)