Skip to content

[SROA] Optimize reloaded values in allocas that escape into readonly nocapture calls. #116645

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions llvm/include/llvm/Analysis/PtrUseVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class PtrUseVisitorBase {
/// Is the pointer escaped at some point?
bool isEscaped() const { return EscapedInfo != nullptr; }

/// Is the pointer escaped into a read-only nocapture call at some point?
///
/// \returns true once setEscapedReadOnly() has recorded an instruction, i.e.
/// when EscapedReadOnly is non-null.
bool isEscapedReadOnly() const { return EscapedReadOnly != nullptr; }

/// Get the instruction causing the visit to abort.
/// \returns a pointer to the instruction causing the abort if one is
/// available; otherwise returns null.
Expand All @@ -74,6 +77,10 @@ class PtrUseVisitorBase {
/// is available; otherwise returns null.
Instruction *getEscapingInst() const { return EscapedInfo; }

/// Get the instruction causing the pointer to escape which is a read-only
/// nocapture call.
/// \returns a pointer to the instruction recorded by setEscapedReadOnly() if
/// one is available; otherwise returns null.
Instruction *getEscapedReadOnlyInst() const { return EscapedReadOnly; }

/// Mark the visit as aborted. Intended for use in a void return.
/// \param I The instruction which caused the visit to abort, if available.
void setAborted(Instruction *I) {
Expand All @@ -88,6 +95,12 @@ class PtrUseVisitorBase {
EscapedInfo = I;
}

/// Mark the pointer as escaped into a readonly-nocapture call.
///
/// Only EscapedReadOnly is written; EscapedInfo and AbortedInfo are left
/// untouched, so isEscaped() is not affected by this call.
/// \param I The instruction the pointer escapes into; must not be null.
void setEscapedReadOnly(Instruction *I) {
assert(I && "Expected a valid pointer in setEscapedReadOnly");
EscapedReadOnly = I;
}

/// Mark the pointer as escaped, and the visit as aborted. Intended
/// for use in a void return.
/// \param I The instruction which both escapes the pointer and aborts the
Expand All @@ -100,6 +113,7 @@ class PtrUseVisitorBase {
private:
Instruction *AbortedInfo = nullptr;
Instruction *EscapedInfo = nullptr;
Instruction *EscapedReadOnly = nullptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm concerned that, as implemented, this may affect other users of PtrUseVisitor, which only check isEscaped() and don't know that they need to check isEscapedReadOnly() as well. A read-only escape should probably still report as an escape?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking more closely, I see that the code calling setEscapedReadOnly is part of SROA only rather than the generic implementation, so this is not an issue.

};

protected:
Expand Down
7 changes: 7 additions & 0 deletions llvm/include/llvm/Transforms/Utils/SSAUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,13 @@ class LoadAndStorePromoter {
/// Return false if a sub-class wants to keep one of the loads/stores
/// after the SSA construction.
virtual bool shouldDelete(Instruction *I) const { return true; }

/// Return the value to use for the point in the code that the alloca is
/// positioned. This will only be used if an Alloca is included in Insts,
/// otherwise the value of an uninitialized load will be assumed to be poison.
/// \param AI The alloca instruction that was found in Insts.
/// \returns the value to treat as stored at the alloca. This default
/// implementation returns null, so a sub-class that includes an alloca in
/// Insts is presumably expected to override it (NOTE(review): confirm that a
/// null value is never consumed by the SSA updater).
virtual Value *getValueToUseForAlloca(Instruction *AI) const {
return nullptr;
}
};

} // end namespace llvm
Expand Down
108 changes: 107 additions & 1 deletion llvm/lib/Transforms/Scalar/SROA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/PtrUseVisitor.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
Expand Down Expand Up @@ -83,6 +84,7 @@
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
#include "llvm/Transforms/Utils/SSAUpdater.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
Expand Down Expand Up @@ -246,6 +248,7 @@ class SROA {
bool presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS);
AllocaInst *rewritePartition(AllocaInst &AI, AllocaSlices &AS, Partition &P);
bool splitAlloca(AllocaInst &AI, AllocaSlices &AS);
bool propagateStoredValuesToLoads(AllocaInst &AI, AllocaSlices &AS);
std::pair<bool /*Changed*/, bool /*CFGChanged*/> runOnAlloca(AllocaInst &AI);
void clobberUse(Use &U);
bool deleteDeadInstructions(SmallPtrSetImpl<AllocaInst *> &DeletedAllocas);
Expand Down Expand Up @@ -598,6 +601,7 @@ class AllocaSlices {
/// If this is true, the slices are never fully built and should be
/// ignored.
bool isEscaped() const { return PointerEscapingInstr; }
/// Test whether a pointer to the allocation is passed to a call that neither
/// captures it nor writes through it.
///
/// Unlike isEscaped(), this does not stop the slices from being built; it is
/// true when PointerEscapingInstrReadOnly is non-null.
bool isEscapedReadOnly() const { return PointerEscapingInstrReadOnly; }

/// Support for iterating over the slices.
/// @{
Expand Down Expand Up @@ -680,6 +684,7 @@ class AllocaSlices {
/// store a pointer to that here and abort trying to form slices of the
/// alloca. This will be null if the alloca slices are analyzed successfully.
Instruction *PointerEscapingInstr;
/// The read-only nocapture call the alloca pointer was passed to, as reported
/// by the slice builder after a successful analysis. This is null if no such
/// call was seen, or if the analysis bailed out on an escape or abort.
Instruction *PointerEscapingInstrReadOnly;

/// The slices of the alloca.
///
Expand Down Expand Up @@ -1390,14 +1395,26 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {

/// Disable SROA entirely if there are unhandled users of the alloca.
///
/// This is the fallback visitor: it records \p I as the instruction that
/// aborted the visit.
void visitInstruction(Instruction &I) { PI.setAborted(&I); }

/// Visit a call that uses the alloca pointer.
///
/// If the pointer is passed as a data operand that the call neither captures
/// nor writes through, the call is recorded as a read-only escape and the
/// visit of this use stops there. Every other call is forwarded to the
/// generic PtrUseVisitor handling.
void visitCallBase(CallBase &CB) {
  // The operand-indexed attribute queries are only valid for data operands
  // (call arguments and bundle operands). The pointer may also be used as the
  // callee, in which case it must not be looked up by operand number.
  if (CB.isDataOperand(U)) {
    const unsigned OpNo = U->getOperandNo();
    // If the call operand is NoCapture ReadOnly, then we mark it as
    // EscapedReadOnly.
    if (CB.doesNotCapture(OpNo) && CB.onlyReadsMemory(OpNo)) {
      PI.setEscapedReadOnly(&CB);
      return;
    }
  }

  Base::visitCallBase(CB);
}
};

AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI)
:
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
AI(AI),
#endif
PointerEscapingInstr(nullptr) {
PointerEscapingInstr(nullptr), PointerEscapingInstrReadOnly(nullptr) {
SliceBuilder PB(DL, AI, *this);
SliceBuilder::PtrInfo PtrI = PB.visitPtr(AI);
if (PtrI.isEscaped() || PtrI.isAborted()) {
Expand All @@ -1408,6 +1425,7 @@ AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI)
assert(PointerEscapingInstr && "Did not track a bad instruction");
return;
}
PointerEscapingInstrReadOnly = PtrI.getEscapedReadOnlyInst();

llvm::erase_if(Slices, [](const Slice &S) { return S.isDead(); });

Expand Down Expand Up @@ -1445,6 +1463,9 @@ void AllocaSlices::print(raw_ostream &OS) const {
return;
}

if (PointerEscapingInstrReadOnly)
OS << "Escapes into ReadOnly: " << *PointerEscapingInstrReadOnly << "\n";

OS << "Slices of alloca: " << AI << "\n";
for (const_iterator I = begin(), E = end(); I != E; ++I)
print(OS, I);
Expand Down Expand Up @@ -5454,6 +5475,86 @@ void SROA::clobberUse(Use &U) {
}
}

/// A minimal LoadAndStorePromoter that keeps store instructions (and the
/// alloca) in place after SSA construction.
class BasicLoadAndStorePromoter : public LoadAndStorePromoter {
  /// Type of the undef value reported for the alloca's position.
  Type *ZeroType;

public:
  BasicLoadAndStorePromoter(ArrayRef<const Instruction *> Insts, SSAUpdater &S,
                            Type *ZeroType)
      : LoadAndStorePromoter(Insts, S), ZeroType(ZeroType) {}

  /// Stores and allocas must survive the promotion; any other instruction may
  /// be deleted.
  bool shouldDelete(Instruction *I) const override {
    if (isa<StoreInst>(I) || isa<AllocaInst>(I))
      return false;
    return true;
  }

  /// The alloca is treated as a store of an undef value of ZeroType.
  Value *getValueToUseForAlloca(Instruction *I) const override {
    return UndefValue::get(ZeroType);
  }
};

/// Forward stored values to the loads of an alloca whose slices were built
/// successfully but whose pointer escapes into a read-only nocapture call.
///
/// The slices of \p AS are walked in order and grouped into partitions of
/// overlapping slices. For each partition in which every slice covers exactly
/// the same byte range and every user is a simple load or store of one common
/// type (assume-like intrinsics are tolerated), the loads are rewritten via
/// SSAUpdater to use the stored values. Stores and the alloca itself are kept
/// (see BasicLoadAndStorePromoter).
///
/// \param AI The alloca being processed; it is appended to the promoted
/// instructions so it provides the value seen before any store.
/// \param AS The slices built for \p AI.
/// \returns true unconditionally, i.e. the function is conservatively
/// reported as changed.
bool SROA::propagateStoredValuesToLoads(AllocaInst &AI, AllocaSlices &AS) {
  // Look through each "partition", looking for slices with the same start/end
  // that do not overlap with any before them. The slices are sorted by
  // increasing beginOffset. We don't use AS.partitions(), as it will use a more
  // sophisticated algorithm that takes splittable slices into account.
  auto PartitionBegin = AS.begin();
  auto PartitionEnd = PartitionBegin;
  while (PartitionBegin != AS.end()) {
    // Read the offsets only after the end() check above: the slice list may
    // be empty, and PartitionBegin equals end() once the last partition has
    // been consumed, so it must never be dereferenced outside this loop.
    const uint64_t BeginOffset = PartitionBegin->beginOffset();
    uint64_t EndOffset = PartitionBegin->endOffset();
    bool AllSameAndValid = true;
    SmallVector<Instruction *> Insts;
    Type *PartitionType = nullptr;
    while (PartitionEnd != AS.end() &&
           (PartitionEnd->beginOffset() < EndOffset ||
            PartitionEnd->endOffset() <= EndOffset)) {
      if (AllSameAndValid) {
        AllSameAndValid &= PartitionEnd->beginOffset() == BeginOffset &&
                           PartitionEnd->endOffset() == EndOffset;
        Instruction *User =
            cast<Instruction>(PartitionEnd->getUse()->getUser());
        if (auto *LI = dyn_cast<LoadInst>(User)) {
          Type *UserTy = LI->getType();
          // LoadAndStorePromoter requires all the types to be the same.
          if (!LI->isSimple() || (PartitionType && UserTy != PartitionType))
            AllSameAndValid = false;
          PartitionType = UserTy;
          Insts.push_back(User);
        } else if (auto *SI = dyn_cast<StoreInst>(User)) {
          Type *UserTy = SI->getValueOperand()->getType();
          // Non-simple (volatile/atomic) stores are rejected just like
          // non-simple loads, matching what regular SROA does.
          if (!SI->isSimple() || (PartitionType && UserTy != PartitionType))
            AllSameAndValid = false;
          PartitionType = UserTy;
          Insts.push_back(User);
        } else if (!isAssumeLikeIntrinsic(User)) {
          AllSameAndValid = false;
        }
      }
      EndOffset = std::max(EndOffset, PartitionEnd->endOffset());
      ++PartitionEnd;
    }

    // So long as all the slices start and end offsets matched, update loads to
    // the values stored in the partition.
    if (AllSameAndValid && !Insts.empty()) {
      LLVM_DEBUG(dbgs() << "Propagate values on slice [" << BeginOffset << ", "
                        << EndOffset << ")\n");
      SmallVector<PHINode *, 4> NewPHIs;
      SSAUpdater SSA(&NewPHIs);
      // The alloca acts as the initial definition of the partition's value.
      Insts.push_back(&AI);
      BasicLoadAndStorePromoter Promoter(Insts, SSA, PartitionType);
      Promoter.run(Insts);
    }

    // Step on to the next partition; its offsets are read at the top of the
    // loop once the iterator is known to be valid.
    PartitionBegin = PartitionEnd;
  }
  return true;
}

/// Analyze an alloca for SROA.
///
/// This analyzes the alloca to ensure we can reason about it, builds
Expand Down Expand Up @@ -5494,6 +5595,11 @@ SROA::runOnAlloca(AllocaInst &AI) {
if (AS.isEscaped())
return {Changed, CFGChanged};

if (AS.isEscapedReadOnly()) {
Changed |= propagateStoredValuesToLoads(AI, AS);
return {Changed, CFGChanged};
}

// Delete all the dead users of this alloca before splitting and rewriting it.
for (Instruction *DeadUser : AS.getDeadUsers()) {
// Free up everything used by this instruction.
Expand Down
14 changes: 12 additions & 2 deletions llvm/lib/Transforms/Utils/SSAUpdater.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,17 +412,21 @@ void LoadAndStorePromoter::run(const SmallVectorImpl<Instruction *> &Insts) {
if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
updateDebugInfo(SI);
SSA.AddAvailableValue(BB, SI->getOperand(0));
} else
} else if (auto *AI = dyn_cast<AllocaInst>(User)) {
// Treat an AllocaInst as a store of the getValueToUseForAlloca value.
SSA.AddAvailableValue(BB, getValueToUseForAlloca(AI));
} else {
// Otherwise it is a load, queue it to rewrite as a live-in load.
LiveInLoads.push_back(cast<LoadInst>(User));
}
BlockUses.clear();
continue;
}

// Otherwise, check to see if this block is all loads.
bool HasStore = false;
for (Instruction *I : BlockUses) {
if (isa<StoreInst>(I)) {
if (isa<StoreInst>(I) || isa<AllocaInst>(I)) {
HasStore = true;
break;
}
Expand Down Expand Up @@ -468,6 +472,12 @@ void LoadAndStorePromoter::run(const SmallVectorImpl<Instruction *> &Insts) {

// Remember that this is the active value in the block.
StoredValue = SI->getOperand(0);
} else if (auto *AI = dyn_cast<AllocaInst>(&I)) {
// Check if this is an alloca, in which case we treat it as a store of
// getValueToUseForAlloca.
if (!isInstInList(AI, Insts))
continue;
StoredValue = getValueToUseForAlloca(AI);
}
}

Expand Down
Loading
Loading