Skip to content

Commit 4080cbd

Browse files
committed
[AMDGPU] Adopt new lowering sequence for fdiv16 (llvm#109295)
The current lowering of `fdiv16` can generate incorrectly rounded result in some cases. The new sequence was provided by the HW team, as shown below written in C++. ``` half fdiv(half a, half b) { float a32 = float(a); float b32 = float(b); float r32 = 1.0f / b32; float q32 = a32 * r32; float e32 = -b32 * q32 + a32; q32 = e32 * r32 + q32; e32 = -b32 * q32 + a32; float tmp = e32 * r32; uint32_t tmp32 = std::bit_cast<uint32_t>(tmp); tmp32 = tmp32 & 0xff800000; tmp = std::bit_cast<float>(tmp32); q32 = tmp + q32; half q16 = half(q32); q16 = div_fixup_f16(q16); return q16; } ``` Fixes SWDEV-477608. (cherry picked from commit 88a239d) Change-Id: I4d1dffc2037c66d2ebe0cf7b2e144eae78008d04
1 parent 1a70de0 commit 4080cbd

File tree

8 files changed

+3507
-1032
lines changed

8 files changed

+3507
-1032
lines changed

llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4887,16 +4887,40 @@ bool AMDGPULegalizerInfo::legalizeFDIV16(MachineInstr &MI,
48874887
LLT S16 = LLT::scalar(16);
48884888
LLT S32 = LLT::scalar(32);
48894889

4890+
// a32.u = opx(V_CVT_F32_F16, a.u); // CVT to F32
4891+
// b32.u = opx(V_CVT_F32_F16, b.u); // CVT to F32
4892+
// r32.u = opx(V_RCP_F32, b32.u); // rcp = 1 / d
4893+
// q32.u = opx(V_MUL_F32, a32.u, r32.u); // q = n * rcp
4894+
// e32.u = opx(V_MAD_F32, (b32.u^_neg32), q32.u, a32.u); // err = -d * q + n
4895+
// q32.u = opx(V_MAD_F32, e32.u, r32.u, q32.u); // q = n * rcp
4896+
// e32.u = opx(V_MAD_F32, (b32.u^_neg32), q32.u, a32.u); // err = -d * q + n
4897+
// tmp.u = opx(V_MUL_F32, e32.u, r32.u);
4898+
// tmp.u = opx(V_AND_B32, tmp.u, 0xff800000)
4899+
// q32.u = opx(V_ADD_F32, tmp.u, q32.u);
4900+
// q16.u = opx(V_CVT_F16_F32, q32.u);
4901+
// q16.u = opx(V_DIV_FIXUP_F16, q16.u, b.u, a.u); // q = touchup(q, d, n)
4902+
48904903
auto LHSExt = B.buildFPExt(S32, LHS, Flags);
48914904
auto RHSExt = B.buildFPExt(S32, RHS, Flags);
4892-
4893-
auto RCP = B.buildIntrinsic(Intrinsic::amdgcn_rcp, {S32})
4905+
auto NegRHSExt = B.buildFNeg(S32, RHSExt);
4906+
auto Rcp = B.buildIntrinsic(Intrinsic::amdgcn_rcp, {S32})
48944907
.addUse(RHSExt.getReg(0))
48954908
.setMIFlags(Flags);
4896-
4897-
auto QUOT = B.buildFMul(S32, LHSExt, RCP, Flags);
4898-
auto RDst = B.buildFPTrunc(S16, QUOT, Flags);
4899-
4909+
auto Quot = B.buildFMul(S32, LHSExt, Rcp, Flags);
4910+
MachineInstrBuilder Err;
4911+
if (ST.hasMadMacF32Insts()) {
4912+
Err = B.buildFMAD(S32, NegRHSExt, Quot, LHSExt, Flags);
4913+
Quot = B.buildFMAD(S32, Err, Rcp, Quot, Flags);
4914+
Err = B.buildFMAD(S32, NegRHSExt, Quot, LHSExt, Flags);
4915+
} else {
4916+
Err = B.buildFMA(S32, NegRHSExt, Quot, LHSExt, Flags);
4917+
Quot = B.buildFMA(S32, Err, Rcp, Quot, Flags);
4918+
Err = B.buildFMA(S32, NegRHSExt, Quot, LHSExt, Flags);
4919+
}
4920+
auto Tmp = B.buildFMul(S32, Err, Rcp, Flags);
4921+
Tmp = B.buildAnd(S32, Tmp, B.buildConstant(S32, 0xff800000));
4922+
Quot = B.buildFAdd(S32, Tmp, Quot, Flags);
4923+
auto RDst = B.buildFPTrunc(S16, Quot, Flags);
49004924
B.buildIntrinsic(Intrinsic::amdgcn_div_fixup, Res)
49014925
.addUse(RDst.getReg(0))
49024926
.addUse(RHS)

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10818,19 +10818,48 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const {
1081810818
return FastLowered;
1081910819

1082010820
SDLoc SL(Op);
10821-
SDValue Src0 = Op.getOperand(0);
10822-
SDValue Src1 = Op.getOperand(1);
10823-
10824-
SDValue CvtSrc0 = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, Src0);
10825-
SDValue CvtSrc1 = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, Src1);
10826-
10827-
SDValue RcpSrc1 = DAG.getNode(AMDGPUISD::RCP, SL, MVT::f32, CvtSrc1);
10828-
SDValue Quot = DAG.getNode(ISD::FMUL, SL, MVT::f32, CvtSrc0, RcpSrc1);
10829-
10830-
SDValue FPRoundFlag = DAG.getTargetConstant(0, SL, MVT::i32);
10831-
SDValue BestQuot = DAG.getNode(ISD::FP_ROUND, SL, MVT::f16, Quot, FPRoundFlag);
10821+
SDValue LHS = Op.getOperand(0);
10822+
SDValue RHS = Op.getOperand(1);
1083210823

10833-
return DAG.getNode(AMDGPUISD::DIV_FIXUP, SL, MVT::f16, BestQuot, Src1, Src0);
10824+
// a32.u = opx(V_CVT_F32_F16, a.u); // CVT to F32
10825+
// b32.u = opx(V_CVT_F32_F16, b.u); // CVT to F32
10826+
// r32.u = opx(V_RCP_F32, b32.u); // rcp = 1 / d
10827+
// q32.u = opx(V_MUL_F32, a32.u, r32.u); // q = n * rcp
10828+
// e32.u = opx(V_MAD_F32, (b32.u^_neg32), q32.u, a32.u); // err = -d * q + n
10829+
// q32.u = opx(V_MAD_F32, e32.u, r32.u, q32.u); // q = n * rcp
10830+
// e32.u = opx(V_MAD_F32, (b32.u^_neg32), q32.u, a32.u); // err = -d * q + n
10831+
// tmp.u = opx(V_MUL_F32, e32.u, r32.u);
10832+
// tmp.u = opx(V_AND_B32, tmp.u, 0xff800000)
10833+
// q32.u = opx(V_ADD_F32, tmp.u, q32.u);
10834+
// q16.u = opx(V_CVT_F16_F32, q32.u);
10835+
// q16.u = opx(V_DIV_FIXUP_F16, q16.u, b.u, a.u); // q = touchup(q, d, n)
10836+
10837+
// We will use ISD::FMA on targets that don't support ISD::FMAD.
10838+
unsigned FMADOpCode =
10839+
isOperationLegal(ISD::FMAD, MVT::f32) ? ISD::FMAD : ISD::FMA;
10840+
10841+
SDValue LHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, LHS);
10842+
SDValue RHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, RHS);
10843+
SDValue NegRHSExt = DAG.getNode(ISD::FNEG, SL, MVT::f32, RHSExt);
10844+
SDValue Rcp =
10845+
DAG.getNode(AMDGPUISD::RCP, SL, MVT::f32, RHSExt, Op->getFlags());
10846+
SDValue Quot =
10847+
DAG.getNode(ISD::FMUL, SL, MVT::f32, LHSExt, Rcp, Op->getFlags());
10848+
SDValue Err = DAG.getNode(FMADOpCode, SL, MVT::f32, NegRHSExt, Quot, LHSExt,
10849+
Op->getFlags());
10850+
Quot = DAG.getNode(FMADOpCode, SL, MVT::f32, Err, Rcp, Quot, Op->getFlags());
10851+
Err = DAG.getNode(FMADOpCode, SL, MVT::f32, NegRHSExt, Quot, LHSExt,
10852+
Op->getFlags());
10853+
SDValue Tmp = DAG.getNode(ISD::FMUL, SL, MVT::f32, Err, Rcp, Op->getFlags());
10854+
SDValue TmpCast = DAG.getNode(ISD::BITCAST, SL, MVT::i32, Tmp);
10855+
TmpCast = DAG.getNode(ISD::AND, SL, MVT::i32, TmpCast,
10856+
DAG.getConstant(0xff800000, SL, MVT::i32));
10857+
Tmp = DAG.getNode(ISD::BITCAST, SL, MVT::f32, TmpCast);
10858+
Quot = DAG.getNode(ISD::FADD, SL, MVT::f32, Tmp, Quot, Op->getFlags());
10859+
SDValue RDst = DAG.getNode(ISD::FP_ROUND, SL, MVT::f16, Quot,
10860+
DAG.getConstant(0, SL, MVT::i32));
10861+
return DAG.getNode(AMDGPUISD::DIV_FIXUP, SL, MVT::f16, RDst, RHS, LHS,
10862+
Op->getFlags());
1083410863
}
1083510864

1083610865
// Faster 2.5 ULP division that does not support denormals.

0 commit comments

Comments
 (0)