Skip to content

Commit eeb55d3

Browse files
authored
[SandboxVec][DAG] Update MemDGNode chain upon instr creation (#116896)
The DAG maintains a chain of MemDGNodes that links together all the nodes that may touch memroy. Whenever a new instruction gets created we need to make sure that this chain gets updated. If the new instruction touches memory then its corresponding MemDGNode should be inserted into the chain.
1 parent 8b844de commit eeb55d3

File tree

3 files changed

+90
-14
lines changed

3 files changed

+90
-14
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -329,13 +329,19 @@ class DependencyGraph {
329329
/// chain.
330330
void createNewNodes(const Interval<Instruction> &NewInterval);
331331

332+
/// Helper for `notify*Instr()`. \Returns the first MemDGNode that comes
333+
/// before \p N, including or excluding \p N based on \p IncludingN, or
334+
/// nullptr if not found.
335+
MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN) const;
336+
/// Helper for `notifyMoveInstr()`. \Returns the first MemDGNode that comes
337+
/// after \p N, including or excluding \p N based on \p IncludingN, or nullptr
338+
/// if not found.
339+
MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN) const;
340+
332341
/// Called by the callbacks when a new instruction \p I has been created.
333-
void notifyCreateInstr(Instruction *I) {
334-
getOrCreateNode(I);
335-
// TODO: Update the dependencies for the new node.
336-
// TODO: Update the MemDGNode chain to include the new node if needed.
337-
}
338-
/// Called by the callbacks when instruction \p I is about to get deleted.
342+
void notifyCreateInstr(Instruction *I);
343+
/// Called by the callbacks when instruction \p I is about to get
344+
/// deleted.
339345
void notifyEraseInstr(Instruction *I) {
340346
InstrToNodeMap.erase(I);
341347
// TODO: Update the dependencies.

llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,51 @@ void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
325325
setDefUseUnscheduledSuccs(NewInterval);
326326
}
327327

328+
MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N,
329+
bool IncludingN) const {
330+
auto *I = N->getInstruction();
331+
for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr;
332+
PrevI = PrevI->getPrevNode()) {
333+
auto *PrevN = getNodeOrNull(PrevI);
334+
if (PrevN == nullptr)
335+
return nullptr;
336+
if (auto *PrevMemN = dyn_cast<MemDGNode>(PrevN))
337+
return PrevMemN;
338+
}
339+
return nullptr;
340+
}
341+
342+
MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N,
343+
bool IncludingN) const {
344+
auto *I = N->getInstruction();
345+
for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr;
346+
NextI = NextI->getNextNode()) {
347+
auto *NextN = getNodeOrNull(NextI);
348+
if (NextN == nullptr)
349+
return nullptr;
350+
if (auto *NextMemN = dyn_cast<MemDGNode>(NextN))
351+
return NextMemN;
352+
}
353+
return nullptr;
354+
}
355+
356+
void DependencyGraph::notifyCreateInstr(Instruction *I) {
357+
auto *MemN = dyn_cast<MemDGNode>(getOrCreateNode(I));
358+
// TODO: Update the dependencies for the new node.
359+
360+
// Update the MemDGNode chain if this is a memory node.
361+
if (MemN != nullptr) {
362+
if (auto *PrevMemN = getMemDGNodeBefore(MemN, /*IncludingN=*/false)) {
363+
PrevMemN->NextMemN = MemN;
364+
MemN->PrevMemN = PrevMemN;
365+
}
366+
if (auto *NextMemN = getMemDGNodeAfter(MemN, /*IncludingN=*/false)) {
367+
NextMemN->PrevMemN = MemN;
368+
MemN->NextMemN = NextMemN;
369+
}
370+
}
371+
}
372+
328373
Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
329374
if (Instrs.empty())
330375
return {};

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -814,21 +814,46 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
814814
auto *BB = &*F->begin();
815815
auto It = BB->begin();
816816
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
817-
[[maybe_unused]] auto *S2 = cast<sandboxir::StoreInst>(&*It++);
817+
auto *S2 = cast<sandboxir::StoreInst>(&*It++);
818818
auto *S3 = cast<sandboxir::StoreInst>(&*It++);
819+
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
819820

820821
// Check new instruction callback.
821822
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
822-
DAG.extend({S1, S3});
823+
DAG.extend({S1, Ret});
823824
auto *Arg = F->getArg(3);
824825
auto *Ptr = S1->getPointerOperand();
825-
sandboxir::StoreInst *NewS =
826-
sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
827-
/*IsVolatile=*/true, Ctx);
828-
auto *NewSN = DAG.getNode(NewS);
829-
EXPECT_TRUE(NewSN != nullptr);
826+
{
827+
sandboxir::StoreInst *NewS =
828+
sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
829+
/*IsVolatile=*/true, Ctx);
830+
auto *NewSN = DAG.getNode(NewS);
831+
EXPECT_TRUE(NewSN != nullptr);
832+
833+
// Check the MemDGNode chain.
834+
auto *S2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
835+
auto *NewMemSN = cast<sandboxir::MemDGNode>(NewSN);
836+
auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
837+
EXPECT_EQ(S2MemN->getNextNode(), NewMemSN);
838+
EXPECT_EQ(NewMemSN->getPrevNode(), S2MemN);
839+
EXPECT_EQ(NewMemSN->getNextNode(), S3MemN);
840+
EXPECT_EQ(S3MemN->getPrevNode(), NewMemSN);
841+
}
842+
843+
{
844+
// Also check if new node is at the end of the BB, after Ret.
845+
sandboxir::StoreInst *NewS =
846+
sandboxir::StoreInst::create(Arg, Ptr, Align(8), BB->end(),
847+
/*IsVolatile=*/true, Ctx);
848+
// Check the MemDGNode chain.
849+
auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
850+
auto *NewMemSN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
851+
EXPECT_EQ(S3MemN->getNextNode(), NewMemSN);
852+
EXPECT_EQ(NewMemSN->getPrevNode(), S3MemN);
853+
EXPECT_EQ(NewMemSN->getNextNode(), nullptr);
854+
}
855+
830856
// TODO: Check the dependencies to/from NewSN after they land.
831-
// TODO: Check the MemDGNode chain.
832857
}
833858

834859
TEST_F(DependencyGraphTest, EraseInstrCallback) {

0 commit comments

Comments
 (0)