Skip to content

Commit 616f918

Browse files
committed
implement add for polynomial data structure
- use CRTP for base classes - Add unit test
1 parent 65cbc36 commit 616f918

File tree

4 files changed

+103
-18
lines changed

4 files changed

+103
-18
lines changed

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace polynomial {
3030
/// would want to specify 128-bit polynomials statically in the source code.
3131
constexpr unsigned apintBitWidth = 64;
3232

33-
template <typename CoefficientType>
33+
template <class Derived, typename CoefficientType>
3434
class MonomialBase {
3535
public:
3636
MonomialBase(const CoefficientType &coeff, const APInt &expo)
@@ -55,12 +55,21 @@ class MonomialBase {
5555
return (exponent.ult(other.exponent));
5656
}
5757

58+
Derived add(const Derived &other) {
59+
assert(exponent == other.exponent);
60+
CoefficientType newCoeff = coefficient + other.coefficient;
61+
Derived result;
62+
result.setCoefficient(newCoeff);
63+
result.setExponent(exponent);
64+
return result;
65+
}
66+
5867
virtual bool isMonic() const = 0;
5968
virtual void
6069
coefficientToString(llvm::SmallString<16> &coeffString) const = 0;
6170

62-
template <typename T>
63-
friend ::llvm::hash_code hash_value(const MonomialBase<T> &arg);
71+
template <class D, typename T>
72+
friend ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg);
6473

6574
protected:
6675
CoefficientType coefficient;
@@ -69,15 +78,15 @@ class MonomialBase {
6978

7079
/// A class representing a monomial of a single-variable polynomial with integer
7180
/// coefficients.
72-
class IntMonomial : public MonomialBase<APInt> {
81+
class IntMonomial : public MonomialBase<IntMonomial, APInt> {
7382
public:
7483
IntMonomial(int64_t coeff, uint64_t expo)
7584
: MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
7685

7786
IntMonomial()
7887
: MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {}
7988

80-
~IntMonomial() = default;
89+
~IntMonomial() override = default;
8190

8291
bool isMonic() const override { return coefficient == 1; }
8392

@@ -88,14 +97,14 @@ class IntMonomial : public MonomialBase<APInt> {
8897

8998
/// A class representing a monomial of a single-variable polynomial with integer
9099
/// coefficients.
91-
class FloatMonomial : public MonomialBase<APFloat> {
100+
class FloatMonomial : public MonomialBase<FloatMonomial, APFloat> {
92101
public:
93102
FloatMonomial(double coeff, uint64_t expo)
94103
: MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
95104

96105
FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {}
97106

98-
~FloatMonomial() = default;
107+
~FloatMonomial() override = default;
99108

100109
bool isMonic() const override { return coefficient == APFloat(1.0); }
101110

@@ -104,7 +113,7 @@ class FloatMonomial : public MonomialBase<APFloat> {
104113
}
105114
};
106115

107-
template <typename Monomial>
116+
template <class Derived, typename Monomial>
108117
class PolynomialBase {
109118
public:
110119
PolynomialBase() = delete;
@@ -149,6 +158,30 @@ class PolynomialBase {
149158
}
150159
}
151160

161+
Derived add(const Derived &other) {
162+
SmallVector<Monomial> newTerms;
163+
auto it1 = terms.begin();
164+
auto it2 = other.terms.begin();
165+
while (it1 != terms.end() || it2 != other.terms.end()) {
166+
if (it1 == terms.end()) {
167+
newTerms.emplace_back(*it2);
168+
it2++;
169+
continue;
170+
}
171+
172+
if (it2 == other.terms.end()) {
173+
newTerms.emplace_back(*it1);
174+
it1++;
175+
continue;
176+
}
177+
178+
newTerms.emplace_back(it1->add(*it2));
179+
it1++;
180+
it2++;
181+
}
182+
return Derived(newTerms);
183+
}
184+
152185
// Prints polynomial to 'os'.
153186
void print(raw_ostream &os) const { print(os, " + ", "**"); }
154187

@@ -168,8 +201,8 @@ class PolynomialBase {
168201

169202
ArrayRef<Monomial> getTerms() const { return terms; }
170203

171-
template <typename T>
172-
friend ::llvm::hash_code hash_value(const PolynomialBase<T> &arg);
204+
template <class D, typename T>
205+
friend ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg);
173206

174207
private:
175208
// The monomial terms for this polynomial.
@@ -179,7 +212,7 @@ class PolynomialBase {
179212
/// A single-variable polynomial with integer coefficients.
180213
///
181214
/// Eg: x^1024 + x + 1
182-
class IntPolynomial : public PolynomialBase<IntMonomial> {
215+
class IntPolynomial : public PolynomialBase<IntPolynomial, IntMonomial> {
183216
public:
184217
explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
185218

@@ -196,7 +229,7 @@ class IntPolynomial : public PolynomialBase<IntMonomial> {
196229
/// A single-variable polynomial with double coefficients.
197230
///
198231
/// Eg: 1.0 x^1024 + 3.5 x + 1e-05
199-
class FloatPolynomial : public PolynomialBase<FloatMonomial> {
232+
class FloatPolynomial : public PolynomialBase<FloatPolynomial, FloatMonomial> {
200233
public:
201234
explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)
202235
: PolynomialBase(terms) {}
@@ -212,20 +245,20 @@ class FloatPolynomial : public PolynomialBase<FloatMonomial> {
212245
};
213246

214247
// Make Polynomials hashable.
215-
template <typename T>
216-
inline ::llvm::hash_code hash_value(const PolynomialBase<T> &arg) {
248+
template <class D, typename T>
249+
inline ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg) {
217250
return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
218251
}
219252

220-
template <typename T>
221-
inline ::llvm::hash_code hash_value(const MonomialBase<T> &arg) {
253+
template <class D, typename T>
254+
inline ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg) {
222255
return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
223256
::llvm::hash_value(arg.exponent));
224257
}
225258

226-
template <typename T>
259+
template <class D, typename T>
227260
inline raw_ostream &operator<<(raw_ostream &os,
228-
const PolynomialBase<T> &polynomial) {
261+
const PolynomialBase<D, T> &polynomial) {
229262
polynomial.print(os);
230263
return os;
231264
}

mlir/unittests/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_subdirectory(Index)
1111
add_subdirectory(LLVMIR)
1212
add_subdirectory(MemRef)
1313
add_subdirectory(OpenACC)
14+
add_subdirectory(Polynomial)
1415
add_subdirectory(SCF)
1516
add_subdirectory(SparseTensor)
1617
add_subdirectory(SPIRV)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
add_mlir_unittest(MLIRPolynomialTests
2+
PolynomialMathTest.cpp
3+
)
4+
target_link_libraries(MLIRPolynomialTests
5+
PRIVATE
6+
MLIRIR
7+
MLIRPolynomialDialect
8+
)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//===- PolynomialMathTest.cpp - Polynomial math Tests ---------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
10+
#include "gtest/gtest.h"
11+
12+
using namespace mlir;
13+
using namespace mlir::polynomial;
14+
15+
TEST(AddTest, checkSameDegreeAdditionOfIntPolynomial) {
16+
IntPolynomial x = IntPolynomial::fromCoefficients({1, 2, 3});
17+
IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4});
18+
IntPolynomial expected = IntPolynomial::fromCoefficients({3, 5, 7});
19+
EXPECT_EQ(expected, x.add(y));
20+
}
21+
22+
TEST(AddTest, checkDifferentDegreeAdditionOfIntPolynomial) {
23+
IntPolynomial x = IntPolynomial::fromCoefficients({1, 2});
24+
IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4});
25+
IntPolynomial expected = IntPolynomial::fromCoefficients({3, 5, 4});
26+
EXPECT_EQ(expected, x.add(y));
27+
EXPECT_EQ(expected, y.add(x));
28+
}
29+
30+
TEST(AddTest, checkSameDegreeAdditionOfFloatPolynomial) {
31+
FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5, 3.5});
32+
FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5});
33+
FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 8});
34+
EXPECT_EQ(expected, x.add(y));
35+
}
36+
37+
TEST(AddTest, checkDifferentDegreeAdditionOfFloatPolynomial) {
38+
FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5});
39+
FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5});
40+
FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 4.5});
41+
EXPECT_EQ(expected, x.add(y));
42+
EXPECT_EQ(expected, y.add(x));
43+
}

0 commit comments

Comments
 (0)