Skip to content

Commit b930b14

Browse files
authored
[mlir][complex] Support fast math flag in converting complex.atan2 op (#82101)
When converting complex.atan2 op to standard, we need to keep the fast math flag given to the op. See: https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981
1 parent d1aec79 commit b930b14

File tree

2 files changed

+582
-10
lines changed

2 files changed

+582
-10
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,34 +94,35 @@ struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
9494

9595
auto type = cast<ComplexType>(op.getType());
9696
Type elementType = type.getElementType();
97+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
9798

9899
Value lhs = adaptor.getLhs();
99100
Value rhs = adaptor.getRhs();
100101

101-
Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs);
102-
Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs);
102+
Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf);
103+
Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf);
103104
Value rhsSquaredPlusLhsSquared =
104-
b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
105+
b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
105106
Value sqrtOfRhsSquaredPlusLhsSquared =
106-
b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared);
107+
b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
107108

108109
Value zero =
109110
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
110111
Value one = b.create<arith::ConstantOp>(elementType,
111112
b.getFloatAttr(elementType, 1));
112113
Value i = b.create<complex::CreateOp>(type, zero, one);
113-
Value iTimesLhs = b.create<complex::MulOp>(i, lhs);
114-
Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs);
114+
Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf);
115+
Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf);
115116

116-
Value divResult =
117-
b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
118-
Value logResult = b.create<complex::LogOp>(divResult);
117+
Value divResult = b.create<complex::DivOp>(
118+
rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
119+
Value logResult = b.create<complex::LogOp>(divResult, fmf);
119120

120121
Value negativeOne = b.create<arith::ConstantOp>(
121122
elementType, b.getFloatAttr(elementType, -1));
122123
Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
123124

124-
rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult);
125+
rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
125126
return success();
126127
}
127128
};

0 commit comments

Comments
 (0)