Skip to content

[LexicalDestroyHoisting] Adopt iterative dataflow. #58939

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 117 additions & 83 deletions lib/SILOptimizer/Utils/LexicalDestroyHoisting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "swift/SIL/SILInstruction.h"
#include "swift/SIL/SILValue.h"
#include "swift/SILOptimizer/Analysis/Reachability.h"
#include "swift/SILOptimizer/Analysis/VisitBarrierAccessScopes.h"
#include "swift/SILOptimizer/Utils/CanonicalizeBorrowScope.h"
#include "swift/SILOptimizer/Utils/InstOptUtils.h"
#include "swift/SILOptimizer/Utils/InstructionDeleter.h"
Expand All @@ -43,14 +44,17 @@ struct Context final {
/// value->getDefiningInstruction()
SILInstruction *const definition;

SILBasicBlock *defBlock;

SILFunction &function;

InstructionDeleter &deleter;

Context(SILValue const &value, SILFunction &function,
InstructionDeleter &deleter)
: value(value), definition(value->getDefiningInstruction()),
function(function), deleter(deleter) {
defBlock(value->getParentBlock()), function(function),
deleter(deleter) {
assert(value->isLexical());
assert(value->getOwnershipKind() == OwnershipKind::Owned);
}
Expand All @@ -63,7 +67,7 @@ struct Usage final {
/// Instructions which are users of the simple (i.e. not reborrowed) value.
SmallPtrSet<SILInstruction *, 16> users;
// The instructions from which the hoisting starts, the destroy_values.
llvm::SmallSetVector<SILInstruction *, 4> ends;
llvm::SmallVector<SILInstruction *, 4> ends;

Usage(){};
Usage(Usage const &) = delete;
Expand All @@ -85,7 +89,7 @@ bool findUsage(Context const &context, Usage &usage) {
// flow and determine whether any were reused. They aren't uses over which
// we can't hoist though.
if (isa<DestroyValueInst>(use->getUser())) {
usage.ends.insert(use->getUser());
usage.ends.push_back(use->getUser());
} else {
usage.users.insert(use->getUser());
}
Expand All @@ -95,95 +99,108 @@ bool findUsage(Context const &context, Usage &usage) {

/// How destroy_value hoisting is obstructed.
struct DeinitBarriers final {
/// Blocks up to "before the beginning" of which hoisting was able to proceed.
BasicBlockSetVector hoistingReachesBeginBlocks;

/// Blocks to "after the end" of which hoisting was able to proceed.
BasicBlockSet hoistingReachesEndBlocks;

/// Instructions above which destroy_values cannot be hoisted.
SmallVector<SILInstruction *, 4> barriers;
SmallVector<SILInstruction *, 4> instructions;

/// Blocks one of whose phis is a barrier and consequently out of which
/// destroy_values cannot be hoisted.
SmallVector<SILBasicBlock *, 4> phiBarriers;
SmallVector<SILBasicBlock *, 4> phis;

DeinitBarriers(Context &context)
: hoistingReachesBeginBlocks(&context.function),
hoistingReachesEndBlocks(&context.function) {}
SmallVector<SILBasicBlock *, 4> blocks;

DeinitBarriers(Context &context) {}
DeinitBarriers(DeinitBarriers const &) = delete;
DeinitBarriers &operator=(DeinitBarriers const &) = delete;
};

class BarrierAccessScopeFinder;

/// Works backwards from the current location of destroy_values to the earliest
/// place they can be hoisted to.
///
/// Implements BackwardReachability::BlockReachability.
class DataFlow final {
/// Implements IterativeBackwardReachability::Effects
/// Implements IterativeBackwardReachability::bindBarriers::Visitor
/// Implements VisitBarrierAccessScopes::Effects
class Dataflow final {
using Reachability = IterativeBackwardReachability<Dataflow>;
using Effect = Reachability::Effect;
Context const &context;
Usage const &uses;
DeinitBarriers &result;
DeinitBarriers &barriers;
Reachability::Result result;
Reachability reachability;
SmallPtrSet<BeginAccessInst *, 8> barrierAccessScopes;

enum class Classification { Barrier, Other };

BackwardReachability<DataFlow> reachability;

public:
DataFlow(Context const &context, Usage const &uses, DeinitBarriers &result)
: context(context), uses(uses), result(result),
reachability(&context.function, *this) {
// Seed reachability with the scope ending uses from which the backwards
// data flow will begin.
for (auto *end : uses.ends) {
reachability.initLastUse(end);
}
}
DataFlow(DataFlow const &) = delete;
DataFlow &operator=(DataFlow const &) = delete;
Dataflow(Context const &context, Usage const &uses, DeinitBarriers &barriers)
: context(context), uses(uses), barriers(barriers),
result(&context.function),
reachability(&context.function, context.defBlock, *this, result) {}
Dataflow(Dataflow const &) = delete;
Dataflow &operator=(Dataflow const &) = delete;

void run() { reachability.solveBackward(); }
void run();

private:
friend class BackwardReachability<DataFlow>;
friend Reachability;
friend class BarrierAccessScopeFinder;
friend class VisitBarrierAccessScopes<Dataflow, BarrierAccessScopeFinder>;

bool hasReachableBegin(SILBasicBlock *block) {
return result.hoistingReachesBeginBlocks.contains(block);
}
Classification classifyInstruction(SILInstruction *);

void markReachableBegin(SILBasicBlock *block) {
result.hoistingReachesBeginBlocks.insert(block);
}
bool classificationIsBarrier(Classification);

void markReachableEnd(SILBasicBlock *block) {
result.hoistingReachesEndBlocks.insert(block);
}
/// IterativeBackwardReachability::Effects
/// VisitBarrierAccessScopes::Effects

Classification classifyInstruction(SILInstruction *);
ArrayRef<SILInstruction *> gens() { return uses.ends; }

bool classificationIsBarrier(Classification);
Effect effectForInstruction(SILInstruction *);
Effect effectForPhi(SILBasicBlock *);

/// VisitBarrierAccessScopes::Effects

auto localGens() { return result.localGens; }

void visitedInstruction(SILInstruction *, Classification);
bool isLocalGen(SILInstruction *instruction) {
return result.localGens.contains(instruction);
}

/// IterativeBackwardReachability::bindBarriers::Visitor

void visitBarrierInstruction(SILInstruction *instruction) {
barriers.instructions.push_back(instruction);
}

bool checkReachableBarrier(SILInstruction *);
void visitBarrierPhi(SILBasicBlock *block) { barriers.phis.push_back(block); }

bool checkReachablePhiBarrier(SILBasicBlock *);
void visitBarrierBlock(SILBasicBlock *block) {
barriers.blocks.push_back(block);
}
};

DataFlow::Classification
DataFlow::classifyInstruction(SILInstruction *instruction) {
Dataflow::Classification
Dataflow::classifyInstruction(SILInstruction *instruction) {
if (instruction == context.definition) {
return Classification::Barrier;
}
if (uses.users.contains(instruction)) {
return Classification::Barrier;
}
if (auto *eai = dyn_cast<EndAccessInst>(instruction)) {
return barrierAccessScopes.contains(eai->getBeginAccess())
? Classification::Barrier
: Classification::Other;
}
if (isDeinitBarrier(instruction)) {
return Classification::Barrier;
}
return Classification::Other;
}

bool DataFlow::classificationIsBarrier(Classification classification) {
bool Dataflow::classificationIsBarrier(Classification classification) {
switch (classification) {
case Classification::Barrier:
return true;
Expand All @@ -193,26 +210,15 @@ bool DataFlow::classificationIsBarrier(Classification classification) {
llvm_unreachable("exhaustive switch not exhaustive?!");
}

void DataFlow::visitedInstruction(SILInstruction *instruction,
Classification classification) {
assert(classifyInstruction(instruction) == classification);
switch (classification) {
case Classification::Barrier:
result.barriers.push_back(instruction);
return;
case Classification::Other:
return;
}
llvm_unreachable("exhaustive switch not exhaustive?!");
}

bool DataFlow::checkReachableBarrier(SILInstruction *instruction) {
Dataflow::Effect Dataflow::effectForInstruction(SILInstruction *instruction) {
if (llvm::find(uses.ends, instruction) != uses.ends.end())
return Effect::Gen();
auto classification = classifyInstruction(instruction);
visitedInstruction(instruction, classification);
return classificationIsBarrier(classification);
return classificationIsBarrier(classification) ? Effect::Kill()
: Effect::NoEffect();
}

bool DataFlow::checkReachablePhiBarrier(SILBasicBlock *block) {
Dataflow::Effect Dataflow::effectForPhi(SILBasicBlock *block) {
assert(llvm::all_of(block->getArguments(),
[&](auto argument) { return PhiValue(argument); }));

Expand All @@ -221,10 +227,46 @@ bool DataFlow::checkReachablePhiBarrier(SILBasicBlock *block) {
return classificationIsBarrier(
classifyInstruction(predecessor->getTerminator()));
});
if (isBarrier) {
result.phiBarriers.push_back(block);
return isBarrier ? Effect::Kill() : Effect::NoEffect();
}

/// Finds end_access instructions which are barriers to hoisting because the
/// access scopes they contain barriers to hoisting. Hoisting destroy_values
/// into such access scopes could introduce exclusivity violations.
///
/// Implements BarrierAccessScopeFinder::Visitor
class BarrierAccessScopeFinder final {
using Impl = VisitBarrierAccessScopes<Dataflow, BarrierAccessScopeFinder>;
Impl impl;
Dataflow &dataflow;

public:
BarrierAccessScopeFinder(Context const &context, Dataflow &dataflow)
: impl(&context.function, dataflow, *this), dataflow(dataflow) {}

void find() { impl.visit(); }

private:
friend Impl;

bool isInRegion(SILBasicBlock *block) {
return dataflow.result.discoveredBlocks.contains(block);
}

void visitBarrierAccessScope(BeginAccessInst *bai) {
dataflow.barrierAccessScopes.insert(bai);
for (auto *eai : bai->getEndAccesses()) {
dataflow.reachability.addKill(eai);
}
}
return isBarrier;
};

void Dataflow::run() {
reachability.initialize();
BarrierAccessScopeFinder finder(context, *this);
finder.find();
reachability.solve();
reachability.findBarriers(*this);
}

/// Hoist the destroy_values of %value.
Expand Down Expand Up @@ -256,7 +298,7 @@ bool Rewriter::run() {
//
// A block is a phi barrier iff any of its predecessors' terminators get
// classified as barriers.
for (auto *block : barriers.phiBarriers) {
for (auto *block : barriers.phis) {
madeChange |= createDestroyValue(&block->front());
}

Expand All @@ -271,13 +313,9 @@ bool Rewriter::run() {
// have returned true for P, so none of its instructions would ever have been
// classified (except for via checkReachablePhiBarrier, which doesn't record
// terminator barriers).
for (auto instruction : barriers.barriers) {
for (auto instruction : barriers.instructions) {
if (auto *terminator = dyn_cast<TermInst>(instruction)) {
auto successors = terminator->getParentBlock()->getSuccessorBlocks();
// In order for the instruction to have been classified as a barrier,
// reachability would have had to reach the block containing it.
assert(barriers.hoistingReachesEndBlocks.contains(
terminator->getParentBlock()));
for (auto *successor : successors) {
madeChange |= createDestroyValue(&successor->front());
}
Expand All @@ -301,12 +339,8 @@ bool Rewriter::run() {
// P not having a reachable end--see BackwardReachability::meetOverSuccessors.
//
// control-flow-boundary(B) := beginning-reachable(B) && !end-reachable(P)
for (auto *block : barriers.hoistingReachesBeginBlocks) {
if (auto *predecessor = block->getSinglePredecessorBlock()) {
if (!barriers.hoistingReachesEndBlocks.contains(predecessor)) {
madeChange |= createDestroyValue(&block->front());
}
}
for (auto *block : barriers.blocks) {
madeChange |= createDestroyValue(&block->front());
}

if (madeChange) {
Expand All @@ -324,7 +358,7 @@ bool Rewriter::run() {

bool Rewriter::createDestroyValue(SILInstruction *insertionPoint) {
if (auto *ebi = dyn_cast<DestroyValueInst>(insertionPoint)) {
if (uses.ends.contains(insertionPoint)) {
if (llvm::find(uses.ends, insertionPoint) != uses.ends.end()) {
reusedDestroyValueInsts.insert(insertionPoint);
return false;
}
Expand All @@ -342,7 +376,7 @@ bool run(Context &context) {
return false;

DeinitBarriers barriers(context);
DataFlow flow(context, usage, barriers);
Dataflow flow(context, usage, barriers);
flow.run();

Rewriter rewriter(context, usage, barriers);
Expand Down
Loading