Skip to content

Commit 733b8b2

Browse files
committed
[LAA] Simplify identification of speculatable strides [nfc]
Mostly just avoiding the need to keep both Value and SCEVs flowing through with consistent handling. We can do everything in terms of SCEV - aside from the profitability heuristics which are now isolated in one spot.
1 parent 8165792 commit 733b8b2

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

llvm/lib/Analysis/LoopAccessAnalysis.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2610,7 +2610,7 @@ static Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) {
26102610

26112611
/// Get the stride of a pointer access in a loop. Looks for symbolic
26122612
/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
2613-
static Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
2613+
static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
26142614
auto *PtrTy = dyn_cast<PointerType>(Ptr->getType());
26152615
if (!PtrTy || PtrTy->isAggregateType())
26162616
return nullptr;
@@ -2664,28 +2664,27 @@ static Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
26642664
}
26652665
}
26662666

2667-
// Strip off casts.
2668-
Type *StripedOffRecurrenceCast = nullptr;
2669-
if (const SCEVIntegralCastExpr *C = dyn_cast<SCEVIntegralCastExpr>(V)) {
2670-
StripedOffRecurrenceCast = C->getType();
2671-
V = C->getOperand();
2672-
}
2667+
// Note that the restriction after this loop invariant check are only
2668+
// profitability restrictions.
2669+
if (!SE->isLoopInvariant(V, Lp))
2670+
return nullptr;
26732671

26742672
// Look for the loop invariant symbolic value.
26752673
const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V);
2676-
if (!U)
2677-
return nullptr;
2674+
if (!U) {
2675+
const auto *C = dyn_cast<SCEVIntegralCastExpr>(V);
2676+
if (!C)
2677+
return nullptr;
2678+
U = dyn_cast<SCEVUnknown>(C->getOperand());
2679+
if (!U)
2680+
return nullptr;
26782681

2679-
Value *Stride = U->getValue();
2680-
if (!Lp->isLoopInvariant(Stride))
2681-
return nullptr;
2682-
2683-
// If we have stripped off the recurrence cast we have to make sure that we
2684-
// return the value that is used in this loop so that we can replace it later.
2685-
if (StripedOffRecurrenceCast)
2686-
Stride = getUniqueCastUse(Stride, Lp, StripedOffRecurrenceCast);
2682+
// Match legacy behavior - this is not needed for correctness
2683+
if (!getUniqueCastUse(U->getValue(), Lp, V->getType()))
2684+
return nullptr;
2685+
}
26872686

2688-
return Stride;
2687+
return V;
26892688
}
26902689

26912690
void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
@@ -2699,13 +2698,13 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
26992698
// computation of an interesting IV - but we chose not to as we
27002699
// don't have a cost model here, and broadening the scope exposes
27012700
// far too many unprofitable cases.
2702-
Value *Stride = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
2703-
if (!Stride)
2701+
const SCEV *StrideExpr = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
2702+
if (!StrideExpr)
27042703
return;
27052704

27062705
LLVM_DEBUG(dbgs() << "LAA: Found a strided access that is a candidate for "
27072706
"versioning:");
2708-
LLVM_DEBUG(dbgs() << " Ptr: " << *Ptr << " Stride: " << *Stride << "\n");
2707+
LLVM_DEBUG(dbgs() << " Ptr: " << *Ptr << " Stride: " << *StrideExpr << "\n");
27092708

27102709
if (!SpeculateUnitStride) {
27112710
LLVM_DEBUG(dbgs() << " Chose not to due to -laa-speculate-unit-stride\n");
@@ -2725,7 +2724,6 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
27252724
// of various possible stride specializations, considering the alternatives
27262725
// of using gather/scatters (if available).
27272726

2728-
const SCEV *StrideExpr = PSE->getSCEV(Stride);
27292727
const SCEV *BETakenCount = PSE->getBackedgeTakenCount();
27302728

27312729
// Match the types so we can compare the stride and the BETakenCount.
@@ -2756,8 +2754,10 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
27562754

27572755
// Strip back off the integer cast, and check that our result is a
27582756
// SCEVUnknown as we expect.
2759-
Value *StrideVal = stripIntegerCast(Stride);
2760-
SymbolicStrides[Ptr] = cast<SCEVUnknown>(PSE->getSCEV(StrideVal));
2757+
const SCEV *StrideBase = StrideExpr;
2758+
if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(StrideBase))
2759+
StrideBase = C->getOperand();
2760+
SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideBase);
27612761
}
27622762

27632763
LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,

0 commit comments

Comments
 (0)