Skip to content

[GVN] Restrict equality propagation for pointers #82458

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions llvm/include/llvm/Analysis/Loads.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,17 @@ Value *findAvailablePtrLoadStore(const MemoryLocation &Loc, Type *AccessTy,
unsigned MaxInstsToScan, BatchAAResults *AA,
bool *IsLoadCSE, unsigned *NumScanedInst);

/// Returns true if a pointer value \p A can be replace with another pointer
/// value \B if they are deemed equal through some means (e.g. information from
/// Returns true if a pointer value \p From can be replaced with another pointer
/// value \To if they are deemed equal through some means (e.g. information from
/// conditions).
/// NOTE: the current implementations is incomplete and unsound. It does not
/// reject all invalid cases yet, but will be made stricter in the future. In
/// particular this means returning true means unknown if replacement is safe.
bool canReplacePointersIfEqual(Value *A, Value *B, const DataLayout &DL,
Instruction *CtxI);
/// NOTE: The current implementation allows replacement in Icmp and PtrToInt
/// instructions, as well as when we are replacing with a null pointer.
/// Additionally it also allows replacement of pointers when both pointers have
/// the same underlying object.
bool canReplacePointersIfEqual(const Value *From, const Value *To,
const DataLayout &DL);
bool canReplacePointersInUseIfEqual(const Use &U, const Value *To,
const DataLayout &DL);
}

#endif
23 changes: 20 additions & 3 deletions llvm/include/llvm/Transforms/Utils/Local.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,12 +433,29 @@ unsigned replaceNonLocalUsesWith(Instruction *From, Value *To);

/// Replace each use of 'From' with 'To' if that use is dominated by
/// the given edge. Returns the number of replacements made.
unsigned replaceDominatedUsesWith(Value *From, Value *To, DominatorTree &DT,
unsigned replaceDominatedUsesWith(Value *From, Value *To, const DataLayout &DL,
DominatorTree &DT,
const BasicBlockEdge &Edge);
/// Replace each use of 'From' with 'To' if that use is dominated by
/// the end of the given BasicBlock. Returns the number of replacements made.
unsigned replaceDominatedUsesWith(Value *From, Value *To, DominatorTree &DT,
const BasicBlock *BB);
unsigned replaceDominatedUsesWith(Value *From, Value *To, const DataLayout &DL,
DominatorTree &DT, const BasicBlock *BB);
/// Replace each use of 'From' with 'To' if that use is dominated by
/// the given edge and the callback ShouldReplace returns true. Returns the
/// number of replacements made.
unsigned replaceDominatedUsesWithIf(
Value *From, Value *To, const DataLayout &DL, DominatorTree &DT,
const BasicBlockEdge &Edge,
function_ref<bool(const Use &U, const Value *To, const DataLayout &DL)>
ShouldReplace);
/// Replace each use of 'From' with 'To' if that use is dominated by
/// the end of the given BasicBlock and the callback ShouldReplace returns true.
/// Returns the number of replacements made.
unsigned replaceDominatedUsesWithIf(
Value *From, Value *To, const DataLayout &DL, DominatorTree &DT,
const BasicBlock *BB,
function_ref<bool(const Use &U, const Value *To, const DataLayout &DL)>
ShouldReplace);

/// Return true if this call calls a gc leaf function.
///
Expand Down
70 changes: 53 additions & 17 deletions llvm/lib/Analysis/Loads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -710,22 +710,58 @@ Value *llvm::FindAvailableLoadedValue(LoadInst *Load, BatchAAResults &AA,
return Available;
}

bool llvm::canReplacePointersIfEqual(Value *A, Value *B, const DataLayout &DL,
Instruction *CtxI) {
Type *Ty = A->getType();
assert(Ty == B->getType() && Ty->isPointerTy() &&
"values must have matching pointer types");

// NOTE: The checks in the function are incomplete and currently miss illegal
// cases! The current implementation is a starting point and the
// implementation should be made stricter over time.
if (auto *C = dyn_cast<Constant>(B)) {
// Do not allow replacing a pointer with a constant pointer, unless it is
// either null or at least one byte is dereferenceable.
APInt OneByte(DL.getPointerTypeSizeInBits(Ty), 1);
return C->isNullValue() ||
isDereferenceableAndAlignedPointer(B, Align(1), OneByte, DL, CtxI);
}
// Returns true if a use is either in an ICmp/PtrToInt or a Phi/Select that only
// feeds into them.
static bool isPointerUseReplacable(const Use &U, int MaxLookup = 6) {
if (MaxLookup == 0)
return false;

const User *User = U.getUser();
if (isa<ICmpInst>(User))
return true;
if (isa<PtrToIntInst>(User))
return true;
if (isa<PHINode, SelectInst>(User) &&
all_of(User->uses(), [&](const Use &Use) {
return isPointerUseReplacable(Use, MaxLookup - 1);
}))
return true;

return false;
}

// Returns true if `To` is a null pointer, constant dereferenceable pointer or
// both pointers have the same underlying objects.
static bool isPointerAlwaysReplaceable(const Value *From, const Value *To,
const DataLayout &DL) {
if (isa<ConstantPointerNull>(To))
return true;
if (isa<Constant>(To) &&
isDereferenceablePointer(To, Type::getInt8Ty(To->getContext()), DL))
return true;
if (getUnderlyingObject(From) == getUnderlyingObject(To))
return true;
return false;
}

bool llvm::canReplacePointersInUseIfEqual(const Use &U, const Value *To,
const DataLayout &DL) {
assert(U->getType() == To->getType() && "values must have matching types");
// Not a pointer, just return true.
if (!To->getType()->isPointerTy())
return true;

if (isPointerAlwaysReplaceable(&*U, To, DL))
return true;
return isPointerUseReplacable(U);
}

bool llvm::canReplacePointersIfEqual(const Value *From, const Value *To,
const DataLayout &DL) {
assert(From->getType() == To->getType() && "values must have matching types");
// Not a pointer, just return true.
if (!From->getType()->isPointerTy())
return true;

return true;
return isPointerAlwaysReplaceable(From, To, DL);
}
5 changes: 3 additions & 2 deletions llvm/lib/Transforms/Scalar/EarlyCSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1228,8 +1228,9 @@ bool EarlyCSE::handleBranchCondition(Instruction *CondInst,
LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
} else {
// Replace all dominated uses with the known value.
if (unsigned Count = replaceDominatedUsesWith(Curr, TorF, DT,
BasicBlockEdge(Pred, BB))) {
if (unsigned Count = replaceDominatedUsesWith(
Curr, TorF, Curr->getModule()->getDataLayout(), DT,
BasicBlockEdge(Pred, BB))) {
NumCSECVP += Count;
MadeChanges = true;
}
Expand Down
32 changes: 21 additions & 11 deletions llvm/lib/Transforms/Scalar/GVN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/InstructionPrecedenceTracking.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/MemoryDependenceAnalysis.h"
Expand Down Expand Up @@ -2419,6 +2420,10 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
if (isa<Constant>(LHS) || (isa<Argument>(LHS) && !isa<Constant>(RHS)))
std::swap(LHS, RHS);
assert((isa<Argument>(LHS) || isa<Instruction>(LHS)) && "Unexpected value!");
const DataLayout &DL =
isa<Argument>(LHS)
? cast<Argument>(LHS)->getParent()->getParent()->getDataLayout()
: cast<Instruction>(LHS)->getModule()->getDataLayout();

// If there is no obvious reason to prefer the left-hand side over the
// right-hand side, ensure the longest lived term is on the right-hand side,
Expand All @@ -2445,7 +2450,8 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
// using the leader table is about compiling faster, not optimizing better).
// The leader table only tracks basic blocks, not edges. Only add to if we
// have the simple case where the edge dominates the end.
if (RootDominatesEnd && !isa<Instruction>(RHS))
if (RootDominatesEnd && !isa<Instruction>(RHS) &&
canReplacePointersIfEqual(LHS, RHS, DL))
addToLeaderTable(LVN, RHS, Root.getEnd());

// Replace all occurrences of 'LHS' with 'RHS' everywhere in the scope. As
Expand All @@ -2454,14 +2460,18 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
if (!LHS->hasOneUse()) {
unsigned NumReplacements =
DominatesByEdge
? replaceDominatedUsesWith(LHS, RHS, *DT, Root)
: replaceDominatedUsesWith(LHS, RHS, *DT, Root.getStart());

Changed |= NumReplacements > 0;
NumGVNEqProp += NumReplacements;
// Cached information for anything that uses LHS will be invalid.
if (MD)
MD->invalidateCachedPointerInfo(LHS);
? replaceDominatedUsesWithIf(LHS, RHS, DL, *DT, Root,
canReplacePointersInUseIfEqual)
: replaceDominatedUsesWithIf(LHS, RHS, DL, *DT, Root.getStart(),
canReplacePointersInUseIfEqual);

if (NumReplacements > 0) {
Changed = true;
NumGVNEqProp += NumReplacements;
// Cached information for anything that uses LHS will be invalid.
if (MD)
MD->invalidateCachedPointerInfo(LHS);
}
}

// Now try to deduce additional equalities from this one. For example, if
Expand Down Expand Up @@ -2517,8 +2527,8 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
if (NotCmp && isa<Instruction>(NotCmp)) {
unsigned NumReplacements =
DominatesByEdge
? replaceDominatedUsesWith(NotCmp, NotVal, *DT, Root)
: replaceDominatedUsesWith(NotCmp, NotVal, *DT,
? replaceDominatedUsesWith(NotCmp, NotVal, DL, *DT, Root)
: replaceDominatedUsesWith(NotCmp, NotVal, DL, *DT,
Root.getStart());
Changed |= NumReplacements > 0;
NumGVNEqProp += NumReplacements;
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/Scalar/LoopSink.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ static bool sinkInstruction(
return UIToReplace->getParent() == N && !isa<PHINode>(UIToReplace);
});
// Replaces uses of I with IC in blocks dominated by N
replaceDominatedUsesWith(&I, IC, DT, N);
replaceDominatedUsesWith(&I, IC, I.getModule()->getDataLayout(), DT, N);
LLVM_DEBUG(dbgs() << "Sinking a clone of " << I << " To: " << N->getName()
<< '\n');
NumLoopSunkCloned++;
Expand Down
40 changes: 33 additions & 7 deletions llvm/lib/Transforms/Utils/Local.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3429,15 +3429,16 @@ void llvm::patchReplacementInstruction(Instruction *I, Value *Repl) {
combineMetadataForCSE(ReplInst, I, false);
}

template <typename RootType, typename DominatesFn>
template <typename RootType, typename ShouldReplaceFn>
static unsigned replaceDominatedUsesWith(Value *From, Value *To,
const DataLayout &DL,
const RootType &Root,
const DominatesFn &Dominates) {
const ShouldReplaceFn &ShouldReplace) {
assert(From->getType() == To->getType());

unsigned Count = 0;
for (Use &U : llvm::make_early_inc_range(From->uses())) {
if (!Dominates(Root, U))
if (!ShouldReplace(Root, U))
continue;
LLVM_DEBUG(dbgs() << "Replace dominated use of '";
From->printAsOperand(dbgs());
Expand All @@ -3464,23 +3465,48 @@ unsigned llvm::replaceNonLocalUsesWith(Instruction *From, Value *To) {
}

unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
DominatorTree &DT,
const DataLayout &DL, DominatorTree &DT,
const BasicBlockEdge &Root) {
auto Dominates = [&DT](const BasicBlockEdge &Root, const Use &U) {
return DT.dominates(Root, U);
};
return ::replaceDominatedUsesWith(From, To, Root, Dominates);
return ::replaceDominatedUsesWith(From, To, DL, Root, Dominates);
}

unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To,
DominatorTree &DT,
const DataLayout &DL, DominatorTree &DT,
const BasicBlock *BB) {
auto Dominates = [&DT](const BasicBlock *BB, const Use &U) {
return DT.dominates(BB, U);
};
return ::replaceDominatedUsesWith(From, To, BB, Dominates);
return ::replaceDominatedUsesWith(From, To, DL, BB, Dominates);
}

unsigned llvm::replaceDominatedUsesWithIf(
Value *From, Value *To, const DataLayout &DL, DominatorTree &DT,
const BasicBlockEdge &Root,
function_ref<bool(const Use &U, const Value *To, const DataLayout &DL)>
ShouldReplace) {
auto DominatesAndShouldReplace =
[ShouldReplace, To, &DT, &DL](const BasicBlockEdge &Root, const Use &U) {
return DT.dominates(Root, U) && ShouldReplace(U, To, DL);
};
return ::replaceDominatedUsesWith(From, To, DL, Root,
DominatesAndShouldReplace);
}

unsigned llvm::replaceDominatedUsesWithIf(
Value *From, Value *To, const DataLayout &DL, DominatorTree &DT,
const BasicBlock *BB,
function_ref<bool(const Use &U, const Value *To, const DataLayout &DL)>
ShouldReplace) {
auto DominatesAndShouldReplace = [ShouldReplace, To, &DT,
&DL](const BasicBlock *BB, const Use &U) {
return DT.dominates(BB, U) && ShouldReplace(U, To, DL);
};
return ::replaceDominatedUsesWith(From, To, DL, BB,
DominatesAndShouldReplace);
}
bool llvm::callsGCLeafFunction(const CallBase *Call,
const TargetLibraryInfo &TLI) {
// Check if the function is specifically marked as a gc leaf function.
Expand Down
Loading
Loading