Skip to content

Commit 919f9e8

Browse files
author
Jun Bum Lim
committed
[InlineCost] Improve the cost heuristic for Switch
Summary: The motivation example is like below which has 13 cases but only 2 distinct targets ``` lor.lhs.false2: ; preds = %if.then switch i32 %Status, label %if.then27 [ i32 -7012, label %if.end35 i32 -10008, label %if.end35 i32 -10016, label %if.end35 i32 15000, label %if.end35 i32 14013, label %if.end35 i32 10114, label %if.end35 i32 10107, label %if.end35 i32 10105, label %if.end35 i32 10013, label %if.end35 i32 10011, label %if.end35 i32 7008, label %if.end35 i32 7007, label %if.end35 i32 5002, label %if.end35 ] ``` which is compiled into a balanced binary tree like this on AArch64 (similar on X86) ``` .LBB853_9: // %lor.lhs.false2 mov w8, #10012 cmp w19, w8 b.gt .LBB853_14 // BB#10: // %lor.lhs.false2 mov w8, #5001 cmp w19, w8 b.gt .LBB853_18 // BB#11: // %lor.lhs.false2 mov w8, #-10016 cmp w19, w8 b.eq .LBB853_23 // BB#12: // %lor.lhs.false2 mov w8, #-10008 cmp w19, w8 b.eq .LBB853_23 // BB#13: // %lor.lhs.false2 mov w8, #-7012 cmp w19, w8 b.eq .LBB853_23 b .LBB853_3 .LBB853_14: // %lor.lhs.false2 mov w8, #14012 cmp w19, w8 b.gt .LBB853_21 // BB#15: // %lor.lhs.false2 mov w8, #-10105 add w8, w19, w8 cmp w8, #9 // =9 b.hi .LBB853_17 // BB#16: // %lor.lhs.false2 orr w9, wzr, #0x1 lsl w8, w9, w8 mov w9, #517 and w8, w8, w9 cbnz w8, .LBB853_23 .LBB853_17: // %lor.lhs.false2 mov w8, #10013 cmp w19, w8 b.eq .LBB853_23 b .LBB853_3 .LBB853_18: // %lor.lhs.false2 mov w8, #-7007 add w8, w19, w8 cmp w8, #2 // =2 b.lo .LBB853_23 // BB#19: // %lor.lhs.false2 mov w8, #5002 cmp w19, w8 b.eq .LBB853_23 // BB#20: // %lor.lhs.false2 mov w8, #10011 cmp w19, w8 b.eq .LBB853_23 b .LBB853_3 .LBB853_21: // %lor.lhs.false2 mov w8, #14013 cmp w19, w8 b.eq .LBB853_23 // BB#22: // %lor.lhs.false2 mov w8, #15000 cmp w19, w8 b.ne .LBB853_3 ``` However, the inline cost model estimates the cost to be linear with the number of distinct targets and the cost of the above switch is just 2 InstrCosts. The function containing this switch is then inlined about 900 times. This change use the general way of switch lowering for the inline heuristic. It etimate the number of case clusters with the suitability check for a jump table or bit test. Considering the binary search tree built for the clusters, this change modifies the model to be linear with the size of the balanced binary tree. The model is off by default for now : -inline-generic-switch-cost=false This change was originally proposed by Haicheng in D29870. Reviewers: hans, bmakam, chandlerc, eraman, haicheng, mcrosier Reviewed By: hans Subscribers: joerg, aemerson, llvm-commits, rengolin Differential Revision: https://reviews.llvm.org/D31085 llvm-svn: 301649
1 parent 485ad42 commit 919f9e8

File tree

10 files changed

+405
-115
lines changed

10 files changed

+405
-115
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,12 @@ class TargetTransformInfo {
197197
int getIntrinsicCost(Intrinsic::ID IID, Type *RetTy,
198198
ArrayRef<const Value *> Arguments) const;
199199

200+
/// \return The estimated number of case clusters when lowering \p 'SI'.
201+
/// \p JTSize Set a jump table size only when \p SI is suitable for a jump
202+
/// table.
203+
unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
204+
unsigned &JTSize) const;
205+
200206
/// \brief Estimate the cost of a given IR user when lowered.
201207
///
202208
/// This can estimate the cost of either a ConstantExpr or Instruction when
@@ -764,6 +770,8 @@ class TargetTransformInfo::Concept {
764770
ArrayRef<Type *> ParamTys) = 0;
765771
virtual int getIntrinsicCost(Intrinsic::ID IID, Type *RetTy,
766772
ArrayRef<const Value *> Arguments) = 0;
773+
virtual unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
774+
unsigned &JTSize) = 0;
767775
virtual int getUserCost(const User *U) = 0;
768776
virtual bool hasBranchDivergence() = 0;
769777
virtual bool isSourceOfDivergence(const Value *V) = 0;
@@ -1067,6 +1075,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
10671075
unsigned getMaxInterleaveFactor(unsigned VF) override {
10681076
return Impl.getMaxInterleaveFactor(VF);
10691077
}
1078+
unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
1079+
unsigned &JTSize) override {
1080+
return Impl.getEstimatedNumberOfCaseClusters(SI, JTSize);
1081+
}
10701082
unsigned
10711083
getArithmeticInstrCost(unsigned Opcode, Type *Ty, OperandValueKind Opd1Info,
10721084
OperandValueKind Opd2Info,

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@ class TargetTransformInfoImplBase {
114114
return TTI::TCC_Free;
115115
}
116116

117+
unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
118+
unsigned &JTSize) {
119+
JTSize = 0;
120+
return SI.getNumCases();
121+
}
122+
117123
unsigned getCallCost(FunctionType *FTy, int NumArgs) {
118124
assert(FTy && "FunctionType must be provided to this routine.");
119125

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,62 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
171171
return BaseT::getIntrinsicCost(IID, RetTy, ParamTys);
172172
}
173173

174+
unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
175+
unsigned &JumpTableSize) {
176+
/// Try to find the estimated number of clusters. Note that the number of
177+
/// clusters identified in this function could be different from the actural
178+
/// numbers found in lowering. This function ignore switches that are
179+
/// lowered with a mix of jump table / bit test / BTree. This function was
180+
/// initially intended to be used when estimating the cost of switch in
181+
/// inline cost heuristic, but it's a generic cost model to be used in other
182+
/// places (e.g., in loop unrolling).
183+
unsigned N = SI.getNumCases();
184+
const TargetLoweringBase *TLI = getTLI();
185+
const DataLayout &DL = this->getDataLayout();
186+
187+
JumpTableSize = 0;
188+
bool IsJTAllowed = TLI->areJTsAllowed(SI.getParent()->getParent());
189+
190+
// Early exit if both a jump table and bit test are not allowed.
191+
if (N < 1 || (!IsJTAllowed && DL.getPointerSizeInBits() < N))
192+
return N;
193+
194+
APInt MaxCaseVal = SI.case_begin()->getCaseValue()->getValue();
195+
APInt MinCaseVal = MaxCaseVal;
196+
for (auto CI : SI.cases()) {
197+
const APInt &CaseVal = CI.getCaseValue()->getValue();
198+
if (CaseVal.sgt(MaxCaseVal))
199+
MaxCaseVal = CaseVal;
200+
if (CaseVal.slt(MinCaseVal))
201+
MinCaseVal = CaseVal;
202+
}
203+
204+
// Check if suitable for a bit test
205+
if (N <= DL.getPointerSizeInBits()) {
206+
SmallPtrSet<const BasicBlock *, 4> Dests;
207+
for (auto I : SI.cases())
208+
Dests.insert(I.getCaseSuccessor());
209+
210+
if (TLI->isSuitableForBitTests(Dests.size(), N, MinCaseVal, MaxCaseVal,
211+
DL))
212+
return 1;
213+
}
214+
215+
// Check if suitable for a jump table.
216+
if (IsJTAllowed) {
217+
if (N < 2 || N < TLI->getMinimumJumpTableEntries())
218+
return N;
219+
uint64_t Range =
220+
(MaxCaseVal - MinCaseVal).getLimitedValue(UINT64_MAX - 1) + 1;
221+
// Check whether a range of clusters is dense enough for a jump table
222+
if (TLI->isSuitableForJumpTable(&SI, N, Range)) {
223+
JumpTableSize = Range;
224+
return 1;
225+
}
226+
}
227+
return N;
228+
}
229+
174230
unsigned getJumpBufAlignment() { return getTLI()->getJumpBufAlignment(); }
175231

176232
unsigned getJumpBufSize() { return getTLI()->getJumpBufSize(); }

llvm/include/llvm/Target/TargetLowering.h

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,74 @@ class TargetLoweringBase {
775775
return (!isTypeLegal(VT) && getOperationAction(Op, VT) == Custom);
776776
}
777777

778+
/// Return true if lowering to a jump table is allowed.
779+
bool areJTsAllowed(const Function *Fn) const {
780+
if (Fn->getFnAttribute("no-jump-tables").getValueAsString() == "true")
781+
return false;
782+
783+
return isOperationLegalOrCustom(ISD::BR_JT, MVT::Other) ||
784+
isOperationLegalOrCustom(ISD::BRIND, MVT::Other);
785+
}
786+
787+
/// Check whether the range [Low,High] fits in a machine word.
788+
bool rangeFitsInWord(const APInt &Low, const APInt &High,
789+
const DataLayout &DL) const {
790+
// FIXME: Using the pointer type doesn't seem ideal.
791+
uint64_t BW = DL.getPointerSizeInBits();
792+
uint64_t Range = (High - Low).getLimitedValue(UINT64_MAX - 1) + 1;
793+
return Range <= BW;
794+
}
795+
796+
/// Return true if lowering to a jump table is suitable for a set of case
797+
/// clusters which may contain \p NumCases cases, \p Range range of values.
798+
/// FIXME: This function check the maximum table size and density, but the
799+
/// minimum size is not checked. It would be nice if the the minimum size is
800+
/// also combined within this function. Currently, the minimum size check is
801+
/// performed in findJumpTable() in SelectionDAGBuiler and
802+
/// getEstimatedNumberOfCaseClusters() in BasicTTIImpl.
803+
bool isSuitableForJumpTable(const SwitchInst *SI, uint64_t NumCases,
804+
uint64_t Range) const {
805+
const bool OptForSize = SI->getParent()->getParent()->optForSize();
806+
const unsigned MinDensity = getMinimumJumpTableDensity(OptForSize);
807+
const unsigned MaxJumpTableSize =
808+
OptForSize || getMaximumJumpTableSize() == 0
809+
? UINT_MAX
810+
: getMaximumJumpTableSize();
811+
// Check whether a range of clusters is dense enough for a jump table.
812+
if (Range <= MaxJumpTableSize &&
813+
(NumCases * 100 >= Range * MinDensity)) {
814+
return true;
815+
}
816+
return false;
817+
}
818+
819+
/// Return true if lowering to a bit test is suitable for a set of case
820+
/// clusters which contains \p NumDests unique destinations, \p Low and
821+
/// \p High as its lowest and highest case values, and expects \p NumCmps
822+
/// case value comparisons. Check if the number of destinations, comparison
823+
/// metric, and range are all suitable.
824+
bool isSuitableForBitTests(unsigned NumDests, unsigned NumCmps,
825+
const APInt &Low, const APInt &High,
826+
const DataLayout &DL) const {
827+
// FIXME: I don't think NumCmps is the correct metric: a single case and a
828+
// range of cases both require only one branch to lower. Just looking at the
829+
// number of clusters and destinations should be enough to decide whether to
830+
// build bit tests.
831+
832+
// To lower a range with bit tests, the range must fit the bitwidth of a
833+
// machine word.
834+
if (!rangeFitsInWord(Low, High, DL))
835+
return false;
836+
837+
// Decide whether it's profitable to lower this range with bit tests. Each
838+
// destination requires a bit test and branch, and there is an overall range
839+
// check branch. For a small number of clusters, separate comparisons might
840+
// be cheaper, and for many destinations, splitting the range might be
841+
// better.
842+
return (NumDests == 1 && NumCmps >= 3) || (NumDests == 2 && NumCmps >= 5) ||
843+
(NumDests == 3 && NumCmps >= 6);
844+
}
845+
778846
/// Return true if the specified operation is illegal on this target or
779847
/// unlikely to be made legal with custom lowering. This is used to help guide
780848
/// high-level lowering decisions.
@@ -1149,6 +1217,9 @@ class TargetLoweringBase {
11491217
/// Return lower limit for number of blocks in a jump table.
11501218
unsigned getMinimumJumpTableEntries() const;
11511219

1220+
/// Return lower limit of the density in a jump table.
1221+
unsigned getMinimumJumpTableDensity(bool OptForSize) const;
1222+
11521223
/// Return upper limit for number of entries in a jump table.
11531224
/// Zero if no limit.
11541225
unsigned getMaximumJumpTableSize() const;

llvm/lib/Analysis/InlineCost.cpp

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ static cl::opt<int>
5454
cl::init(45),
5555
cl::desc("Threshold for inlining cold callsites"));
5656

57+
static cl::opt<bool>
58+
EnableGenericSwitchCost("inline-generic-switch-cost", cl::Hidden,
59+
cl::init(false),
60+
cl::desc("Enable generic switch cost model"));
61+
5762
// We introduce this threshold to help performance of instrumentation based
5863
// PGO before we actually hook up inliner with analysis passes such as BPI and
5964
// BFI.
@@ -998,11 +1003,72 @@ bool CallAnalyzer::visitSwitchInst(SwitchInst &SI) {
9981003
if (isa<ConstantInt>(V))
9991004
return true;
10001005

1001-
// Otherwise, we need to accumulate a cost proportional to the number of
1002-
// distinct successor blocks. This fan-out in the CFG cannot be represented
1003-
// for free even if we can represent the core switch as a jumptable that
1004-
// takes a single instruction.
1005-
//
1006+
if (EnableGenericSwitchCost) {
1007+
// Assume the most general case where the swith is lowered into
1008+
// either a jump table, bit test, or a balanced binary tree consisting of
1009+
// case clusters without merging adjacent clusters with the same
1010+
// destination. We do not consider the switches that are lowered with a mix
1011+
// of jump table/bit test/binary search tree. The cost of the switch is
1012+
// proportional to the size of the tree or the size of jump table range.
1013+
1014+
// Exit early for a large switch, assuming one case needs at least one
1015+
// instruction.
1016+
// FIXME: This is not true for a bit test, but ignore such case for now to
1017+
// save compile-time.
1018+
int64_t CostLowerBound =
1019+
std::min((int64_t)INT_MAX,
1020+
(int64_t)SI.getNumCases() * InlineConstants::InstrCost + Cost);
1021+
1022+
if (CostLowerBound > Threshold) {
1023+
Cost = CostLowerBound;
1024+
return false;
1025+
}
1026+
1027+
unsigned JumpTableSize = 0;
1028+
unsigned NumCaseCluster =
1029+
TTI.getEstimatedNumberOfCaseClusters(SI, JumpTableSize);
1030+
1031+
// If suitable for a jump table, consider the cost for the table size and
1032+
// branch to destination.
1033+
if (JumpTableSize) {
1034+
int64_t JTCost = (int64_t)JumpTableSize * InlineConstants::InstrCost +
1035+
4 * InlineConstants::InstrCost;
1036+
Cost = std::min((int64_t)INT_MAX, JTCost + Cost);
1037+
return false;
1038+
}
1039+
1040+
// Considering forming a binary search, we should find the number of nodes
1041+
// which is same as the number of comparisons when lowered. For a given
1042+
// number of clusters, n, we can define a recursive function, f(n), to find
1043+
// the number of nodes in the tree. The recursion is :
1044+
// f(n) = 1 + f(n/2) + f (n - n/2), when n > 3,
1045+
// and f(n) = n, when n <= 3.
1046+
// This will lead a binary tree where the leaf should be either f(2) or f(3)
1047+
// when n > 3. So, the number of comparisons from leaves should be n, while
1048+
// the number of non-leaf should be :
1049+
// 2^(log2(n) - 1) - 1
1050+
// = 2^log2(n) * 2^-1 - 1
1051+
// = n / 2 - 1.
1052+
// Considering comparisons from leaf and non-leaf nodes, we can estimate the
1053+
// number of comparisons in a simple closed form :
1054+
// n + n / 2 - 1 = n * 3 / 2 - 1
1055+
if (NumCaseCluster <= 3) {
1056+
// Suppose a comparison includes one compare and one conditional branch.
1057+
Cost += NumCaseCluster * 2 * InlineConstants::InstrCost;
1058+
return false;
1059+
}
1060+
int64_t ExpectedNumberOfCompare = 3 * (uint64_t)NumCaseCluster / 2 - 1;
1061+
uint64_t SwitchCost =
1062+
ExpectedNumberOfCompare * 2 * InlineConstants::InstrCost;
1063+
Cost = std::min((uint64_t)INT_MAX, SwitchCost + Cost);
1064+
return false;
1065+
}
1066+
1067+
// Use a simple switch cost model where we accumulate a cost proportional to
1068+
// the number of distinct successor blocks. This fan-out in the CFG cannot
1069+
// be represented for free even if we can represent the core switch as a
1070+
// jumptable that takes a single instruction.
1071+
///
10061072
// NB: We convert large switches which are just used to initialize large phi
10071073
// nodes to lookup tables instead in simplify-cfg, so this shouldn't prevent
10081074
// inlining those. It will prevent inlining in cases where the optimization

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ int TargetTransformInfo::getIntrinsicCost(
8383
return Cost;
8484
}
8585

86+
unsigned
87+
TargetTransformInfo::getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
88+
unsigned &JTSize) const {
89+
return TTIImpl->getEstimatedNumberOfCaseClusters(SI, JTSize);
90+
}
91+
8692
int TargetTransformInfo::getUserCost(const User *U) const {
8793
int Cost = TTIImpl->getUserCost(U);
8894
assert(Cost >= 0 && "TTI should not produce negative costs!");

0 commit comments

Comments
 (0)