
[mlir][math] Add expand patterns for acosh, asinh, atanh #90718


Merged
merged 1 commit into from
May 7, 2024

Conversation

jinchen62
Contributor

@jinchen62 jinchen62 commented May 1, 2024

No description provided.

@llvmbot
Member

llvmbot commented May 1, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-math

Author: jinchen (jinchen62)

Changes
  • acosh
  • asinh
  • atanh
  • cosh
  • sinh

Full diff: https://github.com/llvm/llvm-project/pull/90718.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+156-9)
  • (modified) mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir (+202)
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 428c1c37c4e8b5..36bb8c6245e499 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -615,10 +615,47 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
   return success();
 }
 
-#define LN2_VALUE                                                              \
-  0.693147180559945309417232121458176568075500134360255254120680009493393621L
-#define LOG2E_VALUE                                                            \
-  1.442695040888963407359924681001892137426645954152985934135449406931109219L
+//----------------------------------------------------------------------------//
+// AtanhOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+struct AtanhApproximation : public OpRewritePattern<math::AtanhOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::AtanhOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+AtanhApproximation::matchAndRewrite(math::AtanhOp op,
+                                   PatternRewriter &rewriter) const {
+  if (!getElementTypeOrSelf(op.getOperand()).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  auto operand = op.getOperand();
+  VectorShape shape = vectorShape(operand);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  // 1/2 * log((1 + x) / (1 - x))
+  Value cstOne = bcast(f32Cst(builder, 1.0));
+  Value add = builder.create<arith::AddFOp>(operand, cstOne);
+  Value neg = builder.create<arith::NegFOp>(operand);
+  Value sub = builder.create<arith::AddFOp>(neg, cstOne);
+  Value div = builder.create<arith::DivFOp>(add, sub);
+  Value log = builder.create<math::LogOp>(div);
+  Value cstTwo = bcast(f32Cst(builder, 2.0));
+  Value res = builder.create<arith::DivFOp>(log, cstTwo);
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
 
 //----------------------------------------------------------------------------//
 // LogOp and Log2Op approximation.
@@ -635,6 +672,11 @@ struct LogApproximationBase : public OpRewritePattern<Op> {
 };
 } // namespace
 
+#define LN2_VALUE                                                              \
+  0.693147180559945309417232121458176568075500134360255254120680009493393621L
+#define LOG2E_VALUE                                                            \
+  1.442695040888963407359924681001892137426645954152985934135449406931109219L
+
 // This approximation comes from Julien Pommier's SSE math library.
 // Link: http://gruntthepeon.free.fr/ssemath
 template <typename Op>
@@ -1316,6 +1358,106 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
   return success();
 }
 
+//----------------------------------------------------------------------------//
+// SinhOp and CoshOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+template <bool isSine, typename OpTy>
+struct SinhAndCoshApproximation : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+template <bool isSine, typename OpTy>
+LogicalResult SinhAndCoshApproximation<isSine, OpTy>::matchAndRewrite(
+    OpTy op, PatternRewriter &rewriter) const {
+  static_assert(
+      llvm::is_one_of<OpTy, math::SinhOp, math::CoshOp>::value,
+      "SinhAndCoshApproximation pattern expects math::SinhOp or math::CoshOp");
+
+  if (!getElementTypeOrSelf(op.getOperand()).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  auto operand = op.getOperand();
+  VectorShape shape = vectorShape(operand);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  // sinh: 1/2 * (exp(x) - exp(-x))
+  // cosh: 1/2 * (exp(x) + exp(-x))
+  Value a = builder.create<math::ExpOp>(operand);
+  Value neg = builder.create<arith::NegFOp>(operand);
+  Value b = builder.create<math::ExpOp>(neg);
+  Value c;
+  if (isSine)
+    c = builder.create<arith::SubFOp>(a, b);
+  else
+    c = builder.create<arith::AddFOp>(a, b);
+  Value cstTwo = bcast(f32Cst(builder, 2.0));
+  Value res = builder.create<arith::DivFOp>(c, cstTwo);
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
+
+//----------------------------------------------------------------------------//
+// AsinhOp and AcoshOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+template <bool isSine, typename OpTy>
+struct AsinhAndAcoshApproximation : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+template <bool isSine, typename OpTy>
+LogicalResult AsinhAndAcoshApproximation<isSine, OpTy>::matchAndRewrite(
+    OpTy op, PatternRewriter &rewriter) const {
+  static_assert(
+      llvm::is_one_of<OpTy, math::AsinhOp, math::AcoshOp>::value,
+      "AsinhAndAcoshApproximation pattern expects math::AsinhOp or "
+      "math::AcoshOp");
+
+  if (!getElementTypeOrSelf(op.getOperand()).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  auto operand = op.getOperand();
+  VectorShape shape = vectorShape(operand);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  // asinh: log(x + sqrt(x**2 + 1))
+  // acosh: log(x + sqrt(x**2 - 1))
+  Value squared = builder.create<arith::MulFOp>(operand, operand);
+  Value cstOne = bcast(f32Cst(builder, 1.0));
+  Value a;
+  if (isSine)
+    a = builder.create<arith::AddFOp>(squared, cstOne);
+  else
+    a = builder.create<arith::SubFOp>(squared, cstOne);
+  Value sqrt = builder.create<math::SqrtOp>(a);
+  Value b = builder.create<arith::AddFOp>(operand, sqrt);
+  Value res = builder.create<math::LogOp>(b);
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 // Cbrt approximation.
 //----------------------------------------------------------------------------//
@@ -1505,11 +1647,16 @@ void mlir::populateMathPolynomialApproximationPatterns(
            ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
           patterns.getContext());
 
-  patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
-               LogApproximation, Log2Approximation, Log1pApproximation,
-               ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
-               CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
-               SinAndCosApproximation<false, math::CosOp>>(
+  patterns.add<AtanApproximation, Atan2Approximation, AtanhApproximation,
+               TanhApproximation, LogApproximation, Log2Approximation,
+               Log1pApproximation, ErfPolynomialApproximation, ExpApproximation,
+               ExpM1Approximation, CbrtApproximation,
+               SinAndCosApproximation<true, math::SinOp>,
+               SinAndCosApproximation<false, math::CosOp>,
+               SinhAndCoshApproximation<true, math::SinhOp>,
+               SinhAndCoshApproximation<false, math::CoshOp>,
+               AsinhAndAcoshApproximation<true, math::AsinhOp>,
+               AsinhAndAcoshApproximation<false, math::AcoshOp>>(
       patterns.getContext());
   if (options.enableAvx2) {
     patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index d3b19be9ecaf8f..9b73cdf57f5a35 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -568,6 +568,203 @@ func.func @atan2() {
 }
 
 
+// -------------------------------------------------------------------------- //
+// sinh
+// -------------------------------------------------------------------------- //
+
+func.func @sinh_f32(%a : f32) {
+  %r = math.sinh %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @sinh_3xf32(%a : vector<3xf32>) {
+  %r = math.sinh %a : vector<3xf32>
+  vector.print %r : vector<3xf32>
+  return
+}
+
+func.func @sinh() {
+  // CHECK: 0
+  %zero = arith.constant 0.0 : f32
+  call @sinh_f32(%zero) : (f32) -> ()
+
+  // CHECK: 0.521095
+  %cst1 = arith.constant 0.5 : f32
+  call @sinh_f32(%cst1) : (f32) -> ()
+
+  // CHECK: -1.1752
+  %cst2 = arith.constant -1.0 : f32
+  call @sinh_f32(%cst2) : (f32) -> ()
+
+  // CHECK: 10.0179
+  %cst3 = arith.constant 3.0 : f32
+  call @sinh_f32(%cst3) : (f32) -> ()
+
+  // CHECK: 0.252612, 0.991007, 3.62686
+  %vec_x = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+  call @sinh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+  return
+}
+
+
+// -------------------------------------------------------------------------- //
+// cosh
+// -------------------------------------------------------------------------- //
+
+func.func @cosh_f32(%a : f32) {
+  %r = math.cosh %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @cosh_3xf32(%a : vector<3xf32>) {
+  %r = math.cosh %a : vector<3xf32>
+  vector.print %r : vector<3xf32>
+  return
+}
+
+func.func @cosh() {
+  // CHECK: 1
+  %zero = arith.constant 0.0 : f32
+  call @cosh_f32(%zero) : (f32) -> ()
+
+  // CHECK: 1.54308
+  %cst1 = arith.constant 1.0 : f32
+  call @cosh_f32(%cst1) : (f32) -> ()
+
+  // CHECK: 1.54308
+  %cst2 = arith.constant -1.0 : f32
+  call @cosh_f32(%cst2) : (f32) -> ()
+
+  // CHECK: 10.0677
+  %cst3 = arith.constant 3.0 : f32
+  call @cosh_f32(%cst3) : (f32) -> ()
+
+  // CHECK: 1.03141, 1.40787, 3.7622
+  %vec_x = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+  call @cosh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+  return
+}
+
+
+// -------------------------------------------------------------------------- //
+// asinh
+// -------------------------------------------------------------------------- //
+
+func.func @asinh_f32(%a : f32) {
+  %r = math.asinh %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @asinh_3xf32(%a : vector<3xf32>) {
+  %r = math.asinh %a : vector<3xf32>
+  vector.print %r : vector<3xf32>
+  return
+}
+
+func.func @asinh() {
+  // CHECK: 0
+  %zero = arith.constant 0.0 : f32
+  call @asinh_f32(%zero) : (f32) -> ()
+
+  // CHECK: 0.881374
+  %cst1 = arith.constant 1.0 : f32
+  call @asinh_f32(%cst1) : (f32) -> ()
+
+  // CHECK: -0.881374
+  %cst2 = arith.constant -1.0 : f32
+  call @asinh_f32(%cst2) : (f32) -> ()
+
+  // CHECK: 1.81845
+  %cst3 = arith.constant 3.0 : f32
+  call @asinh_f32(%cst3) : (f32) -> ()
+
+  // CHECK: 0.247466, 0.790169, 1.44364
+  %vec_x = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+  call @asinh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+  return
+}
+
+
+// -------------------------------------------------------------------------- //
+// acosh
+// -------------------------------------------------------------------------- //
+
+func.func @acosh_f32(%a : f32) {
+  %r = math.acosh %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @acosh_3xf32(%a : vector<3xf32>) {
+  %r = math.acosh %a : vector<3xf32>
+  vector.print %r : vector<3xf32>
+  return
+}
+
+func.func @acosh() {
+  // CHECK: 0
+  %one = arith.constant 1.0 : f32
+  call @acosh_f32(%one) : (f32) -> ()
+
+  // CHECK: 1.31696
+  %cst1 = arith.constant 2.0 : f32
+  call @acosh_f32(%cst1) : (f32) -> ()
+
+  // CHECK: 2.99322
+  %cst2 = arith.constant 10.0 : f32
+  call @acosh_f32(%cst2) : (f32) -> ()
+
+  // CHECK: 0.962424, 1.76275, 2.47789
+  %vec_x = arith.constant dense<[1.5, 3.0, 6.0]> : vector<3xf32>
+  call @acosh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+  return
+}
+
+
+// -------------------------------------------------------------------------- //
+// atanh
+// -------------------------------------------------------------------------- //
+
+func.func @atanh_f32(%a : f32) {
+  %r = math.atanh %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @atanh_3xf32(%a : vector<3xf32>) {
+  %r = math.atanh %a : vector<3xf32>
+  vector.print %r : vector<3xf32>
+  return
+}
+
+func.func @atanh() {
+  // CHECK: 0
+  %zero = arith.constant 0.0 : f32
+  call @atanh_f32(%zero) : (f32) -> ()
+
+  // CHECK: 0.549306
+  %cst1 = arith.constant 0.5 : f32
+  call @atanh_f32(%cst1) : (f32) -> ()
+
+  // CHECK: -0.549306
+  %cst2 = arith.constant -0.5 : f32
+  call @atanh_f32(%cst2) : (f32) -> ()
+
+  // CHECK: 0.255413, 0.394229, 2.99448
+  %vec_x = arith.constant dense<[0.25, 0.375, 0.995]> : vector<3xf32>
+  call @atanh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+  return
+}
+
+
 // -------------------------------------------------------------------------- //
 // Cbrt.
 // -------------------------------------------------------------------------- //
@@ -696,6 +893,11 @@ func.func @main() {
   call @cos(): () -> ()
   call @atan() : () -> ()
   call @atan2() : () -> ()
+  call @sinh() : () -> ()
+  call @cosh() : () -> ()
+  call @asinh() : () -> ()
+  call @acosh() : () -> ()
+  call @atanh() : () -> ()
   call @cbrt() : () -> ()
   call @floorf() : () -> ()
   call @ceilf() : () -> ()
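
For readers who want to sanity-check the identities used in the diff above, here is a minimal scalar sketch in plain C++. It is not part of the patch: the function names are hypothetical, and the patch itself builds `arith`/`math` ops through an `ImplicitLocOpBuilder` rather than calling `<cmath>`. The sketch only mirrors the arithmetic of each rewrite in `float`, matching the f32-only restriction in the patterns.

```cpp
#include <cmath>

// Scalar f32 mirrors of the rewrites in the diff. All names below are
// illustrative assumptions, not identifiers from the MLIR code base.

// sinh: 1/2 * (exp(x) - exp(-x))
float sinhExpanded(float x) { return (std::exp(x) - std::exp(-x)) / 2.0f; }

// cosh: 1/2 * (exp(x) + exp(-x))
float coshExpanded(float x) { return (std::exp(x) + std::exp(-x)) / 2.0f; }

// asinh: log(x + sqrt(x**2 + 1))
float asinhExpanded(float x) { return std::log(x + std::sqrt(x * x + 1.0f)); }

// acosh: log(x + sqrt(x**2 - 1)), meaningful for x >= 1
float acoshExpanded(float x) { return std::log(x + std::sqrt(x * x - 1.0f)); }

// atanh: 1/2 * log((1 + x) / (1 - x)), meaningful for |x| < 1
float atanhExpanded(float x) {
  return std::log((1.0f + x) / (1.0f - x)) / 2.0f;
}
```

Calling, for example, `atanhExpanded(0.5f)` or `asinhExpanded(1.0f)` should land close to the corresponding `CHECK` values in the test file (0.549306 and 0.881374), up to f32 rounding.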


github-actions bot commented May 1, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Member

@pashu123 pashu123 left a comment


Nice work! I have added some comments; please address them.

@zjgarvey
Contributor

zjgarvey commented May 2, 2024

Why are these called polynomial approximations when you aren't approximating these functions with polynomials??

@pashu123
Member

pashu123 commented May 2, 2024

Why are these called polynomial approximations when you aren't approximating these functions with polynomials??

@zjgarvey You are right! @jinchen62 You should move them to https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
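
To make the distinction in this exchange concrete, the following sketch contrasts an exact-identity rewrite (an "expansion", as in this PR) with an actual polynomial approximation. Everything here is an illustrative assumption: the function names are hypothetical, the truncated Taylor polynomial is not taken from the PR or from MLIR, and `double` is used only to keep rounding noise out of the comparison.

```cpp
#include <cmath>

// Expansion: an exact identity rewritten in terms of other ops,
// atanh(x) = 1/2 * log((1 + x) / (1 - x)). Accuracy depends only on log.
double atanhByIdentity(double x) {
  return 0.5 * std::log((1.0 + x) / (1.0 - x));
}

// Polynomial approximation, shown for contrast only (hypothetical, not from
// the PR): the Taylor series x + x^3/3 + x^5/5, evaluated in Horner form.
// It is accurate near 0 and degrades as |x| approaches 1.
double atanhByPolynomial(double x) {
  double x2 = x * x;
  return x * (1.0 + x2 * (1.0 / 3.0 + x2 / 5.0));
}
```

Comparing both against `std::atanh` at a small input such as 0.1 and at an input near the edge of the domain such as 0.9 shows the difference: the identity tracks the reference at both points, while the truncated polynomial only does so near zero.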

Member

@pashu123 pashu123 left a comment


LGTM!

@jinchen62 jinchen62 changed the title from "[mlir][math] Add Polynomial Approximation for few ops" to "[mlir][math] Add conversions for acosh, asinh, atanh" on May 3, 2024
@jinchen62
Contributor Author

I will run the IREE tests before merging.

@jinchen62 jinchen62 force-pushed the poly_approx branch 2 times, most recently from 3805e68 to 1d3a1a8, on May 4, 2024 03:20
@jinchen62
Contributor Author

@pashu123 Could you review again? I removed the expand pattern for the tanh op, since it already has a polynomial approximation. I also added a pass that includes all math expand patterns, so that on the IREE side we could change iree-org/iree@06febe5. If this pass doesn't make sense, I will just add the missing expand patterns on the IREE side.

@pashu123
Member

pashu123 commented May 4, 2024

@pashu123 Could you review again? I removed the expand pattern for the tanh op, since it already has a polynomial approximation. I also added a pass that includes all math expand patterns, so that on the IREE side we could change iree-org/iree@06febe5. If this pass doesn't make sense, I will just add the missing expand patterns on the IREE side.

I suggest adding those missing patterns.

Member

@pashu123 pashu123 left a comment


Minor comments!

@jinchen62 jinchen62 force-pushed the poly_approx branch 3 times, most recently from f6c7bfc to 5670262, on May 6, 2024 15:12
@jinchen62 jinchen62 requested a review from pashu123 May 6, 2024 23:13
@jinchen62
Contributor Author

Minor comments!

Addressed.

Member

@pashu123 pashu123 left a comment


Thanks for putting this up. LGTM!

@jinchen62
Contributor Author

@pashu123 Could you merge it? I don't have merge permission. Thanks!

@jinchen62 jinchen62 changed the title from "[mlir][math] Add conversions for acosh, asinh, atanh" to "[mlir][math] Add expand patterns for acosh, asinh, atanh" on May 7, 2024
@pashu123 pashu123 merged commit a62a702 into llvm:main May 7, 2024
4 checks passed
@pashu123
Member

pashu123 commented May 7, 2024

@pashu123 Could you merge it? I don’t have the permission. Thanks!

Done.

@jinchen62 jinchen62 deleted the poly_approx branch May 8, 2024 06:55
jinchen62 added a commit to iree-org/iree that referenced this pull request May 13, 2024
bangtianliu pushed a commit to bangtianliu/iree that referenced this pull request Jun 5, 2024
LLITCHEV pushed a commit to LLITCHEV/iree that referenced this pull request Jul 30, 2024
4 participants