-
Notifications
You must be signed in to change notification settings - Fork 13.6k
polynomial: Add basic ops #89525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
polynomial: Add basic ops #89525
Conversation
@llvm/pr-subscribers-mlir Author: Jeremy Kun (j2kun) ChangesAdds a few basic polynomial ops.
Full diff: https://github.com/llvm/llvm-project/pull/89525.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index 39b05b9d3ad14b..fa767649f649b6 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -102,6 +102,8 @@ class Polynomial {
unsigned getDegree() const;
+ ArrayRef<Monomial> getTerms() const { return terms; }
+
friend ::llvm::hash_code hash_value(const Polynomial &arg);
private:
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 5d8da8399b01b5..89a1bd8a5bb30f 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -131,23 +131,137 @@ def Polynomial_PolynomialType : Polynomial_Type<"Polynomial", "polynomial"> {
let assemblyFormat = "`<` $ring `>`";
}
+def PolynomialLike: TypeOrContainer<Polynomial_PolynomialType, "polynomial-like">;
+
class Polynomial_Op<string mnemonic, list<Trait> traits = []> :
- Op<Polynomial_Dialect, mnemonic, traits # [Pure]>;
+ Op<Polynomial_Dialect, mnemonic, traits # [Pure]> {
+ let assemblyFormat = [{
+ operands attr-dict `:` `(` qualified(type(operands)) `)` `->` qualified(type(results))
+ }];
+}
class Polynomial_UnaryOp<string mnemonic, list<Trait> traits = []> :
Polynomial_Op<mnemonic, traits # [SameOperandsAndResultType]> {
let arguments = (ins Polynomial_PolynomialType:$operand);
let results = (outs Polynomial_PolynomialType:$result);
-
- let assemblyFormat = "$operand attr-dict `:` qualified(type($result))";
}
class Polynomial_BinaryOp<string mnemonic, list<Trait> traits = []> :
- Polynomial_Op<mnemonic, traits # [SameOperandsAndResultType]> {
- let arguments = (ins Polynomial_PolynomialType:$lhs, Polynomial_PolynomialType:$rhs);
- let results = (outs Polynomial_PolynomialType:$result);
+ Polynomial_Op<mnemonic, !listconcat(traits, [Pure, SameOperandsAndResultType, ElementwiseMappable])> {
+ let arguments = (ins PolynomialLike:$lhs, PolynomialLike:$rhs);
+ let results = (outs PolynomialLike:$result);
+ let assemblyFormat = "operands attr-dict `:` qualified(type($result))";
+}
+
+def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
+ let summary = "Addition operation between polynomials.";
+}
+
+def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
+ let summary = "Subtraction operation between polynomials.";
+}
+
+def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
+ let summary = "Multiplication operation between polynomials.";
+}
+
+def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
+ ElementwiseMappable, AllTypesMatch<["polynomial", "output"]>]> {
+ let summary = "Multiplication by a scalar of the field.";
+
+ let arguments = (ins
+ PolynomialLike:$polynomial,
+ AnyInteger:$scalar
+ );
+
+ let results = (outs
+ PolynomialLike:$output
+ );
+
+ let assemblyFormat = "operands attr-dict `:` qualified(type($polynomial)) `,` type($scalar)";
+}
+
+def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
+ let summary = "Compute the leading term of the polynomial.";
+ let description = [{
+ The degree of a polynomial is the largest $k$ for which the coefficient
+ $a_k$ of $x^k$ is nonzero. The leading term is the term $a_k x^k$, which
+ this op represents as a pair of results.
+ }];
+ let arguments = (ins Polynomial_PolynomialType:$input);
+ let results = (outs Index:$degree, AnyInteger:$coefficient);
+ let assemblyFormat = "operands attr-dict `:` qualified(type($input)) `->` `(` type($degree) `,` type($coefficient) `)`";
+}
+
+def Polynomial_MonomialOp: Polynomial_Op<"monomial"> {
+ let summary = "Create a polynomial that consists of a single monomial.";
+ let arguments = (ins AnyInteger:$coefficient, Index:$degree);
+ let results = (outs Polynomial_PolynomialType:$output);
+}
+
+def Polynomial_MonomialMulOp: Polynomial_Op<"monomial_mul", [AllTypesMatch<["input", "output"]>]> {
+ let summary = "Multiply a polynomial by a monic monomial.";
+ let description = [{
+ In the ring of polynomials mod $x^n - 1$, `monomial_mul` can be interpreted
+ as a cyclic shift of the coefficients of the polynomial. For some rings,
+ this results in optimized lowerings that involve rotations and rescaling
+ of the coefficients of the input.
+ }];
+ let arguments = (ins Polynomial_PolynomialType:$input, Index:$monomialDegree);
+ let results = (outs Polynomial_PolynomialType:$output);
+ let hasVerifier = 1;
+}
+
+def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure]> {
+ let summary = "Creates a polynomial from integer coefficients stored in a tensor.";
+ let description = [{
+ `polynomial.from_tensor` creates a polynomial value from a tensor of coefficients.
+ The input tensor must list the coefficients in degree-increasing order.
+
+ The input one-dimensional tensor may have size at most the degree of the
+ ring's ideal generator polynomial, with smaller dimension implying that
+ all higher-degree terms have coefficient zero.
+ }];
+ let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
+ let results = (outs Polynomial_PolynomialType:$output);
+
+ let assemblyFormat = "$input attr-dict `:` type($input) `->` qualified(type($output))";
+
+ let builders = [
+ // Builder that infers coefficient modulus from tensor bit width,
+ // and uses whatever input ring is provided by the caller.
+ OpBuilder<(ins "::mlir::Value":$input, "RingAttr":$ring)>
+ ];
+ let hasVerifier = 1;
+}
+
+def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
+ let summary = "Creates a tensor containing the coefficients of a polynomial.";
+ let description = [{
+ `polynomial.to_tensor` creates a tensor value containing the coefficients of the
+ input polynomial. The output tensor contains the coefficients in
+ degree-increasing order.
+
+ Operations that act on the coefficients of a polynomial, such as extracting
+ a specific coefficient or extracting a range of coefficients, should be
+ implemented by composing `to_tensor` with the relevant `tensor` dialect
+ ops.
+
+ The output tensor has shape equal to the degree of the ring's ideal
+ generator polynomial, including zeroes.
+ }];
+ let arguments = (ins Polynomial_PolynomialType:$input);
+ let results = (outs RankedTensorOf<[AnyInteger]>:$output);
+ let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+
+ let hasVerifier = 1;
+}
- let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($result))";
+def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
+ let summary = "Define a constant polynomial via an attribute.";
+ let arguments = (ins Polynomial_PolynomialAttr:$input);
+ let results = (outs Polynomial_PolynomialType:$output);
+ let assemblyFormat = "$input attr-dict `:` qualified(type($output))";
}
#endif // POLYNOMIAL_OPS
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp
index a672a59b8a465d..04a56ed8812280 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp
@@ -8,10 +8,19 @@
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/APInt.h"
#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
-#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
using namespace mlir;
using namespace mlir::polynomial;
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 96c59a28b8fdce..2ad375a2081ebf 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -6,10 +6,90 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
+#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APInt.h"
using namespace mlir;
using namespace mlir::polynomial;
-#define GET_OP_CLASSES
-#include "mlir/Dialect/Polynomial/IR/Polynomial.cpp.inc"
+void FromTensorOp::build(OpBuilder &builder, OperationState &result,
+ Value input, RingAttr ring) {
+ TensorType tensorType = dyn_cast<TensorType>(input.getType());
+ auto bitWidth = tensorType.getElementTypeBitWidth();
+ APInt cmod(1 + bitWidth, 1);
+ cmod = cmod << bitWidth;
+ Type resultType = PolynomialType::get(builder.getContext(), ring);
+ build(builder, result, resultType, input);
+}
+
+LogicalResult FromTensorOp::verify() {
+ auto tensorShape = getInput().getType().getShape();
+ auto ring = getOutput().getType().getRing();
+ auto polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
+ bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
+ if (!compatible) {
+ return emitOpError()
+ << "input type " << getInput().getType()
+ << " does not match output type " << getOutput().getType()
+ << ". The input type must be a tensor of shape [d] where d "
+ "is at most the degree of the polynomial generator of "
+ "the output ring's ideal.";
+ }
+
+ APInt coefficientModulus = ring.getCoefficientModulus().getValue();
+ unsigned cmodBitWidth = coefficientModulus.ceilLogBase2();
+ unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
+
+ if (inputBitWidth > cmodBitWidth) {
+ return emitOpError() << "input tensor element type "
+ << getInput().getType().getElementType()
+ << " is too large to fit in the coefficients of "
+ << getOutput().getType()
+ << ". The input tensor's elements must be rescaled"
+ " to fit before using from_tensor.";
+ }
+
+ return success();
+}
+
+LogicalResult ToTensorOp::verify() {
+ auto tensorShape = getOutput().getType().getShape();
+ auto polyDegree = getInput()
+ .getType()
+ .getRing()
+ .getPolynomialModulus()
+ .getPolynomial()
+ .getDegree();
+ bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
+
+ return compatible
+ ? success()
+ : emitOpError()
+ << "input type " << getInput().getType()
+ << " does not match output type " << getOutput().getType()
+ << ". The input type must be a tensor of shape [d] where d "
+ "is exactly the degree of the polynomial generator of "
+ "the output ring's ideal.";
+}
+
+LogicalResult MonomialMulOp::verify() {
+ auto ring = getInput().getType().getRing();
+ auto idealTerms = ring.getPolynomialModulus().getPolynomial().getTerms();
+ bool compatible =
+ idealTerms.size() == 2 &&
+ (idealTerms[0].coefficient == -1 && idealTerms[0].exponent == 0) &&
+ (idealTerms[1].coefficient == 1);
+
+ return compatible ? success()
+ : emitOpError()
+ << "ring type " << ring
+ << " is not supported yet. The ring "
+ "must be of the form (x^n - 1) for some n";
+}
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
new file mode 100644
index 00000000000000..dae14471344ec3
--- /dev/null
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// This simply tests for syntax.
+
+#my_poly = #polynomial.polynomial<1 + x**1024>
+#my_poly_2 = #polynomial.polynomial<2>
+#my_poly_3 = #polynomial.polynomial<3x>
+#my_poly_4 = #polynomial.polynomial<t**3 + 4t + 2>
+#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
+#one_plus_x_squared = #polynomial.polynomial<1 + x**2>
+
+#ideal = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=18, polynomialModulus=#ideal>
+!poly_ty = !polynomial.polynomial<#ring>
+
+module {
+ func.func @test_multiply() -> !polynomial.polynomial<#ring1> {
+ %c0 = arith.constant 0 : index
+ %two = arith.constant 2 : i16
+ %five = arith.constant 5 : i16
+ %coeffs1 = tensor.from_elements %two, %two, %five : tensor<3xi16>
+ %coeffs2 = tensor.from_elements %five, %five, %two : tensor<3xi16>
+
+ %poly1 = polynomial.from_tensor %coeffs1 : tensor<3xi16> -> !polynomial.polynomial<#ring1>
+ %poly2 = polynomial.from_tensor %coeffs2 : tensor<3xi16> -> !polynomial.polynomial<#ring1>
+
+ %3 = polynomial.mul %poly1, %poly2 : !polynomial.polynomial<#ring1>
+
+ return %3 : !polynomial.polynomial<#ring1>
+ }
+
+ func.func @test_elementwise(%p0 : !polynomial.polynomial<#ring1>, %p1: !polynomial.polynomial<#ring1>) {
+ %tp0 = tensor.from_elements %p0, %p1 : tensor<2x!polynomial.polynomial<#ring1>>
+ %tp1 = tensor.from_elements %p1, %p0 : tensor<2x!polynomial.polynomial<#ring1>>
+
+ %c = arith.constant 2 : i32
+ %mul_const_sclr = polynomial.mul_scalar %tp0, %c : tensor<2x!polynomial.polynomial<#ring1>>, i32
+
+ %add = polynomial.add %tp0, %tp1 : tensor<2x!polynomial.polynomial<#ring1>>
+ %sub = polynomial.sub %tp0, %tp1 : tensor<2x!polynomial.polynomial<#ring1>>
+ %mul = polynomial.mul %tp0, %tp1 : tensor<2x!polynomial.polynomial<#ring1>>
+
+ return
+ }
+
+ func.func @test_to_from_tensor(%p0 : !polynomial.polynomial<#ring1>) {
+ %c0 = arith.constant 0 : index
+ %two = arith.constant 2 : i16
+ %coeffs1 = tensor.from_elements %two, %two : tensor<2xi16>
+ // CHECK: from_tensor
+ %poly = polynomial.from_tensor %coeffs1 : tensor<2xi16> -> !polynomial.polynomial<#ring1>
+ // CHECK: to_tensor
+ %tensor = polynomial.to_tensor %poly : !polynomial.polynomial<#ring1> -> tensor<1024xi16>
+
+ return
+ }
+
+ func.func @test_degree(%p0 : !polynomial.polynomial<#ring1>) {
+ %0, %1 = polynomial.leading_term %p0 : !polynomial.polynomial<#ring1> -> (index, i32)
+ return
+ }
+
+ func.func @test_monomial() {
+ %deg = arith.constant 1023 : index
+ %five = arith.constant 5 : i16
+ %0 = polynomial.monomial %five, %deg : (i16, index) -> !polynomial.polynomial<#ring1>
+ return
+ }
+
+ func.func @test_constant() {
+ %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<#ring1>
+ %1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1>
+ return
+ }
+}
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
new file mode 100644
index 00000000000000..e536571eb8f3ee
--- /dev/null
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt --verify-diagnostics %s
+
+#my_poly = #polynomial.polynomial<1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+module {
+ func.func @test_from_tensor_too_large_coeffs() {
+ %two = arith.constant 2 : i32
+ %coeffs1 = tensor.from_elements %two, %two : tensor<2xi32>
+ // expected-error@below {{is too large to fit in the coefficients}}
+ %poly = polynomial.from_tensor %coeffs1 : tensor<2xi32> -> !polynomial.polynomial<#ring>
+ return
+ }
+}
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
9a205a6
to
d7aceb1
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with a bit more tests added.
Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
Merging because the failing tests are accounted for by tests reverted (and un-reverted) in #90232 |
Adds a few basic polynomial ops.
mul
to standard MLIR)