-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][math] Add expand patterns for acosh, asinh, atanh #90718
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-math Author: jinchen (jinchen62) Changes
Full diff: https://github.com/llvm/llvm-project/pull/90718.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 428c1c37c4e8b5..36bb8c6245e499 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -615,10 +615,47 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
return success();
}
-#define LN2_VALUE \
- 0.693147180559945309417232121458176568075500134360255254120680009493393621L
-#define LOG2E_VALUE \
- 1.442695040888963407359924681001892137426645954152985934135449406931109219L
+//----------------------------------------------------------------------------//
+// AtanhOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+struct AtanhApproximation : public OpRewritePattern<math::AtanhOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(math::AtanhOp op,
+ PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+AtanhApproximation::matchAndRewrite(math::AtanhOp op,
+ PatternRewriter &rewriter) const {
+ if (!getElementTypeOrSelf(op.getOperand()).isF32())
+ return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+ auto operand = op.getOperand();
+ VectorShape shape = vectorShape(operand);
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ auto bcast = [&](Value value) -> Value {
+ return broadcast(builder, value, shape);
+ };
+
+ // 1/2 * log((1 + x) / (1 - x))
+ Value cstOne = bcast(f32Cst(builder, 1.0));
+ Value add = builder.create<arith::AddFOp>(operand, cstOne);
+ Value neg = builder.create<arith::NegFOp>(operand);
+ Value sub = builder.create<arith::AddFOp>(neg, cstOne);
+ Value div = builder.create<arith::DivFOp>(add, sub);
+ Value log = builder.create<math::LogOp>(div);
+ Value cstTwo = bcast(f32Cst(builder, 2.0));
+ Value res = builder.create<arith::DivFOp>(log, cstTwo);
+ rewriter.replaceOp(op, res);
+
+ return success();
+}
//----------------------------------------------------------------------------//
// LogOp and Log2Op approximation.
@@ -635,6 +672,11 @@ struct LogApproximationBase : public OpRewritePattern<Op> {
};
} // namespace
+#define LN2_VALUE \
+ 0.693147180559945309417232121458176568075500134360255254120680009493393621L
+#define LOG2E_VALUE \
+ 1.442695040888963407359924681001892137426645954152985934135449406931109219L
+
// This approximation comes from Julien Pommier's SSE math library.
// Link: http://gruntthepeon.free.fr/ssemath
template <typename Op>
@@ -1316,6 +1358,106 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
return success();
}
+//----------------------------------------------------------------------------//
+// SinhOp and CoshOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+template <bool isSine, typename OpTy>
+struct SinhAndCoshApproximation : public OpRewritePattern<OpTy> {
+public:
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+template <bool isSine, typename OpTy>
+LogicalResult SinhAndCoshApproximation<isSine, OpTy>::matchAndRewrite(
+ OpTy op, PatternRewriter &rewriter) const {
+ static_assert(
+ llvm::is_one_of<OpTy, math::SinhOp, math::CoshOp>::value,
+ "SinAndCosApproximation pattern expects math::SinhOp or math::CoshOp");
+
+ if (!getElementTypeOrSelf(op.getOperand()).isF32())
+ return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+ auto operand = op.getOperand();
+ VectorShape shape = vectorShape(operand);
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ auto bcast = [&](Value value) -> Value {
+ return broadcast(builder, value, shape);
+ };
+
+ // sinh: 1/2 * (exp(x) – exp(-x))
+ // cosh: 1/2 * (exp(x) + exp(-x))
+ Value a = builder.create<math::ExpOp>(operand);
+ Value neg = builder.create<arith::NegFOp>(operand);
+ Value b = builder.create<math::ExpOp>(neg);
+ Value c;
+ if (isSine)
+ c = builder.create<arith::SubFOp>(a, b);
+ else
+ c = builder.create<arith::AddFOp>(a, b);
+ Value cstTwo = bcast(f32Cst(builder, 2.0));
+ Value res = builder.create<arith::DivFOp>(c, cstTwo);
+ rewriter.replaceOp(op, res);
+
+ return success();
+}
+
+//----------------------------------------------------------------------------//
+// AsinhOp and AcoshOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+template <bool isSine, typename OpTy>
+struct AsinhAndAcoshApproximation : public OpRewritePattern<OpTy> {
+public:
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+template <bool isSine, typename OpTy>
+LogicalResult AsinhAndAcoshApproximation<isSine, OpTy>::matchAndRewrite(
+ OpTy op, PatternRewriter &rewriter) const {
+ static_assert(
+ llvm::is_one_of<OpTy, math::AsinhOp, math::AcoshOp>::value,
+ "SinAndCosApproximation pattern expects math::AsinhOp or math::AcoshOp");
+
+ if (!getElementTypeOrSelf(op.getOperand()).isF32())
+ return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+ auto operand = op.getOperand();
+ VectorShape shape = vectorShape(operand);
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ auto bcast = [&](Value value) -> Value {
+ return broadcast(builder, value, shape);
+ };
+
+ // asinh: log(x + sqrt(x**2 + 1))
+ // acosh: log(x + sqrt(x**2 - 1))
+ Value squared = builder.create<arith::MulFOp>(operand, operand);
+ Value cstOne = bcast(f32Cst(builder, 1.0));
+ Value a;
+ if (isSine)
+ a = builder.create<arith::AddFOp>(squared, cstOne);
+ else
+ a = builder.create<arith::SubFOp>(squared, cstOne);
+ Value sqrt = builder.create<math::SqrtOp>(a);
+ Value b = builder.create<arith::AddFOp>(operand, sqrt);
+ Value res = builder.create<math::LogOp>(b);
+ rewriter.replaceOp(op, res);
+
+ return success();
+}
+
//----------------------------------------------------------------------------//
// Cbrt approximation.
//----------------------------------------------------------------------------//
@@ -1505,11 +1647,16 @@ void mlir::populateMathPolynomialApproximationPatterns(
ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
patterns.getContext());
- patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
- LogApproximation, Log2Approximation, Log1pApproximation,
- ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
- CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
- SinAndCosApproximation<false, math::CosOp>>(
+ patterns.add<AtanApproximation, Atan2Approximation, AtanhApproximation,
+ TanhApproximation, LogApproximation, Log2Approximation,
+ Log1pApproximation, ErfPolynomialApproximation, ExpApproximation,
+ ExpM1Approximation, CbrtApproximation,
+ SinAndCosApproximation<true, math::SinOp>,
+ SinAndCosApproximation<false, math::CosOp>,
+ SinhAndCoshApproximation<true, math::SinhOp>,
+ SinhAndCoshApproximation<false, math::CoshOp>,
+ AsinhAndAcoshApproximation<true, math::AsinhOp>,
+ AsinhAndAcoshApproximation<false, math::AcoshOp>>(
patterns.getContext());
if (options.enableAvx2) {
patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index d3b19be9ecaf8f..9b73cdf57f5a35 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -568,6 +568,203 @@ func.func @atan2() {
}
+// -------------------------------------------------------------------------- //
+// sinh
+// -------------------------------------------------------------------------- //
+
+func.func @sinh_f32(%a : f32) {
+ %r = math.sinh %a : f32
+ vector.print %r : f32
+ return
+}
+
+func.func @sinh_3xf32(%a : vector<3xf32>) {
+ %r = math.sinh %a : vector<3xf32>
+ vector.print %r : vector<3xf32>
+ return
+}
+
+func.func @sinh() {
+ // CHECK: 0
+ %zero = arith.constant 0.0 : f32
+ call @sinh_f32(%zero) : (f32) -> ()
+
+ // CHECK: 0.521095
+ %cst1 = arith.constant 0.5 : f32
+ call @sinh_f32(%cst1) : (f32) -> ()
+
+ // CHECK: -1.1752
+ %cst2 = arith.constant -1.0 : f32
+ call @sinh_f32(%cst2) : (f32) -> ()
+
+ // CHECK: 10.0179
+ %cst3 = arith.constant 3.0 : f32
+ call @sinh_f32(%cst3) : (f32) -> ()
+
+ // CHECK: 0.252612, 0.991007, 3.62686
+ %vec_x = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+ call @sinh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+ return
+}
+
+
+// -------------------------------------------------------------------------- //
+// cosh
+// -------------------------------------------------------------------------- //
+
+func.func @cosh_f32(%a : f32) {
+ %r = math.cosh %a : f32
+ vector.print %r : f32
+ return
+}
+
+func.func @cosh_3xf32(%a : vector<3xf32>) {
+ %r = math.cosh %a : vector<3xf32>
+ vector.print %r : vector<3xf32>
+ return
+}
+
+func.func @cosh() {
+ // CHECK: 1
+ %zero = arith.constant 0.0 : f32
+ call @cosh_f32(%zero) : (f32) -> ()
+
+ // CHECK: 1.54308
+ %cst1 = arith.constant 1.0 : f32
+ call @cosh_f32(%cst1) : (f32) -> ()
+
+ // CHECK: 1.54308
+ %cst2 = arith.constant -1.0 : f32
+ call @cosh_f32(%cst2) : (f32) -> ()
+
+ // CHECK: 10.0677
+ %cst3 = arith.constant 3.0 : f32
+ call @cosh_f32(%cst3) : (f32) -> ()
+
+ // CHECK: 1.03141, 1.40787, 3.7622
+ %vec_x = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+ call @cosh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+ return
+}
+
+
+// -------------------------------------------------------------------------- //
+// asinh
+// -------------------------------------------------------------------------- //
+
+func.func @asinh_f32(%a : f32) {
+ %r = math.asinh %a : f32
+ vector.print %r : f32
+ return
+}
+
+func.func @asinh_3xf32(%a : vector<3xf32>) {
+ %r = math.asinh %a : vector<3xf32>
+ vector.print %r : vector<3xf32>
+ return
+}
+
+func.func @asinh() {
+ // CHECK: 0
+ %zero = arith.constant 0.0 : f32
+ call @asinh_f32(%zero) : (f32) -> ()
+
+ // CHECK: 0.881374
+ %cst1 = arith.constant 1.0 : f32
+ call @asinh_f32(%cst1) : (f32) -> ()
+
+ // CHECK: -0.881374
+ %cst2 = arith.constant -1.0 : f32
+ call @asinh_f32(%cst2) : (f32) -> ()
+
+ // CHECK: 1.81845
+ %cst3 = arith.constant 3.0 : f32
+ call @asinh_f32(%cst3) : (f32) -> ()
+
+ // CHECK: 0.247466, 0.790169, 1.44364
+ %vec_x = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+ call @asinh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+ return
+}
+
+
+// -------------------------------------------------------------------------- //
+// acosh
+// -------------------------------------------------------------------------- //
+
+func.func @acosh_f32(%a : f32) {
+ %r = math.acosh %a : f32
+ vector.print %r : f32
+ return
+}
+
+func.func @acosh_3xf32(%a : vector<3xf32>) {
+ %r = math.acosh %a : vector<3xf32>
+ vector.print %r : vector<3xf32>
+ return
+}
+
+func.func @acosh() {
+ // CHECK: 0
+ %zero = arith.constant 1.0 : f32
+ call @acosh_f32(%zero) : (f32) -> ()
+
+ // CHECK: 1.31696
+ %cst1 = arith.constant 2.0 : f32
+ call @acosh_f32(%cst1) : (f32) -> ()
+
+ // CHECK: 2.99322
+ %cst2 = arith.constant 10.0 : f32
+ call @acosh_f32(%cst2) : (f32) -> ()
+
+ // CHECK: 0.962424, 1.76275, 2.47789
+ %vec_x = arith.constant dense<[1.5, 3.0, 6.0]> : vector<3xf32>
+ call @acosh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+ return
+}
+
+
+// -------------------------------------------------------------------------- //
+// atanh
+// -------------------------------------------------------------------------- //
+
+func.func @atanh_f32(%a : f32) {
+ %r = math.atanh %a : f32
+ vector.print %r : f32
+ return
+}
+
+func.func @atanh_3xf32(%a : vector<3xf32>) {
+ %r = math.atanh %a : vector<3xf32>
+ vector.print %r : vector<3xf32>
+ return
+}
+
+func.func @atanh() {
+ // CHECK: 0
+ %zero = arith.constant 0.0 : f32
+ call @atanh_f32(%zero) : (f32) -> ()
+
+ // CHECK: 0.549306
+ %cst1 = arith.constant 0.5 : f32
+ call @atanh_f32(%cst1) : (f32) -> ()
+
+ // CHECK: -0.549306
+ %cst2 = arith.constant -0.5 : f32
+ call @atanh_f32(%cst2) : (f32) -> ()
+
+ // CHECK: 0.255413, 0.394229, 2.99448
+ %vec_x = arith.constant dense<[0.25, 0.375, 0.995]> : vector<3xf32>
+ call @atanh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+ return
+}
+
+
// -------------------------------------------------------------------------- //
// Cbrt.
// -------------------------------------------------------------------------- //
@@ -696,6 +893,11 @@ func.func @main() {
call @cos(): () -> ()
call @atan() : () -> ()
call @atan2() : () -> ()
+ call @sinh() : () -> ()
+ call @cosh() : () -> ()
+ call @asinh() : () -> ()
+ call @acosh() : () -> ()
+ call @atanh() : () -> ()
call @cbrt() : () -> ()
call @floorf() : () -> ()
call @ceilf() : () -> ()
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work! I have added some comments please address.
Why are these called polynomial approximations when you aren't approximating these functions with polynomials?? |
@zjgarvey You are right! @jinchen62 You should move them to https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp |
314d9a2
to
6d32d4a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
I will test iree tests before merge. |
3805e68
to
1d3a1a8
Compare
@pashu123 Could you review again? I removed the expand pattern of tanh op since it has polynomial approximation. And I added a pass including all math expand patterns so on IREE side we could change iree-org/iree@06febe5, if this pass doesn't make sense I would just add those missing expand patterns on IREE. |
I suggest adding those missing patterns. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor comments!
f6c7bfc
to
5670262
Compare
Addressed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for putting this up. LGTM!
@pashu123 Could you merge it? I don’t have the permission. Thanks! |
Done. |
need to update llvm with llvm/llvm-project#90718
need to update llvm with llvm/llvm-project#90718
need to update llvm with llvm/llvm-project#90718 Signed-off-by: Lubo Litchev <[email protected]>
No description provided.