-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][complex] Prevent underflow in complex.abs #79786
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][complex] Prevent underflow in complex.abs #79786
Conversation
@llvm/pr-subscribers-mlir-complex @llvm/pr-subscribers-mlir Author: Kai Sasaki (Lewuathe) ChangesThe previous PR was not correct on the way to handle the negative value. It is necessary to take the absolute value of the given real (or imaginary) part to be multiplied with the sqrt part. See: #76316 Full diff: https://github.com/llvm/llvm-project/pull/79786.diff 4 Files Affected:
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 4c9dad9e2c17312..6e0eddab0e3aae0 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -26,29 +26,59 @@ namespace mlir {
using namespace mlir;
namespace {
+// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op.getLoc();
- auto type = op.getType();
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
- Value real =
- rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
- Value imag =
- rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
- Value realSqr =
- rewriter.create<arith::MulFOp>(loc, real, real, fmf.getValue());
- Value imagSqr =
- rewriter.create<arith::MulFOp>(loc, imag, imag, fmf.getValue());
- Value sqNorm =
- rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr, fmf.getValue());
-
- rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
+ Type elementType = op.getType();
+ Value arg = adaptor.getComplex();
+
+ Value zero =
+ b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+ Value one = b.create<arith::ConstantOp>(elementType,
+ b.getFloatAttr(elementType, 1.0));
+
+ Value real = b.create<complex::ReOp>(elementType, arg);
+ Value imag = b.create<complex::ImOp>(elementType, arg);
+
+ Value realIsZero =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
+ Value imagIsZero =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
+
+ // Real > Imag
+ Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue());
+ Value imagSq =
+ b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue());
+ Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue());
+ Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue());
+ Value realAbs = b.create<math::AbsFOp>(real, fmf.getValue());
+ Value absImag = b.create<arith::MulFOp>(imagSqrt, realAbs, fmf.getValue());
+
+ // Real <= Imag
+ Value realDivImag = b.create<arith::DivFOp>(real, imag, fmf.getValue());
+ Value realSq =
+ b.create<arith::MulFOp>(realDivImag, realDivImag, fmf.getValue());
+ Value realSqPlusOne = b.create<arith::AddFOp>(realSq, one, fmf.getValue());
+ Value realSqrt = b.create<math::SqrtOp>(realSqPlusOne, fmf.getValue());
+ Value imagAbs = b.create<math::AbsFOp>(imag, fmf.getValue());
+ Value absReal = b.create<arith::MulFOp>(realSqrt, imagAbs, fmf.getValue());
+
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(
+ op, realIsZero, imag,
+ b.create<arith::SelectOp>(
+ imagIsZero, real,
+ b.create<arith::SelectOp>(
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, real, imag),
+ absImag, absReal)));
+
return success();
}
};
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 8fa29ea43854a4b..de519f6a706ab03 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -7,13 +7,30 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
%abs = complex.abs %arg: complex<f32>
return %abs : f32
}
+
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
-// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
-// CHECK: return %[[NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
+// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
+// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
+// CHECK: return %[[ABS3]] : f32
// -----
@@ -241,12 +258,28 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
%log = complex.log %arg: complex<f32>
return %log : complex<f32>
}
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
+// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
+// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
@@ -469,12 +502,28 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
// CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
// CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL2]], %[[REAL2]] : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG2]], %[[IMAG2]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
+// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL2]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
+// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG2]] : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL2]], %[[ABS1]] : f32
+// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG2]], %[[ABS2]] : f32
// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32
// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32
// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
@@ -716,13 +765,29 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
%abs = complex.abs %arg fastmath<nnan,contract> : complex<f32>
return %abs : f32
}
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath<nnan,contract> : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
-// CHECK: return %[[NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
+// CHECK: return %[[ABS3]] : f32
// -----
@@ -807,12 +872,28 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
%log = complex.log %arg fastmath<nnan,contract> : complex<f32>
return %log : complex<f32>
}
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] fastmath<nnan,contract> : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
index 9983dd46f094334..c01f3698c73868c 100644
--- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
@@ -6,12 +6,31 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
%abs = complex.abs %arg: complex<f32>
return %abs : f32
}
+// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
-// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] : f32
-// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] : f32
-// CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32
+// CHECK: %[[REAL_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[IMAG]], %[[ZERO]] : f32
+
+// CHECK: %[[IMAG_DIV_REAL:.*]] = llvm.fdiv %[[IMAG]], %[[REAL]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = llvm.fadd %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = llvm.intr.sqrt(%[[IMAG_SQ_PLUS_ONE]]) : (f32) -> f32
+// CHECK: %[[REAL_ABS:.*]] = llvm.intr.fabs(%[[REAL]]) : (f32) -> f32
+// CHECK: %[[ABS_IMAG:.*]] = llvm.fmul %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
+
+// CHECK: %[[REAL_DIV_IMAG:.*]] = llvm.fdiv %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = llvm.fadd %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = llvm.intr.sqrt(%[[REAL_SQ_PLUS_ONE]]) : (f32) -> f32
+// CHECK: %[[IMAG_ABS:.*]] = llvm.intr.fabs(%[[IMAG]]) : (f32) -> f32
+// CHECK: %[[ABS_REAL:.*]] = llvm.fmul %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
+
+// CHECK: %[[REAL_GT_IMAG:.*]] = llvm.fcmp "ogt" %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = llvm.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : i1, f32
+// CHECK: %[[ABS2:.*]] = llvm.select %[[IMAG_IS_ZERO]], %[[REAL]], %[[ABS1]] : i1, f32
+// CHECK: %[[NORM:.*]] = llvm.select %[[REAL_IS_ZERO]], %[[IMAG]], %[[ABS2]] : i1, f32
// CHECK: llvm.return %[[NORM]] : f32
// CHECK-LABEL: llvm.func @complex_eq
diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
index 349b92a7aefa2e3..efdf057f08df271 100644
--- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
+++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -106,6 +106,27 @@ func.func @angle(%arg: complex<f32>) -> f32 {
func.return %angle : f32
}
+func.func @test_element_f64(%input: tensor<?xcomplex<f64>>,
+ %func: (complex<f64>) -> f64) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %size = tensor.dim %input, %c0: tensor<?xcomplex<f64>>
+
+ scf.for %i = %c0 to %size step %c1 {
+ %elem = tensor.extract %input[%i]: tensor<?xcomplex<f64>>
+
+ %val = func.call_indirect %func(%elem) : (complex<f64>) -> f64
+ vector.print %val : f64
+ scf.yield
+ }
+ func.return
+}
+
+func.func @abs(%arg: complex<f64>) -> f64 {
+ %abs = complex.abs %arg : complex<f64>
+ func.return %abs : f64
+}
+
func.func @entry() {
// complex.sqrt test
%sqrt_test = arith.constant dense<[
@@ -300,5 +321,32 @@ func.func @entry() {
call @test_element(%angle_test_cast, %angle_func)
: (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> ()
+ // complex.abs test
+ %abs_test = arith.constant dense<[
+ (1.0, 1.0),
+ // CHECK: 1.414
+ (1.0e300, 1.0e300),
+ // CHECK-NEXT: 1.41421e+300
+ (1.0e-300, 1.0e-300),
+ // CHECK-NEXT: 1.41421e-300
+ (5.0, 0.0),
+ // CHECK-NEXT: 5
+ (0.0, 6.0),
+ // CHECK-NEXT: 6
+ (7.0, 8.0),
+ // CHECK-NEXT: 10.6301
+ (-1.0, -1.0),
+ // CHECK-NEXT: 1.414
+ (-1.0e300, -1.0e300)
+ // CHECK-NEXT: 1.41421e+300
+ ]> : tensor<8xcomplex<f64>>
+ %abs_test_cast = tensor.cast %abs_test
+ : tensor<8xcomplex<f64>> to tensor<?xcomplex<f64>>
+
+ %abs_func = func.constant @abs : (complex<f64>) -> f64
+
+ call @test_element_f64(%abs_test_cast, %abs_func)
+ : (tensor<?xcomplex<f64>>, (complex<f64>) -> f64) -> ()
+
func.return
}
|
b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue()); | ||
Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue()); | ||
Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue()); | ||
Value realAbs = b.create<math::AbsFOp>(real, fmf.getValue()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to get the absolute value in case of the negative value.
// CHECK-NEXT: 6 | ||
(7.0, 8.0), | ||
// CHECK-NEXT: 10.6301 | ||
(-1.0, -1.0), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add test case with the negative values.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@joker-eph @matthias-springer I've fixed a bug in the previous PR that caused the integration test failure in the previous change. Could you review this change when you get a chance?
The build which previously failed is now finished successfully. |
This reverts commit 4effff2. It makes `complex.abs(-1)` return `-1`.
The previous PR was not enough about the way to handle the negative value. It is necessary to take the absolute value of the given real (or imaginary) part to be multiplied with the sqrt part in the case of either is zero. See: llvm#76316
The previous PR was not enough about the way to handle the negative value. It is necessary to take the absolute value of the given real (or imaginary) part to be multiplied with the sqrt part in the case of either is zero. See: llvm#76316
The previous PR was not correct on the way to handle the negative value. It is necessary to take the absolute value of the given real (or imaginary) part to be multiplied with the sqrt part.
See: #76316