Skip to content

Commit b4e7b56

Browse files
authored
Revert "Fix rsqrt inaccuracies." (#88705)
Reverts #88691
1 parent c7bd284 commit b4e7b56

File tree

2 files changed

+13
-92
lines changed

2 files changed

+13
-92
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 12 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,9 @@ using namespace mlir;
2727

2828
namespace {
2929

30-
enum class AbsFn { abs, sqrt, rsqrt };
31-
32-
// Returns the absolute value, its square root or its reciprocal square root.
30+
// Returns the absolute value or its square root.
3331
Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
34-
ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
32+
ImplicitLocOpBuilder &b, bool returnSqrt = false) {
3533
Value one = b.create<arith::ConstantOp>(real.getType(),
3634
b.getFloatAttr(real.getType(), 1.0));
3735

@@ -45,13 +43,7 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
4543
Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
4644
Value result;
4745

48-
if (fn == AbsFn::rsqrt) {
49-
ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmf);
50-
min = b.create<math::RsqrtOp>(min, fmf);
51-
max = b.create<math::RsqrtOp>(max, fmf);
52-
}
53-
54-
if (fn == AbsFn::sqrt) {
46+
if (returnSqrt) {
5547
Value quarter = b.create<arith::ConstantOp>(
5648
real.getType(), b.getFloatAttr(real.getType(), 0.25));
5749
// sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
@@ -871,7 +863,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
871863

872864
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
873865
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
874-
Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
866+
Value absSqrt = computeAbs(real, imag, fmf, b, /*returnSqrt=*/true);
875867
Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
876868
Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
877869
Value cos = b.create<math::CosOp>(sqrtArg, fmf);
@@ -1155,74 +1147,18 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
11551147
LogicalResult
11561148
matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
11571149
ConversionPatternRewriter &rewriter) const override {
1158-
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1150+
mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
11591151
auto type = cast<ComplexType>(adaptor.getComplex().getType());
11601152
auto elementType = cast<FloatType>(type.getElementType());
11611153

1162-
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
1163-
1164-
auto cst = [&](APFloat v) {
1165-
return b.create<arith::ConstantOp>(elementType,
1166-
b.getFloatAttr(elementType, v));
1167-
};
1168-
const auto &floatSemantics = elementType.getFloatSemantics();
1169-
Value zero = cst(APFloat::getZero(floatSemantics));
1170-
Value inf = cst(APFloat::getInf(floatSemantics));
1171-
Value negHalf = b.create<arith::ConstantOp>(
1172-
elementType, b.getFloatAttr(elementType, -0.5));
1173-
Value nan = cst(APFloat::getNaN(floatSemantics));
1174-
1175-
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
1176-
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
1177-
Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
1178-
Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
1179-
Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
1180-
Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
1181-
Value sin = b.create<math::SinOp>(rsqrtArg, fmf);
1182-
1183-
Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
1184-
Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);
1185-
1186-
if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1187-
arith::FastMathFlags::ninf)) {
1188-
Value negOne = b.create<arith::ConstantOp>(
1189-
elementType, b.getFloatAttr(elementType, -1));
1190-
1191-
Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
1192-
Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
1193-
Value negImagSignedZero =
1194-
b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);
1154+
Value c = builder.create<arith::ConstantOp>(
1155+
elementType, builder.getFloatAttr(elementType, -0.5));
1156+
Value d = builder.create<arith::ConstantOp>(
1157+
elementType, builder.getFloatAttr(elementType, 0));
11951158

1196-
Value absReal = b.create<math::AbsFOp>(real, fmf);
1197-
Value absImag = b.create<math::AbsFOp>(imag, fmf);
1198-
1199-
Value absImagIsInf =
1200-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
1201-
Value realIsNan =
1202-
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
1203-
Value realIsInf =
1204-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1205-
Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);
1206-
1207-
Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);
1208-
1209-
resultReal =
1210-
b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
1211-
resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
1212-
resultImag);
1213-
}
1214-
1215-
Value isRealZero =
1216-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
1217-
Value isImagZero =
1218-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1219-
Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);
1220-
1221-
resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
1222-
resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);
1223-
1224-
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1225-
resultImag);
1159+
rewriter.replaceOp(op,
1160+
{powOpConversionImpl(builder, type, adaptor.getComplex(),
1161+
c, d, op.getFastmath())});
12261162
return success();
12271163
}
12281164
};

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -837,21 +837,6 @@ func.func @complex_rsqrt(%arg: complex<f32>) -> complex<f32> {
837837
return %rsqrt : complex<f32>
838838
}
839839

840-
// CHECK-COUNT-5: arith.select
841-
// CHECK-NOT: arith.select
842-
843-
// -----
844-
845-
// CHECK-LABEL: func @complex_rsqrt_nnan_ninf
846-
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
847-
func.func @complex_rsqrt_nnan_ninf(%arg: complex<f32>) -> complex<f32> {
848-
%sqrt = complex.rsqrt %arg fastmath<nnan,ninf> : complex<f32>
849-
return %sqrt : complex<f32>
850-
}
851-
852-
// CHECK-COUNT-3: arith.select
853-
// CHECK-NOT: arith.select
854-
855840
// -----
856841

857842
// CHECK-LABEL: func.func @complex_angle
@@ -2118,4 +2103,4 @@ func.func @complex_tanh_with_fmf(%arg: complex<f32>) -> complex<f32> {
21182103
// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
21192104
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
21202105
// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath<nnan,contract> : f32
2121-
// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
2106+
// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>

0 commit comments

Comments
 (0)