Skip to content

Commit c2749e5

Browse files
committed
[SandboxVec][DAG] Extend DAG
This patch implements growing the DAG towards the top or bottom. This does the necessary dependency checks and adds new mem dependencies.
1 parent 69c0067 commit c2749e5

File tree

3 files changed

+172
-16
lines changed

3 files changed

+172
-16
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,10 @@ class DependencyGraph {
280280
/// \p DstN.
281281
void scanAndAddDeps(MemDGNode &DstN, const Interval<MemDGNode> &SrcScanRange);
282282

283+
/// Create DAG nodes for instrs in \p NewInterval and update the MemNode
284+
/// chain.
285+
void createNewNodes(const Interval<Instruction> &NewInterval);
286+
283287
public:
284288
DependencyGraph(AAResults &AA)
285289
: BatchAA(std::make_unique<BatchAAResults>(AA)) {}
@@ -305,8 +309,10 @@ class DependencyGraph {
305309
return It->second.get();
306310
}
307311
/// Build/extend the dependency graph such that it includes \p Instrs. Returns
308-
/// the interval spanning \p Instrs.
312+
/// the range of instructions added to the DAG.
309313
Interval<Instruction> extend(ArrayRef<Instruction *> Instrs);
314+
/// \Returns the range of instructions included in the DAG.
315+
Interval<Instruction> getInterval() const { return DAGInterval; }
310316
#ifndef NDEBUG
311317
void print(raw_ostream &OS) const;
312318
LLVM_DUMP_METHOD void dump() const;

llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp

Lines changed: 98 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -215,17 +215,11 @@ void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
215215
}
216216
}
217217

218-
Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
219-
if (Instrs.empty())
220-
return {};
221-
222-
Interval<Instruction> InstrInterval(Instrs);
223-
224-
DGNode *LastN = getOrCreateNode(InstrInterval.top());
225-
// Create DGNodes for all instrs in Interval to avoid future Instruction to
226-
// DGNode lookups.
218+
void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
219+
// Create Nodes only for the new sections of the DAG.
220+
DGNode *LastN = getOrCreateNode(NewInterval.top());
227221
MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN);
228-
for (Instruction &I : drop_begin(InstrInterval)) {
222+
for (Instruction &I : drop_begin(NewInterval)) {
229223
auto *N = getOrCreateNode(&I);
230224
// Build the Mem node chain.
231225
if (auto *MemN = dyn_cast<MemDGNode>(N)) {
@@ -235,16 +229,105 @@ Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
235229
LastMemN = MemN;
236230
}
237231
}
232+
// Link new MemDGNode chain with the old one, if any.
233+
if (!DAGInterval.empty()) {
234+
bool NewIsAbove = NewInterval.bottom()->comesBefore(DAGInterval.bottom());
235+
const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
236+
const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
237+
MemDGNode *LinkTopN =
238+
MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this);
239+
MemDGNode *LinkBotN =
240+
MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this);
241+
assert(LinkTopN->comesBefore(LinkBotN) && "Wrong order!");
242+
if (LinkTopN != nullptr && LinkBotN != nullptr) {
243+
LinkTopN->setNextNode(LinkBotN);
244+
LinkBotN->setPrevNode(LinkTopN);
245+
}
246+
#ifndef NDEBUG
247+
// TODO: Remove this once we've done enough testing.
248+
// Check that the chain is well formed.
249+
auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval);
250+
MemDGNode *ChainTopN =
251+
MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this);
252+
MemDGNode *ChainBotN =
253+
MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this);
254+
if (ChainTopN != nullptr && ChainBotN != nullptr) {
255+
for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr;
256+
LastN = N, N = N->getNextNode()) {
257+
assert(N == LastN->getNextNode() && "Bad chain!");
258+
assert(N->getPrevNode() == LastN && "Bad chain!");
259+
}
260+
}
261+
#endif // NDEBUG
262+
}
263+
}
264+
265+
Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
266+
if (Instrs.empty())
267+
return {};
268+
269+
Interval<Instruction> InstrsInterval(Instrs);
270+
Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval);
271+
auto NewInterval = Union.getSingleDiff(DAGInterval);
272+
if (NewInterval.empty())
273+
return {};
274+
275+
createNewNodes(NewInterval);
276+
238277
// Create the dependencies.
239-
auto DstRange = MemDGNodeIntervalBuilder::make(InstrInterval, *this);
240-
if (!DstRange.empty()) {
241-
for (MemDGNode &DstN : drop_begin(DstRange)) {
242-
auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
278+
//
279+
// 1. DAGInterval empty 2. New is below Old 3. New is above old
280+
// ------------------------ ------------------- -------------------
281+
// Scan: DstN: Scan:
282+
// +---+ -ScanTopN +---+DstTopN -ScanTopN
283+
// | | | |New| |
284+
// |Old| | +---+ -ScanBotN
285+
// | | | +---+
286+
// DstN: Scan: +---+DstN: | | |
287+
// +---+DstTopN -ScanTopN +---+DstTopN | |Old|
288+
// |New| | |New| | | |
289+
// +---+DstBotN -ScanBotN +---+DstBotN -ScanBotN +---+DstBotN
290+
291+
// 1. This is a new DAG.
292+
if (DAGInterval.empty()) {
293+
assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!");
294+
auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
295+
if (!DstRange.empty()) {
296+
for (MemDGNode &DstN : drop_begin(DstRange)) {
297+
auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
298+
scanAndAddDeps(DstN, SrcRange);
299+
}
300+
}
301+
}
302+
// 2. The new section is below the old section.
303+
else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) {
304+
auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
305+
auto SrcRangeFull = MemDGNodeIntervalBuilder::make(
306+
DAGInterval.getUnionInterval(NewInterval), *this);
307+
for (MemDGNode &DstN : DstRange) {
308+
auto SrcRange =
309+
Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
243310
scanAndAddDeps(DstN, SrcRange);
244311
}
245312
}
313+
// 3. The new section is above the old section.
314+
else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) {
315+
auto DstRange = MemDGNodeIntervalBuilder::make(
316+
NewInterval.getUnionInterval(DAGInterval), *this);
317+
auto SrcRangeFull = MemDGNodeIntervalBuilder::make(NewInterval, *this);
318+
if (!DstRange.empty()) {
319+
for (MemDGNode &DstN : drop_begin(DstRange)) {
320+
auto SrcRange =
321+
Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
322+
scanAndAddDeps(DstN, SrcRange);
323+
}
324+
}
325+
} else {
326+
llvm_unreachable("We don't expect extending in both directions!");
327+
}
246328

247-
return InstrInterval;
329+
DAGInterval = Union;
330+
return NewInterval;
248331
}
249332

250333
#ifndef NDEBUG

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,3 +681,70 @@ define void @foo() {
681681
EXPECT_FALSE(memDependency(StackSaveN, AllocaN));
682682
EXPECT_FALSE(memDependency(AllocaN, StackRestoreN));
683683
}
684+
685+
TEST_F(DependencyGraphTest, Extend) {
686+
parseIR(C, R"IR(
687+
define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
688+
store i8 %v1, ptr %ptr
689+
store i8 %v2, ptr %ptr
690+
store i8 %v3, ptr %ptr
691+
store i8 %v4, ptr %ptr
692+
store i8 %v5, ptr %ptr
693+
ret void
694+
}
695+
)IR");
696+
llvm::Function *LLVMF = &*M->getFunction("foo");
697+
sandboxir::Context Ctx(C);
698+
auto *F = Ctx.createFunction(LLVMF);
699+
auto *BB = &*F->begin();
700+
auto It = BB->begin();
701+
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
702+
auto *S2 = cast<sandboxir::StoreInst>(&*It++);
703+
auto *S3 = cast<sandboxir::StoreInst>(&*It++);
704+
auto *S4 = cast<sandboxir::StoreInst>(&*It++);
705+
auto *S5 = cast<sandboxir::StoreInst>(&*It++);
706+
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
707+
{
708+
// Scenario 1: Build new DAG
709+
auto NewIntvl = DAG.extend({S3, S3});
710+
EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S3, S3));
711+
EXPECT_EQ(DAG.getInterval().top(), S3);
712+
EXPECT_EQ(DAG.getInterval().bottom(), S3);
713+
[[maybe_unused]] auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
714+
}
715+
{
716+
// Scenario 2: Extend below
717+
auto NewIntvl = DAG.extend({S5, S5});
718+
EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S4, S5));
719+
auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
720+
auto *S4N = cast<sandboxir::MemDGNode>(DAG.getNode(S4));
721+
auto *S5N = cast<sandboxir::MemDGNode>(DAG.getNode(S5));
722+
EXPECT_TRUE(S4N->hasMemPred(S3N));
723+
EXPECT_TRUE(S5N->hasMemPred(S4N));
724+
EXPECT_TRUE(S5N->hasMemPred(S3N));
725+
}
726+
{
727+
// Scenario 3: Extend above
728+
auto NewIntvl = DAG.extend({S1, S2});
729+
EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S1, S2));
730+
auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
731+
auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
732+
auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
733+
auto *S4N = cast<sandboxir::MemDGNode>(DAG.getNode(S4));
734+
auto *S5N = cast<sandboxir::MemDGNode>(DAG.getNode(S5));
735+
736+
EXPECT_TRUE(S2N->hasMemPred(S1N));
737+
738+
EXPECT_TRUE(S3N->hasMemPred(S2N));
739+
EXPECT_TRUE(S3N->hasMemPred(S1N));
740+
741+
EXPECT_TRUE(S4N->hasMemPred(S3N));
742+
EXPECT_TRUE(S4N->hasMemPred(S2N));
743+
EXPECT_TRUE(S4N->hasMemPred(S1N));
744+
745+
EXPECT_TRUE(S5N->hasMemPred(S4N));
746+
EXPECT_TRUE(S5N->hasMemPred(S3N));
747+
EXPECT_TRUE(S5N->hasMemPred(S2N));
748+
EXPECT_TRUE(S5N->hasMemPred(S1N));
749+
}
750+
}

0 commit comments

Comments
 (0)