Skip to content

[mlir][polynomial] verify from_tensor coeff type #93243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 30, 2024

Conversation

j2kun
Copy link
Contributor

@j2kun j2kun commented May 23, 2024

Rebased over #93227

@j2kun j2kun marked this pull request as draft May 23, 2024 21:29
@llvmbot llvmbot added the mlir label May 23, 2024
@llvmbot
Copy link
Member

llvmbot commented May 23, 2024

@llvm/pr-subscribers-mlir

Author: Jeremy Kun (j2kun)

Changes

Rebased over #93227


Patch is 22.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/93243.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td (+11-4)
  • (modified) mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td (+27-6)
  • (modified) mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td (+22-20)
  • (modified) mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp (+21-25)
  • (modified) mlir/test/Dialect/Polynomial/canonicalization.mlir (+32-17)
  • (modified) mlir/test/Dialect/Polynomial/ops.mlir (+4-4)
  • (modified) mlir/test/Dialect/Polynomial/ops_errors.mlir (+12-26)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index f99cbccd243ec..755396c8b9023 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -277,7 +277,6 @@ def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[
   Polynomial_TypedIntPolynomialAttr
 ]>;
 
-// Not deriving from Polynomial_Op due to need for custom assembly format
 def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant",
     [Pure, InferTypeOpAdaptor]> {
   let summary = "Define a constant polynomial via an attribute.";
@@ -312,9 +311,12 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
 
       `f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`
 
-    The choice of primitive root is determined by subsequent lowerings.
+    The choice of primitive root may be optionally specified.
   }];
-  let arguments = (ins Polynomial_PolynomialType:$input);
+  let arguments = (ins
+    Polynomial_PolynomialType:$input,
+    OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
+  );
   let results = (outs RankedTensorOf<[AnyInteger]>:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
   let hasCanonicalizer = 1;
@@ -332,8 +334,13 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
     output polynomial at powers of a primitive `n`-th root of unity (see
     `polynomial.ntt`). The ring of the polynomial is taken from the required
     encoding attribute of the tensor.
+
+    The choice of primitive root may be optionally specified.
   }];
-  let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
+  let arguments = (
+    ins RankedTensorOf<[AnyInteger]>:$input,
+    OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
+  );
   let results = (outs Polynomial_PolynomialType:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
index 655020adf808b..2d3ed60a35fd9 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
@@ -166,24 +166,45 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
   let parameters = (ins
     "Type": $coefficientType,
     OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
-    OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
-    OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
+    OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus
   );
   let assemblyFormat = "`<` struct(params) `>`";
   let builders = [
     AttrBuilderWithInferredContext<
         (ins "::mlir::Type":$coefficientTy,
               CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
-              CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
-              CArg<"::mlir::IntegerAttr", "nullptr"> :$primitiveRootAttr), [{
+              CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr), [{
       return $_get(
         coefficientTy.getContext(),
         coefficientTy,
         coefficientModulusAttr,
-        polynomialModulusAttr,
-        primitiveRootAttr);
+        polynomialModulusAttr);
     }]>,
   ];
 }
 
+def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
+  let summary = "an attribute containing an integer and its degree as a root of unity";
+  let description = [{
+    A primitive root attribute stores an integer root `value` and an integer
+    `degree`, corresponding to a primitive root of unity of the given degree in
+    an unspecified ring.
+
+    This is used as an attribute on `polynomial.ntt` and `polynomial.intt` ops
+    to specify the root of unity used in lowering the transform.
+
+    Example:
+
+    ```mlir
+    #poly = #polynomial.primitive_root<value=123 : i32, degree : 7 index>
+    ```
+  }];
+  let parameters = (ins
+    "::mlir::IntegerAttr":$value,
+    "::mlir::IntegerAttr":$degree
+  );
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+
 #endif // POLYNOMIAL_ATTRIBUTES
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index e37bcf76a20f2..93ea6e4e43698 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -17,6 +17,8 @@ include "mlir/IR/PatternBase.td"
 
 defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
 
+def Equal : Constraint<CPred<"$0 == $1">>;
+
 // Get a -1 integer attribute of the same type as the polynomial SSA value's
 // ring coefficient type.
 def getMinusOne
@@ -31,15 +33,15 @@ def SubAsAdd : Pat<
       (Arith_ConstantOp (getMinusOne $g))))>;
 
 def INTTAfterNTT : Pat<
-  (Polynomial_INTTOp (Polynomial_NTTOp $poly)),
+  (Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2),
   (replaceWithValue $poly),
-  []
+  [(Equal $r1, $r2)]
 >;
 
 def NTTAfterINTT : Pat<
-  (Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
+  (Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2),
   (replaceWithValue $tensor),
-  []
+  [(Equal $r1, $r2)]
 >;
 
 // NTTs are expensive, and addition in coefficient or NTT domain should be
@@ -47,35 +49,35 @@ def NTTAfterINTT : Pat<
 // ntt(a) + ntt(b) -> ntt(a + b)
 def NTTOfAdd : Pat<
   (Arith_AddIOp
-    (Polynomial_NTTOp $p1),
-    (Polynomial_NTTOp $p2),
+    (Polynomial_NTTOp $p1, $r1),
+    (Polynomial_NTTOp $p2, $r2),
     $overflow),
-  (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2)),
-  []
+  (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2), $r1),
+  [(Equal $r1, $r2)]
 >;
 // intt(a) + intt(b) -> intt(a + b)
 def INTTOfAdd : Pat<
   (Polynomial_AddOp
-    (Polynomial_INTTOp $t1),
-    (Polynomial_INTTOp $t2)),
-  (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow)),
-  []
+    (Polynomial_INTTOp $t1, $r1),
+    (Polynomial_INTTOp $t2, $r2)),
+  (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow), $r1),
+  [(Equal $r1, $r2)]
 >;
 // repeated for sub
 def NTTOfSub : Pat<
   (Arith_SubIOp
-    (Polynomial_NTTOp $p1),
-    (Polynomial_NTTOp $p2),
+    (Polynomial_NTTOp $p1, $r1),
+    (Polynomial_NTTOp $p2, $r2),
     $overflow),
-  (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2)),
-  []
+  (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2), $r1),
+  [(Equal $r1, $r2)]
 >;
 def INTTOfSub : Pat<
   (Polynomial_SubOp
-    (Polynomial_INTTOp $t1),
-    (Polynomial_INTTOp $t2)),
-  (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow)),
-  []
+    (Polynomial_INTTOp $t1, $r1),
+    (Polynomial_INTTOp $t2, $r2)),
+  (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow), $r1),
+  [(Equal $r1, $r2)]
 >;
 
 #endif  // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 3d302797ce513..3719979177215 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -47,11 +47,8 @@ LogicalResult FromTensorOp::verify() {
     return diag;
   }
 
-  APInt coefficientModulus = ring.getCoefficientModulus().getValue();
-  unsigned cmodBitWidth = coefficientModulus.ceilLogBase2();
   unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
-
-  if (inputBitWidth > cmodBitWidth) {
+  if (inputBitWidth > ring.getCoefficientType().getIntOrFloatBitWidth()) {
     InFlightDiagnostic diag = emitOpError()
                               << "input tensor element type "
                               << getInput().getType().getElementType()
@@ -108,14 +105,15 @@ LogicalResult MulScalarOp::verify() {
 }
 
 /// Test if a value is a primitive nth root of unity modulo cmod.
-bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
+bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
                                const APInt &cmod) {
   // Root bitwidth may be 1 less then cmod.
   APInt r = APInt(root).zext(cmod.getBitWidth());
   assert(r.ule(cmod) && "root must be less than cmod");
+  unsigned upperBound = n.getZExtValue();
 
   APInt a = r;
-  for (size_t k = 1; k < n; k++) {
+  for (size_t k = 1; k < upperBound; k++) {
     if (a.isOne())
       return false;
     a = (a * r).urem(cmod);
@@ -126,7 +124,8 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
 /// Verify that the types involved in an NTT or INTT operation are
 /// compatible.
 static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
-                                 RankedTensorType tensorType) {
+                                 RankedTensorType tensorType,
+                                 std::optional<PrimitiveRootAttr> root) {
   Attribute encoding = tensorType.getEncoding();
   if (!encoding) {
     return op->emitOpError()
@@ -157,33 +156,30 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
     return diag;
   }
 
-  if (!ring.getPrimitiveRoot()) {
-    return op->emitOpError()
-           << "ring type " << ring << " does not provide a primitive root "
-           << "of unity, which is required to express an NTT";
-  }
-
-  if (!isPrimitiveNthRootOfUnity(ring.getPrimitiveRoot().getValue(), polyDegree,
-                                 ring.getCoefficientModulus().getValue())) {
-    return op->emitOpError()
-           << "ring type " << ring << " has a primitiveRoot attribute '"
-           << ring.getPrimitiveRoot()
-           << "' that is not a primitive root of the coefficient ring";
+  if (root.has_value()) {
+    APInt rootValue = root.value().getValue().getValue();
+    APInt rootDegree = root.value().getDegree().getValue();
+    APInt cmod = ring.getCoefficientModulus().getValue();
+    if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
+      return op->emitOpError()
+             << "provided root " << rootValue.getZExtValue()
+             << " is not a primitive root "
+             << "of unity mod " << cmod.getZExtValue()
+             << ", with the specified degree " << rootDegree.getZExtValue();
+    }
   }
 
   return success();
 }
 
 LogicalResult NTTOp::verify() {
-  auto ring = getInput().getType().getRing();
-  auto tensorType = getOutput().getType();
-  return verifyNTTOp(this->getOperation(), ring, tensorType);
+  return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
+                     getOutput().getType(), getRoot());
 }
 
 LogicalResult INTTOp::verify() {
-  auto tensorType = getInput().getType();
-  auto ring = getOutput().getType().getRing();
-  return verifyNTTOp(this->getOperation(), ring, tensorType);
+  return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
+                     getInput().getType(), getRoot());
 }
 
 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index 489d9ec2720d6..b79938627e415 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt -canonicalize %s | FileCheck %s
 #ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
+#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
 !ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 !tensor_ty = tensor<8xi32, #ntt_ring>
 
@@ -10,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty
   // CHECK-NOT: polynomial.ntt
   // CHECK-NOT: polynomial.intt
   // CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]]  : [[T]]
-  %t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
-  %p1 = polynomial.intt %t0: !tensor_ty -> !ntt_poly_ty
+  %t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+  %p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
   %p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
   // CHECK: return %[[RESULT]] : [[T]]
   return %p2 : !ntt_poly_ty
@@ -23,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
   // CHECK-NOT: polynomial.intt
   // CHECK-NOT: polynomial.ntt
   // CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
-  %p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
-  %t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
+  %p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+  %t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
   %t2 = arith.addi %t1, %t1 : !tensor_ty
   // CHECK: return %[[RESULT]] : [[T]]
   return %t2 : !tensor_ty
@@ -51,10 +52,10 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
 func.func @test_canonicalize_fold_add_through_ntt(
     %poly0 : !ntt_poly_ty,
     %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
-  %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
-  %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+  %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+  %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
   %a_plus_b = arith.addi %0, %1 : !tensor_ty
-  %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
   return %out : !ntt_poly_ty
 }
 
@@ -65,10 +66,10 @@ func.func @test_canonicalize_fold_add_through_ntt(
 func.func @test_canonicalize_fold_add_through_intt(
     %tensor0 : !tensor_ty,
     %tensor1 : !tensor_ty) -> !tensor_ty {
-  %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
-  %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+  %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+  %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
   %a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
-  %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+  %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
   return %out : !tensor_ty
 }
 
@@ -80,10 +81,10 @@ func.func @test_canonicalize_fold_add_through_intt(
 func.func @test_canonicalize_fold_sub_through_ntt(
     %poly0 : !ntt_poly_ty,
     %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
-  %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
-  %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+  %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+  %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
   %a_plus_b = arith.subi %0, %1 : !tensor_ty
-  %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
   return %out : !ntt_poly_ty
 }
 
@@ -94,9 +95,23 @@ func.func @test_canonicalize_fold_sub_through_ntt(
 func.func @test_canonicalize_fold_sub_through_intt(
     %tensor0 : !tensor_ty,
     %tensor1 : !tensor_ty) -> !tensor_ty {
-  %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
-  %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+  %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+  %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
   %a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
-  %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+  %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
   return %out : !tensor_ty
 }
+
+
+// CHECK-LABEL: test_canonicalize_do_not_fold_different_roots
+// CHECK: arith.addi
+func.func @test_canonicalize_do_not_fold_different_roots(
+    %poly0 : !ntt_poly_ty,
+    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+  %0 = polynomial.ntt %poly0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
+  %1 = polynomial.ntt %poly1 {root=#polynomial.primitive_root<value=33:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
+  %a_plus_b = arith.addi %0, %1 : !tensor_ty
+  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
+  return %out : !ntt_poly_ty
+}
+
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index 4716e37ff8852..8c134ab789d60 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -11,11 +11,11 @@
 #one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
 
 #ideal = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 #ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
 !ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 
 module {
@@ -91,12 +91,12 @@ module {
   }
 
   func.func @test_ntt(%0 : !ntt_poly_ty) {
-    %1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
+    %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
     return
   }
 
   func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
-    %1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
+    %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
     return
   }
 }
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index af8e4aa5da862..4937e17027afa 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt --split-input-file --verify-diagnostics %s
 
 #my_poly = #polynomial.int_polynomial<1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i32, polynomialModulus=#my_poly>
+#ring = #polynomial.ring<coefficientType=i16>
 !ty = !polynomial.polynomial<ring=#ring>
 
 func.func @test_from_tensor_too_large_coeffs() {
@@ -55,28 +55,28 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_ntt
 // CHECK-NOT: polynomial.ntt
 func.func @test_invalid_ntt(%0 : !poly_ty) {
   // expected-error@below {{expects a ring encoding to be provided to the tensor}}
-  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
+  %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32>
   return
 }
 
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_ntt
 // CHECK-NOT: polynomial.ntt
 func.func @test_invalid_ntt(%0 : !poly_ty) {
   // expected-error@below {{tensor encoding is not a ring attribute}}
-  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
+  %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32, #my_poly>
   return
 }
 
@@ -84,21 +84,21 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
 #ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
-#ring1 = #polynomial.ring<coefficientType=i16, coefficientModu...
[truncated]

Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

@j2kun j2kun force-pushed the verify-cmod-bitwidth branch from 7c44485 to af5567b Compare May 30, 2024 14:18
@j2kun j2kun marked this pull request as ready for review May 30, 2024 14:18
Use the coefficient type to verify if a tensor fits in a polynomial
ring. Downstream we had originally not specified the coefficient type
and just used the implied bit width of the coefficient modulus.  Now
that we specify the type, this is simpler.
@j2kun j2kun force-pushed the verify-cmod-bitwidth branch from af5567b to 3cf04e4 Compare May 30, 2024 17:57
Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff cc2fafa1788908f69366821a04407083f770483e 3cf04e4cc22debffa4e83e0e7427dd46c64ff54c -- mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
View the diff from clang-format here.
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index ea83db4fdd..3117721a94 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -170,9 +170,9 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
     if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
       return op->emitOpError()
              << "provided root " << rootValue.getZExtValue()
-             << " is not a primitive root " << "of unity mod "
-             << cmod.getZExtValue() << ", with the specified degree "
-             << rootDegree.getZExtValue();
+             << " is not a primitive root "
+             << "of unity mod " << cmod.getZExtValue()
+             << ", with the specified degree " << rootDegree.getZExtValue();
     }
   }
 

@j2kun j2kun force-pushed the verify-cmod-bitwidth branch from 3cf04e4 to 24d2b29 Compare May 30, 2024 18:02
@j2kun j2kun merged commit 692ae54 into llvm:main May 30, 2024
4 of 6 checks passed
j2kun added a commit that referenced this pull request May 31, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants