Skip to content

Commit 45b5e62

Browse files
committed
!fixup isCanonical/getCanonical
1 parent 434fb0d commit 45b5e62

File tree

2 files changed

+114
-16
lines changed

2 files changed

+114
-16
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,40 @@ extern bool VerifySCEV;
6767

6868
class SCEV;
6969

70-
class SCEVUse : public PointerIntPair<const SCEV *, 2> {
70+
class SCEVUse : public PointerIntPair<const SCEV *, 3> {
71+
bool computeIsCanonical() const;
72+
const SCEV *computeCanonical(ScalarEvolution &SE) const;
73+
7174
public:
7275
SCEVUse() : PointerIntPair(nullptr, 0) {}
7376
SCEVUse(const SCEV *S) : PointerIntPair(S, 0) {}
74-
SCEVUse(const SCEV *S, int Flags) : PointerIntPair(S, Flags) {}
77+
SCEVUse(const SCEV *S, int Flags) : PointerIntPair(S, Flags) {
78+
if (Flags > 0)
79+
setInt(Flags | 1);
80+
}
7581

7682
operator const SCEV *() const { return getPointer(); }
7783
const SCEV *operator->() const { return getPointer(); }
7884
const SCEV *operator->() { return getPointer(); }
7985

80-
void *getRawPointer() { return getOpaqueValue(); }
86+
void *getRawPointer() const { return getOpaqueValue(); }
87+
88+
bool isCanonical() const {
89+
assert(((getFlags() & 1) != 0 || computeIsCanonical()) &&
90+
"Canonical bit set incorrectly");
91+
return (getFlags() & 1) == 0;
92+
}
93+
94+
const SCEV *getCanonical(ScalarEvolution &SE) {
95+
if (isCanonical())
96+
return getPointer();
97+
return computeCanonical(SE);
98+
}
99+
100+
unsigned getFlags() const { return getInt(); }
81101

102+
bool operator==(const SCEVUse &RHS) const;
103+
bool operator==(const SCEV *RHS) const;
82104
/// Print out the internal representation of this scalar to the specified
83105
/// stream. This should really only be used for debugging purposes.
84106
void print(raw_ostream &OS) const;
@@ -127,7 +149,7 @@ template <> struct DenseMapInfo<SCEVUse> {
127149
}
128150

129151
static bool isEqual(const SCEVUse LHS, const SCEVUse RHS) {
130-
return LHS == RHS;
152+
return LHS.getRawPointer() == RHS.getRawPointer();
131153
}
132154
};
133155

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 88 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,59 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
253253
// SCEV class definitions
254254
//===----------------------------------------------------------------------===//
255255

256+
class SCEVDropFlags : public SCEVRewriteVisitor<SCEVDropFlags> {
257+
using Base = SCEVRewriteVisitor<SCEVDropFlags>;
258+
259+
public:
260+
SCEVDropFlags(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
261+
262+
static SCEVUse rewrite(SCEVUse Scev, ScalarEvolution &SE) {
263+
SCEVDropFlags Rewriter(SE);
264+
return Rewriter.visit(Scev);
265+
}
266+
267+
SCEVUse visitAddExpr(const SCEVAddExpr *Expr) {
268+
SmallVector<const SCEV *, 2> Operands;
269+
bool Changed = false;
270+
for (const auto Op : Expr->operands()) {
271+
Operands.push_back(visit(Op));
272+
Changed |= Op != Operands.back();
273+
}
274+
return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
275+
}
276+
277+
SCEVUse visitMulExpr(const SCEVMulExpr *Expr) {
278+
SmallVector<SCEVUse, 2> Operands;
279+
bool Changed = false;
280+
for (const auto Op : Expr->operands()) {
281+
Operands.push_back(visit(Op));
282+
Changed |= Op != Operands.back();
283+
}
284+
return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
285+
}
286+
};
287+
288+
const SCEV *SCEVUse::computeCanonical(ScalarEvolution &SE) const {
289+
return SCEVDropFlags::rewrite(*this, SE);
290+
}
291+
292+
bool SCEVUse::computeIsCanonical() const {
293+
if (!getRawPointer() ||
294+
DenseMapInfo<SCEVUse>::getEmptyKey().getRawPointer() == getRawPointer() ||
295+
DenseMapInfo<SCEVUse>::getTombstoneKey().getRawPointer() ==
296+
getRawPointer() ||
297+
isa<SCEVCouldNotCompute>(this))
298+
return true;
299+
return !SCEVExprContains(*this, [](SCEVUse U) { return U.getFlags() != 0; });
300+
}
301+
302+
bool SCEVUse::operator==(const SCEVUse &RHS) const {
303+
assert(isCanonical() && RHS.isCanonical());
304+
return getPointer() == RHS.getPointer();
305+
}
306+
307+
bool SCEVUse::operator==(const SCEV *RHS) const { return getPointer() == RHS; }
308+
256309
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
257310
LLVM_DUMP_METHOD void SCEVUse::dump() const {
258311
print(dbgs());
@@ -677,9 +730,10 @@ static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
677730
static std::optional<int>
678731
CompareSCEVComplexity(EquivalenceClasses<SCEVUse> &EqCacheSCEV,
679732
const LoopInfo *const LI, SCEVUse LHS, SCEVUse RHS,
680-
DominatorTree &DT, unsigned Depth = 0) {
733+
DominatorTree &DT, ScalarEvolution &SE,
734+
unsigned Depth = 0) {
681735
// Fast-path: SCEVs are uniqued so we can do a quick equality check.
682-
if (LHS == RHS)
736+
if (LHS.getCanonical(SE) == RHS.getCanonical(SE))
683737
return 0;
684738

685739
// Primarily, sort the SCEVs by their getSCEVType().
@@ -769,7 +823,7 @@ CompareSCEVComplexity(EquivalenceClasses<SCEVUse> &EqCacheSCEV,
769823
return (int)LNumOps - (int)RNumOps;
770824

771825
for (unsigned i = 0; i != LNumOps; ++i) {
772-
auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT,
826+
auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT, SE,
773827
Depth + 1);
774828
if (X != 0)
775829
return X;
@@ -794,14 +848,14 @@ CompareSCEVComplexity(EquivalenceClasses<SCEVUse> &EqCacheSCEV,
794848
/// this to depend on where the addresses of various SCEV objects happened to
795849
/// land in memory.
796850
static void GroupByComplexity(SmallVectorImpl<SCEVUse> &Ops, LoopInfo *LI,
797-
DominatorTree &DT) {
851+
DominatorTree &DT, ScalarEvolution &SE) {
798852
if (Ops.size() < 2) return; // Noop
799853

800854
EquivalenceClasses<SCEVUse> EqCacheSCEV;
801855

802856
// Whether LHS has provably less complexity than RHS.
803857
auto IsLessComplex = [&](SCEVUse LHS, SCEVUse RHS) {
804-
auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT);
858+
auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT, SE);
805859
return Complexity && *Complexity < 0;
806860
};
807861
if (Ops.size() == 2) {
@@ -882,7 +936,7 @@ constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
882936
if (Folded && IsAbsorber(Folded->getAPInt()))
883937
return Folded;
884938

885-
GroupByComplexity(Ops, &LI, DT);
939+
GroupByComplexity(Ops, &LI, DT, SE);
886940
if (Folded && !IsIdentity(Folded->getAPInt()))
887941
Ops.insert(Ops.begin(), Folded);
888942

@@ -2586,7 +2640,9 @@ SCEVUse ScalarEvolution::getAddExpr(SmallVectorImpl<SCEVUse> &Ops,
25862640
SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
25872641
if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
25882642
Add->setNoWrapFlags(ComputeFlags(Ops));
2589-
return S;
2643+
bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
2644+
int UseFlags = IsCanonical ? 0 : 1;
2645+
return {S, UseFlags};
25902646
}
25912647

25922648
// Okay, check to see if the same value occurs in the operand list more than
@@ -2595,7 +2651,8 @@ SCEVUse ScalarEvolution::getAddExpr(SmallVectorImpl<SCEVUse> &Ops,
25952651
Type *Ty = Ops[0]->getType();
25962652
bool FoundMatch = false;
25972653
for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2598-
if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2654+
if (Ops[i].getCanonical(*this) ==
2655+
Ops[i + 1].getCanonical(*this)) { // X + Y + Y --> X + Y*2
25992656
// Scan ahead to count how many equal operands there are.
26002657
unsigned Count = 2;
26012658
while (i+Count != e && Ops[i+Count] == Ops[i])
@@ -2817,7 +2874,7 @@ SCEVUse ScalarEvolution::getAddExpr(SmallVectorImpl<SCEVUse> &Ops,
28172874
if (isa<SCEVConstant>(MulOpSCEV))
28182875
continue;
28192876
for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2820-
if (MulOpSCEV == Ops[AddOp]) {
2877+
if (MulOpSCEV.getCanonical(*this) == Ops[AddOp].getCanonical(*this)) {
28212878
// Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
28222879
SCEVUse InnerMul = Mul->getOperand(MulOp == 0);
28232880
if (Mul->getNumOperands() != 2) {
@@ -3018,7 +3075,9 @@ SCEVUse ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
30183075
registerUser(S, Ops);
30193076
}
30203077
S->setNoWrapFlags(Flags);
3021-
return S;
3078+
bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
3079+
int UseFlags = IsCanonical ? 0 : 1;
3080+
return {S, UseFlags};
30223081
}
30233082

30243083
SCEVUse ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
@@ -3063,7 +3122,9 @@ SCEVUse ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
30633122
registerUser(S, Ops);
30643123
}
30653124
S->setNoWrapFlags(Flags);
3066-
return S;
3125+
bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
3126+
int UseFlags = IsCanonical ? 0 : 1;
3127+
return {S, UseFlags};
30673128
}
30683129

30693130
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
@@ -3165,7 +3226,9 @@ SCEVUse ScalarEvolution::getMulExpr(SmallVectorImpl<SCEVUse> &Ops,
31653226
SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
31663227
if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
31673228
Mul->setNoWrapFlags(ComputeFlags(Ops));
3168-
return S;
3229+
bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
3230+
int UseFlags = IsCanonical ? 0 : 1;
3231+
return {S, UseFlags};
31693232
}
31703233

31713234
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
@@ -13647,6 +13710,19 @@ ScalarEvolution::~ScalarEvolution() {
1364713710
HasRecMap.clear();
1364813711
BackedgeTakenCounts.clear();
1364913712
PredicatedBackedgeTakenCounts.clear();
13713+
UnsignedRanges.clear();
13714+
SignedRanges.clear();
13715+
13716+
BECountUsers.clear();
13717+
SCEVUsers.clear();
13718+
FoldCache.clear();
13719+
FoldCacheUser.clear();
13720+
ValuesAtScopes.clear();
13721+
ValuesAtScopesUsers.clear();
13722+
LoopDispositions.clear();
13723+
13724+
BlockDispositions.clear();
13725+
ConstantMultipleCache.clear();
1365013726

1365113727
assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
1365213728
assert(PendingPhiRanges.empty() && "getRangeRef garbage");

0 commit comments

Comments
 (0)