Skip to content

Commit 5ae8f2d

Browse files
committed
[SPIR-V] Add pass to merge convergence region exit targets
The structurizer required regions to be SESE: single entry, single exit. This new pass transforms multiple-exit regions into single-exit regions.

```
          +---+
          | A |
          +---+
          /   \
       +---+ +---+
       | B | | C |    A, B & C belong to the same convergence region.
       +---+ +---+
         |     |
       +---+ +---+
       | D | | E |    D & E belong to the parent convergence region.
       +---+ +---+    This means B & C are the exit blocks of the region,
          \   /       and D & E are the targets of those exits.
           \ /
            |
          +---+
          | F |
          +---+
```

This pass would assign one value per exit target:
D = 0
E = 1

Then, create one variable per exit block (B, C), and assign it the correct value: in B, the variable will have the value 0, and in C, the value 1.

Then, we'd create a new block H, with a PHI node to gather those 2 variables, and a switch to route to the correct target.

Finally, the branches in B and C are updated to exit to this new block.

```
          +---+
          | A |
          +---+
          /   \
       +---+ +---+
       | B | | C |
       +---+ +---+
          \   /
          +---+
          | H |
          +---+
          /   \
       +---+ +---+
       | D | | E |
       +---+ +---+
          \   /
           \ /
            |
          +---+
          | F |
          +---+
```

Note: the variable is set depending on the condition used to branch. If B's terminator was conditional, the variable would be set using a SELECT. All internal edges of a region are left intact; only exiting edges are updated.

Signed-off-by: Nathan Gauër <[email protected]>
1 parent 371eccd commit 5ae8f2d

10 files changed

+647
-1
lines changed

llvm/lib/Target/SPIRV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ add_llvm_target(SPIRVCodeGen
2424
SPIRVInstrInfo.cpp
2525
SPIRVInstructionSelector.cpp
2626
SPIRVStripConvergentIntrinsics.cpp
27+
SPIRVMergeRegionExitTargets.cpp
2728
SPIRVISelLowering.cpp
2829
SPIRVLegalizerInfo.cpp
2930
SPIRVMCInstLower.cpp

llvm/lib/Target/SPIRV/SPIRV.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class InstructionSelector;
2020
class RegisterBankInfo;
2121

2222
ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
23+
FunctionPass *createSPIRVMergeRegionExitTargetsPass();
2324
FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
2425
FunctionPass *createSPIRVRegularizerPass();
2526
FunctionPass *createSPIRVPreLegalizerPass();

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,16 @@ class SPIRVEmitIntrinsics
150150
ModulePass::getAnalysisUsage(AU);
151151
}
152152
};
153+
154+
bool isConvergenceIntrinsic(const Instruction *I) {
155+
const auto *II = dyn_cast<IntrinsicInst>(I);
156+
if (!II)
157+
return false;
158+
159+
return II->getIntrinsicID() == Intrinsic::experimental_convergence_entry ||
160+
II->getIntrinsicID() == Intrinsic::experimental_convergence_loop ||
161+
II->getIntrinsicID() == Intrinsic::experimental_convergence_anchor;
162+
}
153163
} // namespace
154164

155165
char SPIRVEmitIntrinsics::ID = 0;
@@ -1074,6 +1084,10 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV,
10741084

10751085
void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
10761086
IRBuilder<> &B) {
1087+
// Don't assign types to LLVM tokens.
1088+
if (isConvergenceIntrinsic(I))
1089+
return;
1090+
10771091
reportFatalOnTokenType(I);
10781092
if (!isPointerTy(I->getType()) || !requireAssignType(I) ||
10791093
isa<BitCastInst>(I))
@@ -1092,6 +1106,10 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
10921106

10931107
void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
10941108
IRBuilder<> &B) {
1109+
// Don't assign types to LLVM tokens.
1110+
if (isConvergenceIntrinsic(I))
1111+
return;
1112+
10951113
reportFatalOnTokenType(I);
10961114
Type *Ty = I->getType();
10971115
if (!Ty->isVoidTy() && !isPointerTy(Ty) && requireAssignType(I)) {
@@ -1319,6 +1337,11 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
13191337
I = visit(*I);
13201338
if (!I)
13211339
continue;
1340+
1341+
// Don't emit intrinsics for convergence operations.
1342+
if (isConvergenceIntrinsic(I))
1343+
continue;
1344+
13221345
processInstrAfterVisit(I, B);
13231346
}
13241347

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ def OpFwidthCoarse: UnOp<"OpFwidthCoarse", 215>;
615615
def OpPhi: Op<245, (outs ID:$res), (ins TYPE:$type, ID:$var0, ID:$block0, variable_ops),
616616
"$res = OpPhi $type $var0 $block0">;
617617
def OpLoopMerge: Op<246, (outs), (ins ID:$merge, ID:$continue, LoopControl:$lc, variable_ops),
618-
"OpLoopMerge $merge $merge $continue $lc">;
618+
"OpLoopMerge $merge $continue $lc">;
619619
def OpSelectionMerge: Op<247, (outs), (ins ID:$merge, SelectionControl:$sc),
620620
"OpSelectionMerge $merge $sc">;
621621
def OpLabel: Op<248, (outs ID:$label), (ins), "$label = OpLabel">;
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Merge the multiple exit targets of a convergence region into a single block.
10+
// Each exit target will be assigned a constant value, and a phi node + switch
11+
// will allow the new exit target to re-route to the correct basic block.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/CodeGen/IntrinsicLowering.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
32+
33+
using namespace llvm;
34+
35+
namespace llvm {
36+
// Forward declaration of the pass initialization function. It is called with
// the global PassRegistry from the SPIRVMergeRegionExitTargets constructor.
void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &);
37+
} // namespace llvm
38+
39+
namespace llvm {
40+
41+
class SPIRVMergeRegionExitTargets : public FunctionPass {
42+
public:
43+
static char ID;
44+
45+
SPIRVMergeRegionExitTargets() : FunctionPass(ID) {
46+
initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry());
47+
};
48+
49+
// Gather all the successors of |BB|.
50+
// This function asserts if the terminator neither a branch, switch or return.
51+
std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
52+
std::unordered_set<BasicBlock *> output;
53+
auto *T = BB->getTerminator();
54+
55+
if (auto *BI = dyn_cast<BranchInst>(T)) {
56+
output.insert(BI->getSuccessor(0));
57+
if (BI->isConditional())
58+
output.insert(BI->getSuccessor(1));
59+
return output;
60+
}
61+
62+
if (auto *SI = dyn_cast<SwitchInst>(T)) {
63+
output.insert(SI->getDefaultDest());
64+
for (auto &Case : SI->cases()) {
65+
output.insert(Case.getCaseSuccessor());
66+
}
67+
return output;
68+
}
69+
70+
if (auto *RI = dyn_cast<ReturnInst>(T))
71+
return output;
72+
73+
assert(false && "Unhandled terminator type.");
74+
return output;
75+
}
76+
77+
/// Create a value in BB set to the value associated with the branch the block
78+
/// terminator will take.
79+
llvm::Value *createExitVariable(
80+
BasicBlock *BB,
81+
const std::unordered_map<BasicBlock *, ConstantInt *> &TargetToValue) {
82+
auto *T = BB->getTerminator();
83+
if (auto *RI = dyn_cast<ReturnInst>(T)) {
84+
return nullptr;
85+
}
86+
87+
IRBuilder<> Builder(BB);
88+
Builder.SetInsertPoint(T);
89+
90+
if (auto *BI = dyn_cast<BranchInst>(T)) {
91+
92+
BasicBlock *LHSTarget = BI->getSuccessor(0);
93+
BasicBlock *RHSTarget =
94+
BI->isConditional() ? BI->getSuccessor(1) : nullptr;
95+
96+
Value *LHS = TargetToValue.count(LHSTarget) != 0
97+
? TargetToValue.at(LHSTarget)
98+
: nullptr;
99+
Value *RHS = TargetToValue.count(RHSTarget) != 0
100+
? TargetToValue.at(RHSTarget)
101+
: nullptr;
102+
103+
if (LHS == nullptr || RHS == nullptr)
104+
return LHS == nullptr ? RHS : LHS;
105+
return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
106+
}
107+
108+
// TODO: add support for switch cases.
109+
assert(false && "Unhandled terminator type.");
110+
}
111+
112+
/// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
113+
void replaceBranchTargets(BasicBlock *BB,
114+
const std::unordered_set<BasicBlock *> ToReplace,
115+
BasicBlock *NewTarget) {
116+
auto *T = BB->getTerminator();
117+
if (auto *RI = dyn_cast<ReturnInst>(T))
118+
return;
119+
120+
if (auto *BI = dyn_cast<BranchInst>(T)) {
121+
for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
122+
if (ToReplace.count(BI->getSuccessor(i)) != 0)
123+
BI->setSuccessor(i, NewTarget);
124+
}
125+
return;
126+
}
127+
128+
if (auto *SI = dyn_cast<SwitchInst>(T)) {
129+
for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
130+
if (ToReplace.count(SI->getSuccessor(i)) != 0)
131+
SI->setSuccessor(i, NewTarget);
132+
}
133+
return;
134+
}
135+
136+
assert(false && "Unhandled terminator type.");
137+
}
138+
139+
// Run the pass on the given convergence region, ignoring the sub-regions.
140+
// Returns true if the CFG changed, false otherwise.
141+
bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
142+
const SPIRV::ConvergenceRegion *CR) {
143+
// Gather all the exit targets for this region.
144+
std::unordered_set<BasicBlock *> ExitTargets;
145+
for (BasicBlock *Exit : CR->Exits) {
146+
for (BasicBlock *Target : gatherSuccessors(Exit)) {
147+
if (CR->Blocks.count(Target) == 0)
148+
ExitTargets.insert(Target);
149+
}
150+
}
151+
152+
// If we have zero or one exit target, nothing do to.
153+
if (ExitTargets.size() <= 1)
154+
return false;
155+
156+
// Create the new single exit target.
157+
auto F = CR->Entry->getParent();
158+
auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
159+
IRBuilder<> Builder(NewExitTarget);
160+
161+
// CodeGen output needs to be stable. Using the set as-is would order
162+
// the targets differently depending on the allocation pattern.
163+
// Sorting per basic-block ordering in the function.
164+
std::vector<BasicBlock *> SortedExitTargets;
165+
std::vector<BasicBlock *> SortedExits;
166+
for (BasicBlock &BB : *F) {
167+
if (ExitTargets.count(&BB) != 0)
168+
SortedExitTargets.push_back(&BB);
169+
if (CR->Exits.count(&BB) != 0)
170+
SortedExits.push_back(&BB);
171+
}
172+
173+
// Creating one constant per distinct exit target. This will be route to the
174+
// correct target.
175+
std::unordered_map<BasicBlock *, ConstantInt *> TargetToValue;
176+
for (BasicBlock *Target : SortedExitTargets)
177+
TargetToValue.emplace(Target, Builder.getInt32(TargetToValue.size()));
178+
179+
// Creating one variable per exit node, set to the constant matching the
180+
// targeted external block.
181+
std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
182+
for (auto Exit : SortedExits) {
183+
llvm::Value *Value = createExitVariable(Exit, TargetToValue);
184+
ExitToVariable.emplace_back(std::make_pair(Exit, Value));
185+
}
186+
187+
// Gather the correct value depending on the exit we came from.
188+
llvm::PHINode *node =
189+
Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size());
190+
for (auto [BB, Value] : ExitToVariable) {
191+
node->addIncoming(Value, BB);
192+
}
193+
194+
// Creating the switch to jump to the correct exit target.
195+
std::vector<std::pair<BasicBlock *, ConstantInt *>> CasesList(
196+
TargetToValue.begin(), TargetToValue.end());
197+
llvm::SwitchInst *Sw =
198+
Builder.CreateSwitch(node, CasesList[0].first, CasesList.size() - 1);
199+
for (size_t i = 1; i < CasesList.size(); i++)
200+
Sw->addCase(CasesList[i].second, CasesList[i].first);
201+
202+
// Fix exit branches to redirect to the new exit.
203+
for (auto Exit : CR->Exits)
204+
replaceBranchTargets(Exit, ExitTargets, NewExitTarget);
205+
206+
return true;
207+
}
208+
209+
/// Run the pass on the given convergence region and sub-regions (DFS).
210+
/// Returns true if a region/sub-region was modified, false otherwise.
211+
/// This returns as soon as one region/sub-region has been modified.
212+
bool runOnConvergenceRegion(LoopInfo &LI,
213+
const SPIRV::ConvergenceRegion *CR) {
214+
for (auto *Child : CR->Children)
215+
if (runOnConvergenceRegion(LI, Child))
216+
return true;
217+
218+
return runOnConvergenceRegionNoRecurse(LI, CR);
219+
}
220+
221+
#if !NDEBUG
222+
/// Validates each edge exiting the region has the same destination basic
223+
/// block.
224+
void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
225+
for (auto *Child : CR->Children)
226+
validateRegionExits(Child);
227+
228+
std::unordered_set<BasicBlock *> ExitTargets;
229+
for (auto *Exit : CR->Exits) {
230+
auto Set = gatherSuccessors(Exit);
231+
for (auto *BB : Set) {
232+
if (CR->Blocks.count(BB) == 0)
233+
ExitTargets.insert(BB);
234+
}
235+
}
236+
237+
assert(ExitTargets.size() <= 1);
238+
}
239+
#endif
240+
241+
virtual bool runOnFunction(Function &F) override {
242+
LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
243+
const auto *TopLevelRegion =
244+
getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
245+
.getRegionInfo()
246+
.getTopLevelRegion();
247+
248+
// FIXME: very inefficient method: each time a region is modified, we bubble
249+
// back up, and recompute the whole convergence region tree. Once the
250+
// algorithm is completed and test coverage good enough, rewrite this pass
251+
// to be efficient instead of simple.
252+
bool modified = false;
253+
while (runOnConvergenceRegion(LI, TopLevelRegion)) {
254+
TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
255+
.getRegionInfo()
256+
.getTopLevelRegion();
257+
modified = true;
258+
}
259+
260+
F.dump();
261+
#if !NDEBUG
262+
validateRegionExits(TopLevelRegion);
263+
#endif
264+
return modified;
265+
}
266+
267+
void getAnalysisUsage(AnalysisUsage &AU) const override {
268+
AU.addRequired<DominatorTreeWrapperPass>();
269+
AU.addRequired<LoopInfoWrapperPass>();
270+
AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
271+
FunctionPass::getAnalysisUsage(AU);
272+
}
273+
};
274+
} // namespace llvm
275+
276+
// Definition of the static pass identifier declared in the class; it is the
// value handed to the FunctionPass constructor.
char SPIRVMergeRegionExitTargets::ID = 0;
277+
278+
// Pass registration, listing the passes this one depends on.
// NOTE(review): the argument string and description say "split region exit
// blocks" while the pass is named SPIRVMergeRegionExitTargets; this looks
// copied from another pass. Confirm before renaming.
INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
279+
"SPIRV split region exit blocks", false, false)
280+
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
281+
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
282+
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
283+
INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)
284+
285+
INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
286+
"SPIRV split region exit blocks", false, false)
287+
288+
/// Allocates a new instance of the SPIRVMergeRegionExitTargets pass and
/// returns it through the generic FunctionPass interface.
FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {
  FunctionPass *Pass = new SPIRVMergeRegionExitTargets();
  return Pass;
}

llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ void SPIRVPassConfig::addIRPasses() {
164164
// - all loop exits are dominated by the loop pre-header.
165165
// - loops have a single back-edge.
166166
addPass(createLoopSimplifyPass());
167+
addPass(createSPIRVMergeRegionExitTargetsPass());
167168
}
168169

169170
TargetPassConfig::addIRPasses();

0 commit comments

Comments
 (0)