Skip to content

Commit d36eb79

Browse files
authored
[RISCV] Support Strict FP arithmetic Op when only have Zvfhmin (#68867)
Include: STRICT_FADD, STRICT_FSUB, STRICT_FMUL, STRICT_FDIV, STRICT_FSQRT and STRICT_FMA.
1 parent ab03141 commit d36eb79

11 files changed

+2910
-556
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ class VectorLegalizer {
179179
/// type.
180180
void PromoteSETCC(SDNode *Node, SmallVectorImpl<SDValue> &Results);
181181

182+
void PromoteSTRICT(SDNode *Node, SmallVectorImpl<SDValue> &Results);
183+
182184
public:
183185
VectorLegalizer(SelectionDAG& dag) :
184186
DAG(dag), TLI(dag.getTargetLoweringInfo()) {}
@@ -636,6 +638,47 @@ void VectorLegalizer::PromoteSETCC(SDNode *Node,
636638
Results.push_back(Res);
637639
}
638640

641+
void VectorLegalizer::PromoteSTRICT(SDNode *Node,
642+
SmallVectorImpl<SDValue> &Results) {
643+
MVT VecVT = Node->getOperand(1).getSimpleValueType();
644+
MVT NewVecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), VecVT);
645+
646+
assert(VecVT.isFloatingPoint());
647+
648+
SDLoc DL(Node);
649+
SmallVector<SDValue, 5> Operands(Node->getNumOperands());
650+
SmallVector<SDValue, 2> Chains;
651+
652+
for (unsigned j = 1; j != Node->getNumOperands(); ++j)
653+
if (Node->getOperand(j).getValueType().isVector() &&
654+
!(ISD::isVPOpcode(Node->getOpcode()) &&
655+
ISD::getVPMaskIdx(Node->getOpcode()) == j)) // Skip mask operand.
656+
{
657+
// promote the vector operand.
658+
SDValue Ext =
659+
DAG.getNode(ISD::STRICT_FP_EXTEND, DL, {NewVecVT, MVT::Other},
660+
{Node->getOperand(0), Node->getOperand(j)});
661+
Operands[j] = Ext.getValue(0);
662+
Chains.push_back(Ext.getValue(1));
663+
} else
664+
Operands[j] = Node->getOperand(j); // Skip no vector operand.
665+
666+
SDVTList VTs = DAG.getVTList(NewVecVT, Node->getValueType(1));
667+
668+
Operands[0] = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
669+
670+
SDValue Res =
671+
DAG.getNode(Node->getOpcode(), DL, VTs, Operands, Node->getFlags());
672+
673+
SDValue Round =
674+
DAG.getNode(ISD::STRICT_FP_ROUND, DL, {VecVT, MVT::Other},
675+
{Res.getValue(1), Res.getValue(0),
676+
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)});
677+
678+
Results.push_back(Round.getValue(0));
679+
Results.push_back(Round.getValue(1));
680+
}
681+
639682
void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
640683
// For a few operations there is a specific concept for promotion based on
641684
// the operand's type.
@@ -676,6 +719,14 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
676719
// Promote the operation by extending the operand.
677720
PromoteSETCC(Node, Results);
678721
return;
722+
case ISD::STRICT_FADD:
723+
case ISD::STRICT_FSUB:
724+
case ISD::STRICT_FMUL:
725+
case ISD::STRICT_FDIV:
726+
case ISD::STRICT_FSQRT:
727+
case ISD::STRICT_FMA:
728+
PromoteSTRICT(Node, Results);
729+
return;
679730
case ISD::FP_ROUND:
680731
case ISD::FP_EXTEND:
681732
// These operations are used to do promotion so they can't be promoted

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -896,12 +896,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
896896

897897
// TODO: support more ops.
898898
static const unsigned ZvfhminPromoteOps[] = {
899-
ISD::FMINNUM, ISD::FMAXNUM, ISD::FADD, ISD::FSUB,
900-
ISD::FMUL, ISD::FMA, ISD::FDIV, ISD::FSQRT,
901-
ISD::FABS, ISD::FNEG, ISD::FCOPYSIGN, ISD::FCEIL,
902-
ISD::FFLOOR, ISD::FROUND, ISD::FROUNDEVEN, ISD::FRINT,
903-
ISD::FNEARBYINT, ISD::IS_FPCLASS, ISD::SETCC, ISD::FMAXIMUM,
904-
ISD::FMINIMUM};
899+
ISD::FMINNUM, ISD::FMAXNUM, ISD::FADD, ISD::FSUB,
900+
ISD::FMUL, ISD::FMA, ISD::FDIV, ISD::FSQRT,
901+
ISD::FABS, ISD::FNEG, ISD::FCOPYSIGN, ISD::FCEIL,
902+
ISD::FFLOOR, ISD::FROUND, ISD::FROUNDEVEN, ISD::FRINT,
903+
ISD::FNEARBYINT, ISD::IS_FPCLASS, ISD::SETCC, ISD::FMAXIMUM,
904+
ISD::FMINIMUM, ISD::STRICT_FADD, ISD::STRICT_FSUB, ISD::STRICT_FMUL,
905+
ISD::STRICT_FDIV, ISD::STRICT_FSQRT, ISD::STRICT_FMA};
905906

906907
// TODO: support more vp ops.
907908
static const unsigned ZvfhminPromoteVPOps[] = {
@@ -5597,6 +5598,41 @@ static SDValue SplitVectorReductionOp(SDValue Op, SelectionDAG &DAG) {
55975598
{ResLo, Hi, MaskHi, EVLHi}, Op->getFlags());
55985599
}
55995600

5601+
static SDValue SplitStrictFPVectorOp(SDValue Op, SelectionDAG &DAG) {
5602+
5603+
assert(Op->isStrictFPOpcode());
5604+
5605+
auto [LoVT, HiVT] = DAG.GetSplitDestVTs(Op->getValueType(0));
5606+
5607+
SDVTList LoVTs = DAG.getVTList(LoVT, Op->getValueType(1));
5608+
SDVTList HiVTs = DAG.getVTList(HiVT, Op->getValueType(1));
5609+
5610+
SDLoc DL(Op);
5611+
5612+
SmallVector<SDValue, 4> LoOperands(Op.getNumOperands());
5613+
SmallVector<SDValue, 4> HiOperands(Op.getNumOperands());
5614+
5615+
for (unsigned j = 0; j != Op.getNumOperands(); ++j) {
5616+
if (!Op.getOperand(j).getValueType().isVector()) {
5617+
LoOperands[j] = Op.getOperand(j);
5618+
HiOperands[j] = Op.getOperand(j);
5619+
continue;
5620+
}
5621+
std::tie(LoOperands[j], HiOperands[j]) =
5622+
DAG.SplitVector(Op.getOperand(j), DL);
5623+
}
5624+
5625+
SDValue LoRes =
5626+
DAG.getNode(Op.getOpcode(), DL, LoVTs, LoOperands, Op->getFlags());
5627+
HiOperands[0] = LoRes.getValue(1);
5628+
SDValue HiRes =
5629+
DAG.getNode(Op.getOpcode(), DL, HiVTs, HiOperands, Op->getFlags());
5630+
5631+
SDValue V = DAG.getNode(ISD::CONCAT_VECTORS, DL, Op->getValueType(0),
5632+
LoRes.getValue(0), HiRes.getValue(0));
5633+
return DAG.getMergeValues({V, HiRes.getValue(1)}, DL);
5634+
}
5635+
56005636
SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
56015637
SelectionDAG &DAG) const {
56025638
switch (Op.getOpcode()) {
@@ -6374,6 +6410,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
63746410
case ISD::STRICT_FDIV:
63756411
case ISD::STRICT_FSQRT:
63766412
case ISD::STRICT_FMA:
6413+
if (Op.getValueType() == MVT::nxv32f16 &&
6414+
(Subtarget.hasVInstructionsF16Minimal() &&
6415+
!Subtarget.hasVInstructionsF16()))
6416+
return SplitStrictFPVectorOp(Op, DAG);
63776417
return lowerToScalableOp(Op, DAG);
63786418
case ISD::STRICT_FSETCC:
63796419
case ISD::STRICT_FSETCCS:

0 commit comments

Comments
 (0)