Skip to content

Commit 3c912c4

Browse files
committed
[DAG][X86] Convert isNegatibleForFree/GetNegatedExpression to a target hook (PR42863)
This patch converts the DAGCombine isNegatibleForFree/GetNegatedExpression into overridable TLI hooks. The intention is to let us extend existing FNEG combines to work more generally with negatible float ops, allowing it work with target specific combines and opcodes (e.g. X86's FMA variants). Unlike the SimplifyDemandedBits, we can't just handle target nodes through a Target callback, we need to do this as an override to allow targets to handle generic opcodes as well. This does mean that the target implementations has to duplicate some checks (recursion depth etc.). Partial reversion of rL372756 - I've identified the infinite loop issue inside the X86 override but haven't fixed it yet so I've only (re)committed the common TargetLowering refactoring part of the patch. Differential Revision: https://reviews.llvm.org/D67557 llvm-svn: 373343
1 parent 796cd31 commit 3c912c4

File tree

3 files changed

+284
-276
lines changed

3 files changed

+284
-276
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3386,6 +3386,18 @@ class TargetLowering : public TargetLoweringBase {
33863386
llvm_unreachable("Not Implemented");
33873387
}
33883388

3389+
/// Return 1 if we can compute the negated form of the specified expression
3390+
/// for the same cost as the expression itself, or 2 if we can compute the
3391+
/// negated form more cheaply than the expression itself. Else return 0.
3392+
virtual char isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
3393+
bool LegalOperations, bool ForCodeSize,
3394+
unsigned Depth = 0) const;
3395+
3396+
/// If isNegatibleForFree returns true, return the newly negated expression.
3397+
virtual SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG,
3398+
bool LegalOperations, bool ForCodeSize,
3399+
unsigned Depth = 0) const;
3400+
33893401
//===--------------------------------------------------------------------===//
33903402
// Lowering methods - These methods must be implemented by targets so that
33913403
// the SelectionDAGBuilder code knows how to lower these.

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 32 additions & 276 deletions
Original file line numberDiff line numberDiff line change
@@ -785,252 +785,6 @@ void DAGCombiner::deleteAndRecombine(SDNode *N) {
785785
DAG.DeleteNode(N);
786786
}
787787

788-
/// Return 1 if we can compute the negated form of the specified expression for
789-
/// the same cost as the expression itself, or 2 if we can compute the negated
790-
/// form more cheaply than the expression itself.
791-
static char isNegatibleForFree(SDValue Op, bool LegalOperations,
792-
const TargetLowering &TLI,
793-
const TargetOptions *Options,
794-
bool ForCodeSize,
795-
unsigned Depth = 0) {
796-
// fneg is removable even if it has multiple uses.
797-
if (Op.getOpcode() == ISD::FNEG)
798-
return 2;
799-
800-
// Don't allow anything with multiple uses unless we know it is free.
801-
EVT VT = Op.getValueType();
802-
const SDNodeFlags Flags = Op->getFlags();
803-
if (!Op.hasOneUse() &&
804-
!(Op.getOpcode() == ISD::FP_EXTEND &&
805-
TLI.isFPExtFree(VT, Op.getOperand(0).getValueType())))
806-
return 0;
807-
808-
// Don't recurse exponentially.
809-
if (Depth > SelectionDAG::MaxRecursionDepth)
810-
return 0;
811-
812-
switch (Op.getOpcode()) {
813-
default: return false;
814-
case ISD::ConstantFP: {
815-
if (!LegalOperations)
816-
return 1;
817-
818-
// Don't invert constant FP values after legalization unless the target says
819-
// the negated constant is legal.
820-
return TLI.isOperationLegal(ISD::ConstantFP, VT) ||
821-
TLI.isFPImmLegal(neg(cast<ConstantFPSDNode>(Op)->getValueAPF()), VT,
822-
ForCodeSize);
823-
}
824-
case ISD::BUILD_VECTOR: {
825-
// Only permit BUILD_VECTOR of constants.
826-
if (llvm::any_of(Op->op_values(), [&](SDValue N) {
827-
return !N.isUndef() && !isa<ConstantFPSDNode>(N);
828-
}))
829-
return 0;
830-
if (!LegalOperations)
831-
return 1;
832-
if (TLI.isOperationLegal(ISD::ConstantFP, VT) &&
833-
TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
834-
return 1;
835-
return llvm::all_of(Op->op_values(), [&](SDValue N) {
836-
return N.isUndef() ||
837-
TLI.isFPImmLegal(neg(cast<ConstantFPSDNode>(N)->getValueAPF()), VT,
838-
ForCodeSize);
839-
});
840-
}
841-
case ISD::FADD:
842-
if (!Options->NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
843-
return 0;
844-
845-
// After operation legalization, it might not be legal to create new FSUBs.
846-
if (LegalOperations && !TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
847-
return 0;
848-
849-
// fold (fneg (fadd A, B)) -> (fsub (fneg A), B)
850-
if (char V = isNegatibleForFree(Op.getOperand(0), LegalOperations, TLI,
851-
Options, ForCodeSize, Depth + 1))
852-
return V;
853-
// fold (fneg (fadd A, B)) -> (fsub (fneg B), A)
854-
return isNegatibleForFree(Op.getOperand(1), LegalOperations, TLI, Options,
855-
ForCodeSize, Depth + 1);
856-
case ISD::FSUB:
857-
// We can't turn -(A-B) into B-A when we honor signed zeros.
858-
if (!Options->NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
859-
return 0;
860-
861-
// fold (fneg (fsub A, B)) -> (fsub B, A)
862-
return 1;
863-
864-
case ISD::FMUL:
865-
case ISD::FDIV:
866-
// fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) or (fmul X, (fneg Y))
867-
if (char V = isNegatibleForFree(Op.getOperand(0), LegalOperations, TLI,
868-
Options, ForCodeSize, Depth + 1))
869-
return V;
870-
871-
// Ignore X * 2.0 because that is expected to be canonicalized to X + X.
872-
if (auto *C = isConstOrConstSplatFP(Op.getOperand(1)))
873-
if (C->isExactlyValue(2.0) && Op.getOpcode() == ISD::FMUL)
874-
return 0;
875-
876-
return isNegatibleForFree(Op.getOperand(1), LegalOperations, TLI, Options,
877-
ForCodeSize, Depth + 1);
878-
879-
case ISD::FMA:
880-
case ISD::FMAD: {
881-
if (!Options->NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
882-
return 0;
883-
884-
// fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
885-
// fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z))
886-
char V2 = isNegatibleForFree(Op.getOperand(2), LegalOperations, TLI,
887-
Options, ForCodeSize, Depth + 1);
888-
if (!V2)
889-
return 0;
890-
891-
// One of Op0/Op1 must be cheaply negatible, then select the cheapest.
892-
char V0 = isNegatibleForFree(Op.getOperand(0), LegalOperations, TLI,
893-
Options, ForCodeSize, Depth + 1);
894-
char V1 = isNegatibleForFree(Op.getOperand(1), LegalOperations, TLI,
895-
Options, ForCodeSize, Depth + 1);
896-
char V01 = std::max(V0, V1);
897-
return V01 ? std::max(V01, V2) : 0;
898-
}
899-
900-
case ISD::FP_EXTEND:
901-
case ISD::FP_ROUND:
902-
case ISD::FSIN:
903-
return isNegatibleForFree(Op.getOperand(0), LegalOperations, TLI, Options,
904-
ForCodeSize, Depth + 1);
905-
}
906-
}
907-
908-
/// If isNegatibleForFree returns true, return the newly negated expression.
909-
static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
910-
bool LegalOperations, bool ForCodeSize,
911-
unsigned Depth = 0) {
912-
// fneg is removable even if it has multiple uses.
913-
if (Op.getOpcode() == ISD::FNEG)
914-
return Op.getOperand(0);
915-
916-
assert(Depth <= SelectionDAG::MaxRecursionDepth &&
917-
"GetNegatedExpression doesn't match isNegatibleForFree");
918-
const TargetOptions &Options = DAG.getTarget().Options;
919-
const SDNodeFlags Flags = Op->getFlags();
920-
921-
switch (Op.getOpcode()) {
922-
default: llvm_unreachable("Unknown code");
923-
case ISD::ConstantFP: {
924-
APFloat V = cast<ConstantFPSDNode>(Op)->getValueAPF();
925-
V.changeSign();
926-
return DAG.getConstantFP(V, SDLoc(Op), Op.getValueType());
927-
}
928-
case ISD::BUILD_VECTOR: {
929-
SmallVector<SDValue, 4> Ops;
930-
for (SDValue C : Op->op_values()) {
931-
if (C.isUndef()) {
932-
Ops.push_back(C);
933-
continue;
934-
}
935-
APFloat V = cast<ConstantFPSDNode>(C)->getValueAPF();
936-
V.changeSign();
937-
Ops.push_back(DAG.getConstantFP(V, SDLoc(Op), C.getValueType()));
938-
}
939-
return DAG.getBuildVector(Op.getValueType(), SDLoc(Op), Ops);
940-
}
941-
case ISD::FADD:
942-
assert((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
943-
"Expected NSZ fp-flag");
944-
945-
// fold (fneg (fadd A, B)) -> (fsub (fneg A), B)
946-
if (isNegatibleForFree(Op.getOperand(0), LegalOperations,
947-
DAG.getTargetLoweringInfo(), &Options, ForCodeSize,
948-
Depth + 1))
949-
return DAG.getNode(ISD::FSUB, SDLoc(Op), Op.getValueType(),
950-
GetNegatedExpression(Op.getOperand(0), DAG,
951-
LegalOperations, ForCodeSize,
952-
Depth + 1),
953-
Op.getOperand(1), Flags);
954-
// fold (fneg (fadd A, B)) -> (fsub (fneg B), A)
955-
return DAG.getNode(ISD::FSUB, SDLoc(Op), Op.getValueType(),
956-
GetNegatedExpression(Op.getOperand(1), DAG,
957-
LegalOperations, ForCodeSize,
958-
Depth + 1),
959-
Op.getOperand(0), Flags);
960-
case ISD::FSUB:
961-
// fold (fneg (fsub 0, B)) -> B
962-
if (ConstantFPSDNode *N0CFP =
963-
isConstOrConstSplatFP(Op.getOperand(0), /*AllowUndefs*/ true))
964-
if (N0CFP->isZero())
965-
return Op.getOperand(1);
966-
967-
// fold (fneg (fsub A, B)) -> (fsub B, A)
968-
return DAG.getNode(ISD::FSUB, SDLoc(Op), Op.getValueType(),
969-
Op.getOperand(1), Op.getOperand(0), Flags);
970-
971-
case ISD::FMUL:
972-
case ISD::FDIV:
973-
// fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y)
974-
if (isNegatibleForFree(Op.getOperand(0), LegalOperations,
975-
DAG.getTargetLoweringInfo(), &Options, ForCodeSize,
976-
Depth + 1))
977-
return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(),
978-
GetNegatedExpression(Op.getOperand(0), DAG,
979-
LegalOperations, ForCodeSize,
980-
Depth + 1),
981-
Op.getOperand(1), Flags);
982-
983-
// fold (fneg (fmul X, Y)) -> (fmul X, (fneg Y))
984-
return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(),
985-
Op.getOperand(0),
986-
GetNegatedExpression(Op.getOperand(1), DAG,
987-
LegalOperations, ForCodeSize,
988-
Depth + 1), Flags);
989-
990-
case ISD::FMA:
991-
case ISD::FMAD: {
992-
assert((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
993-
"Expected NSZ fp-flag");
994-
995-
SDValue Neg2 = GetNegatedExpression(Op.getOperand(2), DAG, LegalOperations,
996-
ForCodeSize, Depth + 1);
997-
998-
char V0 = isNegatibleForFree(Op.getOperand(0), LegalOperations,
999-
DAG.getTargetLoweringInfo(), &Options,
1000-
ForCodeSize, Depth + 1);
1001-
char V1 = isNegatibleForFree(Op.getOperand(1), LegalOperations,
1002-
DAG.getTargetLoweringInfo(), &Options,
1003-
ForCodeSize, Depth + 1);
1004-
if (V0 >= V1) {
1005-
// fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
1006-
SDValue Neg0 = GetNegatedExpression(
1007-
Op.getOperand(0), DAG, LegalOperations, ForCodeSize, Depth + 1);
1008-
return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(), Neg0,
1009-
Op.getOperand(1), Neg2, Flags);
1010-
}
1011-
1012-
// fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z))
1013-
SDValue Neg1 = GetNegatedExpression(Op.getOperand(1), DAG, LegalOperations,
1014-
ForCodeSize, Depth + 1);
1015-
return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(),
1016-
Op.getOperand(0), Neg1, Neg2, Flags);
1017-
}
1018-
1019-
case ISD::FP_EXTEND:
1020-
case ISD::FSIN:
1021-
return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(),
1022-
GetNegatedExpression(Op.getOperand(0), DAG,
1023-
LegalOperations, ForCodeSize,
1024-
Depth + 1));
1025-
case ISD::FP_ROUND:
1026-
return DAG.getNode(ISD::FP_ROUND, SDLoc(Op), Op.getValueType(),
1027-
GetNegatedExpression(Op.getOperand(0), DAG,
1028-
LegalOperations, ForCodeSize,
1029-
Depth + 1),
1030-
Op.getOperand(1));
1031-
}
1032-
}
1033-
1034788
// APInts must be the same size for most operations, this helper
1035789
// function zero extends the shorter of the pair so that they match.
1036790
// We provide an Offset so that we can create bitwidths that won't overflow.
@@ -12053,17 +11807,17 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
1205311807

1205411808
// fold (fadd A, (fneg B)) -> (fsub A, B)
1205511809
if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) &&
12056-
isNegatibleForFree(N1, LegalOperations, TLI, &Options, ForCodeSize) == 2)
12057-
return DAG.getNode(ISD::FSUB, DL, VT, N0,
12058-
GetNegatedExpression(N1, DAG, LegalOperations,
12059-
ForCodeSize), Flags);
11810+
TLI.isNegatibleForFree(N1, DAG, LegalOperations, ForCodeSize) == 2)
11811+
return DAG.getNode(
11812+
ISD::FSUB, DL, VT, N0,
11813+
TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize), Flags);
1206011814

1206111815
// fold (fadd (fneg A), B) -> (fsub B, A)
1206211816
if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) &&
12063-
isNegatibleForFree(N0, LegalOperations, TLI, &Options, ForCodeSize) == 2)
12064-
return DAG.getNode(ISD::FSUB, DL, VT, N1,
12065-
GetNegatedExpression(N0, DAG, LegalOperations,
12066-
ForCodeSize), Flags);
11817+
TLI.isNegatibleForFree(N0, DAG, LegalOperations, ForCodeSize) == 2)
11818+
return DAG.getNode(
11819+
ISD::FSUB, DL, VT, N1,
11820+
TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize), Flags);
1206711821

1206811822
auto isFMulNegTwo = [](SDValue FMul) {
1206911823
if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
@@ -12242,16 +11996,16 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) {
1224211996
if (N0CFP && N0CFP->isZero()) {
1224311997
if (N0CFP->isNegative() ||
1224411998
(Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
12245-
if (isNegatibleForFree(N1, LegalOperations, TLI, &Options, ForCodeSize))
12246-
return GetNegatedExpression(N1, DAG, LegalOperations, ForCodeSize);
11999+
if (TLI.isNegatibleForFree(N1, DAG, LegalOperations, ForCodeSize))
12000+
return TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize);
1224712001
if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
1224812002
return DAG.getNode(ISD::FNEG, DL, VT, N1, Flags);
1224912003
}
1225012004
}
1225112005

1225212006
if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
12253-
(Flags.hasAllowReassociation() && Flags.hasNoSignedZeros()))
12254-
&& N1.getOpcode() == ISD::FADD) {
12007+
(Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
12008+
N1.getOpcode() == ISD::FADD) {
1225512009
// X - (X + Y) -> -Y
1225612010
if (N0 == N1->getOperand(0))
1225712011
return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1), Flags);
@@ -12261,10 +12015,10 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) {
1226112015
}
1226212016

1226312017
// fold (fsub A, (fneg B)) -> (fadd A, B)
12264-
if (isNegatibleForFree(N1, LegalOperations, TLI, &Options, ForCodeSize))
12265-
return DAG.getNode(ISD::FADD, DL, VT, N0,
12266-
GetNegatedExpression(N1, DAG, LegalOperations,
12267-
ForCodeSize), Flags);
12018+
if (TLI.isNegatibleForFree(N1, DAG, LegalOperations, ForCodeSize))
12019+
return DAG.getNode(
12020+
ISD::FADD, DL, VT, N0,
12021+
TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize), Flags);
1226812022

1226912023
// FSUB -> FMA combines:
1227012024
if (SDValue Fused = visitFSUBForFMACombine(N)) {
@@ -12278,11 +12032,10 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) {
1227812032
/// Return true if both inputs are at least as cheap in negated form and at
1227912033
/// least one input is strictly cheaper in negated form.
1228012034
bool DAGCombiner::isCheaperToUseNegatedFPOps(SDValue X, SDValue Y) {
12281-
const TargetOptions &Options = DAG.getTarget().Options;
12282-
if (char LHSNeg = isNegatibleForFree(X, LegalOperations, TLI, &Options,
12283-
ForCodeSize))
12284-
if (char RHSNeg = isNegatibleForFree(Y, LegalOperations, TLI, &Options,
12285-
ForCodeSize))
12035+
if (char LHSNeg =
12036+
TLI.isNegatibleForFree(X, DAG, LegalOperations, ForCodeSize))
12037+
if (char RHSNeg =
12038+
TLI.isNegatibleForFree(Y, DAG, LegalOperations, ForCodeSize))
1228612039
// Both negated operands are at least as cheap as their counterparts.
1228712040
// Check to see if at least one is cheaper negated.
1228812041
if (LHSNeg == 2 || RHSNeg == 2)
@@ -12363,8 +12116,10 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
1236312116

1236412117
// -N0 * -N1 --> N0 * N1
1236512118
if (isCheaperToUseNegatedFPOps(N0, N1)) {
12366-
SDValue NegN0 = GetNegatedExpression(N0, DAG, LegalOperations, ForCodeSize);
12367-
SDValue NegN1 = GetNegatedExpression(N1, DAG, LegalOperations, ForCodeSize);
12119+
SDValue NegN0 =
12120+
TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize);
12121+
SDValue NegN1 =
12122+
TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize);
1236812123
return DAG.getNode(ISD::FMUL, DL, VT, NegN0, NegN1, Flags);
1236912124
}
1237012125

@@ -12446,8 +12201,10 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
1244612201

1244712202
// (-N0 * -N1) + N2 --> (N0 * N1) + N2
1244812203
if (isCheaperToUseNegatedFPOps(N0, N1)) {
12449-
SDValue NegN0 = GetNegatedExpression(N0, DAG, LegalOperations, ForCodeSize);
12450-
SDValue NegN1 = GetNegatedExpression(N1, DAG, LegalOperations, ForCodeSize);
12204+
SDValue NegN0 =
12205+
TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize);
12206+
SDValue NegN1 =
12207+
TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize);
1245112208
return DAG.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2, Flags);
1245212209
}
1245312210

@@ -12708,8 +12465,8 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
1270812465
if (isCheaperToUseNegatedFPOps(N0, N1))
1270912466
return DAG.getNode(
1271012467
ISD::FDIV, SDLoc(N), VT,
12711-
GetNegatedExpression(N0, DAG, LegalOperations, ForCodeSize),
12712-
GetNegatedExpression(N1, DAG, LegalOperations, ForCodeSize), Flags);
12468+
TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize),
12469+
TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize), Flags);
1271312470

1271412471
return SDValue();
1271512472
}
@@ -13263,9 +13020,8 @@ SDValue DAGCombiner::visitFNEG(SDNode *N) {
1326313020
if (isConstantFPBuildVectorOrConstantFP(N0))
1326413021
return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0);
1326513022

13266-
if (isNegatibleForFree(N0, LegalOperations, DAG.getTargetLoweringInfo(),
13267-
&DAG.getTarget().Options, ForCodeSize))
13268-
return GetNegatedExpression(N0, DAG, LegalOperations, ForCodeSize);
13023+
if (TLI.isNegatibleForFree(N0, DAG, LegalOperations, ForCodeSize))
13024+
return TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize);
1326913025

1327013026
// Transform fneg(bitconvert(x)) -> bitconvert(x ^ sign) to avoid loading
1327113027
// constant pool values.

0 commit comments

Comments
 (0)