Skip to content

Commit a5641f1

Browse files
authored
[SPIR-V] Add pass to merge convergence region exit targets (#92531)
The structurizer required regions to be SESE: single entry, single exit.
This new pass transforms multiple-exit regions into single-exit regions.

```
           +---+
           | A |
           +---+
           /   \
        +---+ +---+
        | B | | C |    A, B & C belong to the same convergence region.
        +---+ +---+
          |     |
        +---+ +---+
        | D | | E |    D & E belong to the parent convergence region.
        +---+ +---+    This means B & C are the exit blocks of the region.
           \   /       And D & E the targets of those exits.
            \ /
             |
           +---+
           | F |
           +---+
```

This pass would assign one value per exit target:
  D = 0
  E = 1

Then, create one variable per exit block (B, C), and assign it to the correct value: in B, the variable will have the value 0, and in C, the value 1.

Then, we'd create a new block H, with a PHI node to gather those 2 variables, and a switch, to route to the correct target.

Finally, the branches in B and C are updated to exit to this new block.

```
           +---+
           | A |
           +---+
           /   \
        +---+ +---+
        | B | | C |
        +---+ +---+
           \   /
           +---+
           | H |
           +---+
           /   \
        +---+ +---+
        | D | | E |
        +---+ +---+
           \   /
            \ /
             |
           +---+
           | F |
           +---+
```

Note: the variable is set depending on the condition used to branch. If B's terminator was conditional, the variable would be set using a SELECT. All internal edges of a region are left intact, only exiting edges are updated.

---------

Signed-off-by: Nathan Gauër <[email protected]>
1 parent 5bfc444 commit a5641f1

10 files changed

+641
-1
lines changed

llvm/lib/Target/SPIRV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ add_llvm_target(SPIRVCodeGen
2525
SPIRVInstrInfo.cpp
2626
SPIRVInstructionSelector.cpp
2727
SPIRVStripConvergentIntrinsics.cpp
28+
SPIRVMergeRegionExitTargets.cpp
2829
SPIRVISelLowering.cpp
2930
SPIRVLegalizerInfo.cpp
3031
SPIRVMCInstLower.cpp

llvm/lib/Target/SPIRV/SPIRV.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class InstructionSelector;
2020
class RegisterBankInfo;
2121

2222
ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
23+
FunctionPass *createSPIRVMergeRegionExitTargetsPass();
2324
FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
2425
FunctionPass *createSPIRVRegularizerPass();
2526
FunctionPass *createSPIRVPreLegalizerPass();

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,16 @@ class SPIRVEmitIntrinsics
151151
ModulePass::getAnalysisUsage(AU);
152152
}
153153
};
154+
155+
bool isConvergenceIntrinsic(const Instruction *I) {
156+
const auto *II = dyn_cast<IntrinsicInst>(I);
157+
if (!II)
158+
return false;
159+
160+
return II->getIntrinsicID() == Intrinsic::experimental_convergence_entry ||
161+
II->getIntrinsicID() == Intrinsic::experimental_convergence_loop ||
162+
II->getIntrinsicID() == Intrinsic::experimental_convergence_anchor;
163+
}
154164
} // namespace
155165

156166
char SPIRVEmitIntrinsics::ID = 0;
@@ -1353,6 +1363,10 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
13531363
Worklist.push_back(&I);
13541364

13551365
for (auto &I : Worklist) {
1366+
// Don't emit intrinsics for convergence intrinsics.
1367+
if (isConvergenceIntrinsic(I))
1368+
continue;
1369+
13561370
insertAssignPtrTypeIntrs(I, B);
13571371
insertAssignTypeIntrs(I, B);
13581372
insertPtrCastOrAssignTypeInstr(I, B);
@@ -1371,6 +1385,11 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
13711385
I = visit(*I);
13721386
if (!I)
13731387
continue;
1388+
1389+
// Don't emit intrinsics for convergence operations.
1390+
if (isConvergenceIntrinsic(I))
1391+
continue;
1392+
13741393
processInstrAfterVisit(I, B);
13751394
}
13761395

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ def OpFwidthCoarse: UnOp<"OpFwidthCoarse", 215>;
617617
def OpPhi: Op<245, (outs ID:$res), (ins TYPE:$type, ID:$var0, ID:$block0, variable_ops),
618618
"$res = OpPhi $type $var0 $block0">;
619619
def OpLoopMerge: Op<246, (outs), (ins ID:$merge, ID:$continue, LoopControl:$lc, variable_ops),
620-
"OpLoopMerge $merge $merge $continue $lc">;
620+
"OpLoopMerge $merge $continue $lc">;
621621
def OpSelectionMerge: Op<247, (outs), (ins ID:$merge, SelectionControl:$sc),
622622
"OpSelectionMerge $merge $sc">;
623623
def OpLabel: Op<248, (outs ID:$label), (ins), "$label = OpLabel">;
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Merge the multiple exit targets of a convergence region into a single block.
10+
// Each exit target will be assigned a constant value, and a phi node + switch
11+
// will allow the new exit target to re-route to the correct basic block.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
16+
#include "SPIRV.h"
17+
#include "SPIRVSubtarget.h"
18+
#include "SPIRVTargetMachine.h"
19+
#include "SPIRVUtils.h"
20+
#include "llvm/Analysis/LoopInfo.h"
21+
#include "llvm/CodeGen/IntrinsicLowering.h"
22+
#include "llvm/IR/CFG.h"
23+
#include "llvm/IR/Dominators.h"
24+
#include "llvm/IR/IRBuilder.h"
25+
#include "llvm/IR/IntrinsicInst.h"
26+
#include "llvm/IR/Intrinsics.h"
27+
#include "llvm/IR/IntrinsicsSPIRV.h"
28+
#include "llvm/InitializePasses.h"
29+
#include "llvm/Transforms/Utils/Cloning.h"
30+
#include "llvm/Transforms/Utils/LoopSimplify.h"
31+
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
32+
33+
using namespace llvm;
34+
35+
namespace llvm {
36+
void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &);
37+
38+
class SPIRVMergeRegionExitTargets : public FunctionPass {
39+
public:
40+
static char ID;
41+
42+
/// Builds the pass and runs its initializer against the global PassRegistry
/// so the pass is registered before it is first used.
SPIRVMergeRegionExitTargets() : FunctionPass(ID) {
  initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry());
}
45+
46+
// Gather all the successors of |BB|.
47+
// This function asserts if the terminator neither a branch, switch or return.
48+
std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
49+
std::unordered_set<BasicBlock *> output;
50+
auto *T = BB->getTerminator();
51+
52+
if (auto *BI = dyn_cast<BranchInst>(T)) {
53+
output.insert(BI->getSuccessor(0));
54+
if (BI->isConditional())
55+
output.insert(BI->getSuccessor(1));
56+
return output;
57+
}
58+
59+
if (auto *SI = dyn_cast<SwitchInst>(T)) {
60+
output.insert(SI->getDefaultDest());
61+
for (auto &Case : SI->cases())
62+
output.insert(Case.getCaseSuccessor());
63+
return output;
64+
}
65+
66+
if (auto *RI = dyn_cast<ReturnInst>(T))
67+
return output;
68+
69+
assert(false && "Unhandled terminator type.");
70+
return output;
71+
}
72+
73+
/// Create a value in BB set to the value associated with the branch the block
74+
/// terminator will take.
75+
llvm::Value *createExitVariable(
76+
BasicBlock *BB,
77+
const std::unordered_map<BasicBlock *, ConstantInt *> &TargetToValue) {
78+
auto *T = BB->getTerminator();
79+
if (auto *RI = dyn_cast<ReturnInst>(T))
80+
return nullptr;
81+
82+
IRBuilder<> Builder(BB);
83+
Builder.SetInsertPoint(T);
84+
85+
if (auto *BI = dyn_cast<BranchInst>(T)) {
86+
87+
BasicBlock *LHSTarget = BI->getSuccessor(0);
88+
BasicBlock *RHSTarget =
89+
BI->isConditional() ? BI->getSuccessor(1) : nullptr;
90+
91+
Value *LHS = TargetToValue.count(LHSTarget) != 0
92+
? TargetToValue.at(LHSTarget)
93+
: nullptr;
94+
Value *RHS = TargetToValue.count(RHSTarget) != 0
95+
? TargetToValue.at(RHSTarget)
96+
: nullptr;
97+
98+
if (LHS == nullptr || RHS == nullptr)
99+
return LHS == nullptr ? RHS : LHS;
100+
return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
101+
}
102+
103+
// TODO: add support for switch cases.
104+
assert(false && "Unhandled terminator type.");
105+
}
106+
107+
/// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
108+
void replaceBranchTargets(BasicBlock *BB,
109+
const std::unordered_set<BasicBlock *> ToReplace,
110+
BasicBlock *NewTarget) {
111+
auto *T = BB->getTerminator();
112+
if (auto *RI = dyn_cast<ReturnInst>(T))
113+
return;
114+
115+
if (auto *BI = dyn_cast<BranchInst>(T)) {
116+
for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
117+
if (ToReplace.count(BI->getSuccessor(i)) != 0)
118+
BI->setSuccessor(i, NewTarget);
119+
}
120+
return;
121+
}
122+
123+
if (auto *SI = dyn_cast<SwitchInst>(T)) {
124+
for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
125+
if (ToReplace.count(SI->getSuccessor(i)) != 0)
126+
SI->setSuccessor(i, NewTarget);
127+
}
128+
return;
129+
}
130+
131+
assert(false && "Unhandled terminator type.");
132+
}
133+
134+
// Run the pass on the given convergence region, ignoring the sub-regions.
135+
// Returns true if the CFG changed, false otherwise.
136+
bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
137+
const SPIRV::ConvergenceRegion *CR) {
138+
// Gather all the exit targets for this region.
139+
std::unordered_set<BasicBlock *> ExitTargets;
140+
for (BasicBlock *Exit : CR->Exits) {
141+
for (BasicBlock *Target : gatherSuccessors(Exit)) {
142+
if (CR->Blocks.count(Target) == 0)
143+
ExitTargets.insert(Target);
144+
}
145+
}
146+
147+
// If we have zero or one exit target, nothing do to.
148+
if (ExitTargets.size() <= 1)
149+
return false;
150+
151+
// Create the new single exit target.
152+
auto F = CR->Entry->getParent();
153+
auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
154+
IRBuilder<> Builder(NewExitTarget);
155+
156+
// CodeGen output needs to be stable. Using the set as-is would order
157+
// the targets differently depending on the allocation pattern.
158+
// Sorting per basic-block ordering in the function.
159+
std::vector<BasicBlock *> SortedExitTargets;
160+
std::vector<BasicBlock *> SortedExits;
161+
for (BasicBlock &BB : *F) {
162+
if (ExitTargets.count(&BB) != 0)
163+
SortedExitTargets.push_back(&BB);
164+
if (CR->Exits.count(&BB) != 0)
165+
SortedExits.push_back(&BB);
166+
}
167+
168+
// Creating one constant per distinct exit target. This will be route to the
169+
// correct target.
170+
std::unordered_map<BasicBlock *, ConstantInt *> TargetToValue;
171+
for (BasicBlock *Target : SortedExitTargets)
172+
TargetToValue.emplace(Target, Builder.getInt32(TargetToValue.size()));
173+
174+
// Creating one variable per exit node, set to the constant matching the
175+
// targeted external block.
176+
std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
177+
for (auto Exit : SortedExits) {
178+
llvm::Value *Value = createExitVariable(Exit, TargetToValue);
179+
ExitToVariable.emplace_back(std::make_pair(Exit, Value));
180+
}
181+
182+
// Gather the correct value depending on the exit we came from.
183+
llvm::PHINode *node =
184+
Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size());
185+
for (auto [BB, Value] : ExitToVariable) {
186+
node->addIncoming(Value, BB);
187+
}
188+
189+
// Creating the switch to jump to the correct exit target.
190+
std::vector<std::pair<BasicBlock *, ConstantInt *>> CasesList(
191+
TargetToValue.begin(), TargetToValue.end());
192+
llvm::SwitchInst *Sw =
193+
Builder.CreateSwitch(node, CasesList[0].first, CasesList.size() - 1);
194+
for (size_t i = 1; i < CasesList.size(); i++)
195+
Sw->addCase(CasesList[i].second, CasesList[i].first);
196+
197+
// Fix exit branches to redirect to the new exit.
198+
for (auto Exit : CR->Exits)
199+
replaceBranchTargets(Exit, ExitTargets, NewExitTarget);
200+
201+
return true;
202+
}
203+
204+
/// Run the pass on the given convergence region and sub-regions (DFS).
205+
/// Returns true if a region/sub-region was modified, false otherwise.
206+
/// This returns as soon as one region/sub-region has been modified.
207+
bool runOnConvergenceRegion(LoopInfo &LI,
208+
const SPIRV::ConvergenceRegion *CR) {
209+
for (auto *Child : CR->Children)
210+
if (runOnConvergenceRegion(LI, Child))
211+
return true;
212+
213+
return runOnConvergenceRegionNoRecurse(LI, CR);
214+
}
215+
216+
#if !NDEBUG
217+
/// Validates each edge exiting the region has the same destination basic
218+
/// block.
219+
void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
220+
for (auto *Child : CR->Children)
221+
validateRegionExits(Child);
222+
223+
std::unordered_set<BasicBlock *> ExitTargets;
224+
for (auto *Exit : CR->Exits) {
225+
auto Set = gatherSuccessors(Exit);
226+
for (auto *BB : Set) {
227+
if (CR->Blocks.count(BB) == 0)
228+
ExitTargets.insert(BB);
229+
}
230+
}
231+
232+
assert(ExitTargets.size() <= 1);
233+
}
234+
#endif
235+
236+
/// Pass entry point: keeps merging region exit targets until no convergence
/// region of the function is modified any more.
/// Returns true if at least one region was modified.
virtual bool runOnFunction(Function &F) override {
  LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  // Fetches the root of the convergence region tree from the analysis pass.
  auto QueryTopLevelRegion = [this]() {
    return getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
        .getRegionInfo()
        .getTopLevelRegion();
  };

  // FIXME: very inefficient method: each time a region is modified, we bubble
  // back up, and recompute the whole convergence region tree. Once the
  // algorithm is completed and test coverage good enough, rewrite this pass
  // to be efficient instead of simple.
  // NOTE(review): the root is re-queried through getAnalysis after each
  // change; confirm this really yields a recomputed tree.
  bool Changed = false;
  const auto *Root = QueryTopLevelRegion();
  for (;;) {
    if (!runOnConvergenceRegion(LI, Root))
      break;
    Changed = true;
    Root = QueryTopLevelRegion();
  }

#if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
  validateRegionExits(Root);
#endif
  return Changed;
}
260+
261+
/// Declares the analyses this pass requires: the dominator tree, loop info
/// and the SPIR-V convergence region analysis. The base-class requirements
/// are added afterwards.
void getAnalysisUsage(AnalysisUsage &AU) const override {
  AU.addRequired<DominatorTreeWrapperPass>()
      .addRequired<LoopInfoWrapperPass>()
      .addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
  FunctionPass::getAnalysisUsage(AU);
}
267+
};
268+
} // namespace llvm
269+
270+
char SPIRVMergeRegionExitTargets::ID = 0;
271+
272+
INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
273+
"SPIRV split region exit blocks", false, false)
274+
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
275+
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
276+
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
277+
INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)
278+
279+
INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
280+
"SPIRV split region exit blocks", false, false)
281+
282+
FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {
283+
return new SPIRVMergeRegionExitTargets();
284+
}

llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,11 @@ void SPIRVPassConfig::addIRPasses() {
164164
// - all loop exits are dominated by the loop pre-header.
165165
// - loops have a single back-edge.
166166
addPass(createLoopSimplifyPass());
167+
168+
// 2. Merge the convergence region exit nodes into one. After this step,
169+
// regions are single-entry, single-exit. This will help determine the
170+
// correct merge block.
171+
addPass(createSPIRVMergeRegionExitTargetsPass());
167172
}
168173

169174
TargetPassConfig::addIRPasses();

0 commit comments

Comments
 (0)