Skip to content

[InlineCost] Consider the default branch when calculating cost #77856

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 11, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ class TargetTransformInfoImplBase {
(void)PSI;
(void)BFI;
JTSize = 0;
return SI.getNumCases();
bool HasDefault = !SI.defaultDestUndefined();
return SI.getNumCases() + HasDefault;
}

unsigned getInliningThresholdMultiplier() const { return 1; }
Expand Down
11 changes: 6 additions & 5 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
/// inline cost heuristic, but it's a generic cost model to be used in other
/// places (e.g., in loop unrolling).
unsigned N = SI.getNumCases();
bool HasDefault = !SI.defaultDestUndefined();
const TargetLoweringBase *TLI = getTLI();
const DataLayout &DL = this->getDataLayout();

Expand All @@ -454,7 +455,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {

// Early exit if both a jump table and bit test are not allowed.
if (N < 1 || (!IsJTAllowed && DL.getIndexSizeInBits(0u) < N))
return N;
return N + HasDefault;

APInt MaxCaseVal = SI.case_begin()->getCaseValue()->getValue();
APInt MinCaseVal = MaxCaseVal;
Expand All @@ -474,23 +475,23 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {

if (TLI->isSuitableForBitTests(Dests.size(), N, MinCaseVal, MaxCaseVal,
DL))
return 1;
return 1 + HasDefault;
}

// Check if suitable for a jump table.
if (IsJTAllowed) {
if (N < 2 || N < TLI->getMinimumJumpTableEntries())
return N;
return N + HasDefault;
uint64_t Range =
(MaxCaseVal - MinCaseVal)
.getLimitedValue(std::numeric_limits<uint64_t>::max() - 1) + 1;
// Check whether a range of clusters is dense enough for a jump table
if (TLI->isSuitableForJumpTable(&SI, N, Range, PSI, BFI)) {
JumpTableSize = Range;
return 1;
return 1 + HasDefault;
}
}
return N;
return N + HasDefault;
}

bool shouldBuildLookupTables() {
Expand Down
7 changes: 7 additions & 0 deletions llvm/include/llvm/IR/Instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class DataLayout;
class StringRef;
class Type;
class Value;
class UnreachableInst;

//===----------------------------------------------------------------------===//
// AllocaInst Class
Expand Down Expand Up @@ -3505,6 +3506,12 @@ class SwitchInst : public Instruction {
return cast<BasicBlock>(getOperand(1));
}

/// Returns true if the default branch must result in immediate undefined
/// behavior, false otherwise.
bool defaultDestUndefined() const {
return isa<UnreachableInst>(getDefaultDest()->getFirstNonPHIOrDbg());
}

void setDefaultDest(BasicBlock *DefaultCase) {
setOperand(1, reinterpret_cast<Value*>(DefaultCase));
}
Expand Down
6 changes: 5 additions & 1 deletion llvm/lib/Analysis/InlineCost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -707,8 +707,9 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
if (JumpTableSize) {
int64_t JTCost =
static_cast<int64_t>(JumpTableSize) * InstrCost + 4 * InstrCost;

addCost(JTCost);
if (NumCaseCluster > 1)
addCost((NumCaseCluster - 1) * 2 * InstrCost);
return;
}

Expand Down Expand Up @@ -1238,6 +1239,9 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
int64_t JTCost = static_cast<int64_t>(JumpTableSize) * InstrCost +
JTCostMultiplier * InstrCost;
increment(InlineCostFeatureIndex::jump_table_penalty, JTCost);
if (NumCaseCluster > 1)
increment(InlineCostFeatureIndex::case_cluster_penalty,
(NumCaseCluster - 1) * CaseClusterCostMultiplier * InstrCost);
return;
}

Expand Down
Loading