@@ -2610,7 +2610,7 @@ static Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) {
2610
2610
2611
2611
// / Get the stride of a pointer access in a loop. Looks for symbolic
2612
2612
// / strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
2613
- static Value *getStrideFromPointer (Value *Ptr , ScalarEvolution *SE, Loop *Lp) {
2613
+ static const SCEV *getStrideFromPointer (Value *Ptr , ScalarEvolution *SE, Loop *Lp) {
2614
2614
auto *PtrTy = dyn_cast<PointerType>(Ptr ->getType ());
2615
2615
if (!PtrTy || PtrTy->isAggregateType ())
2616
2616
return nullptr ;
@@ -2664,28 +2664,27 @@ static Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
2664
2664
}
2665
2665
}
2666
2666
2667
- // Strip off casts.
2668
- Type *StripedOffRecurrenceCast = nullptr ;
2669
- if (const SCEVIntegralCastExpr *C = dyn_cast<SCEVIntegralCastExpr>(V)) {
2670
- StripedOffRecurrenceCast = C->getType ();
2671
- V = C->getOperand ();
2672
- }
2667
+ // Note that the restriction after this loop invariant check are only
2668
+ // profitability restrictions.
2669
+ if (!SE->isLoopInvariant (V, Lp))
2670
+ return nullptr ;
2673
2671
2674
2672
// Look for the loop invariant symbolic value.
2675
2673
const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V);
2676
- if (!U)
2677
- return nullptr ;
2674
+ if (!U) {
2675
+ const auto *C = dyn_cast<SCEVIntegralCastExpr>(V);
2676
+ if (!C)
2677
+ return nullptr ;
2678
+ U = dyn_cast<SCEVUnknown>(C->getOperand ());
2679
+ if (!U)
2680
+ return nullptr ;
2678
2681
2679
- Value *Stride = U->getValue ();
2680
- if (!Lp->isLoopInvariant (Stride))
2681
- return nullptr ;
2682
-
2683
- // If we have stripped off the recurrence cast we have to make sure that we
2684
- // return the value that is used in this loop so that we can replace it later.
2685
- if (StripedOffRecurrenceCast)
2686
- Stride = getUniqueCastUse (Stride, Lp, StripedOffRecurrenceCast);
2682
+ // Match legacy behavior - this is not needed for correctness
2683
+ if (!getUniqueCastUse (U->getValue (), Lp, V->getType ()))
2684
+ return nullptr ;
2685
+ }
2687
2686
2688
- return Stride ;
2687
+ return V ;
2689
2688
}
2690
2689
2691
2690
void LoopAccessInfo::collectStridedAccess (Value *MemAccess) {
@@ -2699,13 +2698,13 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
2699
2698
// computation of an interesting IV - but we chose not to as we
2700
2699
// don't have a cost model here, and broadening the scope exposes
2701
2700
// far too many unprofitable cases.
2702
- Value *Stride = getStrideFromPointer (Ptr , PSE->getSE (), TheLoop);
2703
- if (!Stride )
2701
+ const SCEV *StrideExpr = getStrideFromPointer (Ptr , PSE->getSE (), TheLoop);
2702
+ if (!StrideExpr )
2704
2703
return ;
2705
2704
2706
2705
LLVM_DEBUG (dbgs () << " LAA: Found a strided access that is a candidate for "
2707
2706
" versioning:" );
2708
- LLVM_DEBUG (dbgs () << " Ptr: " << *Ptr << " Stride: " << *Stride << " \n " );
2707
+ LLVM_DEBUG (dbgs () << " Ptr: " << *Ptr << " Stride: " << *StrideExpr << " \n " );
2709
2708
2710
2709
if (!SpeculateUnitStride) {
2711
2710
LLVM_DEBUG (dbgs () << " Chose not to due to -laa-speculate-unit-stride\n " );
@@ -2725,7 +2724,6 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
2725
2724
// of various possible stride specializations, considering the alternatives
2726
2725
// of using gather/scatters (if available).
2727
2726
2728
- const SCEV *StrideExpr = PSE->getSCEV (Stride);
2729
2727
const SCEV *BETakenCount = PSE->getBackedgeTakenCount ();
2730
2728
2731
2729
// Match the types so we can compare the stride and the BETakenCount.
@@ -2756,8 +2754,10 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
2756
2754
2757
2755
// Strip back off the integer cast, and check that our result is a
2758
2756
// SCEVUnknown as we expect.
2759
- Value *StrideVal = stripIntegerCast (Stride);
2760
- SymbolicStrides[Ptr ] = cast<SCEVUnknown>(PSE->getSCEV (StrideVal));
2757
+ const SCEV *StrideBase = StrideExpr;
2758
+ if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(StrideBase))
2759
+ StrideBase = C->getOperand ();
2760
+ SymbolicStrides[Ptr ] = cast<SCEVUnknown>(StrideBase);
2761
2761
}
2762
2762
2763
2763
LoopAccessInfo::LoopAccessInfo (Loop *L, ScalarEvolution *SE,
0 commit comments