Skip to content

Commit c72fc49

Browse files
committed
[mlir][math] Add Polynomial Approximation for acosh, asinh, atanh ops
1 parent 02ce822 commit c72fc49

File tree

4 files changed

+200
-12
lines changed

4 files changed

+200
-12
lines changed

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ void populateExpandTanPattern(RewritePatternSet &patterns);
3131
void populateExpandSinhPattern(RewritePatternSet &patterns);
3232
void populateExpandCoshPattern(RewritePatternSet &patterns);
3333
void populateExpandTanhPattern(RewritePatternSet &patterns);
34+
void populateExpandAsinhPattern(RewritePatternSet &patterns);
35+
void populateExpandAcoshPattern(RewritePatternSet &patterns);
36+
void populateExpandAtanhPattern(RewritePatternSet &patterns);
3437
void populateExpandFmaFPattern(RewritePatternSet &patterns);
3538
void populateExpandFloorFPattern(RewritePatternSet &patterns);
3639
void populateExpandCeilFPattern(RewritePatternSet &patterns);

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,14 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
7373
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
7474
Value operand = op.getOperand();
7575
Type opType = operand.getType();
76-
Value exp = b.create<math::ExpOp>(operand);
7776

78-
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
79-
Value nexp = b.create<arith::DivFOp>(one, exp);
77+
Value exp = b.create<math::ExpOp>(operand);
78+
Value neg = b.create<arith::NegFOp>(operand);
79+
Value nexp = b.create<math::ExpOp>(neg);
8080
Value sub = b.create<arith::SubFOp>(exp, nexp);
81-
Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
82-
Value div = b.create<arith::DivFOp>(sub, two);
83-
rewriter.replaceOp(op, div);
81+
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
82+
Value res = b.create<arith::MulFOp>(sub, half);
83+
rewriter.replaceOp(op, res);
8484
return success();
8585
}
8686

@@ -89,14 +89,14 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
8989
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
9090
Value operand = op.getOperand();
9191
Type opType = operand.getType();
92-
Value exp = b.create<math::ExpOp>(operand);
9392

94-
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
95-
Value nexp = b.create<arith::DivFOp>(one, exp);
93+
Value exp = b.create<math::ExpOp>(operand);
94+
Value neg = b.create<arith::NegFOp>(operand);
95+
Value nexp = b.create<math::ExpOp>(neg);
9696
Value add = b.create<arith::AddFOp>(exp, nexp);
97-
Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
98-
Value div = b.create<arith::DivFOp>(add, two);
99-
rewriter.replaceOp(op, div);
97+
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
98+
Value res = b.create<arith::MulFOp>(add, half);
99+
rewriter.replaceOp(op, res);
100100
return success();
101101
}
102102

@@ -152,6 +152,57 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
152152
return success();
153153
}
154154

155+
// asinh(float x) -> log(x + sqrt(x**2 + 1))
156+
static LogicalResult convertAsinhOp(math::AsinhOp op,
157+
PatternRewriter &rewriter) {
158+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
159+
Value operand = op.getOperand();
160+
Type opType = operand.getType();
161+
162+
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
163+
Value fma = b.create<math::FmaOp>(operand, operand, one);
164+
Value sqrt = b.create<math::SqrtOp>(fma);
165+
Value add = b.create<arith::AddFOp>(operand, sqrt);
166+
Value res = b.create<math::LogOp>(add);
167+
rewriter.replaceOp(op, res);
168+
return success();
169+
}
170+
171+
// acosh(float x) -> log(x + sqrt(x**2 - 1))
172+
static LogicalResult convertAcoshOp(math::AcoshOp op,
173+
PatternRewriter &rewriter) {
174+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
175+
Value operand = op.getOperand();
176+
Type opType = operand.getType();
177+
178+
Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
179+
Value fma = b.create<math::FmaOp>(operand, operand, negOne);
180+
Value sqrt = b.create<math::SqrtOp>(fma);
181+
Value add = b.create<arith::AddFOp>(operand, sqrt);
182+
Value res = b.create<math::LogOp>(add);
183+
rewriter.replaceOp(op, res);
184+
return success();
185+
}
186+
187+
// atanh(float x) -> log((1 + x) / (1 - x)) / 2
188+
static LogicalResult convertAtanhOp(math::AtanhOp op,
189+
PatternRewriter &rewriter) {
190+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
191+
Value operand = op.getOperand();
192+
Type opType = operand.getType();
193+
194+
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
195+
Value add = b.create<arith::AddFOp>(operand, one);
196+
Value neg = b.create<arith::NegFOp>(operand);
197+
Value sub = b.create<arith::AddFOp>(neg, one);
198+
Value div = b.create<arith::DivFOp>(add, sub);
199+
Value log = b.create<math::LogOp>(div);
200+
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
201+
Value res = b.create<arith::MulFOp>(log, half);
202+
rewriter.replaceOp(op, res);
203+
return success();
204+
}
205+
155206
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
156207
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
157208
Value operandA = op.getOperand(0);
@@ -584,6 +635,18 @@ void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
584635
patterns.add(convertTanhOp);
585636
}
586637

638+
void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
639+
patterns.add(convertAsinhOp);
640+
}
641+
642+
void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
643+
patterns.add(convertAcoshOp);
644+
}
645+
646+
void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
647+
patterns.add(convertAtanhOp);
648+
}
649+
587650
void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
588651
patterns.add(convertFmaFOp);
589652
}

mlir/test/lib/Dialect/Math/TestExpandMath.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ void TestExpandMathPass::runOnOperation() {
4242
populateExpandSinhPattern(patterns);
4343
populateExpandCoshPattern(patterns);
4444
populateExpandTanhPattern(patterns);
45+
populateExpandAsinhPattern(patterns);
46+
populateExpandAcoshPattern(patterns);
47+
populateExpandAtanhPattern(patterns);
4548
populateExpandFmaFPattern(patterns);
4649
populateExpandFloorFPattern(patterns);
4750
populateExpandCeilFPattern(patterns);

mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,122 @@ func.func @tanh() {
717717
return
718718
}
719719

720+
// -------------------------------------------------------------------------- //
721+
// Asinh.
722+
// -------------------------------------------------------------------------- //
723+
724+
func.func @asinh_f32(%a : f32) {
725+
%r = math.asinh %a : f32
726+
vector.print %r : f32
727+
return
728+
}
729+
730+
func.func @asinh_3xf32(%a : vector<3xf32>) {
731+
%r = math.asinh %a : vector<3xf32>
732+
vector.print %r : vector<3xf32>
733+
return
734+
}
735+
736+
func.func @asinh() {
737+
// CHECK: 0
738+
%zero = arith.constant 0.0 : f32
739+
call @asinh_f32(%zero) : (f32) -> ()
740+
741+
// CHECK: 0.881374
742+
%cst1 = arith.constant 1.0 : f32
743+
call @asinh_f32(%cst1) : (f32) -> ()
744+
745+
// CHECK: -0.881374
746+
%cst2 = arith.constant -1.0 : f32
747+
call @asinh_f32(%cst2) : (f32) -> ()
748+
749+
// CHECK: 1.81845
750+
%cst3 = arith.constant 3.0 : f32
751+
call @asinh_f32(%cst3) : (f32) -> ()
752+
753+
// CHECK: 0.247466, 0.790169, 1.44364
754+
%vec_x = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
755+
call @asinh_3xf32(%vec_x) : (vector<3xf32>) -> ()
756+
757+
return
758+
}
759+
760+
// -------------------------------------------------------------------------- //
761+
// Acosh.
762+
// -------------------------------------------------------------------------- //
763+
764+
func.func @acosh_f32(%a : f32) {
765+
%r = math.acosh %a : f32
766+
vector.print %r : f32
767+
return
768+
}
769+
770+
func.func @acosh_3xf32(%a : vector<3xf32>) {
771+
%r = math.acosh %a : vector<3xf32>
772+
vector.print %r : vector<3xf32>
773+
return
774+
}
775+
776+
func.func @acosh() {
777+
// CHECK: 0
778+
%zero = arith.constant 1.0 : f32
779+
call @acosh_f32(%zero) : (f32) -> ()
780+
781+
// CHECK: 1.31696
782+
%cst1 = arith.constant 2.0 : f32
783+
call @acosh_f32(%cst1) : (f32) -> ()
784+
785+
// CHECK: 2.99322
786+
%cst2 = arith.constant 10.0 : f32
787+
call @acosh_f32(%cst2) : (f32) -> ()
788+
789+
// CHECK: 0.962424, 1.76275, 2.47789
790+
%vec_x = arith.constant dense<[1.5, 3.0, 6.0]> : vector<3xf32>
791+
call @acosh_3xf32(%vec_x) : (vector<3xf32>) -> ()
792+
793+
return
794+
}
795+
796+
// -------------------------------------------------------------------------- //
797+
// Atanh.
798+
// -------------------------------------------------------------------------- //
799+
800+
func.func @atanh_f32(%a : f32) {
801+
%r = math.atanh %a : f32
802+
vector.print %r : f32
803+
return
804+
}
805+
806+
func.func @atanh_3xf32(%a : vector<3xf32>) {
807+
%r = math.atanh %a : vector<3xf32>
808+
vector.print %r : vector<3xf32>
809+
return
810+
}
811+
812+
func.func @atanh() {
813+
// CHECK: 0
814+
%zero = arith.constant 0.0 : f32
815+
call @atanh_f32(%zero) : (f32) -> ()
816+
817+
// CHECK: 0.549306
818+
%cst1 = arith.constant 0.5 : f32
819+
call @atanh_f32(%cst1) : (f32) -> ()
820+
821+
// CHECK: -0.549306
822+
%cst2 = arith.constant -0.5 : f32
823+
call @atanh_f32(%cst2) : (f32) -> ()
824+
825+
// CHECK: inf
826+
%cst3 = arith.constant 1.0 : f32
827+
call @atanh_f32(%cst3) : (f32) -> ()
828+
829+
// CHECK: 0.255413, 0.394229, 2.99448
830+
%vec_x = arith.constant dense<[0.25, 0.375, 0.995]> : vector<3xf32>
831+
call @atanh_3xf32(%vec_x) : (vector<3xf32>) -> ()
832+
833+
return
834+
}
835+
720836
func.func @main() {
721837
call @exp2f() : () -> ()
722838
call @roundf() : () -> ()
@@ -725,5 +841,8 @@ func.func @main() {
725841
call @sinh() : () -> ()
726842
call @cosh() : () -> ()
727843
call @tanh() : () -> ()
844+
call @asinh() : () -> ()
845+
call @acosh() : () -> ()
846+
call @atanh() : () -> ()
728847
return
729848
}

0 commit comments

Comments
 (0)