@@ -253,6 +253,59 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
253
253
// SCEV class definitions
254
254
//===----------------------------------------------------------------------===//
255
255
256
+ class SCEVDropFlags : public SCEVRewriteVisitor<SCEVDropFlags> {
257
+ using Base = SCEVRewriteVisitor<SCEVDropFlags>;
258
+
259
+ public:
260
+ SCEVDropFlags(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
261
+
262
+ static SCEVUse rewrite(SCEVUse Scev, ScalarEvolution &SE) {
263
+ SCEVDropFlags Rewriter(SE);
264
+ return Rewriter.visit(Scev);
265
+ }
266
+
267
+ SCEVUse visitAddExpr(const SCEVAddExpr *Expr) {
268
+ SmallVector<const SCEV *, 2> Operands;
269
+ bool Changed = false;
270
+ for (const auto Op : Expr->operands()) {
271
+ Operands.push_back(visit(Op));
272
+ Changed |= Op != Operands.back();
273
+ }
274
+ return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
275
+ }
276
+
277
+ SCEVUse visitMulExpr(const SCEVMulExpr *Expr) {
278
+ SmallVector<SCEVUse, 2> Operands;
279
+ bool Changed = false;
280
+ for (const auto Op : Expr->operands()) {
281
+ Operands.push_back(visit(Op));
282
+ Changed |= Op != Operands.back();
283
+ }
284
+ return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
285
+ }
286
+ };
287
+
288
+ const SCEV *SCEVUse::computeCanonical(ScalarEvolution &SE) const {
289
+ return SCEVDropFlags::rewrite(*this, SE);
290
+ }
291
+
292
+ bool SCEVUse::computeIsCanonical() const {
293
+ if (!getRawPointer() ||
294
+ DenseMapInfo<SCEVUse>::getEmptyKey().getRawPointer() == getRawPointer() ||
295
+ DenseMapInfo<SCEVUse>::getTombstoneKey().getRawPointer() ==
296
+ getRawPointer() ||
297
+ isa<SCEVCouldNotCompute>(this))
298
+ return true;
299
+ return !SCEVExprContains(*this, [](SCEVUse U) { return U.getFlags() != 0; });
300
+ }
301
+
302
+ bool SCEVUse::operator==(const SCEVUse &RHS) const {
303
+ assert(isCanonical() && RHS.isCanonical());
304
+ return getPointer() == RHS.getPointer();
305
+ }
306
+
307
+ bool SCEVUse::operator==(const SCEV *RHS) const { return getPointer() == RHS; }
308
+
256
309
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
257
310
LLVM_DUMP_METHOD void SCEVUse::dump() const {
258
311
print(dbgs());
@@ -677,9 +730,10 @@ static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
677
730
static std::optional<int>
678
731
CompareSCEVComplexity(EquivalenceClasses<SCEVUse> &EqCacheSCEV,
679
732
const LoopInfo *const LI, SCEVUse LHS, SCEVUse RHS,
680
- DominatorTree &DT, unsigned Depth = 0) {
733
+ DominatorTree &DT, ScalarEvolution &SE,
734
+ unsigned Depth = 0) {
681
735
// Fast-path: SCEVs are uniqued so we can do a quick equality check.
682
- if (LHS == RHS)
736
+ if (LHS.getCanonical(SE) == RHS.getCanonical(SE) )
683
737
return 0;
684
738
685
739
// Primarily, sort the SCEVs by their getSCEVType().
@@ -769,7 +823,7 @@ CompareSCEVComplexity(EquivalenceClasses<SCEVUse> &EqCacheSCEV,
769
823
return (int)LNumOps - (int)RNumOps;
770
824
771
825
for (unsigned i = 0; i != LNumOps; ++i) {
772
- auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT,
826
+ auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT, SE,
773
827
Depth + 1);
774
828
if (X != 0)
775
829
return X;
@@ -794,14 +848,14 @@ CompareSCEVComplexity(EquivalenceClasses<SCEVUse> &EqCacheSCEV,
794
848
/// this to depend on where the addresses of various SCEV objects happened to
795
849
/// land in memory.
796
850
static void GroupByComplexity(SmallVectorImpl<SCEVUse> &Ops, LoopInfo *LI,
797
- DominatorTree &DT) {
851
+ DominatorTree &DT, ScalarEvolution &SE ) {
798
852
if (Ops.size() < 2) return; // Noop
799
853
800
854
EquivalenceClasses<SCEVUse> EqCacheSCEV;
801
855
802
856
// Whether LHS has provably less complexity than RHS.
803
857
auto IsLessComplex = [&](SCEVUse LHS, SCEVUse RHS) {
804
- auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT);
858
+ auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT, SE );
805
859
return Complexity && *Complexity < 0;
806
860
};
807
861
if (Ops.size() == 2) {
@@ -882,7 +936,7 @@ constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
882
936
if (Folded && IsAbsorber(Folded->getAPInt()))
883
937
return Folded;
884
938
885
- GroupByComplexity(Ops, &LI, DT);
939
+ GroupByComplexity(Ops, &LI, DT, SE );
886
940
if (Folded && !IsIdentity(Folded->getAPInt()))
887
941
Ops.insert(Ops.begin(), Folded);
888
942
@@ -2586,7 +2640,9 @@ SCEVUse ScalarEvolution::getAddExpr(SmallVectorImpl<SCEVUse> &Ops,
2586
2640
SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2587
2641
if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2588
2642
Add->setNoWrapFlags(ComputeFlags(Ops));
2589
- return S;
2643
+ bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
2644
+ int UseFlags = IsCanonical ? 0 : 1;
2645
+ return {S, UseFlags};
2590
2646
}
2591
2647
2592
2648
// Okay, check to see if the same value occurs in the operand list more than
@@ -2595,7 +2651,8 @@ SCEVUse ScalarEvolution::getAddExpr(SmallVectorImpl<SCEVUse> &Ops,
2595
2651
Type *Ty = Ops[0]->getType();
2596
2652
bool FoundMatch = false;
2597
2653
for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2598
- if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2654
+ if (Ops[i].getCanonical(*this) ==
2655
+ Ops[i + 1].getCanonical(*this)) { // X + Y + Y --> X + Y*2
2599
2656
// Scan ahead to count how many equal operands there are.
2600
2657
unsigned Count = 2;
2601
2658
while (i+Count != e && Ops[i+Count] == Ops[i])
@@ -2817,7 +2874,7 @@ SCEVUse ScalarEvolution::getAddExpr(SmallVectorImpl<SCEVUse> &Ops,
2817
2874
if (isa<SCEVConstant>(MulOpSCEV))
2818
2875
continue;
2819
2876
for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2820
- if (MulOpSCEV == Ops[AddOp]) {
2877
+ if (MulOpSCEV.getCanonical(*this) == Ops[AddOp].getCanonical(*this) ) {
2821
2878
// Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2822
2879
SCEVUse InnerMul = Mul->getOperand(MulOp == 0);
2823
2880
if (Mul->getNumOperands() != 2) {
@@ -3018,7 +3075,9 @@ SCEVUse ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
3018
3075
registerUser(S, Ops);
3019
3076
}
3020
3077
S->setNoWrapFlags(Flags);
3021
- return S;
3078
+ bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
3079
+ int UseFlags = IsCanonical ? 0 : 1;
3080
+ return {S, UseFlags};
3022
3081
}
3023
3082
3024
3083
SCEVUse ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
@@ -3063,7 +3122,9 @@ SCEVUse ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
3063
3122
registerUser(S, Ops);
3064
3123
}
3065
3124
S->setNoWrapFlags(Flags);
3066
- return S;
3125
+ bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
3126
+ int UseFlags = IsCanonical ? 0 : 1;
3127
+ return {S, UseFlags};
3067
3128
}
3068
3129
3069
3130
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
@@ -3165,7 +3226,9 @@ SCEVUse ScalarEvolution::getMulExpr(SmallVectorImpl<SCEVUse> &Ops,
3165
3226
SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3166
3227
if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3167
3228
Mul->setNoWrapFlags(ComputeFlags(Ops));
3168
- return S;
3229
+ bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
3230
+ int UseFlags = IsCanonical ? 0 : 1;
3231
+ return {S, UseFlags};
3169
3232
}
3170
3233
3171
3234
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
@@ -13647,6 +13710,19 @@ ScalarEvolution::~ScalarEvolution() {
13647
13710
HasRecMap.clear();
13648
13711
BackedgeTakenCounts.clear();
13649
13712
PredicatedBackedgeTakenCounts.clear();
13713
+ UnsignedRanges.clear();
13714
+ SignedRanges.clear();
13715
+
13716
+ BECountUsers.clear();
13717
+ SCEVUsers.clear();
13718
+ FoldCache.clear();
13719
+ FoldCacheUser.clear();
13720
+ ValuesAtScopes.clear();
13721
+ ValuesAtScopesUsers.clear();
13722
+ LoopDispositions.clear();
13723
+
13724
+ BlockDispositions.clear();
13725
+ ConstantMultipleCache.clear();
13650
13726
13651
13727
assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13652
13728
assert(PendingPhiRanges.empty() && "getRangeRef garbage");
0 commit comments