Skip to content

Commit dc88dba

Browse files
committed
[mlir][math] Add Polynomial Approximation for acosh, asinh, atanh, cosh, sinh ops
1 parent 465807e commit dc88dba

File tree

2 files changed

+361
-9
lines changed

2 files changed

+361
-9
lines changed

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 155 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -615,10 +615,47 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
615615
return success();
616616
}
617617

618-
#define LN2_VALUE \
619-
0.693147180559945309417232121458176568075500134360255254120680009493393621L
620-
#define LOG2E_VALUE \
621-
1.442695040888963407359924681001892137426645954152985934135449406931109219L
618+
//----------------------------------------------------------------------------//
619+
// AtanhOp approximation.
620+
//----------------------------------------------------------------------------//
621+
622+
namespace {
623+
struct AtanhApproximation : public OpRewritePattern<math::AtanhOp> {
624+
public:
625+
using OpRewritePattern::OpRewritePattern;
626+
627+
LogicalResult matchAndRewrite(math::AtanhOp op,
628+
PatternRewriter &rewriter) const final;
629+
};
630+
} // namespace
631+
632+
LogicalResult
633+
AtanhApproximation::matchAndRewrite(math::AtanhOp op,
634+
PatternRewriter &rewriter) const {
635+
if (!getElementTypeOrSelf(op.getOperand()).isF32())
636+
return rewriter.notifyMatchFailure(op, "unsupported operand type");
637+
638+
auto operand = op.getOperand();
639+
VectorShape shape = vectorShape(operand);
640+
641+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
642+
auto bcast = [&](Value value) -> Value {
643+
return broadcast(builder, value, shape);
644+
};
645+
646+
// 1/2 * log((1 + x) / (1 - x))
647+
Value cstOne = bcast(f32Cst(builder, 1.0));
648+
Value add = builder.create<arith::AddFOp>(operand, cstOne);
649+
Value neg = builder.create<arith::NegFOp>(operand);
650+
Value sub = builder.create<arith::AddFOp>(neg, cstOne);
651+
Value div = builder.create<arith::DivFOp>(add, sub);
652+
Value log = builder.create<math::LogOp>(div);
653+
Value cstHalf = bcast(f32Cst(builder, 0.5));
654+
Value res = builder.create<arith::MulFOp>(log, cstHalf);
655+
rewriter.replaceOp(op, res);
656+
657+
return success();
658+
}
622659

623660
//----------------------------------------------------------------------------//
624661
// LogOp and Log2Op approximation.
@@ -635,6 +672,11 @@ struct LogApproximationBase : public OpRewritePattern<Op> {
635672
};
636673
} // namespace
637674

675+
#define LN2_VALUE \
676+
0.693147180559945309417232121458176568075500134360255254120680009493393621L
677+
#define LOG2E_VALUE \
678+
1.442695040888963407359924681001892137426645954152985934135449406931109219L
679+
638680
// This approximation comes from Julien Pommier's SSE math library.
639681
// Link: http://gruntthepeon.free.fr/ssemath
640682
template <typename Op>
@@ -1316,6 +1358,105 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
13161358
return success();
13171359
}
13181360

1361+
//----------------------------------------------------------------------------//
1362+
// SinhOp and CoshOp approximation.
1363+
//----------------------------------------------------------------------------//
1364+
1365+
namespace {
1366+
1367+
template <bool isSine, typename OpTy>
1368+
struct SinhAndCoshApproximation : public OpRewritePattern<OpTy> {
1369+
public:
1370+
using OpRewritePattern<OpTy>::OpRewritePattern;
1371+
1372+
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
1373+
};
1374+
} // namespace
1375+
1376+
template <bool isSine, typename OpTy>
1377+
LogicalResult SinhAndCoshApproximation<isSine, OpTy>::matchAndRewrite(
1378+
OpTy op, PatternRewriter &rewriter) const {
1379+
static_assert(
1380+
llvm::is_one_of<OpTy, math::SinhOp, math::CoshOp>::value,
1381+
"SinhAndCoshApproximation pattern expects math::SinhOp or math::CoshOp");
1382+
1383+
if (!getElementTypeOrSelf(op.getOperand()).isF32())
1384+
return rewriter.notifyMatchFailure(op, "unsupported operand type");
1385+
1386+
auto operand = op.getOperand();
1387+
VectorShape shape = vectorShape(operand);
1388+
1389+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1390+
auto bcast = [&](Value value) -> Value {
1391+
return broadcast(builder, value, shape);
1392+
};
1393+
1394+
// sinh: 1/2 * (exp(x) – exp(-x))
1395+
// cosh: 1/2 * (exp(x) + exp(-x))
1396+
Value exp = builder.create<math::ExpOp>(operand);
1397+
Value neg = builder.create<arith::NegFOp>(operand);
1398+
Value negExp = builder.create<math::ExpOp>(neg);
1399+
Value addOrSub;
1400+
if (isSine)
1401+
addOrSub = builder.create<arith::SubFOp>(exp, negExp);
1402+
else
1403+
addOrSub = builder.create<arith::AddFOp>(exp, negExp);
1404+
Value cstHalf = bcast(f32Cst(builder, 0.5));
1405+
Value res = builder.create<arith::MulFOp>(addOrSub, cstHalf);
1406+
rewriter.replaceOp(op, res);
1407+
1408+
return success();
1409+
}
1410+
1411+
//----------------------------------------------------------------------------//
1412+
// AsinhOp and AcoshOp approximation.
1413+
//----------------------------------------------------------------------------//
1414+
1415+
namespace {
1416+
1417+
template <bool isSine, typename OpTy>
1418+
struct AsinhAndAcoshApproximation : public OpRewritePattern<OpTy> {
1419+
public:
1420+
using OpRewritePattern<OpTy>::OpRewritePattern;
1421+
1422+
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
1423+
};
1424+
} // namespace
1425+
1426+
template <bool isSine, typename OpTy>
1427+
LogicalResult AsinhAndAcoshApproximation<isSine, OpTy>::matchAndRewrite(
1428+
OpTy op, PatternRewriter &rewriter) const {
1429+
static_assert(
1430+
llvm::is_one_of<OpTy, math::AsinhOp, math::AcoshOp>::value,
1431+
"SinAndCosApproximation pattern expects math::AsinhOp or math::AcoshOp");
1432+
1433+
if (!getElementTypeOrSelf(op.getOperand()).isF32())
1434+
return rewriter.notifyMatchFailure(op, "unsupported operand type");
1435+
1436+
auto operand = op.getOperand();
1437+
VectorShape shape = vectorShape(operand);
1438+
1439+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1440+
auto bcast = [&](Value value) -> Value {
1441+
return broadcast(builder, value, shape);
1442+
};
1443+
1444+
// asinh: log(x + sqrt(x**2 + 1))
1445+
// acosh: log(x + sqrt(x**2 - 1))
1446+
Value cst;
1447+
if (isSine)
1448+
cst = bcast(f32Cst(builder, 1.0));
1449+
else
1450+
cst = bcast(f32Cst(builder, -1.0));
1451+
Value fma = builder.create<math::FmaOp>(operand, operand, cst);
1452+
Value sqrt = builder.create<math::SqrtOp>(fma);
1453+
Value b = builder.create<arith::AddFOp>(operand, sqrt);
1454+
Value res = builder.create<math::LogOp>(b);
1455+
rewriter.replaceOp(op, res);
1456+
1457+
return success();
1458+
}
1459+
13191460
//----------------------------------------------------------------------------//
13201461
// Cbrt approximation.
13211462
//----------------------------------------------------------------------------//
@@ -1505,11 +1646,16 @@ void mlir::populateMathPolynomialApproximationPatterns(
15051646
ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
15061647
patterns.getContext());
15071648

1508-
patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
1509-
LogApproximation, Log2Approximation, Log1pApproximation,
1510-
ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1511-
CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
1512-
SinAndCosApproximation<false, math::CosOp>>(
1649+
patterns.add<AtanApproximation, Atan2Approximation, AtanhApproximation,
1650+
TanhApproximation, LogApproximation, Log2Approximation,
1651+
Log1pApproximation, ErfPolynomialApproximation, ExpApproximation,
1652+
ExpM1Approximation, CbrtApproximation,
1653+
SinAndCosApproximation<true, math::SinOp>,
1654+
SinAndCosApproximation<false, math::CosOp>,
1655+
SinhAndCoshApproximation<true, math::SinhOp>,
1656+
SinhAndCoshApproximation<false, math::CoshOp>,
1657+
AsinhAndAcoshApproximation<true, math::AsinhOp>,
1658+
AsinhAndAcoshApproximation<false, math::AcoshOp>>(
15131659
patterns.getContext());
15141660
if (options.enableAvx2) {
15151661
patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(

0 commit comments

Comments
 (0)