Skip to content

Commit c41bb11

Browse files
authored
Revert "[mlir][complex] Prevent underflow in complex.abs" (#79722)
Reverts #76316 Buildbot test is broken.
1 parent 5a5ce01 commit c41bb11

File tree

4 files changed

+40
-200
lines changed

4 files changed

+40
-200
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 14 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -26,57 +26,29 @@ namespace mlir {
2626
using namespace mlir;
2727

2828
namespace {
29-
// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
3029
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
3130
using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
3231

3332
LogicalResult
3433
matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
3534
ConversionPatternRewriter &rewriter) const override {
36-
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
35+
auto loc = op.getLoc();
36+
auto type = op.getType();
3737

3838
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
3939

40-
Type elementType = op.getType();
41-
Value arg = adaptor.getComplex();
42-
43-
Value zero =
44-
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
45-
Value one = b.create<arith::ConstantOp>(elementType,
46-
b.getFloatAttr(elementType, 1.0));
47-
48-
Value real = b.create<complex::ReOp>(elementType, arg);
49-
Value imag = b.create<complex::ImOp>(elementType, arg);
50-
51-
Value realIsZero =
52-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
53-
Value imagIsZero =
54-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
55-
56-
// Real > Imag
57-
Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue());
58-
Value imagSq =
59-
b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue());
60-
Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue());
61-
Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue());
62-
Value absImag = b.create<arith::MulFOp>(imagSqrt, real, fmf.getValue());
63-
64-
// Real <= Imag
65-
Value realDivImag = b.create<arith::DivFOp>(real, imag, fmf.getValue());
66-
Value realSq =
67-
b.create<arith::MulFOp>(realDivImag, realDivImag, fmf.getValue());
68-
Value realSqPlusOne = b.create<arith::AddFOp>(realSq, one, fmf.getValue());
69-
Value realSqrt = b.create<math::SqrtOp>(realSqPlusOne, fmf.getValue());
70-
Value absReal = b.create<arith::MulFOp>(realSqrt, imag, fmf.getValue());
71-
72-
rewriter.replaceOpWithNewOp<arith::SelectOp>(
73-
op, realIsZero, imag,
74-
b.create<arith::SelectOp>(
75-
imagIsZero, real,
76-
b.create<arith::SelectOp>(
77-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, real, imag),
78-
absImag, absReal)));
79-
40+
Value real =
41+
rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
42+
Value imag =
43+
rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
44+
Value realSqr =
45+
rewriter.create<arith::MulFOp>(loc, real, real, fmf.getValue());
46+
Value imagSqr =
47+
rewriter.create<arith::MulFOp>(loc, imag, imag, fmf.getValue());
48+
Value sqNorm =
49+
rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr, fmf.getValue());
50+
51+
rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
8052
return success();
8153
}
8254
};

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 22 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,13 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
77
%abs = complex.abs %arg: complex<f32>
88
return %abs : f32
99
}
10-
11-
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
12-
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
1310
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
1411
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
15-
// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
16-
// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
17-
// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
18-
// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
19-
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
20-
// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
21-
// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] : f32
22-
// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
23-
// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
24-
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
25-
// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
26-
// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] : f32
27-
// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
28-
// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
29-
// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
30-
// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
31-
// CHECK: return %[[ABS3]] : f32
12+
// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
13+
// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
14+
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] : f32
15+
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
16+
// CHECK: return %[[NORM]] : f32
3217

3318
// -----
3419

@@ -256,26 +241,12 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
256241
%log = complex.log %arg: complex<f32>
257242
return %log : complex<f32>
258243
}
259-
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
260-
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
261244
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
262245
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
263-
// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
264-
// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
265-
// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
266-
// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
267-
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
268-
// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
269-
// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] : f32
270-
// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
271-
// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
272-
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
273-
// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
274-
// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] : f32
275-
// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
276-
// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
277-
// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
278-
// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
246+
// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
247+
// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
248+
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
249+
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
279250
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
280251
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
281252
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
@@ -498,26 +469,12 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
498469
// CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
499470
// CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
500471
// CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
501-
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
502-
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
503472
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
504473
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
505-
// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32
506-
// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32
507-
// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] : f32
508-
// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
509-
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
510-
// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
511-
// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL2]] : f32
512-
// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] : f32
513-
// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
514-
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
515-
// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
516-
// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG2]] : f32
517-
// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32
518-
// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
519-
// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL2]], %[[ABS1]] : f32
520-
// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG2]], %[[ABS2]] : f32
474+
// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL2]], %[[REAL2]] : f32
475+
// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG2]], %[[IMAG2]] : f32
476+
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
477+
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
521478
// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32
522479
// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32
523480
// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
@@ -759,27 +716,13 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
759716
%abs = complex.abs %arg fastmath<nnan,contract> : complex<f32>
760717
return %abs : f32
761718
}
762-
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
763-
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
764719
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
765720
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
766-
// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
767-
// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
768-
// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
769-
// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
770-
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
771-
// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
772-
// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] fastmath<nnan,contract> : f32
773-
// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
774-
// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
775-
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
776-
// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
777-
// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] fastmath<nnan,contract> : f32
778-
// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
779-
// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
780-
// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
781-
// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
782-
// CHECK: return %[[ABS3]] : f32
721+
// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
722+
// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
723+
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath<nnan,contract> : f32
724+
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
725+
// CHECK: return %[[NORM]] : f32
783726

784727
// -----
785728

@@ -864,26 +807,12 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
864807
%log = complex.log %arg fastmath<nnan,contract> : complex<f32>
865808
return %log : complex<f32>
866809
}
867-
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
868-
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
869810
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
870811
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
871-
// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
872-
// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
873-
// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
874-
// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
875-
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
876-
// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
877-
// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] fastmath<nnan,contract> : f32
878-
// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
879-
// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
880-
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
881-
// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
882-
// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] fastmath<nnan,contract> : f32
883-
// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
884-
// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
885-
// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
886-
// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
812+
// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
813+
// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
814+
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] fastmath<nnan,contract> : f32
815+
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
887816
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] fastmath<nnan,contract> : f32
888817
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
889818
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>

mlir/test/Conversion/ComplexToStandard/full-conversion.mlir

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,12 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
66
%abs = complex.abs %arg: complex<f32>
77
return %abs : f32
88
}
9-
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
10-
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
119
// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
1210
// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
13-
// CHECK: %[[REAL_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[REAL]], %[[ZERO]] : f32
14-
// CHECK: %[[IMAG_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[IMAG]], %[[ZERO]] : f32
15-
16-
// CHECK: %[[IMAG_DIV_REAL:.*]] = llvm.fdiv %[[IMAG]], %[[REAL]] : f32
17-
// CHECK: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
18-
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = llvm.fadd %[[IMAG_SQ]], %[[ONE]] : f32
19-
// CHECK: %[[IMAG_SQRT:.*]] = llvm.intr.sqrt(%[[IMAG_SQ_PLUS_ONE]]) : (f32) -> f32
20-
// CHECK: %[[ABS_IMAG:.*]] = llvm.fmul %[[IMAG_SQRT]], %[[REAL]] : f32
21-
22-
// CHECK: %[[REAL_DIV_IMAG:.*]] = llvm.fdiv %[[REAL]], %[[IMAG]] : f32
23-
// CHECK: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
24-
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = llvm.fadd %[[REAL_SQ]], %[[ONE]] : f32
25-
// CHECK: %[[REAL_SQRT:.*]] = llvm.intr.sqrt(%[[REAL_SQ_PLUS_ONE]]) : (f32) -> f32
26-
// CHECK: %[[ABS_REAL:.*]] = llvm.fmul %[[REAL_SQRT]], %[[IMAG]] : f32
27-
28-
// CHECK: %[[REAL_GT_IMAG:.*]] = llvm.fcmp "ogt" %[[REAL]], %[[IMAG]] : f32
29-
// CHECK: %[[ABS1:.*]] = llvm.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : i1, f32
30-
// CHECK: %[[ABS2:.*]] = llvm.select %[[IMAG_IS_ZERO]], %[[REAL]], %[[ABS1]] : i1, f32
31-
// CHECK: %[[NORM:.*]] = llvm.select %[[REAL_IS_ZERO]], %[[IMAG]], %[[ABS2]] : i1, f32
11+
// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] : f32
12+
// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] : f32
13+
// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] : f32
14+
// CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32
3215
// CHECK: llvm.return %[[NORM]] : f32
3316

3417
// CHECK-LABEL: llvm.func @complex_eq

mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -106,27 +106,6 @@ func.func @angle(%arg: complex<f32>) -> f32 {
106106
func.return %angle : f32
107107
}
108108

109-
func.func @test_element_f64(%input: tensor<?xcomplex<f64>>,
110-
%func: (complex<f64>) -> f64) {
111-
%c0 = arith.constant 0 : index
112-
%c1 = arith.constant 1 : index
113-
%size = tensor.dim %input, %c0: tensor<?xcomplex<f64>>
114-
115-
scf.for %i = %c0 to %size step %c1 {
116-
%elem = tensor.extract %input[%i]: tensor<?xcomplex<f64>>
117-
118-
%val = func.call_indirect %func(%elem) : (complex<f64>) -> f64
119-
vector.print %val : f64
120-
scf.yield
121-
}
122-
func.return
123-
}
124-
125-
func.func @abs(%arg: complex<f64>) -> f64 {
126-
%abs = complex.abs %arg : complex<f64>
127-
func.return %abs : f64
128-
}
129-
130109
func.func @entry() {
131110
// complex.sqrt test
132111
%sqrt_test = arith.constant dense<[
@@ -321,28 +300,5 @@ func.func @entry() {
321300
call @test_element(%angle_test_cast, %angle_func)
322301
: (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> ()
323302

324-
// complex.abs test
325-
%abs_test = arith.constant dense<[
326-
(1.0, 1.0),
327-
// CHECK: 1.414
328-
(1.0e300, 1.0e300),
329-
// CHECK-NEXT: 1.41421e+300
330-
(1.0e-300, 1.0e-300),
331-
// CHECK-NEXT: 1.41421e-300
332-
(5.0, 0.0),
333-
// CHECK-NEXT: 5
334-
(0.0, 6.0),
335-
// CHECK-NEXT: 6
336-
(7.0, 8.0)
337-
// CHECK-NEXT: 10.6301
338-
]> : tensor<6xcomplex<f64>>
339-
%abs_test_cast = tensor.cast %abs_test
340-
: tensor<6xcomplex<f64>> to tensor<?xcomplex<f64>>
341-
342-
%abs_func = func.constant @abs : (complex<f64>) -> f64
343-
344-
call @test_element_f64(%abs_test_cast, %abs_func)
345-
: (tensor<?xcomplex<f64>>, (complex<f64>) -> f64) -> ()
346-
347303
func.return
348304
}

0 commit comments

Comments
 (0)