@@ -73,14 +73,14 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
73
73
ImplicitLocOpBuilder b (op->getLoc (), rewriter);
74
74
Value operand = op.getOperand ();
75
75
Type opType = operand.getType ();
76
- Value exp = b.create <math::ExpOp>(operand);
77
76
78
- Value one = createFloatConst (op->getLoc (), opType, 1.0 , rewriter);
79
- Value nexp = b.create <arith::DivFOp>(one, exp);
77
+ Value exp = b.create <math::ExpOp>(operand);
78
+ Value neg = b.create <arith::NegFOp>(operand);
79
+ Value nexp = b.create <math::ExpOp>(neg);
80
80
Value sub = b.create <arith::SubFOp>(exp, nexp);
81
- Value two = createFloatConst (op->getLoc (), opType, 2.0 , rewriter);
82
- Value div = b.create <arith::DivFOp >(sub, two );
83
- rewriter.replaceOp (op, div );
81
+ Value half = createFloatConst (op->getLoc (), opType, 0.5 , rewriter);
82
+ Value res = b.create <arith::MulFOp >(sub, half );
83
+ rewriter.replaceOp (op, res );
84
84
return success ();
85
85
}
86
86
@@ -89,14 +89,14 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
89
89
ImplicitLocOpBuilder b (op->getLoc (), rewriter);
90
90
Value operand = op.getOperand ();
91
91
Type opType = operand.getType ();
92
- Value exp = b.create <math::ExpOp>(operand);
93
92
94
- Value one = createFloatConst (op->getLoc (), opType, 1.0 , rewriter);
95
- Value nexp = b.create <arith::DivFOp>(one, exp);
93
+ Value exp = b.create <math::ExpOp>(operand);
94
+ Value neg = b.create <arith::NegFOp>(operand);
95
+ Value nexp = b.create <math::ExpOp>(neg);
96
96
Value add = b.create <arith::AddFOp>(exp, nexp);
97
- Value two = createFloatConst (op->getLoc (), opType, 2.0 , rewriter);
98
- Value div = b.create <arith::DivFOp >(add, two );
99
- rewriter.replaceOp (op, div );
97
+ Value half = createFloatConst (op->getLoc (), opType, 0.5 , rewriter);
98
+ Value res = b.create <arith::MulFOp >(add, half );
99
+ rewriter.replaceOp (op, res );
100
100
return success ();
101
101
}
102
102
@@ -152,6 +152,57 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
152
152
return success ();
153
153
}
154
154
155
+ // asinh(float x) -> log(x + sqrt(x**2 + 1))
156
+ static LogicalResult convertAsinhOp (math::AsinhOp op,
157
+ PatternRewriter &rewriter) {
158
+ ImplicitLocOpBuilder b (op->getLoc (), rewriter);
159
+ Value operand = op.getOperand ();
160
+ Type opType = operand.getType ();
161
+
162
+ Value one = createFloatConst (op->getLoc (), opType, 1.0 , rewriter);
163
+ Value fma = b.create <math::FmaOp>(operand, operand, one);
164
+ Value sqrt = b.create <math::SqrtOp>(fma);
165
+ Value add = b.create <arith::AddFOp>(operand, sqrt);
166
+ Value res = b.create <math::LogOp>(add);
167
+ rewriter.replaceOp (op, res);
168
+ return success ();
169
+ }
170
+
171
+ // acosh(float x) -> log(x + sqrt(x**2 - 1))
172
+ static LogicalResult convertAcoshOp (math::AcoshOp op,
173
+ PatternRewriter &rewriter) {
174
+ ImplicitLocOpBuilder b (op->getLoc (), rewriter);
175
+ Value operand = op.getOperand ();
176
+ Type opType = operand.getType ();
177
+
178
+ Value negOne = createFloatConst (op->getLoc (), opType, -1.0 , rewriter);
179
+ Value fma = b.create <math::FmaOp>(operand, operand, negOne);
180
+ Value sqrt = b.create <math::SqrtOp>(fma);
181
+ Value add = b.create <arith::AddFOp>(operand, sqrt);
182
+ Value res = b.create <math::LogOp>(add);
183
+ rewriter.replaceOp (op, res);
184
+ return success ();
185
+ }
186
+
187
+ // atanh(float x) -> log((1 + x) / (1 - x)) / 2
188
+ static LogicalResult convertAtanhOp (math::AtanhOp op,
189
+ PatternRewriter &rewriter) {
190
+ ImplicitLocOpBuilder b (op->getLoc (), rewriter);
191
+ Value operand = op.getOperand ();
192
+ Type opType = operand.getType ();
193
+
194
+ Value one = createFloatConst (op->getLoc (), opType, 1.0 , rewriter);
195
+ Value add = b.create <arith::AddFOp>(operand, one);
196
+ Value neg = b.create <arith::NegFOp>(operand);
197
+ Value sub = b.create <arith::AddFOp>(neg, one);
198
+ Value div = b.create <arith::DivFOp>(add, sub);
199
+ Value log = b.create <math::LogOp>(div);
200
+ Value half = createFloatConst (op->getLoc (), opType, 0.5 , rewriter);
201
+ Value res = b.create <arith::MulFOp>(log, half);
202
+ rewriter.replaceOp (op, res);
203
+ return success ();
204
+ }
205
+
155
206
static LogicalResult convertFmaFOp (math::FmaOp op, PatternRewriter &rewriter) {
156
207
ImplicitLocOpBuilder b (op->getLoc (), rewriter);
157
208
Value operandA = op.getOperand (0 );
@@ -584,6 +635,18 @@ void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
584
635
patterns.add (convertTanhOp);
585
636
}
586
637
638
+ void mlir::populateExpandAsinhPattern (RewritePatternSet &patterns) {
639
+ patterns.add (convertAsinhOp);
640
+ }
641
+
642
+ void mlir::populateExpandAcoshPattern (RewritePatternSet &patterns) {
643
+ patterns.add (convertAcoshOp);
644
+ }
645
+
646
+ void mlir::populateExpandAtanhPattern (RewritePatternSet &patterns) {
647
+ patterns.add (convertAtanhOp);
648
+ }
649
+
587
650
void mlir::populateExpandFmaFPattern (RewritePatternSet &patterns) {
588
651
patterns.add (convertFmaFOp);
589
652
}
0 commit comments