Skip to content

Commit e8dd95e

Browse files
authored
[SandboxVec][DAG] Extend DAG (llvm#111908)
This patch implements growing the DAG towards the top or bottom. This does the necessary dependency checks and adds new mem dependencies.
1 parent 8b17916 commit e8dd95e

File tree

3 files changed

+176
-16
lines changed

3 files changed

+176
-16
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,10 @@ class DependencyGraph {
284284
/// \p DstN.
285285
void scanAndAddDeps(MemDGNode &DstN, const Interval<MemDGNode> &SrcScanRange);
286286

287+
/// Create DAG nodes for instrs in \p NewInterval and update the MemNode
288+
/// chain.
289+
void createNewNodes(const Interval<Instruction> &NewInterval);
290+
287291
public:
288292
DependencyGraph(AAResults &AA)
289293
: BatchAA(std::make_unique<BatchAAResults>(AA)) {}
@@ -309,8 +313,10 @@ class DependencyGraph {
309313
return It->second.get();
310314
}
311315
/// Build/extend the dependency graph such that it includes \p Instrs. Returns
312-
/// the interval spanning \p Instrs.
316+
/// the range of instructions added to the DAG.
313317
Interval<Instruction> extend(ArrayRef<Instruction *> Instrs);
318+
/// \Returns the range of instructions included in the DAG.
319+
Interval<Instruction> getInterval() const { return DAGInterval; }
314320
#ifndef NDEBUG
315321
void print(raw_ostream &OS) const;
316322
LLVM_DUMP_METHOD void dump() const;

llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp

Lines changed: 102 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -215,17 +215,11 @@ void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
215215
}
216216
}
217217

218-
Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
219-
if (Instrs.empty())
220-
return {};
221-
222-
Interval<Instruction> InstrInterval(Instrs);
223-
224-
DGNode *LastN = getOrCreateNode(InstrInterval.top());
225-
// Create DGNodes for all instrs in Interval to avoid future Instruction to
226-
// DGNode lookups.
218+
void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
219+
// Create Nodes only for the new sections of the DAG.
220+
DGNode *LastN = getOrCreateNode(NewInterval.top());
227221
MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN);
228-
for (Instruction &I : drop_begin(InstrInterval)) {
222+
for (Instruction &I : drop_begin(NewInterval)) {
229223
auto *N = getOrCreateNode(&I);
230224
// Build the Mem node chain.
231225
if (auto *MemN = dyn_cast<MemDGNode>(N)) {
@@ -235,16 +229,109 @@ Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
235229
LastMemN = MemN;
236230
}
237231
}
232+
// Link new MemDGNode chain with the old one, if any.
233+
if (!DAGInterval.empty()) {
234+
// TODO: Implement Interval::comesBefore() to replace this check.
235+
bool NewIsAbove = NewInterval.bottom()->comesBefore(DAGInterval.top());
236+
assert(
237+
(NewIsAbove || DAGInterval.bottom()->comesBefore(NewInterval.top())) &&
238+
"Expected NewInterval below DAGInterval.");
239+
const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
240+
const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
241+
MemDGNode *LinkTopN =
242+
MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this);
243+
MemDGNode *LinkBotN =
244+
MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this);
245+
assert(LinkTopN->comesBefore(LinkBotN) && "Wrong order!");
246+
if (LinkTopN != nullptr && LinkBotN != nullptr) {
247+
LinkTopN->setNextNode(LinkBotN);
248+
LinkBotN->setPrevNode(LinkTopN);
249+
}
250+
#ifndef NDEBUG
251+
// TODO: Remove this once we've done enough testing.
252+
// Check that the chain is well formed.
253+
auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval);
254+
MemDGNode *ChainTopN =
255+
MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this);
256+
MemDGNode *ChainBotN =
257+
MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this);
258+
if (ChainTopN != nullptr && ChainBotN != nullptr) {
259+
for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr;
260+
LastN = N, N = N->getNextNode()) {
261+
assert(N == LastN->getNextNode() && "Bad chain!");
262+
assert(N->getPrevNode() == LastN && "Bad chain!");
263+
}
264+
}
265+
#endif // NDEBUG
266+
}
267+
}
268+
269+
Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
270+
if (Instrs.empty())
271+
return {};
272+
273+
Interval<Instruction> InstrsInterval(Instrs);
274+
Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval);
275+
auto NewInterval = Union.getSingleDiff(DAGInterval);
276+
if (NewInterval.empty())
277+
return {};
278+
279+
createNewNodes(NewInterval);
280+
238281
// Create the dependencies.
239-
auto DstRange = MemDGNodeIntervalBuilder::make(InstrInterval, *this);
240-
if (!DstRange.empty()) {
241-
for (MemDGNode &DstN : drop_begin(DstRange)) {
242-
auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
282+
//
283+
// 1. DAGInterval empty 2. New is below Old 3. New is above old
284+
// ------------------------ ------------------- -------------------
285+
// Scan: DstN: Scan:
286+
// +---+ -ScanTopN +---+DstTopN -ScanTopN
287+
// | | | |New| |
288+
// |Old| | +---+ -ScanBotN
289+
// | | | +---+
290+
// DstN: Scan: +---+DstN: | | |
291+
// +---+DstTopN -ScanTopN +---+DstTopN | |Old|
292+
// |New| | |New| | | |
293+
// +---+DstBotN -ScanBotN +---+DstBotN -ScanBotN +---+DstBotN
294+
295+
// 1. This is a new DAG.
296+
if (DAGInterval.empty()) {
297+
assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!");
298+
auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
299+
if (!DstRange.empty()) {
300+
for (MemDGNode &DstN : drop_begin(DstRange)) {
301+
auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
302+
scanAndAddDeps(DstN, SrcRange);
303+
}
304+
}
305+
}
306+
// 2. The new section is below the old section.
307+
else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) {
308+
auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
309+
auto SrcRangeFull = MemDGNodeIntervalBuilder::make(
310+
DAGInterval.getUnionInterval(NewInterval), *this);
311+
for (MemDGNode &DstN : DstRange) {
312+
auto SrcRange =
313+
Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
243314
scanAndAddDeps(DstN, SrcRange);
244315
}
245316
}
317+
// 3. The new section is above the old section.
318+
else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) {
319+
auto DstRange = MemDGNodeIntervalBuilder::make(
320+
NewInterval.getUnionInterval(DAGInterval), *this);
321+
auto SrcRangeFull = MemDGNodeIntervalBuilder::make(NewInterval, *this);
322+
if (!DstRange.empty()) {
323+
for (MemDGNode &DstN : drop_begin(DstRange)) {
324+
auto SrcRange =
325+
Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
326+
scanAndAddDeps(DstN, SrcRange);
327+
}
328+
}
329+
} else {
330+
llvm_unreachable("We don't expect extending in both directions!");
331+
}
246332

247-
return InstrInterval;
333+
DAGInterval = Union;
334+
return NewInterval;
248335
}
249336

250337
#ifndef NDEBUG

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,3 +681,70 @@ define void @foo() {
681681
EXPECT_FALSE(memDependency(StackSaveN, AllocaN));
682682
EXPECT_FALSE(memDependency(AllocaN, StackRestoreN));
683683
}
684+
685+
TEST_F(DependencyGraphTest, Extend) {
686+
parseIR(C, R"IR(
687+
define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
688+
store i8 %v1, ptr %ptr
689+
store i8 %v2, ptr %ptr
690+
store i8 %v3, ptr %ptr
691+
store i8 %v4, ptr %ptr
692+
store i8 %v5, ptr %ptr
693+
ret void
694+
}
695+
)IR");
696+
llvm::Function *LLVMF = &*M->getFunction("foo");
697+
sandboxir::Context Ctx(C);
698+
auto *F = Ctx.createFunction(LLVMF);
699+
auto *BB = &*F->begin();
700+
auto It = BB->begin();
701+
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
702+
auto *S2 = cast<sandboxir::StoreInst>(&*It++);
703+
auto *S3 = cast<sandboxir::StoreInst>(&*It++);
704+
auto *S4 = cast<sandboxir::StoreInst>(&*It++);
705+
auto *S5 = cast<sandboxir::StoreInst>(&*It++);
706+
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
707+
{
708+
// Scenario 1: Build new DAG
709+
auto NewIntvl = DAG.extend({S3, S3});
710+
EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S3, S3));
711+
EXPECT_EQ(DAG.getInterval().top(), S3);
712+
EXPECT_EQ(DAG.getInterval().bottom(), S3);
713+
[[maybe_unused]] auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
714+
}
715+
{
716+
// Scenario 2: Extend below
717+
auto NewIntvl = DAG.extend({S5, S5});
718+
EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S4, S5));
719+
auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
720+
auto *S4N = cast<sandboxir::MemDGNode>(DAG.getNode(S4));
721+
auto *S5N = cast<sandboxir::MemDGNode>(DAG.getNode(S5));
722+
EXPECT_TRUE(S4N->hasMemPred(S3N));
723+
EXPECT_TRUE(S5N->hasMemPred(S4N));
724+
EXPECT_TRUE(S5N->hasMemPred(S3N));
725+
}
726+
{
727+
// Scenario 3: Extend above
728+
auto NewIntvl = DAG.extend({S1, S2});
729+
EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S1, S2));
730+
auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
731+
auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
732+
auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
733+
auto *S4N = cast<sandboxir::MemDGNode>(DAG.getNode(S4));
734+
auto *S5N = cast<sandboxir::MemDGNode>(DAG.getNode(S5));
735+
736+
EXPECT_TRUE(S2N->hasMemPred(S1N));
737+
738+
EXPECT_TRUE(S3N->hasMemPred(S2N));
739+
EXPECT_TRUE(S3N->hasMemPred(S1N));
740+
741+
EXPECT_TRUE(S4N->hasMemPred(S3N));
742+
EXPECT_TRUE(S4N->hasMemPred(S2N));
743+
EXPECT_TRUE(S4N->hasMemPred(S1N));
744+
745+
EXPECT_TRUE(S5N->hasMemPred(S4N));
746+
EXPECT_TRUE(S5N->hasMemPred(S3N));
747+
EXPECT_TRUE(S5N->hasMemPred(S2N));
748+
EXPECT_TRUE(S5N->hasMemPred(S1N));
749+
}
750+
}

0 commit comments

Comments
 (0)