@@ -43,3 +43,60 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
43
43
// CHECK: [[ADD:%.+]] = polynomial.add %[[p0]], %[[p1neg]]
44
44
return %0 : !sub_ty
45
45
}
46
+
47
+ // CHECK-LABEL: test_canonicalize_fold_add_through_ntt
48
+ // CHECK: polynomial.add
49
+ // CHECK-NOT: polynomial.ntt
50
+ // CHECK-NOT: polynomial.intt
51
+ func.func @test_canonicalize_fold_add_through_ntt (
52
+ %poly0 : !ntt_poly_ty ,
53
+ %poly1 : !ntt_poly_ty ) -> !ntt_poly_ty {
54
+ %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
55
+ %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
56
+ %a_plus_b = arith.addi %0 , %1 : !tensor_ty
57
+ %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
58
+ return %out : !ntt_poly_ty
59
+ }
60
+
61
+ // CHECK-LABEL: test_canonicalize_fold_add_through_intt
62
+ // CHECK: arith.addi
63
+ // CHECK-NOT: polynomial.intt
64
+ // CHECK-NOT: polynomial.iintt
65
+ func.func @test_canonicalize_fold_add_through_intt (
66
+ %tensor0 : !tensor_ty ,
67
+ %tensor1 : !tensor_ty ) -> !tensor_ty {
68
+ %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
69
+ %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
70
+ %a_plus_b = polynomial.add %0 , %1 : !ntt_poly_ty
71
+ %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
72
+ return %out : !tensor_ty
73
+ }
74
+
75
+ // CHECK-LABEL: test_canonicalize_fold_sub_through_ntt
76
+ // CHECK: polynomial.mul_scalar
77
+ // CHECK: polynomial.add
78
+ // CHECK-NOT: polynomial.ntt
79
+ // CHECK-NOT: polynomial.intt
80
+ func.func @test_canonicalize_fold_sub_through_ntt (
81
+ %poly0 : !ntt_poly_ty ,
82
+ %poly1 : !ntt_poly_ty ) -> !ntt_poly_ty {
83
+ %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
84
+ %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
85
+ %a_plus_b = arith.subi %0 , %1 : !tensor_ty
86
+ %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
87
+ return %out : !ntt_poly_ty
88
+ }
89
+
90
+ // CHECK-LABEL: test_canonicalize_fold_sub_through_intt
91
+ // CHECK: arith.subi
92
+ // CHECK-NOT: polynomial.intt
93
+ // CHECK-NOT: polynomial.iintt
94
+ func.func @test_canonicalize_fold_sub_through_intt (
95
+ %tensor0 : !tensor_ty ,
96
+ %tensor1 : !tensor_ty ) -> !tensor_ty {
97
+ %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
98
+ %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
99
+ %a_plus_b = polynomial.sub %0 , %1 : !ntt_poly_ty
100
+ %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
101
+ return %out : !tensor_ty
102
+ }
0 commit comments