Skip to content

Commit dafeb04

Browse files
committed
[mlir][complex] Support fast math flag in converting complex.atan2 op to
standard
1 parent 118a2a5 commit dafeb04

File tree

2 files changed

+586
-12
lines changed

2 files changed

+586
-12
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -94,34 +94,37 @@ struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
9494

9595
auto type = cast<ComplexType>(op.getType());
9696
Type elementType = type.getElementType();
97+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
9798

9899
Value lhs = adaptor.getLhs();
99100
Value rhs = adaptor.getRhs();
100101

101-
Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs);
102-
Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs);
102+
Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf.getValue());
103+
Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf.getValue());
103104
Value rhsSquaredPlusLhsSquared =
104-
b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
105-
Value sqrtOfRhsSquaredPlusLhsSquared =
106-
b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared);
105+
b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf.getValue());
106+
Value sqrtOfRhsSquaredPlusLhsSquared = b.create<complex::SqrtOp>(
107+
type, rhsSquaredPlusLhsSquared, fmf.getValue());
107108

108109
Value zero =
109110
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
110111
Value one = b.create<arith::ConstantOp>(elementType,
111112
b.getFloatAttr(elementType, 1));
112113
Value i = b.create<complex::CreateOp>(type, zero, one);
113-
Value iTimesLhs = b.create<complex::MulOp>(i, lhs);
114-
Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs);
114+
Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf.getValue());
115+
Value rhsPlusILhs =
116+
b.create<complex::AddOp>(rhs, iTimesLhs, fmf.getValue());
115117

116-
Value divResult =
117-
b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
118-
Value logResult = b.create<complex::LogOp>(divResult);
118+
Value divResult = b.create<complex::DivOp>(
119+
rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf.getValue());
120+
Value logResult = b.create<complex::LogOp>(divResult, fmf.getValue());
119121

120122
Value negativeOne = b.create<arith::ConstantOp>(
121123
elementType, b.getFloatAttr(elementType, -1));
122124
Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
123125

124-
rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult);
126+
rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult,
127+
fmf.getValue());
125128
return success();
126129
}
127130
};

0 commit comments

Comments
 (0)