Skip to content

[polynomial] Move primitive root attribute to ntt/intt ops. #93227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,6 @@ def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[
Polynomial_TypedIntPolynomialAttr
]>;

// Not deriving from Polynomial_Op due to need for custom assembly format
def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant",
[Pure, InferTypeOpAdaptor]> {
let summary = "Define a constant polynomial via an attribute.";
Expand Down Expand Up @@ -312,9 +311,12 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {

`f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`

The choice of primitive root is determined by subsequent lowerings.
The choice of primitive root may be optionally specified.
}];
let arguments = (ins Polynomial_PolynomialType:$input);
let arguments = (ins
Polynomial_PolynomialType:$input,
OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
);
let results = (outs RankedTensorOf<[AnyInteger]>:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
let hasCanonicalizer = 1;
Expand All @@ -332,8 +334,13 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
output polynomial at powers of a primitive `n`-th root of unity (see
`polynomial.ntt`). The ring of the polynomial is taken from the required
encoding attribute of the tensor.

The choice of primitive root may be optionally specified.
}];
let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
let arguments = (
ins RankedTensorOf<[AnyInteger]>:$input,
OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
);
let results = (outs Polynomial_PolynomialType:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
let hasCanonicalizer = 1;
Expand Down
33 changes: 27 additions & 6 deletions mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -166,24 +166,45 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
let parameters = (ins
"Type": $coefficientType,
OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus
);
let assemblyFormat = "`<` struct(params) `>`";
let builders = [
AttrBuilderWithInferredContext<
(ins "::mlir::Type":$coefficientTy,
CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
CArg<"::mlir::IntegerAttr", "nullptr"> :$primitiveRootAttr), [{
CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr), [{
return $_get(
coefficientTy.getContext(),
coefficientTy,
coefficientModulusAttr,
polynomialModulusAttr,
primitiveRootAttr);
polynomialModulusAttr);
}]>,
];
}

def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
let summary = "an attribute containing an integer and its degree as a root of unity";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if mlir docs can render latex

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like yes:
image

let description = [{
A primitive root attribute stores an integer root `value` and an integer
`degree`, corresponding to a primitive root of unity of the given degree in
an unspecified ring.

This is used as an attribute on `polynomial.ntt` and `polynomial.intt` ops
to specify the root of unity used in lowering the transform.

Example:

```mlir
#poly = #polynomial.primitive_root<value=123 : i32, degree : 7 index>
```
}];
let parameters = (ins
"::mlir::IntegerAttr":$value,
"::mlir::IntegerAttr":$degree
);
let assemblyFormat = "`<` struct(params) `>`";
}


#endif // POLYNOMIAL_ATTRIBUTES
42 changes: 22 additions & 20 deletions mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ include "mlir/IR/PatternBase.td"

defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;

def Equal : Constraint<CPred<"$0 == $1">>;

// Get a -1 integer attribute of the same type as the polynomial SSA value's
// ring coefficient type.
def getMinusOne
Expand All @@ -31,51 +33,51 @@ def SubAsAdd : Pat<
(Arith_ConstantOp (getMinusOne $g))))>;

def INTTAfterNTT : Pat<
(Polynomial_INTTOp (Polynomial_NTTOp $poly)),
(Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2),
(replaceWithValue $poly),
[]
[(Equal $r1, $r2)]
>;

def NTTAfterINTT : Pat<
(Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
(Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2),
(replaceWithValue $tensor),
[]
[(Equal $r1, $r2)]
>;

// NTTs are expensive, and addition in coefficient or NTT domain should be
// equivalently expensive, so reducing the number of NTTs is optimal.
// ntt(a) + ntt(b) -> ntt(a + b)
def NTTOfAdd : Pat<
(Arith_AddIOp
(Polynomial_NTTOp $p1),
(Polynomial_NTTOp $p2),
(Polynomial_NTTOp $p1, $r1),
(Polynomial_NTTOp $p2, $r2),
$overflow),
(Polynomial_NTTOp (Polynomial_AddOp $p1, $p2)),
[]
(Polynomial_NTTOp (Polynomial_AddOp $p1, $p2), $r1),
[(Equal $r1, $r2)]
>;
// intt(a) + intt(b) -> intt(a + b)
def INTTOfAdd : Pat<
(Polynomial_AddOp
(Polynomial_INTTOp $t1),
(Polynomial_INTTOp $t2)),
(Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow)),
[]
(Polynomial_INTTOp $t1, $r1),
(Polynomial_INTTOp $t2, $r2)),
(Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow), $r1),
[(Equal $r1, $r2)]
>;
// repeated for sub
def NTTOfSub : Pat<
(Arith_SubIOp
(Polynomial_NTTOp $p1),
(Polynomial_NTTOp $p2),
(Polynomial_NTTOp $p1, $r1),
(Polynomial_NTTOp $p2, $r2),
$overflow),
(Polynomial_NTTOp (Polynomial_SubOp $p1, $p2)),
[]
(Polynomial_NTTOp (Polynomial_SubOp $p1, $p2), $r1),
[(Equal $r1, $r2)]
>;
def INTTOfSub : Pat<
(Polynomial_SubOp
(Polynomial_INTTOp $t1),
(Polynomial_INTTOp $t2)),
(Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow)),
[]
(Polynomial_INTTOp $t1, $r1),
(Polynomial_INTTOp $t2, $r2)),
(Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow), $r1),
[(Equal $r1, $r2)]
>;

#endif // POLYNOMIAL_CANONICALIZATION
41 changes: 20 additions & 21 deletions mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,15 @@ LogicalResult MulScalarOp::verify() {
}

/// Test if a value is a primitive nth root of unity modulo cmod.
bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
const APInt &cmod) {
// Root bitwidth may be 1 less then cmod.
APInt r = APInt(root).zext(cmod.getBitWidth());
assert(r.ule(cmod) && "root must be less than cmod");
unsigned upperBound = n.getZExtValue();

APInt a = r;
for (size_t k = 1; k < n; k++) {
for (size_t k = 1; k < upperBound; k++) {
if (a.isOne())
return false;
a = (a * r).urem(cmod);
Expand All @@ -126,7 +127,8 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
/// Verify that the types involved in an NTT or INTT operation are
/// compatible.
static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
RankedTensorType tensorType) {
RankedTensorType tensorType,
std::optional<PrimitiveRootAttr> root) {
Attribute encoding = tensorType.getEncoding();
if (!encoding) {
return op->emitOpError()
Expand Down Expand Up @@ -157,33 +159,30 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
return diag;
}

if (!ring.getPrimitiveRoot()) {
return op->emitOpError()
<< "ring type " << ring << " does not provide a primitive root "
<< "of unity, which is required to express an NTT";
}

if (!isPrimitiveNthRootOfUnity(ring.getPrimitiveRoot().getValue(), polyDegree,
ring.getCoefficientModulus().getValue())) {
return op->emitOpError()
<< "ring type " << ring << " has a primitiveRoot attribute '"
<< ring.getPrimitiveRoot()
<< "' that is not a primitive root of the coefficient ring";
if (root.has_value()) {
APInt rootValue = root.value().getValue().getValue();
APInt rootDegree = root.value().getDegree().getValue();
APInt cmod = ring.getCoefficientModulus().getValue();
if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove {, } for single statement if/fors

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe it's questionable convention in the case of nested conditionals? up to you.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is indeed questionable, I usually lean towards having braces for multi-line bodies even if there is a single statement.

return op->emitOpError()
<< "provided root " << rootValue.getZExtValue()
<< " is not a primitive root "
<< "of unity mod " << cmod.getZExtValue()
<< ", with the specified degree " << rootDegree.getZExtValue();
}
}

return success();
}

LogicalResult NTTOp::verify() {
auto ring = getInput().getType().getRing();
auto tensorType = getOutput().getType();
return verifyNTTOp(this->getOperation(), ring, tensorType);
return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
getOutput().getType(), getRoot());
}

LogicalResult INTTOp::verify() {
auto tensorType = getInput().getType();
auto ring = getOutput().getType().getRing();
return verifyNTTOp(this->getOperation(), ring, tensorType);
return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
getInput().getType(), getRoot());
}

ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
Expand Down
49 changes: 32 additions & 17 deletions mlir/test/Dialect/Polynomial/canonicalization.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: mlir-opt -canonicalize %s | FileCheck %s
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
!tensor_ty = tensor<8xi32, #ntt_ring>

Expand All @@ -10,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty
// CHECK-NOT: polynomial.ntt
// CHECK-NOT: polynomial.intt
// CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]] : [[T]]
%t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
%p1 = polynomial.intt %t0: !tensor_ty -> !ntt_poly_ty
%t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
%p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
%p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
// CHECK: return %[[RESULT]] : [[T]]
return %p2 : !ntt_poly_ty
Expand All @@ -23,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
// CHECK-NOT: polynomial.intt
// CHECK-NOT: polynomial.ntt
// CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
%p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
%t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
%p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
%t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
%t2 = arith.addi %t1, %t1 : !tensor_ty
// CHECK: return %[[RESULT]] : [[T]]
return %t2 : !tensor_ty
Expand All @@ -51,10 +52,10 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
func.func @test_canonicalize_fold_add_through_ntt(
%poly0 : !ntt_poly_ty,
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
%0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
%1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
%0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
%1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
%a_plus_b = arith.addi %0, %1 : !tensor_ty
%out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
%out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
return %out : !ntt_poly_ty
}

Expand All @@ -65,10 +66,10 @@ func.func @test_canonicalize_fold_add_through_ntt(
func.func @test_canonicalize_fold_add_through_intt(
%tensor0 : !tensor_ty,
%tensor1 : !tensor_ty) -> !tensor_ty {
%0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
%1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
%0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
%1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
%a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
%out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
%out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
return %out : !tensor_ty
}

Expand All @@ -80,10 +81,10 @@ func.func @test_canonicalize_fold_add_through_intt(
func.func @test_canonicalize_fold_sub_through_ntt(
%poly0 : !ntt_poly_ty,
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
%0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
%1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
%0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
%1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
%a_plus_b = arith.subi %0, %1 : !tensor_ty
%out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
%out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
return %out : !ntt_poly_ty
}

Expand All @@ -94,9 +95,23 @@ func.func @test_canonicalize_fold_sub_through_ntt(
func.func @test_canonicalize_fold_sub_through_intt(
%tensor0 : !tensor_ty,
%tensor1 : !tensor_ty) -> !tensor_ty {
%0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
%1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
%0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
%1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
%a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
%out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
%out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
return %out : !tensor_ty
}


// CHECK-LABEL: test_canonicalize_do_not_fold_different_roots
// CHECK: arith.addi
func.func @test_canonicalize_do_not_fold_different_roots(
%poly0 : !ntt_poly_ty,
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
%0 = polynomial.ntt %poly0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
%1 = polynomial.ntt %poly1 {root=#polynomial.primitive_root<value=33:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
%a_plus_b = arith.addi %0, %1 : !tensor_ty
%out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
return %out : !ntt_poly_ty
}

8 changes: 4 additions & 4 deletions mlir/test/Dialect/Polynomial/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>

#ideal = #polynomial.int_polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal>
!poly_ty = !polynomial.polynomial<ring=#ring>

#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>

module {
Expand Down Expand Up @@ -91,12 +91,12 @@ module {
}

func.func @test_ntt(%0 : !ntt_poly_ty) {
%1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
%1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
return
}

func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
%1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
%1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
return
}
}
Loading
Loading