Skip to content

Commit e368675

Browse files
authored
[mlir][polynomial] implement add for polynomial data structure (#92169)
A change extracted from #91655, where I'm still trying to get the attributes working for elementwise constant folding of polynomial ops. This piece is self-contained. - use CRTP for base classes - Add unit test --------- Co-authored-by: Jeremy Kun <[email protected]>
1 parent 45daa4f commit e368675

File tree

4 files changed

+118
-18
lines changed

4 files changed

+118
-18
lines changed

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h

Lines changed: 65 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace polynomial {
3030
/// would want to specify 128-bit polynomials statically in the source code.
3131
constexpr unsigned apintBitWidth = 64;
3232

33-
template <typename CoefficientType>
33+
template <class Derived, typename CoefficientType>
3434
class MonomialBase {
3535
public:
3636
MonomialBase(const CoefficientType &coeff, const APInt &expo)
@@ -55,12 +55,21 @@ class MonomialBase {
5555
return (exponent.ult(other.exponent));
5656
}
5757

58+
Derived add(const Derived &other) {
59+
assert(exponent == other.exponent);
60+
CoefficientType newCoeff = coefficient + other.coefficient;
61+
Derived result;
62+
result.setCoefficient(newCoeff);
63+
result.setExponent(exponent);
64+
return result;
65+
}
66+
5867
virtual bool isMonic() const = 0;
5968
virtual void
6069
coefficientToString(llvm::SmallString<16> &coeffString) const = 0;
6170

62-
template <typename T>
63-
friend ::llvm::hash_code hash_value(const MonomialBase<T> &arg);
71+
template <class D, typename T>
72+
friend ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg);
6473

6574
protected:
6675
CoefficientType coefficient;
@@ -69,15 +78,15 @@ class MonomialBase {
6978

7079
/// A class representing a monomial of a single-variable polynomial with integer
7180
/// coefficients.
72-
class IntMonomial : public MonomialBase<APInt> {
81+
class IntMonomial : public MonomialBase<IntMonomial, APInt> {
7382
public:
7483
IntMonomial(int64_t coeff, uint64_t expo)
7584
: MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
7685

7786
IntMonomial()
7887
: MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {}
7988

80-
~IntMonomial() = default;
89+
~IntMonomial() override = default;
8190

8291
bool isMonic() const override { return coefficient == 1; }
8392

@@ -88,14 +97,14 @@ class IntMonomial : public MonomialBase<APInt> {
8897

8998
/// A class representing a monomial of a single-variable polynomial with integer
9099
/// coefficients.
91-
class FloatMonomial : public MonomialBase<APFloat> {
100+
class FloatMonomial : public MonomialBase<FloatMonomial, APFloat> {
92101
public:
93102
FloatMonomial(double coeff, uint64_t expo)
94103
: MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
95104

96105
FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {}
97106

98-
~FloatMonomial() = default;
107+
~FloatMonomial() override = default;
99108

100109
bool isMonic() const override { return coefficient == APFloat(1.0); }
101110

@@ -104,7 +113,7 @@ class FloatMonomial : public MonomialBase<APFloat> {
104113
}
105114
};
106115

107-
template <typename Monomial>
116+
template <class Derived, typename Monomial>
108117
class PolynomialBase {
109118
public:
110119
PolynomialBase() = delete;
@@ -149,6 +158,44 @@ class PolynomialBase {
149158
}
150159
}
151160

161+
Derived add(const Derived &other) {
162+
SmallVector<Monomial> newTerms;
163+
auto it1 = terms.begin();
164+
auto it2 = other.terms.begin();
165+
while (it1 != terms.end() || it2 != other.terms.end()) {
166+
if (it1 == terms.end()) {
167+
newTerms.emplace_back(*it2);
168+
it2++;
169+
continue;
170+
}
171+
172+
if (it2 == other.terms.end()) {
173+
newTerms.emplace_back(*it1);
174+
it1++;
175+
continue;
176+
}
177+
178+
while (it1->getExponent().ult(it2->getExponent())) {
179+
newTerms.emplace_back(*it1);
180+
it1++;
181+
if (it1 == terms.end())
182+
break;
183+
}
184+
185+
while (it2->getExponent().ult(it1->getExponent())) {
186+
newTerms.emplace_back(*it2);
187+
it2++;
188+
if (it2 == terms.end())
189+
break;
190+
}
191+
192+
newTerms.emplace_back(it1->add(*it2));
193+
it1++;
194+
it2++;
195+
}
196+
return Derived(newTerms);
197+
}
198+
152199
// Prints polynomial to 'os'.
153200
void print(raw_ostream &os) const { print(os, " + ", "**"); }
154201

@@ -168,8 +215,8 @@ class PolynomialBase {
168215

169216
ArrayRef<Monomial> getTerms() const { return terms; }
170217

171-
template <typename T>
172-
friend ::llvm::hash_code hash_value(const PolynomialBase<T> &arg);
218+
template <class D, typename T>
219+
friend ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg);
173220

174221
private:
175222
// The monomial terms for this polynomial.
@@ -179,7 +226,7 @@ class PolynomialBase {
179226
/// A single-variable polynomial with integer coefficients.
180227
///
181228
/// Eg: x^1024 + x + 1
182-
class IntPolynomial : public PolynomialBase<IntMonomial> {
229+
class IntPolynomial : public PolynomialBase<IntPolynomial, IntMonomial> {
183230
public:
184231
explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
185232

@@ -196,7 +243,7 @@ class IntPolynomial : public PolynomialBase<IntMonomial> {
196243
/// A single-variable polynomial with double coefficients.
197244
///
198245
/// Eg: 1.0 x^1024 + 3.5 x + 1e-05
199-
class FloatPolynomial : public PolynomialBase<FloatMonomial> {
246+
class FloatPolynomial : public PolynomialBase<FloatPolynomial, FloatMonomial> {
200247
public:
201248
explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)
202249
: PolynomialBase(terms) {}
@@ -212,20 +259,20 @@ class FloatPolynomial : public PolynomialBase<FloatMonomial> {
212259
};
213260

214261
// Make Polynomials hashable.
215-
template <typename T>
216-
inline ::llvm::hash_code hash_value(const PolynomialBase<T> &arg) {
262+
template <class D, typename T>
263+
inline ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg) {
217264
return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
218265
}
219266

220-
template <typename T>
221-
inline ::llvm::hash_code hash_value(const MonomialBase<T> &arg) {
267+
template <class D, typename T>
268+
inline ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg) {
222269
return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
223270
::llvm::hash_value(arg.exponent));
224271
}
225272

226-
template <typename T>
273+
template <class D, typename T>
227274
inline raw_ostream &operator<<(raw_ostream &os,
228-
const PolynomialBase<T> &polynomial) {
275+
const PolynomialBase<D, T> &polynomial) {
229276
polynomial.print(os);
230277
return os;
231278
}

mlir/unittests/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_subdirectory(Index)
1111
add_subdirectory(LLVMIR)
1212
add_subdirectory(MemRef)
1313
add_subdirectory(OpenACC)
14+
add_subdirectory(Polynomial)
1415
add_subdirectory(SCF)
1516
add_subdirectory(SparseTensor)
1617
add_subdirectory(SPIRV)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
add_mlir_unittest(MLIRPolynomialTests
2+
PolynomialMathTest.cpp
3+
)
4+
target_link_libraries(MLIRPolynomialTests
5+
PRIVATE
6+
MLIRIR
7+
MLIRPolynomialDialect
8+
)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//===- PolynomialMathTest.cpp - Polynomial math Tests ---------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
10+
#include "gtest/gtest.h"
11+
12+
using namespace mlir;
13+
using namespace mlir::polynomial;
14+
15+
TEST(AddTest, checkSameDegreeAdditionOfIntPolynomial) {
16+
IntPolynomial x = IntPolynomial::fromCoefficients({1, 2, 3});
17+
IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4});
18+
IntPolynomial expected = IntPolynomial::fromCoefficients({3, 5, 7});
19+
EXPECT_EQ(expected, x.add(y));
20+
}
21+
22+
TEST(AddTest, checkDifferentDegreeAdditionOfIntPolynomial) {
23+
IntMonomial term2t = IntMonomial(2, 1);
24+
IntPolynomial x = IntPolynomial::fromMonomials({term2t}).value();
25+
IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4});
26+
IntPolynomial expected = IntPolynomial::fromCoefficients({2, 5, 4});
27+
EXPECT_EQ(expected, x.add(y));
28+
EXPECT_EQ(expected, y.add(x));
29+
}
30+
31+
TEST(AddTest, checkSameDegreeAdditionOfFloatPolynomial) {
32+
FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5, 3.5});
33+
FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5});
34+
FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 8});
35+
EXPECT_EQ(expected, x.add(y));
36+
}
37+
38+
TEST(AddTest, checkDifferentDegreeAdditionOfFloatPolynomial) {
39+
FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5});
40+
FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5});
41+
FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 4.5});
42+
EXPECT_EQ(expected, x.add(y));
43+
EXPECT_EQ(expected, y.add(x));
44+
}

0 commit comments

Comments
 (0)