Skip to content

[UniformAnalysis] Use Immediate postDom as last join #140013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/ADT/GenericSSAContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ template <typename _FunctionT> class GenericSSAContext {
// a given function.
using DominatorTreeT = DominatorTreeBase<BlockT, false>;

// A post-dominator tree provides the post-dominance relation between
// basic blocks in a given function.
using PostDominatorTreeT = DominatorTreeBase<BlockT, true>;

GenericSSAContext() = default;
GenericSSAContext(const FunctionT *F) : F(F) {}

Expand Down
91 changes: 44 additions & 47 deletions llvm/include/llvm/ADT/GenericUniformityImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
public:
using BlockT = typename ContextT::BlockT;
using DominatorTreeT = typename ContextT::DominatorTreeT;
using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
using FunctionT = typename ContextT::FunctionT;
using ValueRefT = typename ContextT::ValueRefT;
using InstructionT = typename ContextT::InstructionT;
Expand Down Expand Up @@ -296,7 +297,9 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
using DivergencePropagatorT = DivergencePropagator<ContextT>;

GenericSyncDependenceAnalysis(const ContextT &Context,
const DominatorTreeT &DT, const CycleInfoT &CI);
const DominatorTreeT &DT,
const PostDominatorTreeT &PDT,
const CycleInfoT &CI);

/// \brief Computes divergent join points and cycle exits caused by branch
/// divergence in \p Term.
Expand All @@ -315,6 +318,7 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
ModifiedPO CyclePO;

const DominatorTreeT &DT;
const PostDominatorTreeT &PDT;
const CycleInfoT &CI;

DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
Expand All @@ -336,6 +340,7 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
using UseT = typename ContextT::UseT;
using InstructionT = typename ContextT::InstructionT;
using DominatorTreeT = typename ContextT::DominatorTreeT;
using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;

using CycleInfoT = GenericCycleInfo<ContextT>;
using CycleT = typename CycleInfoT::CycleT;
Expand All @@ -348,10 +353,12 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
using TemporalDivergenceTuple =
std::tuple<ConstValueRefT, InstructionT *, const CycleT *>;

GenericUniformityAnalysisImpl(const DominatorTreeT &DT, const CycleInfoT &CI,
GenericUniformityAnalysisImpl(const DominatorTreeT &DT,
const PostDominatorTreeT &PDT,
const CycleInfoT &CI,
const TargetTransformInfo *TTI)
: Context(CI.getSSAContext()), F(*Context.getFunction()), CI(CI),
TTI(TTI), DT(DT), SDA(Context, DT, CI) {}
TTI(TTI), DT(DT), PDT(PDT), SDA(Context, DT, PDT, CI) {}

void initialize();

Expand Down Expand Up @@ -435,6 +442,7 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {

private:
const DominatorTreeT &DT;
const PostDominatorTreeT &PDT;

// Recognized cycles with divergent exits.
SmallPtrSet<const CycleT *, 16> DivergentExitCycles;
Expand Down Expand Up @@ -493,6 +501,7 @@ template <typename ContextT> class DivergencePropagator {
public:
using BlockT = typename ContextT::BlockT;
using DominatorTreeT = typename ContextT::DominatorTreeT;
using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
using FunctionT = typename ContextT::FunctionT;
using ValueRefT = typename ContextT::ValueRefT;

Expand All @@ -507,6 +516,7 @@ template <typename ContextT> class DivergencePropagator {

const ModifiedPO &CyclePOT;
const DominatorTreeT &DT;
const PostDominatorTreeT &PDT;
const CycleInfoT &CI;
const BlockT &DivTermBlock;
const ContextT &Context;
Expand All @@ -522,10 +532,11 @@ template <typename ContextT> class DivergencePropagator {
BlockLabelMapT &BlockLabels;

DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT,
const CycleInfoT &CI, const BlockT &DivTermBlock)
: CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock),
Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT),
BlockLabels(DivDesc->BlockLabels) {}
const PostDominatorTreeT &PDT, const CycleInfoT &CI,
const BlockT &DivTermBlock)
: CyclePOT(CyclePOT), DT(DT), PDT(PDT), CI(CI),
DivTermBlock(DivTermBlock), Context(CI.getSSAContext()),
DivDesc(new DivergenceDescriptorT), BlockLabels(DivDesc->BlockLabels) {}

void printDefs(raw_ostream &Out) {
Out << "Propagator::BlockLabels {\n";
Expand All @@ -542,6 +553,12 @@ template <typename ContextT> class DivergencePropagator {
Out << "}\n";
}

// Returns the block held by the immediate post-dominator node of \p B in
// PDT. Returns nullptr when no such block can be determined: \p B has no
// node in the tree, its node has no immediate dominator, or the immediate
// dominator node does not carry a block.
const BlockT *getIPDom(const BlockT *B) {
  const auto *Node = PDT.getNode(B);
  // Guard against blocks the tree does not know about instead of
  // dereferencing a null node.
  if (!Node)
    return nullptr;

  const auto *IPDomNode = Node->getIDom();
  // NOTE(review): presumably only the tree root lacks an immediate
  // dominator; confirm against DominatorTreeBase.
  if (!IPDomNode)
    return nullptr;

  return IPDomNode->getBlock();
}

// Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
// causes a divergent join.
bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) {
Expand Down Expand Up @@ -610,10 +627,11 @@ template <typename ContextT> class DivergencePropagator {
LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
<< Context.print(&DivTermBlock) << "\n");

// Early stopping criterion
int FloorIdx = CyclePOT.size() - 1;
const BlockT *FloorLabel = nullptr;
int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);
// Immediate Post-dominator of DivTermBlock is the last join
// to visit.
const auto *ImmPDom = getIPDom(&DivTermBlock);

LLVM_DEBUG(dbgs() << "Last join: " << Context.print(ImmPDom) << "\n");

// Bootstrap with branch targets
auto const *DivTermCycle = CI.getCycle(&DivTermBlock);
Expand All @@ -626,34 +644,29 @@ template <typename ContextT> class DivergencePropagator {
LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
<< Context.print(SuccBlock) << "\n");
}
auto SuccIdx = CyclePOT.getIndex(SuccBlock);
visitEdge(*SuccBlock, *SuccBlock);
FloorIdx = std::min<int>(FloorIdx, SuccIdx);
}

while (true) {
auto BlockIdx = FreshLabels.find_last();
if (BlockIdx == -1 || BlockIdx < FloorIdx)
if (BlockIdx == -1)
break;

LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));

FreshLabels.reset(BlockIdx);
if (BlockIdx == DivTermIdx) {
LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n");
const auto *Block = CyclePOT[BlockIdx];
if (Block == ImmPDom) {
LLVM_DEBUG(dbgs() << "Skipping ImmPDom\n");
continue;
}

const auto *Block = CyclePOT[BlockIdx];
LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index "
<< BlockIdx << "\n");

const auto *Label = BlockLabels[Block];
assert(Label);

bool CausedJoin = false;
int LoweredFloorIdx = FloorIdx;

// If the current block is the header of a reducible cycle that
// contains the divergent branch, then the label should be
// propagated to the cycle exits. Such a header is the "last
Expand Down Expand Up @@ -681,28 +694,11 @@ template <typename ContextT> class DivergencePropagator {
if (const auto *BlockCycle = getReducibleParent(Block)) {
SmallVector<BlockT *, 4> BlockCycleExits;
BlockCycle->getExitBlocks(BlockCycleExits);
for (auto *BlockCycleExit : BlockCycleExits) {
CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
LoweredFloorIdx =
std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
}
for (auto *BlockCycleExit : BlockCycleExits)
visitCycleExitEdge(*BlockCycleExit, *Label);
} else {
for (const auto *SuccBlock : successors(Block)) {
CausedJoin |= visitEdge(*SuccBlock, *Label);
LoweredFloorIdx =
std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
}
}

// Floor update
if (CausedJoin) {
// 1. Different labels pushed to successors
FloorIdx = LoweredFloorIdx;
} else if (FloorLabel != Label) {
// 2. No join caused BUT we pushed a label that is different than the
// last pushed label
FloorIdx = LoweredFloorIdx;
FloorLabel = Label;
for (const auto *SuccBlock : successors(Block))
visitEdge(*SuccBlock, *Label);
}
}

Expand Down Expand Up @@ -742,8 +738,9 @@ typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor

template <typename ContextT>
llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
: CyclePO(Context), DT(DT), CI(CI) {
const ContextT &Context, const DominatorTreeT &DT,
const PostDominatorTreeT &PDT, const CycleInfoT &CI)
: CyclePO(Context), DT(DT), PDT(PDT), CI(CI) {
CyclePO.compute(CI);
}

Expand All @@ -761,7 +758,7 @@ auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
return *ItCached->second;

// compute all join points
DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock);
DivergencePropagatorT Propagator(CyclePO, DT, PDT, CI, *DivTermBlock);
auto DivDesc = Propagator.computeJoinPoints();

auto printBlockSet = [&](ConstBlockSet &Blocks) {
Expand Down Expand Up @@ -1155,9 +1152,9 @@ bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(

template <typename ContextT>
GenericUniformityInfo<ContextT>::GenericUniformityInfo(
const DominatorTreeT &DT, const CycleInfoT &CI,
const TargetTransformInfo *TTI) {
DA.reset(new ImplT{DT, CI, TTI});
const DominatorTreeT &DT, const PostDominatorTreeT &PDT,
const CycleInfoT &CI, const TargetTransformInfo *TTI) {
DA.reset(new ImplT{DT, PDT, CI, TTI});
}

template <typename ContextT>
Expand Down
4 changes: 3 additions & 1 deletion llvm/include/llvm/ADT/GenericUniformityInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ template <typename ContextT> class GenericUniformityInfo {
using UseT = typename ContextT::UseT;
using InstructionT = typename ContextT::InstructionT;
using DominatorTreeT = typename ContextT::DominatorTreeT;
using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
using ThisT = GenericUniformityInfo<ContextT>;

using CycleInfoT = GenericCycleInfo<ContextT>;
Expand All @@ -43,7 +44,8 @@ template <typename ContextT> class GenericUniformityInfo {
using TemporalDivergenceTuple =
std::tuple<ConstValueRefT, InstructionT *, const CycleT *>;

GenericUniformityInfo(const DominatorTreeT &DT, const CycleInfoT &CI,
GenericUniformityInfo(const DominatorTreeT &DT, const PostDominatorTreeT &PDT,
const CycleInfoT &CI,
const TargetTransformInfo *TTI = nullptr);
GenericUniformityInfo() = default;
GenericUniformityInfo(GenericUniformityInfo &&) = default;
Expand Down
4 changes: 3 additions & 1 deletion llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/CodeGen/MachineCycleAnalysis.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachinePassManager.h"
#include "llvm/CodeGen/MachinePostDominators.h"
#include "llvm/CodeGen/MachineSSAContext.h"

namespace llvm {
Expand All @@ -31,7 +32,8 @@ using MachineUniformityInfo = GenericUniformityInfo<MachineSSAContext>;
/// everything is uniform.
MachineUniformityInfo computeMachineUniformityInfo(
MachineFunction &F, const MachineCycleInfo &cycleInfo,
const MachineDominatorTree &domTree, bool HasBranchDivergence);
const MachineDominatorTree &domTree,
const MachinePostDominatorTree &pdomTree, bool HasBranchDivergence);

/// Legacy analysis pass which computes a \ref MachineUniformityInfo.
class MachineUniformityAnalysisPass : public MachineFunctionPass {
Expand Down
10 changes: 8 additions & 2 deletions llvm/lib/Analysis/UniformityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/ADT/GenericUniformityImpl.h"
#include "llvm/Analysis/CycleAnalysis.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
Expand Down Expand Up @@ -114,9 +115,10 @@ template struct llvm::GenericUniformityAnalysisImplDeleter<
llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
FunctionAnalysisManager &FAM) {
auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
auto &PDT = FAM.getResult<PostDominatorTreeAnalysis>(F);
auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
auto &CI = FAM.getResult<CycleAnalysis>(F);
UniformityInfo UI{DT, CI, &TTI};
UniformityInfo UI{DT, PDT, CI, &TTI};
// Skip computation if we can assume everything is uniform.
if (TTI.hasBranchDivergence(&F))
UI.compute();
Expand Down Expand Up @@ -148,6 +150,7 @@ UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {}
INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
"Uniformity Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
Expand All @@ -156,18 +159,21 @@ INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
// Declares the analysis dependencies of the legacy uniformity wrapper pass:
// it preserves all analyses and requires the dominator tree, the
// post-dominator tree, cycle info (registered as a transitive requirement)
// and target transform info.
void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
AU.addRequired<DominatorTreeWrapperPass>();
// The post-dominator tree is fetched in runOnFunction and passed to
// UniformityInfo.
AU.addRequired<PostDominatorTreeWrapperPass>();
AU.addRequiredTransitive<CycleInfoWrapperPass>();
AU.addRequired<TargetTransformInfoWrapperPass>();
}

bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
auto &pdomTree = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
auto &targetTransformInfo =
getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

m_function = &F;
m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
m_uniformityInfo =
UniformityInfo{domTree, pdomTree, cycleInfo, &targetTransformInfo};

// Skip computation if we can assume everything is uniform.
if (targetTransformInfo.hasBranchDivergence(m_function))
Expand Down
15 changes: 11 additions & 4 deletions llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/MachineCycleAnalysis.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachinePostDominators.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/MachineSSAContext.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
Expand Down Expand Up @@ -156,9 +157,10 @@ template struct llvm::GenericUniformityAnalysisImplDeleter<

MachineUniformityInfo llvm::computeMachineUniformityInfo(
MachineFunction &F, const MachineCycleInfo &cycleInfo,
const MachineDominatorTree &domTree, bool HasBranchDivergence) {
const MachineDominatorTree &domTree,
const MachinePostDominatorTree &pdomTree, bool HasBranchDivergence) {
assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
MachineUniformityInfo UI(domTree, cycleInfo);
MachineUniformityInfo UI(domTree, pdomTree, cycleInfo);
if (HasBranchDivergence)
UI.compute();
return UI;
Expand All @@ -184,12 +186,13 @@ MachineUniformityAnalysis::Result
MachineUniformityAnalysis::run(MachineFunction &MF,
MachineFunctionAnalysisManager &MFAM) {
auto &DomTree = MFAM.getResult<MachineDominatorTreeAnalysis>(MF);
auto &PDomTree = MFAM.getResult<MachinePostDominatorTreeAnalysis>(MF);
auto &CI = MFAM.getResult<MachineCycleAnalysis>(MF);
auto &FAM = MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
.getManager();
auto &F = MF.getFunction();
auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
return computeMachineUniformityInfo(MF, CI, DomTree,
return computeMachineUniformityInfo(MF, CI, DomTree, PDomTree,
TTI.hasBranchDivergence(&F));
}

Expand All @@ -215,22 +218,26 @@ INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
"Machine Uniformity Info Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
"Machine Uniformity Info Analysis", false, true)

// Declares the analysis dependencies of the legacy machine uniformity pass:
// it preserves all analyses and requires machine cycle info (registered as a
// transitive requirement), the machine dominator tree and the machine
// post-dominator tree, then defers to MachineFunctionPass for the rest.
void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
AU.addRequiredTransitive<MachineCycleInfoWrapperPass>();
AU.addRequired<MachineDominatorTreeWrapperPass>();
// The post-dominator tree is fetched in runOnMachineFunction and passed to
// computeMachineUniformityInfo.
AU.addRequired<MachinePostDominatorTreeWrapperPass>();
MachineFunctionPass::getAnalysisUsage(AU);
}

bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
auto &DomTree = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
auto &PDomTree =
getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
// FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
// default NoTTI
UI = computeMachineUniformityInfo(MF, CI, DomTree, true);
UI = computeMachineUniformityInfo(MF, CI, DomTree, PDomTree, true);
return false;
}

Expand Down
Loading
Loading