-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][complex] Prevent underflow in complex.abs #76316
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-complex Author: Kai Sasaki (Lewuathe) Changes
See: #62011 Full diff: https://github.com/llvm/llvm-project/pull/76316.diff 4 Files Affected:
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index bf753c7062f366..7c1db57b55f996 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -26,29 +26,57 @@ namespace mlir {
using namespace mlir;
namespace {
+// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op.getLoc();
- auto type = op.getType();
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
- Value real =
- rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
- Value imag =
- rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
- Value realSqr =
- rewriter.create<arith::MulFOp>(loc, real, real, fmf.getValue());
- Value imagSqr =
- rewriter.create<arith::MulFOp>(loc, imag, imag, fmf.getValue());
- Value sqNorm =
- rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr, fmf.getValue());
-
- rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
+ Type elementType = op.getType();
+ Value arg = adaptor.getComplex();
+
+ Value zero =
+ b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+ Value one = b.create<arith::ConstantOp>(elementType,
+ b.getFloatAttr(elementType, 1.0));
+
+ Value real = b.create<complex::ReOp>(elementType, arg);
+ Value imag = b.create<complex::ImOp>(elementType, arg);
+
+ Value realIsZero =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
+ Value imagIsZero =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
+
+ // Real > Imag
+ Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue());
+ Value imagSq =
+ b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue());
+ Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue());
+ Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue());
+ Value absImag = b.create<arith::MulFOp>(imagSqrt, real, fmf.getValue());
+
+ // Real <= Imag
+ Value realDivImag = b.create<arith::DivFOp>(real, imag, fmf.getValue());
+ Value realSq =
+ b.create<arith::MulFOp>(realDivImag, realDivImag, fmf.getValue());
+ Value realSqPlusOne = b.create<arith::AddFOp>(realSq, one, fmf.getValue());
+ Value realSqrt = b.create<math::SqrtOp>(realSqPlusOne, fmf.getValue());
+ Value absReal = b.create<arith::MulFOp>(realSqrt, imag, fmf.getValue());
+
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(
+ op, realIsZero, imag,
+ b.create<arith::SelectOp>(
+ imagIsZero, real,
+ b.create<arith::SelectOp>(
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, real, imag),
+ absImag, absReal)));
+
return success();
}
};
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 3af28150fd5c3f..1028c9aae92c05 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -7,13 +7,28 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
%abs = complex.abs %arg: complex<f32>
return %abs : f32
}
+
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
-// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
-// CHECK: return %[[NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
+// CHECK: return %[[ABS3]] : f32
// -----
@@ -241,12 +256,26 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
%log = complex.log %arg: complex<f32>
return %log : complex<f32>
}
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
@@ -469,12 +498,26 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
// CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
// CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL2]], %[[REAL2]] : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG2]], %[[IMAG2]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL2]] : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG2]] : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL2]], %[[ABS1]] : f32
+// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG2]], %[[ABS2]] : f32
// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32
// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32
// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
@@ -716,13 +759,27 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
%abs = complex.abs %arg fastmath<nnan,contract> : complex<f32>
return %abs : f32
}
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath<nnan,contract> : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
-// CHECK: return %[[NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
+// CHECK: return %[[ABS3]] : f32
// -----
@@ -807,12 +864,26 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
%log = complex.log %arg fastmath<nnan,contract> : complex<f32>
return %log : complex<f32>
}
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] fastmath<nnan,contract> : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
index 9983dd46f09433..d710dc8e1adeb7 100644
--- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
@@ -6,12 +6,29 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
%abs = complex.abs %arg: complex<f32>
return %abs : f32
}
+// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
-// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] : f32
-// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] : f32
-// CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32
+// CHECK: %[[REAL_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[IMAG]], %[[ZERO]] : f32
+
+// CHECK: %[[IMAG_DIV_REAL:.*]] = llvm.fdiv %[[IMAG]], %[[REAL]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = llvm.fadd %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = llvm.intr.sqrt(%[[IMAG_SQ_PLUS_ONE]]) : (f32) -> f32
+// CHECK: %[[ABS_IMAG:.*]] = llvm.fmul %[[IMAG_SQRT]], %[[REAL]] : f32
+
+// CHECK: %[[REAL_DIV_IMAG:.*]] = llvm.fdiv %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = llvm.fadd %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = llvm.intr.sqrt(%[[REAL_SQ_PLUS_ONE]]) : (f32) -> f32
+// CHECK: %[[ABS_REAL:.*]] = llvm.fmul %[[REAL_SQRT]], %[[IMAG]] : f32
+
+// CHECK: %[[REAL_GT_IMAG:.*]] = llvm.fcmp "ogt" %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = llvm.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : i1, f32
+// CHECK: %[[ABS2:.*]] = llvm.select %[[IMAG_IS_ZERO]], %[[REAL]], %[[ABS1]] : i1, f32
+// CHECK: %[[NORM:.*]] = llvm.select %[[REAL_IS_ZERO]], %[[IMAG]], %[[ABS2]] : i1, f32
// CHECK: llvm.return %[[NORM]] : f32
// CHECK-LABEL: llvm.func @complex_eq
diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
index 349b92a7aefa2e..b7849945b3cf49 100644
--- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
+++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -106,6 +106,27 @@ func.func @angle(%arg: complex<f32>) -> f32 {
func.return %angle : f32
}
+func.func @test_element_f64(%input: tensor<?xcomplex<f64>>,
+ %func: (complex<f64>) -> f64) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %size = tensor.dim %input, %c0: tensor<?xcomplex<f64>>
+
+ scf.for %i = %c0 to %size step %c1 {
+ %elem = tensor.extract %input[%i]: tensor<?xcomplex<f64>>
+
+ %val = func.call_indirect %func(%elem) : (complex<f64>) -> f64
+ vector.print %val : f64
+ scf.yield
+ }
+ func.return
+}
+
+func.func @abs(%arg: complex<f64>) -> f64 {
+ %abs = complex.abs %arg : complex<f64>
+ func.return %abs : f64
+}
+
func.func @entry() {
// complex.sqrt test
%sqrt_test = arith.constant dense<[
@@ -300,5 +321,22 @@ func.func @entry() {
call @test_element(%angle_test_cast, %angle_func)
: (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> ()
+ // complex.abs test
+ %abs_test = arith.constant dense<[
+ (1.0, 1.0),
+ // CHECK: 1.414
+ (1.0e300, 1.0e300),
+ // CHECK-NEXT: 1.41421e+300
+ (1.0e-300, 1.0e-300)
+ // CHECK-NEXT: 1.41421e-300
+ ]> : tensor<3xcomplex<f64>>
+ %abs_test_cast = tensor.cast %abs_test
+ : tensor<3xcomplex<f64>> to tensor<?xcomplex<f64>>
+
+ %abs_func = func.constant @abs : (complex<f64>) -> f64
+
+ call @test_element_f64(%abs_test_cast, %abs_func)
+ : (tensor<?xcomplex<f64>>, (complex<f64>) -> f64) -> ()
+
func.return
}
|
@llvm/pr-subscribers-mlir Author: Kai Sasaki (Lewuathe) Changes
See: #62011 Full diff: https://github.com/llvm/llvm-project/pull/76316.diff 4 Files Affected:
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index bf753c7062f366..7c1db57b55f996 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -26,29 +26,57 @@ namespace mlir {
using namespace mlir;
namespace {
+// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op.getLoc();
- auto type = op.getType();
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
- Value real =
- rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
- Value imag =
- rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
- Value realSqr =
- rewriter.create<arith::MulFOp>(loc, real, real, fmf.getValue());
- Value imagSqr =
- rewriter.create<arith::MulFOp>(loc, imag, imag, fmf.getValue());
- Value sqNorm =
- rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr, fmf.getValue());
-
- rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
+ Type elementType = op.getType();
+ Value arg = adaptor.getComplex();
+
+ Value zero =
+ b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+ Value one = b.create<arith::ConstantOp>(elementType,
+ b.getFloatAttr(elementType, 1.0));
+
+ Value real = b.create<complex::ReOp>(elementType, arg);
+ Value imag = b.create<complex::ImOp>(elementType, arg);
+
+ Value realIsZero =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
+ Value imagIsZero =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
+
+ // Real > Imag
+ Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue());
+ Value imagSq =
+ b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue());
+ Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue());
+ Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue());
+ Value absImag = b.create<arith::MulFOp>(imagSqrt, real, fmf.getValue());
+
+ // Real <= Imag
+ Value realDivImag = b.create<arith::DivFOp>(real, imag, fmf.getValue());
+ Value realSq =
+ b.create<arith::MulFOp>(realDivImag, realDivImag, fmf.getValue());
+ Value realSqPlusOne = b.create<arith::AddFOp>(realSq, one, fmf.getValue());
+ Value realSqrt = b.create<math::SqrtOp>(realSqPlusOne, fmf.getValue());
+ Value absReal = b.create<arith::MulFOp>(realSqrt, imag, fmf.getValue());
+
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(
+ op, realIsZero, imag,
+ b.create<arith::SelectOp>(
+ imagIsZero, real,
+ b.create<arith::SelectOp>(
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, real, imag),
+ absImag, absReal)));
+
return success();
}
};
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 3af28150fd5c3f..1028c9aae92c05 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -7,13 +7,28 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
%abs = complex.abs %arg: complex<f32>
return %abs : f32
}
+
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
-// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
-// CHECK: return %[[NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
+// CHECK: return %[[ABS3]] : f32
// -----
@@ -241,12 +256,26 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
%log = complex.log %arg: complex<f32>
return %log : complex<f32>
}
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
@@ -469,12 +498,26 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
// CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
// CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL2]], %[[REAL2]] : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG2]], %[[IMAG2]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL2]] : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG2]] : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL2]], %[[ABS1]] : f32
+// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG2]], %[[ABS2]] : f32
// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32
// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32
// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
@@ -716,13 +759,27 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
%abs = complex.abs %arg fastmath<nnan,contract> : complex<f32>
return %abs : f32
}
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath<nnan,contract> : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
-// CHECK: return %[[NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
+// CHECK: return %[[ABS3]] : f32
// -----
@@ -807,12 +864,26 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
%log = complex.log %arg fastmath<nnan,contract> : complex<f32>
return %log : complex<f32>
}
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] fastmath<nnan,contract> : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
index 9983dd46f09433..d710dc8e1adeb7 100644
--- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
@@ -6,12 +6,29 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
%abs = complex.abs %arg: complex<f32>
return %abs : f32
}
+// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
-// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] : f32
-// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] : f32
-// CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32
+// CHECK: %[[REAL_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[IMAG]], %[[ZERO]] : f32
+
+// CHECK: %[[IMAG_DIV_REAL:.*]] = llvm.fdiv %[[IMAG]], %[[REAL]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = llvm.fadd %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = llvm.intr.sqrt(%[[IMAG_SQ_PLUS_ONE]]) : (f32) -> f32
+// CHECK: %[[ABS_IMAG:.*]] = llvm.fmul %[[IMAG_SQRT]], %[[REAL]] : f32
+
+// CHECK: %[[REAL_DIV_IMAG:.*]] = llvm.fdiv %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = llvm.fadd %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = llvm.intr.sqrt(%[[REAL_SQ_PLUS_ONE]]) : (f32) -> f32
+// CHECK: %[[ABS_REAL:.*]] = llvm.fmul %[[REAL_SQRT]], %[[IMAG]] : f32
+
+// CHECK: %[[REAL_GT_IMAG:.*]] = llvm.fcmp "ogt" %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = llvm.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : i1, f32
+// CHECK: %[[ABS2:.*]] = llvm.select %[[IMAG_IS_ZERO]], %[[REAL]], %[[ABS1]] : i1, f32
+// CHECK: %[[NORM:.*]] = llvm.select %[[REAL_IS_ZERO]], %[[IMAG]], %[[ABS2]] : i1, f32
// CHECK: llvm.return %[[NORM]] : f32
// CHECK-LABEL: llvm.func @complex_eq
diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
index 349b92a7aefa2e..b7849945b3cf49 100644
--- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
+++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -106,6 +106,27 @@ func.func @angle(%arg: complex<f32>) -> f32 {
func.return %angle : f32
}
+func.func @test_element_f64(%input: tensor<?xcomplex<f64>>,
+ %func: (complex<f64>) -> f64) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %size = tensor.dim %input, %c0: tensor<?xcomplex<f64>>
+
+ scf.for %i = %c0 to %size step %c1 {
+ %elem = tensor.extract %input[%i]: tensor<?xcomplex<f64>>
+
+ %val = func.call_indirect %func(%elem) : (complex<f64>) -> f64
+ vector.print %val : f64
+ scf.yield
+ }
+ func.return
+}
+
+func.func @abs(%arg: complex<f64>) -> f64 {
+ %abs = complex.abs %arg : complex<f64>
+ func.return %abs : f64
+}
+
func.func @entry() {
// complex.sqrt test
%sqrt_test = arith.constant dense<[
@@ -300,5 +321,22 @@ func.func @entry() {
call @test_element(%angle_test_cast, %angle_func)
: (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> ()
+ // complex.abs test
+ %abs_test = arith.constant dense<[
+ (1.0, 1.0),
+ // CHECK: 1.414
+ (1.0e300, 1.0e300),
+ // CHECK-NEXT: 1.41421e+300
+ (1.0e-300, 1.0e-300)
+ // CHECK-NEXT: 1.41421e-300
+ ]> : tensor<3xcomplex<f64>>
+ %abs_test_cast = tensor.cast %abs_test
+ : tensor<3xcomplex<f64>> to tensor<?xcomplex<f64>>
+
+ %abs_func = func.constant @abs : (complex<f64>) -> f64
+
+ call @test_element_f64(%abs_test_cast, %abs_func)
+ : (tensor<?xcomplex<f64>>, (complex<f64>) -> f64) -> ()
+
func.return
}
|
(1.0e300, 1.0e300), | ||
// CHECK-NEXT: 1.41421e+300 | ||
(1.0e-300, 1.0e-300) | ||
// CHECK-NEXT: 1.41421e-300 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The returned values are inf
and 0 without this change.
I looked into the flang overflow issue too a while back and opened a topic about it: https://discourse.llvm.org/t/overflows-during-folding-of-basic-arith-ops/73952. So as far as I understand, the fix would be for Flang to generate code which does not overflow. MLIR does not promise to handle it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rikhuijzer Thanks for giving me the pointer to the past discussion. The discussion originated from the issue in flang but the problem this is trying to solve could be more general.
- complex.abs cannot treat the full range of f64 due to overflow. The conversion to standard should be responsible for.
- complex dialect can be regarded as the user of lower level dialect. We can let it control how the lower level ops are used.
- complex.sqrt is already taking care of this case. Treating the f64 precision similarly in a dialect is consistent.
- the double in C++ seems to treat the 1.414+300e value properly in the original issue. This may be also expected to complex dialect.
Sorry I don't know enough about this to be able to review. I hope someone else can pick this up. |
@kiranchandramohan @Leporacanthicus @joker-eph Sorry for taking your time but can I ask you guys to review this change when you get a chance? I requested a review to the people listed in the original issue so please correct me if I am wrong. |
Could someone take a look at this change? I think |
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); | ||
|
||
// Real > Imag | ||
Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering what's going to happen here if real
is zero. Should the computation be wrapped in scf.if
instead of using an arith.select
below?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it should be no problem as long as the realIsZero
is being true. The decision operator does not seem to be evaluated. The case with real
is zero is returning expected value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the value is not used afterwards but I'm wondering if executing a division by zero could cause a crash (maybe only on some platforms). Maybe ask on Discourse before merging. If you don't get an answer, I'd say just merge as is and if it causes issues in the future it will be a simple fix (wrapping the computation in scf.if
instead of using arith.select
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. That makes sense. I posted the question here.
https://discourse.llvm.org/t/devision-by-zero-semantics-in-arith-divf/76514
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@matthias-springer Thank you for reviewing! I added some test cases accordingly.
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); | ||
|
||
// Real > Imag | ||
Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the value is not used afterwards but I'm wondering if executing a division by zero could cause a crash (maybe only on some platforms). Maybe ask on Discourse before merging. If you don't get an answer, I'd say just merge as is and if it causes issues in the future it will be a simple fix (wrapping the computation in scf.if
instead of using arith.select
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From the MLIR viewpoint, it can be treated as undefined as it is. The language frontend or the hardware seem to be responsible how the division-by-zero case is treated.
https://discourse.llvm.org/t/devision-by-zero-semantics-in-arith-divf/76514/3
Therefore, I think we can merge this as it is for now. Thanks for reviewing this change anyway!
Reverts #76316 Buildbot test is broken.
I had to revert because this broke a bot: https://lab.llvm.org/buildbot/#/builders/264/builds/6131
|
@joker-eph Oh, sorry for bothering you. I'll check what's going on. But I have a question about reproducing the problem in the topic branch. I think it was not running in the PR request. Is there any way to check the test passes before merging into the |
The previous PR was not correct on the way to handle the negative value. It is necessary to take the absolute value of the given real (or imaginary) part to be multiplied with the sqrt part. See: #76316
The previous PR was not enough about the way to handle the negative value. It is necessary to take the absolute value of the given real (or imaginary) part to be multiplied with the sqrt part in the case of either is zero. See: llvm#76316
The previous PR was not enough about the way to handle the negative value. It is necessary to take the absolute value of the given real (or imaginary) part to be multiplied with the sqrt part in the case of either is zero. See: llvm#76316
complex.abs
op may causes the underflow issue with the large input. We can prevent the problem by using the algorithm listed in "ALGORITHM 312 ABSOLUTE VALUE AND SQUARE ROOT OF A COMPLEX NUMBER" as well as thecomplex.sqrt
.See: #62011