@@ -6899,21 +6899,19 @@ SDValue SITargetLowering::lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
6899
6899
if (Op.getOpcode() != ISD::FP_ROUND)
6900
6900
return Op;
6901
6901
6902
- if (Subtarget->has16BitInsts()) {
6903
- if (getTargetMachine().Options.UnsafeFPMath) {
6904
- SDValue Flags = Op.getOperand(1);
6905
- SDValue Src32 = DAG.getNode(ISD::FP_ROUND, DL, MVT::f32, Src, Flags);
6906
- return DAG.getNode(ISD::FP_ROUND, DL, MVT::f16, Src32, Flags);
6907
- } else {
6908
- SDValue FpToFp16 = LowerF64ToF16(Src, DL, DAG);
6909
- SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
6910
- return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
6911
- }
6912
- } else {
6902
+ if (!Subtarget->has16BitInsts()) {
6913
6903
SDValue FpToFp16 = DAG.getNode(ISD::FP_TO_FP16, DL, MVT::i32, Src);
6914
6904
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
6915
6905
return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
6916
6906
}
6907
+ if (getTargetMachine().Options.UnsafeFPMath) {
6908
+ SDValue Flags = Op.getOperand(1);
6909
+ SDValue Src32 = DAG.getNode(ISD::FP_ROUND, DL, MVT::f32, Src, Flags);
6910
+ return DAG.getNode(ISD::FP_ROUND, DL, MVT::f16, Src32, Flags);
6911
+ }
6912
+ SDValue FpToFp16 = LowerF64ToF16(Src, DL, DAG);
6913
+ SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
6914
+ return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
6917
6915
}
6918
6916
6919
6917
assert(DstVT.getScalarType() == MVT::bf16 &&
0 commit comments