Skip to content

Commit d48777e

Browse files
[mlir][polynomial] remove incorrect canonicalization rule (#110318)
arith.add for tensor does not mod coefficientModulus, and it may overflow; the result could be incorrect It should be rewritten as modular arithmetic instead of arith Revert #93132 Addresses google/heir#749 Cc @j2kun
1 parent 6cbd8a3 commit d48777e

File tree

3 files changed

+2
-111
lines changed

3 files changed

+2
-111
lines changed

mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,9 @@
1111

1212
include "mlir/Dialect/Arith/IR/ArithOps.td"
1313
include "mlir/Dialect/Polynomial/IR/Polynomial.td"
14-
include "mlir/IR/EnumAttr.td"
1514
include "mlir/IR/OpBase.td"
1615
include "mlir/IR/PatternBase.td"
1716

18-
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
19-
2017
def Equal : Constraint<CPred<"$0 == $1">>;
2118

2219
// Get a -1 integer attribute of the same type as the polynomial SSA value's
@@ -44,40 +41,4 @@ def NTTAfterINTT : Pat<
4441
[(Equal $r1, $r2)]
4542
>;
4643

47-
// NTTs are expensive, and addition in coefficient or NTT domain should be
48-
// equivalently expensive, so reducing the number of NTTs is optimal.
49-
// ntt(a) + ntt(b) -> ntt(a + b)
50-
def NTTOfAdd : Pat<
51-
(Arith_AddIOp
52-
(Polynomial_NTTOp $p1, $r1),
53-
(Polynomial_NTTOp $p2, $r2),
54-
$overflow),
55-
(Polynomial_NTTOp (Polynomial_AddOp $p1, $p2), $r1),
56-
[(Equal $r1, $r2)]
57-
>;
58-
// intt(a) + intt(b) -> intt(a + b)
59-
def INTTOfAdd : Pat<
60-
(Polynomial_AddOp
61-
(Polynomial_INTTOp $t1, $r1),
62-
(Polynomial_INTTOp $t2, $r2)),
63-
(Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow), $r1),
64-
[(Equal $r1, $r2)]
65-
>;
66-
// repeated for sub
67-
def NTTOfSub : Pat<
68-
(Arith_SubIOp
69-
(Polynomial_NTTOp $p1, $r1),
70-
(Polynomial_NTTOp $p2, $r2),
71-
$overflow),
72-
(Polynomial_NTTOp (Polynomial_SubOp $p1, $p2), $r1),
73-
[(Equal $r1, $r2)]
74-
>;
75-
def INTTOfSub : Pat<
76-
(Polynomial_SubOp
77-
(Polynomial_INTTOp $t1, $r1),
78-
(Polynomial_INTTOp $t2, $r2)),
79-
(Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow), $r1),
80-
[(Equal $r1, $r2)]
81-
>;
82-
8344
#endif // POLYNOMIAL_CANONICALIZATION

mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,10 @@ void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
289289

290290
void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
291291
MLIRContext *context) {
292-
results.add<NTTAfterINTT, NTTOfAdd, NTTOfSub>(context);
292+
results.add<NTTAfterINTT>(context);
293293
}
294294

295295
void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
296296
MLIRContext *context) {
297-
results.add<INTTAfterNTT, INTTOfAdd, INTTOfSub>(context);
297+
results.add<INTTAfterNTT>(context);
298298
}

mlir/test/Dialect/Polynomial/canonicalization.mlir

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -45,73 +45,3 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
4545
return %0 : !sub_ty
4646
}
4747

48-
// CHECK-LABEL: test_canonicalize_fold_add_through_ntt
49-
// CHECK: polynomial.add
50-
// CHECK-NOT: polynomial.ntt
51-
// CHECK-NOT: polynomial.intt
52-
func.func @test_canonicalize_fold_add_through_ntt(
53-
%poly0 : !ntt_poly_ty,
54-
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
55-
%0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
56-
%1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
57-
%a_plus_b = arith.addi %0, %1 : !tensor_ty
58-
%out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
59-
return %out : !ntt_poly_ty
60-
}
61-
62-
// CHECK-LABEL: test_canonicalize_fold_add_through_intt
63-
// CHECK: arith.addi
64-
// CHECK-NOT: polynomial.intt
65-
// CHECK-NOT: polynomial.iintt
66-
func.func @test_canonicalize_fold_add_through_intt(
67-
%tensor0 : !tensor_ty,
68-
%tensor1 : !tensor_ty) -> !tensor_ty {
69-
%0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
70-
%1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
71-
%a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
72-
%out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
73-
return %out : !tensor_ty
74-
}
75-
76-
// CHECK-LABEL: test_canonicalize_fold_sub_through_ntt
77-
// CHECK: polynomial.mul_scalar
78-
// CHECK: polynomial.add
79-
// CHECK-NOT: polynomial.ntt
80-
// CHECK-NOT: polynomial.intt
81-
func.func @test_canonicalize_fold_sub_through_ntt(
82-
%poly0 : !ntt_poly_ty,
83-
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
84-
%0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
85-
%1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
86-
%a_plus_b = arith.subi %0, %1 : !tensor_ty
87-
%out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
88-
return %out : !ntt_poly_ty
89-
}
90-
91-
// CHECK-LABEL: test_canonicalize_fold_sub_through_intt
92-
// CHECK: arith.subi
93-
// CHECK-NOT: polynomial.intt
94-
// CHECK-NOT: polynomial.iintt
95-
func.func @test_canonicalize_fold_sub_through_intt(
96-
%tensor0 : !tensor_ty,
97-
%tensor1 : !tensor_ty) -> !tensor_ty {
98-
%0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
99-
%1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
100-
%a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
101-
%out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
102-
return %out : !tensor_ty
103-
}
104-
105-
106-
// CHECK-LABEL: test_canonicalize_do_not_fold_different_roots
107-
// CHECK: arith.addi
108-
func.func @test_canonicalize_do_not_fold_different_roots(
109-
%poly0 : !ntt_poly_ty,
110-
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
111-
%0 = polynomial.ntt %poly0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
112-
%1 = polynomial.ntt %poly1 {root=#polynomial.primitive_root<value=33:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
113-
%a_plus_b = arith.addi %0, %1 : !tensor_ty
114-
%out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
115-
return %out : !ntt_poly_ty
116-
}
117-

0 commit comments

Comments
 (0)