Skip to content

Commit b17348c

Browse files
authored
[mlir][complex] Prevent underflow in complex.abs (#79786) (#81092)
1 parent 0df8aed commit b17348c

File tree

4 files changed

+224
-40
lines changed

4 files changed

+224
-40
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,29 +26,59 @@ namespace mlir {
2626
using namespace mlir;
2727

2828
namespace {
29+
// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
2930
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
3031
using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
3132

3233
LogicalResult
3334
matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
3435
ConversionPatternRewriter &rewriter) const override {
35-
auto loc = op.getLoc();
36-
auto type = op.getType();
36+
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
3737

3838
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
3939

40-
Value real =
41-
rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
42-
Value imag =
43-
rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
44-
Value realSqr =
45-
rewriter.create<arith::MulFOp>(loc, real, real, fmf.getValue());
46-
Value imagSqr =
47-
rewriter.create<arith::MulFOp>(loc, imag, imag, fmf.getValue());
48-
Value sqNorm =
49-
rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr, fmf.getValue());
50-
51-
rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
40+
Type elementType = op.getType();
41+
Value arg = adaptor.getComplex();
42+
43+
Value zero =
44+
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
45+
Value one = b.create<arith::ConstantOp>(elementType,
46+
b.getFloatAttr(elementType, 1.0));
47+
48+
Value real = b.create<complex::ReOp>(elementType, arg);
49+
Value imag = b.create<complex::ImOp>(elementType, arg);
50+
51+
Value realIsZero =
52+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
53+
Value imagIsZero =
54+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
55+
56+
// Real > Imag
57+
Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue());
58+
Value imagSq =
59+
b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue());
60+
Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue());
61+
Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue());
62+
Value realAbs = b.create<math::AbsFOp>(real, fmf.getValue());
63+
Value absImag = b.create<arith::MulFOp>(imagSqrt, realAbs, fmf.getValue());
64+
65+
// Real <= Imag
66+
Value realDivImag = b.create<arith::DivFOp>(real, imag, fmf.getValue());
67+
Value realSq =
68+
b.create<arith::MulFOp>(realDivImag, realDivImag, fmf.getValue());
69+
Value realSqPlusOne = b.create<arith::AddFOp>(realSq, one, fmf.getValue());
70+
Value realSqrt = b.create<math::SqrtOp>(realSqPlusOne, fmf.getValue());
71+
Value imagAbs = b.create<math::AbsFOp>(imag, fmf.getValue());
72+
Value absReal = b.create<arith::MulFOp>(realSqrt, imagAbs, fmf.getValue());
73+
74+
rewriter.replaceOpWithNewOp<arith::SelectOp>(
75+
op, realIsZero, imagAbs,
76+
b.create<arith::SelectOp>(
77+
imagIsZero, realAbs,
78+
b.create<arith::SelectOp>(
79+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, real, imag),
80+
absImag, absReal)));
81+
5282
return success();
5383
}
5484
};

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 103 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,30 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
77
%abs = complex.abs %arg: complex<f32>
88
return %abs : f32
99
}
10+
11+
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
12+
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
1013
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
1114
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
12-
// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
13-
// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
14-
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] : f32
15-
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
16-
// CHECK: return %[[NORM]] : f32
15+
// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
16+
// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
17+
// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
18+
// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
19+
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
20+
// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
21+
// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32
22+
// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
23+
// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
24+
// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
25+
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
26+
// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
27+
// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32
28+
// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
29+
// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
30+
// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
31+
// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
32+
// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
33+
// CHECK: return %[[ABS3]] : f32
1734

1835
// -----
1936

@@ -241,12 +258,28 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
241258
%log = complex.log %arg: complex<f32>
242259
return %log : complex<f32>
243260
}
261+
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
262+
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
244263
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
245264
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
246-
// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
247-
// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
248-
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
249-
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
265+
// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
266+
// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
267+
// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
268+
// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
269+
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
270+
// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
271+
// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32
272+
// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
273+
// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
274+
// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
275+
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
276+
// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
277+
// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32
278+
// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
279+
// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
280+
// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
281+
// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
282+
// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
250283
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
251284
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
252285
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
@@ -469,12 +502,28 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
469502
// CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
470503
// CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
471504
// CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
505+
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
506+
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
472507
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
473508
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
474-
// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL2]], %[[REAL2]] : f32
475-
// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG2]], %[[IMAG2]] : f32
476-
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
477-
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
509+
// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32
510+
// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32
511+
// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] : f32
512+
// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
513+
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
514+
// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
515+
// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL2]] : f32
516+
// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
517+
// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] : f32
518+
// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
519+
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
520+
// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
521+
// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG2]] : f32
522+
// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
523+
// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32
524+
// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
525+
// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
526+
// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
478527
// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32
479528
// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32
480529
// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
@@ -716,13 +765,29 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
716765
%abs = complex.abs %arg fastmath<nnan,contract> : complex<f32>
717766
return %abs : f32
718767
}
768+
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
769+
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
719770
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
720771
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
721-
// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
722-
// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
723-
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath<nnan,contract> : f32
724-
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
725-
// CHECK: return %[[NORM]] : f32
772+
// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
773+
// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
774+
// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
775+
// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
776+
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
777+
// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
778+
// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
779+
// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32
780+
// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
781+
// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
782+
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
783+
// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
784+
// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
785+
// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath<nnan,contract> : f32
786+
// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
787+
// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
788+
// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
789+
// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
790+
// CHECK: return %[[ABS3]] : f32
726791

727792
// -----
728793

@@ -807,12 +872,28 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
807872
%log = complex.log %arg fastmath<nnan,contract> : complex<f32>
808873
return %log : complex<f32>
809874
}
875+
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
876+
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
810877
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
811878
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
812-
// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
813-
// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
814-
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] fastmath<nnan,contract> : f32
815-
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
879+
// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
880+
// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
881+
// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
882+
// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
883+
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
884+
// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
885+
// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
886+
// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32
887+
// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
888+
// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
889+
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
890+
// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
891+
// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
892+
// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath<nnan,contract> : f32
893+
// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
894+
// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
895+
// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
896+
// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
816897
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] fastmath<nnan,contract> : f32
817898
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
818899
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>

mlir/test/Conversion/ComplexToStandard/full-conversion.mlir

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,31 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
66
%abs = complex.abs %arg: complex<f32>
77
return %abs : f32
88
}
9+
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
10+
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
911
// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
1012
// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
11-
// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] : f32
12-
// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] : f32
13-
// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] : f32
14-
// CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32
13+
// CHECK: %[[REAL_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[REAL]], %[[ZERO]] : f32
14+
// CHECK: %[[IMAG_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[IMAG]], %[[ZERO]] : f32
15+
16+
// CHECK: %[[IMAG_DIV_REAL:.*]] = llvm.fdiv %[[IMAG]], %[[REAL]] : f32
17+
// CHECK: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
18+
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = llvm.fadd %[[IMAG_SQ]], %[[ONE]] : f32
19+
// CHECK: %[[IMAG_SQRT:.*]] = llvm.intr.sqrt(%[[IMAG_SQ_PLUS_ONE]]) : (f32) -> f32
20+
// CHECK: %[[REAL_ABS:.*]] = llvm.intr.fabs(%[[REAL]]) : (f32) -> f32
21+
// CHECK: %[[ABS_IMAG:.*]] = llvm.fmul %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
22+
23+
// CHECK: %[[REAL_DIV_IMAG:.*]] = llvm.fdiv %[[REAL]], %[[IMAG]] : f32
24+
// CHECK: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
25+
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = llvm.fadd %[[REAL_SQ]], %[[ONE]] : f32
26+
// CHECK: %[[REAL_SQRT:.*]] = llvm.intr.sqrt(%[[REAL_SQ_PLUS_ONE]]) : (f32) -> f32
27+
// CHECK: %[[IMAG_ABS:.*]] = llvm.intr.fabs(%[[IMAG]]) : (f32) -> f32
28+
// CHECK: %[[ABS_REAL:.*]] = llvm.fmul %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
29+
30+
// CHECK: %[[REAL_GT_IMAG:.*]] = llvm.fcmp "ogt" %[[REAL]], %[[IMAG]] : f32
31+
// CHECK: %[[ABS1:.*]] = llvm.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : i1, f32
32+
// CHECK: %[[ABS2:.*]] = llvm.select %[[IMAG_IS_ZERO]], %[[REAL_ABS]], %[[ABS1]] : i1, f32
33+
// CHECK: %[[NORM:.*]] = llvm.select %[[REAL_IS_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : i1, f32
1534
// CHECK: llvm.return %[[NORM]] : f32
1635

1736
// CHECK-LABEL: llvm.func @complex_eq

mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,27 @@ func.func @angle(%arg: complex<f32>) -> f32 {
106106
func.return %angle : f32
107107
}
108108

109+
func.func @test_element_f64(%input: tensor<?xcomplex<f64>>,
110+
%func: (complex<f64>) -> f64) {
111+
%c0 = arith.constant 0 : index
112+
%c1 = arith.constant 1 : index
113+
%size = tensor.dim %input, %c0: tensor<?xcomplex<f64>>
114+
115+
scf.for %i = %c0 to %size step %c1 {
116+
%elem = tensor.extract %input[%i]: tensor<?xcomplex<f64>>
117+
118+
%val = func.call_indirect %func(%elem) : (complex<f64>) -> f64
119+
vector.print %val : f64
120+
scf.yield
121+
}
122+
func.return
123+
}
124+
125+
func.func @abs(%arg: complex<f64>) -> f64 {
126+
%abs = complex.abs %arg : complex<f64>
127+
func.return %abs : f64
128+
}
129+
109130
func.func @entry() {
110131
// complex.sqrt test
111132
%sqrt_test = arith.constant dense<[
@@ -300,5 +321,38 @@ func.func @entry() {
300321
call @test_element(%angle_test_cast, %angle_func)
301322
: (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> ()
302323

324+
// complex.abs test
325+
%abs_test = arith.constant dense<[
326+
(1.0, 1.0),
327+
// CHECK: 1.414
328+
(1.0e300, 1.0e300),
329+
// CHECK-NEXT: 1.41421e+300
330+
(1.0e-300, 1.0e-300),
331+
// CHECK-NEXT: 1.41421e-300
332+
(5.0, 0.0),
333+
// CHECK-NEXT: 5
334+
(0.0, 6.0),
335+
// CHECK-NEXT: 6
336+
(7.0, 8.0),
337+
// CHECK-NEXT: 10.6301
338+
(-1.0, -1.0),
339+
// CHECK-NEXT: 1.414
340+
(-1.0e300, -1.0e300),
341+
// CHECK-NEXT: 1.41421e+300
342+
(-1.0, 0.0),
343+
// CHECK-NOT: -1
344+
// CHECK-NEXT: 1
345+
(0.0, -1.0)
346+
// CHECK-NOT: -1
347+
// CHECK-NEXT: 1
348+
]> : tensor<10xcomplex<f64>>
349+
%abs_test_cast = tensor.cast %abs_test
350+
: tensor<10xcomplex<f64>> to tensor<?xcomplex<f64>>
351+
352+
%abs_func = func.constant @abs : (complex<f64>) -> f64
353+
354+
call @test_element_f64(%abs_test_cast, %abs_func)
355+
: (tensor<?xcomplex<f64>>, (complex<f64>) -> f64) -> ()
356+
303357
func.return
304358
}

0 commit comments

Comments
 (0)