@@ -1957,6 +1957,8 @@ class GeneratedRTChecks {
1957
1957
bool CostTooHigh = false ;
1958
1958
const bool AddBranchWeights;
1959
1959
1960
+ Loop *OuterLoop = nullptr ;
1961
+
1960
1962
public:
1961
1963
GeneratedRTChecks (ScalarEvolution &SE, DominatorTree *DT, LoopInfo *LI,
1962
1964
TargetTransformInfo *TTI, const DataLayout &DL,
@@ -2053,6 +2055,9 @@ class GeneratedRTChecks {
2053
2055
DT->eraseNode (SCEVCheckBlock);
2054
2056
LI->removeBlock (SCEVCheckBlock);
2055
2057
}
2058
+
2059
+ // Outer loop is used as part of the later cost calculations.
2060
+ OuterLoop = L->getParentLoop ();
2056
2061
}
2057
2062
2058
2063
InstructionCost getCost () {
@@ -2076,16 +2081,62 @@ class GeneratedRTChecks {
2076
2081
LLVM_DEBUG (dbgs () << " " << C << " for " << I << " \n " );
2077
2082
RTCheckCost += C;
2078
2083
}
2079
- if (MemCheckBlock)
2084
+ if (MemCheckBlock) {
2085
+ InstructionCost MemCheckCost = 0 ;
2080
2086
for (Instruction &I : *MemCheckBlock) {
2081
2087
if (MemCheckBlock->getTerminator () == &I)
2082
2088
continue ;
2083
2089
InstructionCost C =
2084
2090
TTI->getInstructionCost (&I, TTI::TCK_RecipThroughput);
2085
2091
LLVM_DEBUG (dbgs () << " " << C << " for " << I << " \n " );
2086
- RTCheckCost += C;
2092
+ MemCheckCost += C;
2087
2093
}
2088
2094
2095
+ // If the runtime memory checks are being created inside an outer loop
2096
+ // we should find out if these checks are outer loop invariant. If so,
2097
+ // the checks will likely be hoisted out and so the effective cost will
2098
+ // reduce according to the outer loop trip count.
2099
+ if (OuterLoop) {
2100
+ ScalarEvolution *SE = MemCheckExp.getSE ();
2101
+ // TODO: If profitable, we could refine this further by analysing every
2102
+ // individual memory check, since there could be a mixture of loop
2103
+ // variant and invariant checks that mean the final condition is
2104
+ // variant.
2105
+ const SCEV *Cond = SE->getSCEV (MemRuntimeCheckCond);
2106
+ if (SE->isLoopInvariant (Cond, OuterLoop)) {
2107
+ // It seems reasonable to assume that we can reduce the effective
2108
+ // cost of the checks even when we know nothing about the trip
2109
+ // count. Here I've assumed that the outer loop executes at least
2110
+ // twice.
2111
+ unsigned BestTripCount = 2 ;
2112
+
2113
+ // If exact trip count is known use that.
2114
+ if (unsigned SmallTC = SE->getSmallConstantTripCount (OuterLoop))
2115
+ BestTripCount = SmallTC;
2116
+ else if (LoopVectorizeWithBlockFrequency) {
2117
+ // Else use profile data if available.
2118
+ if (auto EstimatedTC = getLoopEstimatedTripCount (OuterLoop))
2119
+ BestTripCount = *EstimatedTC;
2120
+ }
2121
+
2122
+ InstructionCost NewMemCheckCost = MemCheckCost / BestTripCount;
2123
+
2124
+ // Let's ensure the cost is always at least 1.
2125
+ NewMemCheckCost = std::max (*NewMemCheckCost.getValue (),
2126
+ (InstructionCost::CostType)1 );
2127
+
2128
+ LLVM_DEBUG (dbgs ()
2129
+ << " We expect runtime memory checks to be hoisted "
2130
+ << " out of the outer loop. Cost reduced from "
2131
+ << MemCheckCost << " to " << NewMemCheckCost << ' \n ' );
2132
+
2133
+ MemCheckCost = NewMemCheckCost;
2134
+ }
2135
+ }
2136
+
2137
+ RTCheckCost += MemCheckCost;
2138
+ }
2139
+
2089
2140
if (SCEVCheckBlock || MemCheckBlock)
2090
2141
LLVM_DEBUG (dbgs () << " Total cost of runtime checks: " << RTCheckCost
2091
2142
<< " \n " );
@@ -2144,8 +2195,8 @@ class GeneratedRTChecks {
2144
2195
2145
2196
BranchInst::Create (LoopVectorPreHeader, SCEVCheckBlock);
2146
2197
// Create new preheader for vector loop.
2147
- if (auto *PL = LI-> getLoopFor (LoopVectorPreHeader) )
2148
- PL ->addBasicBlockToLoop (SCEVCheckBlock, *LI);
2198
+ if (OuterLoop )
2199
+ OuterLoop ->addBasicBlockToLoop (SCEVCheckBlock, *LI);
2149
2200
2150
2201
SCEVCheckBlock->getTerminator ()->eraseFromParent ();
2151
2202
SCEVCheckBlock->moveBefore (LoopVectorPreHeader);
@@ -2179,8 +2230,8 @@ class GeneratedRTChecks {
2179
2230
DT->changeImmediateDominator (LoopVectorPreHeader, MemCheckBlock);
2180
2231
MemCheckBlock->moveBefore (LoopVectorPreHeader);
2181
2232
2182
- if (auto *PL = LI-> getLoopFor (LoopVectorPreHeader) )
2183
- PL ->addBasicBlockToLoop (MemCheckBlock, *LI);
2233
+ if (OuterLoop )
2234
+ OuterLoop ->addBasicBlockToLoop (MemCheckBlock, *LI);
2184
2235
2185
2236
BranchInst &BI =
2186
2237
*BranchInst::Create (Bypass, LoopVectorPreHeader, MemRuntimeCheckCond);
0 commit comments