@@ -27,11 +27,9 @@ using namespace mlir;
27
27
28
28
namespace {
29
29
30
- enum class AbsFn { abs, sqrt, rsqrt };
31
-
32
- // Returns the absolute value, its square root or its reciprocal square root.
30
+ // Returns the absolute value or its square root.
33
31
Value computeAbs (Value real, Value imag, arith::FastMathFlags fmf,
34
- ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs ) {
32
+ ImplicitLocOpBuilder &b, bool returnSqrt = false ) {
35
33
Value one = b.create <arith::ConstantOp>(real.getType (),
36
34
b.getFloatAttr (real.getType (), 1.0 ));
37
35
@@ -45,13 +43,7 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
45
43
Value ratioSqPlusOne = b.create <arith::AddFOp>(ratioSq, one, fmf);
46
44
Value result;
47
45
48
- if (fn == AbsFn::rsqrt) {
49
- ratioSqPlusOne = b.create <math::RsqrtOp>(ratioSqPlusOne, fmf);
50
- min = b.create <math::RsqrtOp>(min, fmf);
51
- max = b.create <math::RsqrtOp>(max, fmf);
52
- }
53
-
54
- if (fn == AbsFn::sqrt ) {
46
+ if (returnSqrt) {
55
47
Value quarter = b.create <arith::ConstantOp>(
56
48
real.getType (), b.getFloatAttr (real.getType (), 0.25 ));
57
49
// sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
@@ -871,7 +863,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
871
863
872
864
Value real = b.create <complex::ReOp>(elementType, adaptor.getComplex ());
873
865
Value imag = b.create <complex::ImOp>(elementType, adaptor.getComplex ());
874
- Value absSqrt = computeAbs (real, imag, fmf, b, AbsFn:: sqrt );
866
+ Value absSqrt = computeAbs (real, imag, fmf, b, /* returnSqrt= */ true );
875
867
Value argArg = b.create <math::Atan2Op>(imag, real, fmf);
876
868
Value sqrtArg = b.create <arith::MulFOp>(argArg, half, fmf);
877
869
Value cos = b.create <math::CosOp>(sqrtArg, fmf);
@@ -1155,74 +1147,18 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
1155
1147
LogicalResult
1156
1148
matchAndRewrite (complex::RsqrtOp op, OpAdaptor adaptor,
1157
1149
ConversionPatternRewriter &rewriter) const override {
1158
- mlir::ImplicitLocOpBuilder b (op.getLoc (), rewriter);
1150
+ mlir::ImplicitLocOpBuilder builder (op.getLoc (), rewriter);
1159
1151
auto type = cast<ComplexType>(adaptor.getComplex ().getType ());
1160
1152
auto elementType = cast<FloatType>(type.getElementType ());
1161
1153
1162
- arith::FastMathFlags fmf = op.getFastMathFlagsAttr ().getValue ();
1163
-
1164
- auto cst = [&](APFloat v) {
1165
- return b.create <arith::ConstantOp>(elementType,
1166
- b.getFloatAttr (elementType, v));
1167
- };
1168
- const auto &floatSemantics = elementType.getFloatSemantics ();
1169
- Value zero = cst (APFloat::getZero (floatSemantics));
1170
- Value inf = cst (APFloat::getInf (floatSemantics));
1171
- Value negHalf = b.create <arith::ConstantOp>(
1172
- elementType, b.getFloatAttr (elementType, -0.5 ));
1173
- Value nan = cst (APFloat::getNaN (floatSemantics));
1174
-
1175
- Value real = b.create <complex::ReOp>(elementType, adaptor.getComplex ());
1176
- Value imag = b.create <complex::ImOp>(elementType, adaptor.getComplex ());
1177
- Value absRsqrt = computeAbs (real, imag, fmf, b, AbsFn::rsqrt);
1178
- Value argArg = b.create <math::Atan2Op>(imag, real, fmf);
1179
- Value rsqrtArg = b.create <arith::MulFOp>(argArg, negHalf, fmf);
1180
- Value cos = b.create <math::CosOp>(rsqrtArg, fmf);
1181
- Value sin = b.create <math::SinOp>(rsqrtArg, fmf);
1182
-
1183
- Value resultReal = b.create <arith::MulFOp>(absRsqrt, cos , fmf);
1184
- Value resultImag = b.create <arith::MulFOp>(absRsqrt, sin , fmf);
1185
-
1186
- if (!arith::bitEnumContainsAll (fmf, arith::FastMathFlags::nnan |
1187
- arith::FastMathFlags::ninf)) {
1188
- Value negOne = b.create <arith::ConstantOp>(
1189
- elementType, b.getFloatAttr (elementType, -1 ));
1190
-
1191
- Value realSignedZero = b.create <math::CopySignOp>(zero, real, fmf);
1192
- Value imagSignedZero = b.create <math::CopySignOp>(zero, imag, fmf);
1193
- Value negImagSignedZero =
1194
- b.create <arith::MulFOp>(negOne, imagSignedZero, fmf);
1154
+ Value c = builder.create <arith::ConstantOp>(
1155
+ elementType, builder.getFloatAttr (elementType, -0.5 ));
1156
+ Value d = builder.create <arith::ConstantOp>(
1157
+ elementType, builder.getFloatAttr (elementType, 0 ));
1195
1158
1196
- Value absReal = b.create <math::AbsFOp>(real, fmf);
1197
- Value absImag = b.create <math::AbsFOp>(imag, fmf);
1198
-
1199
- Value absImagIsInf =
1200
- b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
1201
- Value realIsNan =
1202
- b.create <arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
1203
- Value realIsInf =
1204
- b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1205
- Value inIsNanInf = b.create <arith::AndIOp>(absImagIsInf, realIsNan);
1206
-
1207
- Value resultIsZero = b.create <arith::OrIOp>(inIsNanInf, realIsInf);
1208
-
1209
- resultReal =
1210
- b.create <arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
1211
- resultImag = b.create <arith::SelectOp>(resultIsZero, negImagSignedZero,
1212
- resultImag);
1213
- }
1214
-
1215
- Value isRealZero =
1216
- b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
1217
- Value isImagZero =
1218
- b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1219
- Value isZero = b.create <arith::AndIOp>(isRealZero, isImagZero);
1220
-
1221
- resultReal = b.create <arith::SelectOp>(isZero, inf, resultReal);
1222
- resultImag = b.create <arith::SelectOp>(isZero, nan , resultImag);
1223
-
1224
- rewriter.replaceOpWithNewOp <complex::CreateOp>(op, type, resultReal,
1225
- resultImag);
1159
+ rewriter.replaceOp (op,
1160
+ {powOpConversionImpl (builder, type, adaptor.getComplex (),
1161
+ c, d, op.getFastmath ())});
1226
1162
return success ();
1227
1163
}
1228
1164
};
0 commit comments