Skip to content

Commit f4e9fa6

Browse files
committed
[mlir][complex] Prevent underflow in complex.abs (#79786)
The previous PR did not handle negative values correctly. The absolute value of the real (or imaginary) part must be used, both when it is multiplied by the sqrt term and when it is returned directly because the other part is zero. See: #76316
1 parent c8ca98a commit f4e9fa6

File tree

4 files changed

+224
-40
lines changed

4 files changed

+224
-40
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,29 +26,59 @@ namespace mlir {
2626
using namespace mlir;
2727

2828
namespace {
29+
// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
2930
// Conversion pattern that lowers complex.abs to arith/math ops.
//
// The added lines of this diff emit IR that picks the result with a chain of
// arith.select ops:
//   real == 0   -> |imag|
//   imag == 0   -> |real|
//   real > imag -> sqrt((imag / real)^2 + 1) * |real|
//   otherwise   -> sqrt((real / imag)^2 + 1) * |imag|
// The op's fastmath flags are forwarded to the div/mul/add/sqrt/absf ops; the
// cmpf and select ops are created without them.
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
3031
using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
3132

3233
LogicalResult
3334
matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
3435
ConversionPatternRewriter &rewriter) const override {
35-
auto loc = op.getLoc();
36-
auto type = op.getType();
36+
// Builder constructed from op's location, so the create<> calls below do not
// pass a location explicitly.
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
3737

3838
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
3939

40-
Value real =
41-
rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
42-
Value imag =
43-
rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
44-
Value realSqr =
45-
rewriter.create<arith::MulFOp>(loc, real, real, fmf.getValue());
46-
Value imagSqr =
47-
rewriter.create<arith::MulFOp>(loc, imag, imag, fmf.getValue());
48-
Value sqNorm =
49-
rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr, fmf.getValue());
50-
51-
rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
40+
Type elementType = op.getType();
41+
Value arg = adaptor.getComplex();
42+
43+
// Constants 0 and 1 of the result element type.
Value zero =
44+
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
45+
Value one = b.create<arith::ConstantOp>(elementType,
46+
b.getFloatAttr(elementType, 1.0));
47+
48+
Value real = b.create<complex::ReOp>(elementType, arg);
49+
Value imag = b.create<complex::ImOp>(elementType, arg);
50+
51+
// Ordered-equal tests against zero; they drive the two outer selects at the
// end, which return the absolute value of the other component directly.
Value realIsZero =
52+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
53+
Value imagIsZero =
54+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
55+
56+
// Case real > imag: candidate result sqrt((imag / real)^2 + 1) * |real|.
// Note the naming: absImag is this branch's candidate result, not |imag|
// (that is imagAbs below).
57+
Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue());
58+
Value imagSq =
59+
b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue());
60+
Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue());
61+
Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue());
62+
Value realAbs = b.create<math::AbsFOp>(real, fmf.getValue());
63+
Value absImag = b.create<arith::MulFOp>(imagSqrt, realAbs, fmf.getValue());
64+
65+
// Case real <= imag (and any unordered comparison): candidate result
// sqrt((real / imag)^2 + 1) * |imag|. absReal is this branch's candidate
// result, not |real| (that is realAbs above).
66+
Value realDivImag = b.create<arith::DivFOp>(real, imag, fmf.getValue());
67+
Value realSq =
68+
b.create<arith::MulFOp>(realDivImag, realDivImag, fmf.getValue());
69+
Value realSqPlusOne = b.create<arith::AddFOp>(realSq, one, fmf.getValue());
70+
Value realSqrt = b.create<math::SqrtOp>(realSqPlusOne, fmf.getValue());
71+
Value imagAbs = b.create<math::AbsFOp>(imag, fmf.getValue());
72+
Value absReal = b.create<arith::MulFOp>(realSqrt, imagAbs, fmf.getValue());
73+
74+
// Both candidates are always computed; the nested selects choose among
// imagAbs, realAbs, absImag and absReal.
// NOTE(review): the innermost predicate compares the signed real and imag,
// not realAbs and imagAbs. When the larger-magnitude component is negative
// the selected ratio has magnitude > 1 and squaring it may overflow -
// confirm whether comparing magnitudes was intended.
rewriter.replaceOpWithNewOp<arith::SelectOp>(
75+
op, realIsZero, imagAbs,
76+
b.create<arith::SelectOp>(
77+
imagIsZero, realAbs,
78+
b.create<arith::SelectOp>(
79+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, real, imag),
80+
absImag, absReal)));
81+
5282
return success();
5383
}
5484
};

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 103 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,30 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
77
%abs = complex.abs %arg: complex<f32>
88
return %abs : f32
99
}
10+
11+
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
12+
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
1013
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
1114
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
12-
// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
13-
// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
14-
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] : f32
15-
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
16-
// CHECK: return %[[NORM]] : f32
15+
// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
16+
// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
17+
// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
18+
// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
19+
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
20+
// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
21+
// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32
22+
// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
23+
// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
24+
// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
25+
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
26+
// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
27+
// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32
28+
// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
29+
// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
30+
// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
31+
// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
32+
// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
33+
// CHECK: return %[[ABS3]] : f32
1734

1835
// -----
1936

@@ -241,12 +258,28 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
241258
%log = complex.log %arg: complex<f32>
242259
return %log : complex<f32>
243260
}
261+
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
262+
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
244263
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
245264
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
246-
// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
247-
// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
248-
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
249-
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
265+
// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
266+
// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
267+
// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
268+
// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
269+
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
270+
// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
271+
// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32
272+
// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
273+
// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
274+
// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
275+
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
276+
// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
277+
// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32
278+
// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
279+
// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
280+
// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
281+
// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
282+
// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
250283
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
251284
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
252285
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
@@ -469,12 +502,28 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
469502
// CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
470503
// CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
471504
// CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
505+
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
506+
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
472507
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
473508
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
474-
// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL2]], %[[REAL2]] : f32
475-
// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG2]], %[[IMAG2]] : f32
476-
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
477-
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
509+
// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32
510+
// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32
511+
// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] : f32
512+
// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
513+
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
514+
// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
515+
// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL2]] : f32
516+
// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
517+
// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] : f32
518+
// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
519+
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
520+
// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
521+
// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG2]] : f32
522+
// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
523+
// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32
524+
// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
525+
// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
526+
// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
478527
// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32
479528
// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32
480529
// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
@@ -716,13 +765,29 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
716765
%abs = complex.abs %arg fastmath<nnan,contract> : complex<f32>
717766
return %abs : f32
718767
}
768+
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
769+
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
719770
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
720771
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
721-
// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
722-
// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
723-
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath<nnan,contract> : f32
724-
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
725-
// CHECK: return %[[NORM]] : f32
772+
// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
773+
// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
774+
// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
775+
// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
776+
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
777+
// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
778+
// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
779+
// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32
780+
// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
781+
// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
782+
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
783+
// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
784+
// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
785+
// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath<nnan,contract> : f32
786+
// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
787+
// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
788+
// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
789+
// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
790+
// CHECK: return %[[ABS3]] : f32
726791

727792
// -----
728793

@@ -807,12 +872,28 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
807872
%log = complex.log %arg fastmath<nnan,contract> : complex<f32>
808873
return %log : complex<f32>
809874
}
875+
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
876+
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
810877
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
811878
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
812-
// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
813-
// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
814-
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] fastmath<nnan,contract> : f32
815-
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
879+
// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
880+
// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
881+
// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
882+
// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
883+
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
884+
// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
885+
// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
886+
// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32
887+
// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
888+
// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
889+
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
890+
// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
891+
// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
892+
// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath<nnan,contract> : f32
893+
// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
894+
// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
895+
// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
896+
// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
816897
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] fastmath<nnan,contract> : f32
817898
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
818899
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>

mlir/test/Conversion/ComplexToStandard/full-conversion.mlir

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,31 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
66
%abs = complex.abs %arg: complex<f32>
77
return %abs : f32
88
}
9+
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
10+
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
911
// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
1012
// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
11-
// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] : f32
12-
// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] : f32
13-
// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] : f32
14-
// CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32
13+
// CHECK: %[[REAL_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[REAL]], %[[ZERO]] : f32
14+
// CHECK: %[[IMAG_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[IMAG]], %[[ZERO]] : f32
15+
16+
// CHECK: %[[IMAG_DIV_REAL:.*]] = llvm.fdiv %[[IMAG]], %[[REAL]] : f32
17+
// CHECK: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
18+
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = llvm.fadd %[[IMAG_SQ]], %[[ONE]] : f32
19+
// CHECK: %[[IMAG_SQRT:.*]] = llvm.intr.sqrt(%[[IMAG_SQ_PLUS_ONE]]) : (f32) -> f32
20+
// CHECK: %[[REAL_ABS:.*]] = llvm.intr.fabs(%[[REAL]]) : (f32) -> f32
21+
// CHECK: %[[ABS_IMAG:.*]] = llvm.fmul %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
22+
23+
// CHECK: %[[REAL_DIV_IMAG:.*]] = llvm.fdiv %[[REAL]], %[[IMAG]] : f32
24+
// CHECK: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
25+
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = llvm.fadd %[[REAL_SQ]], %[[ONE]] : f32
26+
// CHECK: %[[REAL_SQRT:.*]] = llvm.intr.sqrt(%[[REAL_SQ_PLUS_ONE]]) : (f32) -> f32
27+
// CHECK: %[[IMAG_ABS:.*]] = llvm.intr.fabs(%[[IMAG]]) : (f32) -> f32
28+
// CHECK: %[[ABS_REAL:.*]] = llvm.fmul %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
29+
30+
// CHECK: %[[REAL_GT_IMAG:.*]] = llvm.fcmp "ogt" %[[REAL]], %[[IMAG]] : f32
31+
// CHECK: %[[ABS1:.*]] = llvm.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : i1, f32
32+
// CHECK: %[[ABS2:.*]] = llvm.select %[[IMAG_IS_ZERO]], %[[REAL_ABS]], %[[ABS1]] : i1, f32
33+
// CHECK: %[[NORM:.*]] = llvm.select %[[REAL_IS_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : i1, f32
1534
// CHECK: llvm.return %[[NORM]] : f32
1635

1736
// CHECK-LABEL: llvm.func @complex_eq

mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,27 @@ func.func @angle(%arg: complex<f32>) -> f32 {
106106
func.return %angle : f32
107107
}
108108

109+
// Test driver: walks %input from index 0 to its dimension-0 size in steps of
// 1, calls %func indirectly on each extracted complex<f64> element, and
// prints each f64 result with vector.print.
func.func @test_element_f64(%input: tensor<?xcomplex<f64>>,
110+
%func: (complex<f64>) -> f64) {
111+
%c0 = arith.constant 0 : index
112+
%c1 = arith.constant 1 : index
113+
%size = tensor.dim %input, %c0: tensor<?xcomplex<f64>>
114+
115+
scf.for %i = %c0 to %size step %c1 {
116+
%elem = tensor.extract %input[%i]: tensor<?xcomplex<f64>>
117+
118+
%val = func.call_indirect %func(%elem) : (complex<f64>) -> f64
119+
vector.print %val : f64
120+
scf.yield
121+
}
122+
func.return
123+
}
124+
125+
// Returns complex.abs of a complex<f64> argument. @entry takes this function
// as a func.constant and passes it to @test_element_f64.
func.func @abs(%arg: complex<f64>) -> f64 {
126+
%abs = complex.abs %arg : complex<f64>
127+
func.return %abs : f64
128+
}
129+
109130
func.func @entry() {
110131
// complex.sqrt test
111132
%sqrt_test = arith.constant dense<[
@@ -300,5 +321,38 @@ func.func @entry() {
300321
call @test_element(%angle_test_cast, %angle_func)
301322
: (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> ()
302323

324+
// complex.abs test
325+
%abs_test = arith.constant dense<[
326+
(1.0, 1.0),
327+
// CHECK: 1.414
328+
(1.0e300, 1.0e300),
329+
// CHECK-NEXT: 1.41421e+300
330+
(1.0e-300, 1.0e-300),
331+
// CHECK-NEXT: 1.41421e-300
332+
(5.0, 0.0),
333+
// CHECK-NEXT: 5
334+
(0.0, 6.0),
335+
// CHECK-NEXT: 6
336+
(7.0, 8.0),
337+
// CHECK-NEXT: 10.6301
338+
(-1.0, -1.0),
339+
// CHECK-NEXT: 1.414
340+
(-1.0e300, -1.0e300),
341+
// CHECK-NEXT: 1.41421e+300
342+
(-1.0, 0.0),
343+
// CHECK-NOT: -1
344+
// CHECK-NEXT: 1
345+
(0.0, -1.0)
346+
// CHECK-NOT: -1
347+
// CHECK-NEXT: 1
348+
]> : tensor<10xcomplex<f64>>
349+
%abs_test_cast = tensor.cast %abs_test
350+
: tensor<10xcomplex<f64>> to tensor<?xcomplex<f64>>
351+
352+
%abs_func = func.constant @abs : (complex<f64>) -> f64
353+
354+
call @test_element_f64(%abs_test_cast, %abs_func)
355+
: (tensor<?xcomplex<f64>>, (complex<f64>) -> f64) -> ()
356+
303357
func.return
304358
}

0 commit comments

Comments
 (0)