@@ -615,10 +615,47 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
615
615
return success ();
616
616
}
617
617
618
- #define LN2_VALUE \
619
- 0 .693147180559945309417232121458176568075500134360255254120680009493393621L
620
- #define LOG2E_VALUE \
621
- 1 .442695040888963407359924681001892137426645954152985934135449406931109219L
618
+ // ----------------------------------------------------------------------------//
619
+ // AtanhOp approximation.
620
+ // ----------------------------------------------------------------------------//
621
+
622
+ namespace {
623
+ struct AtanhApproximation : public OpRewritePattern <math::AtanhOp> {
624
+ public:
625
+ using OpRewritePattern::OpRewritePattern;
626
+
627
+ LogicalResult matchAndRewrite (math::AtanhOp op,
628
+ PatternRewriter &rewriter) const final ;
629
+ };
630
+ } // namespace
631
+
632
+ LogicalResult
633
+ AtanhApproximation::matchAndRewrite (math::AtanhOp op,
634
+ PatternRewriter &rewriter) const {
635
+ if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
636
+ return rewriter.notifyMatchFailure (op, " unsupported operand type" );
637
+
638
+ auto operand = op.getOperand ();
639
+ VectorShape shape = vectorShape (operand);
640
+
641
+ ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
642
+ auto bcast = [&](Value value) -> Value {
643
+ return broadcast (builder, value, shape);
644
+ };
645
+
646
+ // 1/2 * log((1 + x) / (1 - x))
647
+ Value cstOne = bcast (f32Cst (builder, 1.0 ));
648
+ Value add = builder.create <arith::AddFOp>(operand, cstOne);
649
+ Value neg = builder.create <arith::NegFOp>(operand);
650
+ Value sub = builder.create <arith::AddFOp>(neg, cstOne);
651
+ Value div = builder.create <arith::DivFOp>(add, sub);
652
+ Value log = builder.create <math::LogOp>(div);
653
+ Value cstHalf = bcast (f32Cst (builder, 0.5 ));
654
+ Value res = builder.create <arith::MulFOp>(log, cstHalf);
655
+ rewriter.replaceOp (op, res);
656
+
657
+ return success ();
658
+ }
622
659
623
660
// ----------------------------------------------------------------------------//
624
661
// LogOp and Log2Op approximation.
@@ -635,6 +672,11 @@ struct LogApproximationBase : public OpRewritePattern<Op> {
635
672
};
636
673
} // namespace
637
674
675
// ln(2) and log2(e), spelled out as long double literals (note the `L`
// suffix) with more digits than any of the float types can hold.
+ #define LN2_VALUE \
676
+ 0 .693147180559945309417232121458176568075500134360255254120680009493393621L
677
+ #define LOG2E_VALUE \
678
+ 1 .442695040888963407359924681001892137426645954152985934135449406931109219L
679
+
638
680
// This approximation comes from Julien Pommier's SSE math library.
639
681
// Link: http://gruntthepeon.free.fr/ssemath
640
682
template <typename Op>
@@ -1316,6 +1358,105 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1316
1358
return success ();
1317
1359
}
1318
1360
1361
+ // ----------------------------------------------------------------------------//
1362
+ // SinhOp and CoshOp approximation.
1363
+ // ----------------------------------------------------------------------------//
1364
+
1365
+ namespace {
1366
+
1367
+ template <bool isSine, typename OpTy>
1368
+ struct SinhAndCoshApproximation : public OpRewritePattern <OpTy> {
1369
+ public:
1370
+ using OpRewritePattern<OpTy>::OpRewritePattern;
1371
+
1372
+ LogicalResult matchAndRewrite (OpTy op, PatternRewriter &rewriter) const final ;
1373
+ };
1374
+ } // namespace
1375
+
1376
+ template <bool isSine, typename OpTy>
1377
+ LogicalResult SinhAndCoshApproximation<isSine, OpTy>::matchAndRewrite(
1378
+ OpTy op, PatternRewriter &rewriter) const {
1379
+ static_assert (
1380
+ llvm::is_one_of<OpTy, math::SinhOp, math::CoshOp>::value,
1381
+ " SinhAndCoshApproximation pattern expects math::SinhOp or math::CoshOp" );
1382
+
1383
+ if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
1384
+ return rewriter.notifyMatchFailure (op, " unsupported operand type" );
1385
+
1386
+ auto operand = op.getOperand ();
1387
+ VectorShape shape = vectorShape (operand);
1388
+
1389
+ ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
1390
+ auto bcast = [&](Value value) -> Value {
1391
+ return broadcast (builder, value, shape);
1392
+ };
1393
+
1394
+ // sinh: 1/2 * (exp(x) – exp(-x))
1395
+ // cosh: 1/2 * (exp(x) + exp(-x))
1396
+ Value exp = builder.create <math::ExpOp>(operand);
1397
+ Value neg = builder.create <arith::NegFOp>(operand);
1398
+ Value negExp = builder.create <math::ExpOp>(neg);
1399
+ Value addOrSub;
1400
+ if (isSine)
1401
+ addOrSub = builder.create <arith::SubFOp>(exp, negExp);
1402
+ else
1403
+ addOrSub = builder.create <arith::AddFOp>(exp, negExp);
1404
+ Value cstHalf = bcast (f32Cst (builder, 0.5 ));
1405
+ Value res = builder.create <arith::MulFOp>(addOrSub, cstHalf);
1406
+ rewriter.replaceOp (op, res);
1407
+
1408
+ return success ();
1409
+ }
1410
+
1411
+ // ----------------------------------------------------------------------------//
1412
+ // AsinhOp and AcoshOp approximation.
1413
+ // ----------------------------------------------------------------------------//
1414
+
1415
+ namespace {
1416
+
1417
+ template <bool isSine, typename OpTy>
1418
+ struct AsinhAndAcoshApproximation : public OpRewritePattern <OpTy> {
1419
+ public:
1420
+ using OpRewritePattern<OpTy>::OpRewritePattern;
1421
+
1422
+ LogicalResult matchAndRewrite (OpTy op, PatternRewriter &rewriter) const final ;
1423
+ };
1424
+ } // namespace
1425
+
1426
+ template <bool isSine, typename OpTy>
1427
+ LogicalResult AsinhAndAcoshApproximation<isSine, OpTy>::matchAndRewrite(
1428
+ OpTy op, PatternRewriter &rewriter) const {
1429
+ static_assert (
1430
+ llvm::is_one_of<OpTy, math::AsinhOp, math::AcoshOp>::value,
1431
+ " SinAndCosApproximation pattern expects math::AsinhOp or math::AcoshOp" );
1432
+
1433
+ if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
1434
+ return rewriter.notifyMatchFailure (op, " unsupported operand type" );
1435
+
1436
+ auto operand = op.getOperand ();
1437
+ VectorShape shape = vectorShape (operand);
1438
+
1439
+ ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
1440
+ auto bcast = [&](Value value) -> Value {
1441
+ return broadcast (builder, value, shape);
1442
+ };
1443
+
1444
+ // asinh: log(x + sqrt(x**2 + 1))
1445
+ // acosh: log(x + sqrt(x**2 - 1))
1446
+ Value cst;
1447
+ if (isSine)
1448
+ cst = bcast (f32Cst (builder, 1.0 ));
1449
+ else
1450
+ cst = bcast (f32Cst (builder, -1.0 ));
1451
+ Value fma = builder.create <math::FmaOp>(operand, operand, cst);
1452
+ Value sqrt = builder.create <math::SqrtOp>(fma);
1453
+ Value b = builder.create <arith::AddFOp>(operand, sqrt);
1454
+ Value res = builder.create <math::LogOp>(b);
1455
+ rewriter.replaceOp (op, res);
1456
+
1457
+ return success ();
1458
+ }
1459
+
1319
1460
// ----------------------------------------------------------------------------//
1320
1461
// Cbrt approximation.
1321
1462
// ----------------------------------------------------------------------------//
@@ -1505,11 +1646,16 @@ void mlir::populateMathPolynomialApproximationPatterns(
1505
1646
ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1506
1647
patterns.getContext ());
1507
1648
1508
- patterns.add <AtanApproximation, Atan2Approximation, TanhApproximation,
1509
- LogApproximation, Log2Approximation, Log1pApproximation,
1510
- ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1511
- CbrtApproximation, SinAndCosApproximation<true , math::SinOp>,
1512
- SinAndCosApproximation<false , math::CosOp>>(
1649
+ patterns.add <AtanApproximation, Atan2Approximation, AtanhApproximation,
1650
+ TanhApproximation, LogApproximation, Log2Approximation,
1651
+ Log1pApproximation, ErfPolynomialApproximation, ExpApproximation,
1652
+ ExpM1Approximation, CbrtApproximation,
1653
+ SinAndCosApproximation<true , math::SinOp>,
1654
+ SinAndCosApproximation<false , math::CosOp>,
1655
+ SinhAndCoshApproximation<true , math::SinhOp>,
1656
+ SinhAndCoshApproximation<false , math::CoshOp>,
1657
+ AsinhAndAcoshApproximation<true , math::AsinhOp>,
1658
+ AsinhAndAcoshApproximation<false , math::AcoshOp>>(
1513
1659
patterns.getContext ());
1514
1660
if (options.enableAvx2 ) {
1515
1661
patterns.add <RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
0 commit comments