-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[polynomial] Move primitive root attribute to ntt/intt ops. #93227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Jeremy Kun (j2kun) ChangesPatch is 20.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/93227.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 3ef899d3376b1..7035219238dd5 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -277,8 +277,7 @@ def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
Polynomial_IntPolynomialAttr
]>;
-// Not deriving from Polynomial_Op due to need for custom assembly format
-def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
+def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
let summary = "Define a constant polynomial via an attribute.";
let description = [{
Example:
@@ -312,9 +311,12 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
`f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`
- The choice of primitive root is determined by subsequent lowerings.
+ The choice of primitive root may be optionally specified.
}];
- let arguments = (ins Polynomial_PolynomialType:$input);
+ let arguments = (ins
+ Polynomial_PolynomialType:$input,
+ OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
+ );
let results = (outs RankedTensorOf<[AnyInteger]>:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
let hasCanonicalizer = 1;
@@ -332,8 +334,13 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
output polynomial at powers of a primitive `n`-th root of unity (see
`polynomial.ntt`). The ring of the polynomial is taken from the required
encoding attribute of the tensor.
+
+ The choice of primitive root may be optionally specified.
}];
- let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
+ let arguments = (
+ ins RankedTensorOf<[AnyInteger]>:$input,
+ OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
+ );
let results = (outs Polynomial_PolynomialType:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
index e5dbfa7fa21ee..2bfdc15ddd034 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
@@ -102,24 +102,45 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
let parameters = (ins
"Type": $coefficientType,
OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
- OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
- OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
+ OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus
);
let assemblyFormat = "`<` struct(params) `>`";
let builders = [
AttrBuilderWithInferredContext<
(ins "::mlir::Type":$coefficientTy,
CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
- CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
- CArg<"::mlir::IntegerAttr", "nullptr"> :$primitiveRootAttr), [{
+ CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr), [{
return $_get(
coefficientTy.getContext(),
coefficientTy,
coefficientModulusAttr,
- polynomialModulusAttr,
- primitiveRootAttr);
+ polynomialModulusAttr);
}]>,
];
}
+def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
+ let summary = "an attribute containing an integer and its degree as a root of unity";
+ let description = [{
+ A primitive root attribute stores an integer root `value` and an integer
+ `degree`, corresponding to a primitive root of unity of the given degree in
+ an unspecified ring.
+
+ This is used as an attribute on `polynomial.ntt` and `polynomial.intt` ops
+ to specify the root of unity used in lowering the transform.
+
+ Example:
+
+ ```mlir
+ #poly = #polynomial.primitive_root<value=123 : i32, degree : 7 index>
+ ```
+ }];
+ let parameters = (ins
+ "::mlir::IntegerAttr":$value,
+ "::mlir::IntegerAttr":$degree
+ );
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+
#endif // POLYNOMIAL_ATTRIBUTES
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index 9d09799c1763a..f8216e1b2307d 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -14,6 +14,10 @@ include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/PatternBase.td"
+defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
+
+def Equal : Constraint<CPred<"$0 == $1">>;
+
// Get a -1 integer attribute of the same type as the polynomial SSA value's
// ring coefficient type.
def getMinusOne
@@ -28,15 +32,51 @@ def SubAsAdd : Pat<
(Arith_ConstantOp (getMinusOne $g))))>;
def INTTAfterNTT : Pat<
- (Polynomial_INTTOp (Polynomial_NTTOp $poly)),
+ (Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2),
(replaceWithValue $poly),
- []
+ [(Equal $r1, $r2)]
>;
def NTTAfterINTT : Pat<
- (Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
+ (Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2),
(replaceWithValue $tensor),
- []
+ [(Equal $r1, $r2)]
+>;
+
+// NTTs are expensive, and addition in coefficient or NTT domain should be
+// equivalently expensive, so reducing the number of NTTs is optimal.
+// ntt(a) + ntt(b) -> ntt(a + b)
+def NTTOfAdd : Pat<
+ (Arith_AddIOp
+ (Polynomial_NTTOp $p1, $r1),
+ (Polynomial_NTTOp $p2, $r2),
+ $overflow),
+ (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2), $r1),
+ [(Equal $r1, $r2)]
+>;
+// intt(a) + intt(b) -> intt(a + b)
+def INTTOfAdd : Pat<
+ (Polynomial_AddOp
+ (Polynomial_INTTOp $t1, $r1),
+ (Polynomial_INTTOp $t2, $r2)),
+ (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow), $r1),
+ [(Equal $r1, $r2)]
+>;
+// repeated for sub
+def NTTOfSub : Pat<
+ (Arith_SubIOp
+ (Polynomial_NTTOp $p1, $r1),
+ (Polynomial_NTTOp $p2, $r2),
+ $overflow),
+ (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2), $r1),
+ [(Equal $r1, $r2)]
+>;
+def INTTOfSub : Pat<
+ (Polynomial_SubOp
+ (Polynomial_INTTOp $t1, $r1),
+ (Polynomial_INTTOp $t2, $r2)),
+ (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow), $r1),
+ [(Equal $r1, $r2)]
>;
#endif // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 1a2439fe810b5..58773a24181df 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -108,14 +108,15 @@ LogicalResult MulScalarOp::verify() {
}
/// Test if a value is a primitive nth root of unity modulo cmod.
-bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
+bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
const APInt &cmod) {
// Root bitwidth may be 1 less then cmod.
APInt r = APInt(root).zext(cmod.getBitWidth());
assert(r.ule(cmod) && "root must be less than cmod");
+ unsigned upperBound = n.getZExtValue();
APInt a = r;
- for (size_t k = 1; k < n; k++) {
+ for (size_t k = 1; k < upperBound; k++) {
if (a.isOne())
return false;
a = (a * r).urem(cmod);
@@ -126,7 +127,8 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
/// Verify that the types involved in an NTT or INTT operation are
/// compatible.
static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
- RankedTensorType tensorType) {
+ RankedTensorType tensorType,
+ std::optional<PrimitiveRootAttr> root) {
Attribute encoding = tensorType.getEncoding();
if (!encoding) {
return op->emitOpError()
@@ -157,33 +159,29 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
return diag;
}
- if (!ring.getPrimitiveRoot()) {
- return op->emitOpError()
- << "ring type " << ring << " does not provide a primitive root "
- << "of unity, which is required to express an NTT";
- }
-
- if (!isPrimitiveNthRootOfUnity(ring.getPrimitiveRoot().getValue(), polyDegree,
- ring.getCoefficientModulus().getValue())) {
- return op->emitOpError()
- << "ring type " << ring << " has a primitiveRoot attribute '"
- << ring.getPrimitiveRoot()
- << "' that is not a primitive root of the coefficient ring";
+ if (root.has_value()) {
+ APInt rootValue = root.value().getValue().getValue();
+ APInt rootDegree = root.value().getDegree().getValue();
+ APInt cmod = ring.getCoefficientModulus().getValue();
+ if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
+ return op->emitOpError()
+ << "provided root " << rootValue.getZExtValue() << " is not a primitive root "
+ << "of unity mod " << cmod.getZExtValue() << ", with the specified degree "
+ << rootDegree.getZExtValue();
+ }
}
return success();
}
LogicalResult NTTOp::verify() {
- auto ring = getInput().getType().getRing();
- auto tensorType = getOutput().getType();
- return verifyNTTOp(this->getOperation(), ring, tensorType);
+ return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
+ getOutput().getType(), getRoot());
}
LogicalResult INTTOp::verify() {
- auto tensorType = getInput().getType();
- auto ring = getOutput().getType().getRing();
- return verifyNTTOp(this->getOperation(), ring, tensorType);
+ return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
+ getInput().getType(), getRoot());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index dbfbf2d93f111..354b76e3d9669 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt -canonicalize %s | FileCheck %s
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
+#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
!tensor_ty = tensor<8xi32, #ntt_ring>
@@ -10,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty
// CHECK-NOT: polynomial.ntt
// CHECK-NOT: polynomial.intt
// CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]] : [[T]]
- %t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
- %p1 = polynomial.intt %t0: !tensor_ty -> !ntt_poly_ty
+ %t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+ %p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
%p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
// CHECK: return %[[RESULT]] : [[T]]
return %p2 : !ntt_poly_ty
@@ -23,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
// CHECK-NOT: polynomial.intt
// CHECK-NOT: polynomial.ntt
// CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
- %p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
- %t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
+ %p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+ %t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
%t2 = arith.addi %t1, %t1 : !tensor_ty
// CHECK: return %[[RESULT]] : [[T]]
return %t2 : !tensor_ty
@@ -43,3 +44,60 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
// CHECK: [[ADD:%.+]] = polynomial.add %[[p0]], %[[p1neg]]
return %0 : !sub_ty
}
+
+// CHECK-LABEL: test_canonicalize_fold_add_through_ntt
+// CHECK: polynomial.add
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+func.func @test_canonicalize_fold_add_through_ntt(
+ %poly0 : !ntt_poly_ty,
+ %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+ %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+ %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
+ %a_plus_b = arith.addi %0, %1 : !tensor_ty
+ %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
+ return %out : !ntt_poly_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_add_through_intt
+// CHECK: arith.addi
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.iintt
+func.func @test_canonicalize_fold_add_through_intt(
+ %tensor0 : !tensor_ty,
+ %tensor1 : !tensor_ty) -> !tensor_ty {
+ %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+ %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
+ %a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
+ %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
+ return %out : !tensor_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_sub_through_ntt
+// CHECK: polynomial.mul_scalar
+// CHECK: polynomial.add
+// CHECK-NOT: polynomial.ntt
+// CHECK-NOT: polynomial.intt
+func.func @test_canonicalize_fold_sub_through_ntt(
+ %poly0 : !ntt_poly_ty,
+ %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+ %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+ %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
+ %a_plus_b = arith.subi %0, %1 : !tensor_ty
+ %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
+ return %out : !ntt_poly_ty
+}
+
+// CHECK-LABEL: test_canonicalize_fold_sub_through_intt
+// CHECK: arith.subi
+// CHECK-NOT: polynomial.intt
+// CHECK-NOT: polynomial.iintt
+func.func @test_canonicalize_fold_sub_through_intt(
+ %tensor0 : !tensor_ty,
+ %tensor1 : !tensor_ty) -> !tensor_ty {
+ %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+ %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
+ %a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
+ %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
+ return %out : !tensor_ty
+}
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index ff709960c50e9..c9bcb7f95b7d3 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -11,11 +11,11 @@
#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
#ideal = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal>
!poly_ty = !polynomial.polynomial<ring=#ring>
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
module {
@@ -87,12 +87,12 @@ module {
}
func.func @test_ntt(%0 : !ntt_poly_ty) {
- %1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
+ %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
return
}
func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
- %1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
+ %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
return
}
}
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index af8e4aa5da862..f22b14897e98a 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -55,28 +55,28 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
// -----
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
!poly_ty = !polynomial.polynomial<ring=#ring>
// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
func.func @test_invalid_ntt(%0 : !poly_ty) {
// expected-error@below {{expects a ring encoding to be provided to the tensor}}
- %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
+ %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32>
return
}
// -----
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
!poly_ty = !polynomial.polynomial<ring=#ring>
// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
func.func @test_invalid_ntt(%0 : !poly_ty) {
// expected-error@below {{tensor encoding is not a ring attribute}}
- %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
+ %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32, #my_poly>
return
}
@@ -84,21 +84,21 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
-#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly>
!poly_ty = !polynomial.polynomial<ring=#ring>
// CHECK-NOT: @test_invalid_intt
// CHECK-NOT: polynomial.intt
func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
// expected-error@below {{not equivalent to the polynomial ring}}
- %1 = polynomial.intt %0 : tensor<1024xi32, #ring1> -> !poly_ty
+ %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<1024xi32, #ring1> -> !poly_ty
return
}
// -----
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
!poly_ty = !polynomial.polynomial<ring=#ring>
// CHECK-NOT: @test_invalid_intt
@@ -106,21 +106,7 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
// expected-error@below {{does not match output type}}
// expected-note@below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
- %1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
- return
-}
-
-// -----
-
-#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
-!poly_ty = !polynomial.polynomial<ring=#ring>
-
-// CHECK-NOT: @test_invalid_ntt
-// CHECK-NOT: polynomial.ntt
-func.func @test_invalid_ntt(%0 : !poly_ty) {
- // expected-error@below {{does not provide a primitive root of unity, which is required to express an NTT}}
- %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
+ %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<1025xi32, #ring> -> !poly_ty
return
}
@@ -128,13 +114,13 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
#my_poly = #polynomial.int_polynomial<-1 + x**8>
// A valid root i...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
9f289ea
to
d3d79d9
Compare
}]>, | ||
]; | ||
} | ||
|
||
def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> { | ||
let summary = "an attribute containing an integer and its degree as a root of unity"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if mlir docs can render latex
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
APInt rootValue = root.value().getValue().getValue(); | ||
APInt rootDegree = root.value().getDegree().getValue(); | ||
APInt cmod = ring.getCoefficientModulus().getValue(); | ||
if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: remove {
, }
for single statement if/fors
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe it's questionable convention in the case of nested conditionals? up to you.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is indeed questionable, I usually lean towards having braces for multi-line bodies even if there is a single statement.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM modulo my general distate for algebra (see what I did there...)
APInt rootValue = root.value().getValue().getValue(); | ||
APInt rootDegree = root.value().getDegree().getValue(); | ||
APInt cmod = ring.getCoefficientModulus().getValue(); | ||
if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is indeed questionable, I usually lean towards having braces for multi-line bodies even if there is a single statement.
Rebased over #93227 --------- Co-authored-by: Jeremy Kun <[email protected]>
Related to llvm#93227 and google/heir#993 When ntt/intt ops are emitted as a result of pattern rewrite, the primitive root attr must be provided in some way, and it is convenient for it to be provided in ring attr. As for using different primitive root for the same polynomial, to_tensor/tensor.cast/from_tensor should be enough for changing primitiveRoot attribute in RingAttr.
Better design to put semantics on the ops, and in this case the ntt/intt op can lower in multiple ways depending on the polynomial ring modulus (it can need an nth root of unity for cyclic polymul -> ntt, or a 2nth root for negacyclic polymul -> ntt)