Skip to content

Commit 3cc066c

Browse files
committed
fold add via constBinaryFold
1 parent 616f918 commit 3cc066c

File tree

6 files changed

+144
-3
lines changed

6 files changed

+144
-3
lines changed

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,20 @@ class PolynomialBase {
175175
continue;
176176
}
177177

178+
while (it1->getExponent().ult(it2->getExponent())) {
179+
newTerms.emplace_back(*it1);
180+
it1++;
181+
if (it1 == terms.end())
182+
break;
183+
}
184+
185+
while (it2->getExponent().ult(it1->getExponent())) {
186+
newTerms.emplace_back(*it2);
187+
it2++;
188+
if (it2 == terms.end())
189+
break;
190+
}
191+
178192
newTerms.emplace_back(it1->add(*it2));
179193
it1++;
180194
it2++;

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def Polynomial_Dialect : Dialect {
5353

5454
let useDefaultTypePrinterParser = 1;
5555
let useDefaultAttributePrinterParser = 1;
56+
let hasConstantMaterializer = 1;
5657
}
5758

5859
class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
@@ -83,6 +84,30 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom
8384
let hasCustomAssemblyFormat = 1;
8485
}
8586

87+
def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
88+
"TypedIntPolynomial", "typed_int_polynomial", [TypedAttrInterface]> {
89+
let summary = "A typed variant of int_polynomial for constant folding.";
90+
let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::IntPolynomialAttr":$value);
91+
let assemblyFormat = "`<` struct(params) `>`";
92+
let builders = [
93+
AttrBuilderWithInferredContext<(ins "Type":$type,
94+
"const IntPolynomial &":$value), [{
95+
return $_get(
96+
type.getContext(),
97+
type,
98+
IntPolynomialAttr::get(type.getContext(), value));
99+
}]>,
100+
AttrBuilderWithInferredContext<(ins "Type":$type,
101+
"const Attribute &":$value), [{
102+
return $_get(type.getContext(), type, ::llvm::cast<IntPolynomialAttr>(value));
103+
}]>
104+
];
105+
let extraClassDeclaration = [{
106+
// used for constFoldBinaryOp
107+
using ValueType = ::mlir::Attribute;
108+
}];
109+
}
110+
86111
def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
87112
let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients.";
88113
let description = [{
@@ -105,6 +130,30 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p
105130
let hasCustomAssemblyFormat = 1;
106131
}
107132

133+
def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr<
134+
"TypedFloatPolynomial", "typed_float_polynomial", [TypedAttrInterface]> {
135+
let summary = "A typed variant of float_polynomial for constant folding.";
136+
let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::FloatPolynomialAttr":$value);
137+
let assemblyFormat = "`<` struct(params) `>`";
138+
let builders = [
139+
AttrBuilderWithInferredContext<(ins "Type":$type,
140+
"const FloatPolynomial &":$value), [{
141+
return $_get(
142+
type.getContext(),
143+
type,
144+
FloatPolynomialAttr::get(type.getContext(), value));
145+
}]>,
146+
AttrBuilderWithInferredContext<(ins "Type":$type,
147+
"const Attribute &":$value), [{
148+
return $_get(type.getContext(), type, ::llvm::cast<FloatPolynomialAttr>(value));
149+
}]>
150+
];
151+
let extraClassDeclaration = [{
152+
// used for constFoldBinaryOp
153+
using ValueType = ::mlir::Attribute;
154+
}];
155+
}
156+
108157
def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
109158
let summary = "An attribute specifying a polynomial ring.";
110159
let description = [{
@@ -221,6 +270,7 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
221270
%2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring>
222271
```
223272
}];
273+
let hasFolder = 1;
224274
}
225275

226276
def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
@@ -441,7 +491,7 @@ def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
441491
]>;
442492

443493
// Not deriving from Polynomial_Op due to need for custom assembly format
444-
def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
494+
def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure, ConstantLike]> {
445495
let summary = "Define a constant polynomial via an attribute.";
446496
let description = [{
447497
Example:
@@ -458,6 +508,7 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
458508
let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
459509
let results = (outs Polynomial_PolynomialType:$output);
460510
let assemblyFormat = "attr-dict `:` type($output)";
511+
let hasFolder = 1;
461512
}
462513

463514
def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {

mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,17 @@ void PolynomialDialect::initialize() {
4848
#include "mlir/Dialect/Polynomial/IR/Polynomial.cpp.inc"
4949
>();
5050
}
51+
52+
Operation *PolynomialDialect::materializeConstant(OpBuilder &builder,
53+
Attribute value, Type type,
54+
Location loc) {
55+
auto intPoly = dyn_cast<TypedIntPolynomialAttr>(value);
56+
auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(value);
57+
if (!intPoly && !floatPoly)
58+
return nullptr;
59+
60+
Type ty = intPoly ? intPoly.getType() : floatPoly.getType();
61+
Attribute valueAttr =
62+
intPoly ? (Attribute)intPoly.getValue() : (Attribute)floatPoly.getValue();
63+
return builder.create<ConstantOp>(loc, ty, valueAttr);
64+
}

mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
10+
#include "mlir/Dialect/Arith/IR/Arith.h"
11+
#include "mlir/Dialect/CommonFolders.h"
1012
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
1113
#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
1214
#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
1315
#include "mlir/IR/Builders.h"
16+
#include "mlir/IR/BuiltinAttributes.h"
1417
#include "mlir/IR/BuiltinTypes.h"
1518
#include "mlir/IR/Dialect.h"
1619
#include "mlir/Support/LogicalResult.h"
@@ -19,6 +22,41 @@
1922
using namespace mlir;
2023
using namespace mlir::polynomial;
2124

25+
OpFoldResult ConstantOp::fold(ConstantOp::FoldAdaptor adaptor) {
26+
PolynomialType ty = dyn_cast<PolynomialType>(getOutput().getType());
27+
28+
if (isa<FloatPolynomialAttr>(ty.getRing().getPolynomialModulus()))
29+
return TypedFloatPolynomialAttr::get(
30+
ty, cast<FloatPolynomialAttr>(getValue()).getPolynomial());
31+
32+
assert(isa<IntPolynomialAttr>(ty.getRing().getPolynomialModulus()) &&
33+
"expected float or integer polynomial");
34+
return TypedIntPolynomialAttr::get(
35+
ty, cast<IntPolynomialAttr>(getValue()).getPolynomial());
36+
}
37+
38+
OpFoldResult AddOp::fold(AddOp::FoldAdaptor adaptor) {
39+
auto lhsElements = dyn_cast<ShapedType>(getLhs().getType());
40+
PolynomialType elementType = cast<PolynomialType>(
41+
lhsElements ? lhsElements.getElementType() : getLhs().getType());
42+
MLIRContext *context = getContext();
43+
44+
if (isa<FloatType>(elementType.getRing().getCoefficientType()))
45+
return constFoldBinaryOp<TypedFloatPolynomialAttr>(
46+
adaptor.getOperands(), elementType, [&](Attribute a, const Attribute &b) {
47+
return FloatPolynomialAttr::get(
48+
context, cast<FloatPolynomialAttr>(a).getPolynomial().add(
49+
cast<FloatPolynomialAttr>(b).getPolynomial()));
50+
});
51+
52+
return constFoldBinaryOp<TypedIntPolynomialAttr>(
53+
adaptor.getOperands(), elementType, [&](Attribute a, const Attribute &b) {
54+
return IntPolynomialAttr::get(
55+
context, cast<IntPolynomialAttr>(a).getPolynomial().add(
56+
cast<IntPolynomialAttr>(b).getPolynomial()));
57+
});
58+
}
59+
2260
void FromTensorOp::build(OpBuilder &builder, OperationState &result,
2361
Value input, RingAttr ring) {
2462
TensorType tensorType = dyn_cast<TensorType>(input.getType());
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt --sccp --canonicalize %s | FileCheck %s
2+
3+
// Tests for folding
4+
5+
#my_poly = #polynomial.int_polynomial<1 + x**1024>
6+
#poly_3t = #polynomial.int_polynomial<3t>
7+
#poly_t3_plus_4t_plus_2 = #polynomial.int_polynomial<t**3 + 4t + 2>
8+
#modulus = #polynomial.int_polynomial<-1 + x**1024>
9+
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#modulus, primitiveRoot=193>
10+
!poly_ty = !polynomial.polynomial<ring=#ring>
11+
12+
// CHECK-LABEL: test_fold_add
13+
// CHECK-NEXT: polynomial.constant {value = #polynomial.int_polynomial<2 + 7x + x**3>}
14+
// CHECK-NEXT: return
15+
func.func @test_fold_add() -> !poly_ty {
16+
%0 = polynomial.constant {value=#poly_3t} : !poly_ty
17+
%1 = polynomial.constant {value=#poly_t3_plus_4t_plus_2} : !poly_ty
18+
%2 = polynomial.add %0, %1 : !poly_ty
19+
return %2 : !poly_ty
20+
}
21+
22+
// Test elementwise folding of add
23+
// Test float folding of add

mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ TEST(AddTest, checkSameDegreeAdditionOfIntPolynomial) {
2020
}
2121

2222
TEST(AddTest, checkDifferentDegreeAdditionOfIntPolynomial) {
23-
IntPolynomial x = IntPolynomial::fromCoefficients({1, 2});
23+
IntMonomial term2t = IntMonomial(2, 1);
24+
IntPolynomial x = IntPolynomial::fromMonomials({term2t}).value();
2425
IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4});
25-
IntPolynomial expected = IntPolynomial::fromCoefficients({3, 5, 4});
26+
IntPolynomial expected = IntPolynomial::fromCoefficients({2, 5, 4});
2627
EXPECT_EQ(expected, x.add(y));
2728
EXPECT_EQ(expected, y.add(x));
2829
}

0 commit comments

Comments
 (0)