Skip to content

[SandboxVec][DAG] Extend DAG #111908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 11, 2024
Merged

[SandboxVec][DAG] Extend DAG #111908

merged 1 commit into from
Oct 11, 2024

Conversation

vporpo
Copy link
Contributor

@vporpo vporpo commented Oct 10, 2024

This patch implements growing the DAG towards the top or bottom. This does the necessary dependency checks and adds new mem dependencies.

@llvmbot
Copy link
Member

llvmbot commented Oct 10, 2024

@llvm/pr-subscribers-vectorizers

Author: vporpo (vporpo)

Changes

This patch implements growing the DAG towards the top or bottom. This does the necessary dependency checks and adds new mem dependencies.


Full diff: https://github.com/llvm/llvm-project/pull/111908.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h (+7-1)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (+98-15)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (+67)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 7d300ea2b60d2d..dabeec8279472c 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -280,6 +280,10 @@ class DependencyGraph {
   /// \p DstN.
   void scanAndAddDeps(MemDGNode &DstN, const Interval<MemDGNode> &SrcScanRange);
 
+  /// Create DAG nodes for instrs in \p NewInterval and update the MemNode
+  /// chain.
+  void createNewNodes(const Interval<Instruction> &NewInterval);
+
 public:
   DependencyGraph(AAResults &AA)
       : BatchAA(std::make_unique<BatchAAResults>(AA)) {}
@@ -305,8 +309,10 @@ class DependencyGraph {
     return It->second.get();
   }
   /// Build/extend the dependency graph such that it includes \p Instrs. Returns
-  /// the interval spanning \p Instrs.
+  /// the range of instructions added to the DAG.
   Interval<Instruction> extend(ArrayRef<Instruction *> Instrs);
+  /// \Returns the range of instructions included in the DAG.
+  Interval<Instruction> getInterval() const { return DAGInterval; }
 #ifndef NDEBUG
   void print(raw_ostream &OS) const;
   LLVM_DUMP_METHOD void dump() const;
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 0cd2240e7ff1b3..bac6d19877bfb6 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -215,17 +215,11 @@ void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
   }
 }
 
-Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
-  if (Instrs.empty())
-    return {};
-
-  Interval<Instruction> InstrInterval(Instrs);
-
-  DGNode *LastN = getOrCreateNode(InstrInterval.top());
-  // Create DGNodes for all instrs in Interval to avoid future Instruction to
-  // DGNode lookups.
+void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
+  // Create Nodes only for the new sections of the DAG.
+  DGNode *LastN = getOrCreateNode(NewInterval.top());
   MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN);
-  for (Instruction &I : drop_begin(InstrInterval)) {
+  for (Instruction &I : drop_begin(NewInterval)) {
     auto *N = getOrCreateNode(&I);
     // Build the Mem node chain.
     if (auto *MemN = dyn_cast<MemDGNode>(N)) {
@@ -235,16 +229,105 @@ Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
       LastMemN = MemN;
     }
   }
+  // Link new MemDGNode chain with the old one, if any.
+  if (!DAGInterval.empty()) {
+    bool NewIsAbove = NewInterval.bottom()->comesBefore(DAGInterval.bottom());
+    const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
+    const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
+    MemDGNode *LinkTopN =
+        MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this);
+    MemDGNode *LinkBotN =
+        MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this);
+    assert(LinkTopN->comesBefore(LinkBotN) && "Wrong order!");
+    if (LinkTopN != nullptr && LinkBotN != nullptr) {
+      LinkTopN->setNextNode(LinkBotN);
+      LinkBotN->setPrevNode(LinkTopN);
+    }
+#ifndef NDEBUG
+    // TODO: Remove this once we've done enough testing.
+    // Check that the chain is well formed.
+    auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval);
+    MemDGNode *ChainTopN =
+        MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this);
+    MemDGNode *ChainBotN =
+        MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this);
+    if (ChainTopN != nullptr && ChainBotN != nullptr) {
+      for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr;
+           LastN = N, N = N->getNextNode()) {
+        assert(N == LastN->getNextNode() && "Bad chain!");
+        assert(N->getPrevNode() == LastN && "Bad chain!");
+      }
+    }
+#endif // NDEBUG
+  }
+}
+
+Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
+  if (Instrs.empty())
+    return {};
+
+  Interval<Instruction> InstrsInterval(Instrs);
+  Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval);
+  auto NewInterval = Union.getSingleDiff(DAGInterval);
+  if (NewInterval.empty())
+    return {};
+
+  createNewNodes(NewInterval);
+
   // Create the dependencies.
-  auto DstRange = MemDGNodeIntervalBuilder::make(InstrInterval, *this);
-  if (!DstRange.empty()) {
-    for (MemDGNode &DstN : drop_begin(DstRange)) {
-      auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
+  //
+  // 1. DAGInterval empty      2. New is below Old     3. New is above old
+  // ------------------------  -------------------      -------------------
+  //                                         Scan:           DstN:    Scan:
+  //                           +---+         -ScanTopN  +---+DstTopN  -ScanTopN
+  //                           |   |         |          |New|         |
+  //                           |Old|         |          +---+         -ScanBotN
+  //                           |   |         |          +---+
+  //      DstN:    Scan:       +---+DstN:    |          |   |
+  // +---+DstTopN  -ScanTopN   +---+DstTopN  |          |Old|
+  // |New|         |           |New|         |          |   |
+  // +---+DstBotN  -ScanBotN   +---+DstBotN  -ScanBotN  +---+DstBotN
+
+  // 1. This is a new DAG.
+  if (DAGInterval.empty()) {
+    assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!");
+    auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
+    if (!DstRange.empty()) {
+      for (MemDGNode &DstN : drop_begin(DstRange)) {
+        auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
+        scanAndAddDeps(DstN, SrcRange);
+      }
+    }
+  }
+  // 2. The new section is below the old section.
+  else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) {
+    auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
+    auto SrcRangeFull = MemDGNodeIntervalBuilder::make(
+        DAGInterval.getUnionInterval(NewInterval), *this);
+    for (MemDGNode &DstN : DstRange) {
+      auto SrcRange =
+          Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
       scanAndAddDeps(DstN, SrcRange);
     }
   }
+  // 3. The new section is above the old section.
+  else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) {
+    auto DstRange = MemDGNodeIntervalBuilder::make(
+        NewInterval.getUnionInterval(DAGInterval), *this);
+    auto SrcRangeFull = MemDGNodeIntervalBuilder::make(NewInterval, *this);
+    if (!DstRange.empty()) {
+      for (MemDGNode &DstN : drop_begin(DstRange)) {
+        auto SrcRange =
+            Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
+        scanAndAddDeps(DstN, SrcRange);
+      }
+    }
+  } else {
+    llvm_unreachable("We don't expect extending in both directions!");
+  }
 
-  return InstrInterval;
+  DAGInterval = Union;
+  return NewInterval;
 }
 
 #ifndef NDEBUG
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 7e2be25fa25ae6..3dbf03e4ba44e2 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -681,3 +681,70 @@ define void @foo() {
   EXPECT_FALSE(memDependency(StackSaveN, AllocaN));
   EXPECT_FALSE(memDependency(AllocaN, StackRestoreN));
 }
+
+TEST_F(DependencyGraphTest, Extend) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
+  store i8 %v1, ptr %ptr
+  store i8 %v2, ptr %ptr
+  store i8 %v3, ptr %ptr
+  store i8 %v4, ptr %ptr
+  store i8 %v5, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S4 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S5 = cast<sandboxir::StoreInst>(&*It++);
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  {
+    // Scenario 1: Build new DAG
+    auto NewIntvl = DAG.extend({S3, S3});
+    EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S3, S3));
+    EXPECT_EQ(DAG.getInterval().top(), S3);
+    EXPECT_EQ(DAG.getInterval().bottom(), S3);
+    [[maybe_unused]] auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+  }
+  {
+    // Scenario 2: Extend below
+    auto NewIntvl = DAG.extend({S5, S5});
+    EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S4, S5));
+    auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+    auto *S4N = cast<sandboxir::MemDGNode>(DAG.getNode(S4));
+    auto *S5N = cast<sandboxir::MemDGNode>(DAG.getNode(S5));
+    EXPECT_TRUE(S4N->hasMemPred(S3N));
+    EXPECT_TRUE(S5N->hasMemPred(S4N));
+    EXPECT_TRUE(S5N->hasMemPred(S3N));
+  }
+  {
+    // Scenario 3: Extend above
+    auto NewIntvl = DAG.extend({S1, S2});
+    EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S1, S2));
+    auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+    auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+    auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+    auto *S4N = cast<sandboxir::MemDGNode>(DAG.getNode(S4));
+    auto *S5N = cast<sandboxir::MemDGNode>(DAG.getNode(S5));
+
+    EXPECT_TRUE(S2N->hasMemPred(S1N));
+
+    EXPECT_TRUE(S3N->hasMemPred(S2N));
+    EXPECT_TRUE(S3N->hasMemPred(S1N));
+
+    EXPECT_TRUE(S4N->hasMemPred(S3N));
+    EXPECT_TRUE(S4N->hasMemPred(S2N));
+    EXPECT_TRUE(S4N->hasMemPred(S1N));
+
+    EXPECT_TRUE(S5N->hasMemPred(S4N));
+    EXPECT_TRUE(S5N->hasMemPred(S3N));
+    EXPECT_TRUE(S5N->hasMemPred(S2N));
+    EXPECT_TRUE(S5N->hasMemPred(S1N));
+  }
+}

@llvmbot
Copy link
Member

llvmbot commented Oct 10, 2024

@llvm/pr-subscribers-llvm-transforms

Author: vporpo (vporpo)

Changes

This patch implements growing the DAG towards the top or bottom. This does the necessary dependency checks and adds new mem dependencies.


Full diff: https://github.com/llvm/llvm-project/pull/111908.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h (+7-1)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (+98-15)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (+67)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 7d300ea2b60d2d..dabeec8279472c 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -280,6 +280,10 @@ class DependencyGraph {
   /// \p DstN.
   void scanAndAddDeps(MemDGNode &DstN, const Interval<MemDGNode> &SrcScanRange);
 
+  /// Create DAG nodes for instrs in \p NewInterval and update the MemNode
+  /// chain.
+  void createNewNodes(const Interval<Instruction> &NewInterval);
+
 public:
   DependencyGraph(AAResults &AA)
       : BatchAA(std::make_unique<BatchAAResults>(AA)) {}
@@ -305,8 +309,10 @@ class DependencyGraph {
     return It->second.get();
   }
   /// Build/extend the dependency graph such that it includes \p Instrs. Returns
-  /// the interval spanning \p Instrs.
+  /// the range of instructions added to the DAG.
   Interval<Instruction> extend(ArrayRef<Instruction *> Instrs);
+  /// \Returns the range of instructions included in the DAG.
+  Interval<Instruction> getInterval() const { return DAGInterval; }
 #ifndef NDEBUG
   void print(raw_ostream &OS) const;
   LLVM_DUMP_METHOD void dump() const;
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 0cd2240e7ff1b3..bac6d19877bfb6 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -215,17 +215,11 @@ void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
   }
 }
 
-Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
-  if (Instrs.empty())
-    return {};
-
-  Interval<Instruction> InstrInterval(Instrs);
-
-  DGNode *LastN = getOrCreateNode(InstrInterval.top());
-  // Create DGNodes for all instrs in Interval to avoid future Instruction to
-  // DGNode lookups.
+void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
+  // Create Nodes only for the new sections of the DAG.
+  DGNode *LastN = getOrCreateNode(NewInterval.top());
   MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN);
-  for (Instruction &I : drop_begin(InstrInterval)) {
+  for (Instruction &I : drop_begin(NewInterval)) {
     auto *N = getOrCreateNode(&I);
     // Build the Mem node chain.
     if (auto *MemN = dyn_cast<MemDGNode>(N)) {
@@ -235,16 +229,105 @@ Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
       LastMemN = MemN;
     }
   }
+  // Link new MemDGNode chain with the old one, if any.
+  if (!DAGInterval.empty()) {
+    bool NewIsAbove = NewInterval.bottom()->comesBefore(DAGInterval.bottom());
+    const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
+    const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
+    MemDGNode *LinkTopN =
+        MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this);
+    MemDGNode *LinkBotN =
+        MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this);
+    assert(LinkTopN->comesBefore(LinkBotN) && "Wrong order!");
+    if (LinkTopN != nullptr && LinkBotN != nullptr) {
+      LinkTopN->setNextNode(LinkBotN);
+      LinkBotN->setPrevNode(LinkTopN);
+    }
+#ifndef NDEBUG
+    // TODO: Remove this once we've done enough testing.
+    // Check that the chain is well formed.
+    auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval);
+    MemDGNode *ChainTopN =
+        MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this);
+    MemDGNode *ChainBotN =
+        MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this);
+    if (ChainTopN != nullptr && ChainBotN != nullptr) {
+      for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr;
+           LastN = N, N = N->getNextNode()) {
+        assert(N == LastN->getNextNode() && "Bad chain!");
+        assert(N->getPrevNode() == LastN && "Bad chain!");
+      }
+    }
+#endif // NDEBUG
+  }
+}
+
+Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
+  if (Instrs.empty())
+    return {};
+
+  Interval<Instruction> InstrsInterval(Instrs);
+  Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval);
+  auto NewInterval = Union.getSingleDiff(DAGInterval);
+  if (NewInterval.empty())
+    return {};
+
+  createNewNodes(NewInterval);
+
   // Create the dependencies.
-  auto DstRange = MemDGNodeIntervalBuilder::make(InstrInterval, *this);
-  if (!DstRange.empty()) {
-    for (MemDGNode &DstN : drop_begin(DstRange)) {
-      auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
+  //
+  // 1. DAGInterval empty      2. New is below Old     3. New is above old
+  // ------------------------  -------------------      -------------------
+  //                                         Scan:           DstN:    Scan:
+  //                           +---+         -ScanTopN  +---+DstTopN  -ScanTopN
+  //                           |   |         |          |New|         |
+  //                           |Old|         |          +---+         -ScanBotN
+  //                           |   |         |          +---+
+  //      DstN:    Scan:       +---+DstN:    |          |   |
+  // +---+DstTopN  -ScanTopN   +---+DstTopN  |          |Old|
+  // |New|         |           |New|         |          |   |
+  // +---+DstBotN  -ScanBotN   +---+DstBotN  -ScanBotN  +---+DstBotN
+
+  // 1. This is a new DAG.
+  if (DAGInterval.empty()) {
+    assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!");
+    auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
+    if (!DstRange.empty()) {
+      for (MemDGNode &DstN : drop_begin(DstRange)) {
+        auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
+        scanAndAddDeps(DstN, SrcRange);
+      }
+    }
+  }
+  // 2. The new section is below the old section.
+  else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) {
+    auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
+    auto SrcRangeFull = MemDGNodeIntervalBuilder::make(
+        DAGInterval.getUnionInterval(NewInterval), *this);
+    for (MemDGNode &DstN : DstRange) {
+      auto SrcRange =
+          Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
       scanAndAddDeps(DstN, SrcRange);
     }
   }
+  // 3. The new section is above the old section.
+  else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) {
+    auto DstRange = MemDGNodeIntervalBuilder::make(
+        NewInterval.getUnionInterval(DAGInterval), *this);
+    auto SrcRangeFull = MemDGNodeIntervalBuilder::make(NewInterval, *this);
+    if (!DstRange.empty()) {
+      for (MemDGNode &DstN : drop_begin(DstRange)) {
+        auto SrcRange =
+            Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
+        scanAndAddDeps(DstN, SrcRange);
+      }
+    }
+  } else {
+    llvm_unreachable("We don't expect extending in both directions!");
+  }
 
-  return InstrInterval;
+  DAGInterval = Union;
+  return NewInterval;
 }
 
 #ifndef NDEBUG
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 7e2be25fa25ae6..3dbf03e4ba44e2 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -681,3 +681,70 @@ define void @foo() {
   EXPECT_FALSE(memDependency(StackSaveN, AllocaN));
   EXPECT_FALSE(memDependency(AllocaN, StackRestoreN));
 }
+
+TEST_F(DependencyGraphTest, Extend) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
+  store i8 %v1, ptr %ptr
+  store i8 %v2, ptr %ptr
+  store i8 %v3, ptr %ptr
+  store i8 %v4, ptr %ptr
+  store i8 %v5, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S4 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S5 = cast<sandboxir::StoreInst>(&*It++);
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  {
+    // Scenario 1: Build new DAG
+    auto NewIntvl = DAG.extend({S3, S3});
+    EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S3, S3));
+    EXPECT_EQ(DAG.getInterval().top(), S3);
+    EXPECT_EQ(DAG.getInterval().bottom(), S3);
+    [[maybe_unused]] auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+  }
+  {
+    // Scenario 2: Extend below
+    auto NewIntvl = DAG.extend({S5, S5});
+    EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S4, S5));
+    auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+    auto *S4N = cast<sandboxir::MemDGNode>(DAG.getNode(S4));
+    auto *S5N = cast<sandboxir::MemDGNode>(DAG.getNode(S5));
+    EXPECT_TRUE(S4N->hasMemPred(S3N));
+    EXPECT_TRUE(S5N->hasMemPred(S4N));
+    EXPECT_TRUE(S5N->hasMemPred(S3N));
+  }
+  {
+    // Scenario 3: Extend above
+    auto NewIntvl = DAG.extend({S1, S2});
+    EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S1, S2));
+    auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+    auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+    auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+    auto *S4N = cast<sandboxir::MemDGNode>(DAG.getNode(S4));
+    auto *S5N = cast<sandboxir::MemDGNode>(DAG.getNode(S5));
+
+    EXPECT_TRUE(S2N->hasMemPred(S1N));
+
+    EXPECT_TRUE(S3N->hasMemPred(S2N));
+    EXPECT_TRUE(S3N->hasMemPred(S1N));
+
+    EXPECT_TRUE(S4N->hasMemPred(S3N));
+    EXPECT_TRUE(S4N->hasMemPred(S2N));
+    EXPECT_TRUE(S4N->hasMemPred(S1N));
+
+    EXPECT_TRUE(S5N->hasMemPred(S4N));
+    EXPECT_TRUE(S5N->hasMemPred(S3N));
+    EXPECT_TRUE(S5N->hasMemPred(S2N));
+    EXPECT_TRUE(S5N->hasMemPred(S1N));
+  }
+}

Copy link
Member

@tmsri tmsri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with minor comments.

@@ -235,16 +229,105 @@ Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
LastMemN = MemN;
}
}
// Link new MemDGNode chain with the old one, if any.
if (!DAGInterval.empty()) {
bool NewIsAbove = NewInterval.bottom()->comesBefore(DAGInterval.bottom());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't it be?

NewInterval.bottom()->comesBefore(DAGInterval.top());

to explicitly check that there was no intervleaving? Also, if NewIsAbove is false, assert that

DAGInterval.bottom()->comesBefore(NewInterval.top());

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, yes it should be DAGInterval.top().
Yes I will add that assertion too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I should add a member function Interval::comesBefore() to avoid these kind of issues. I will implement this in a follow-up patch.

auto DstRange = MemDGNodeIntervalBuilder::make(
NewInterval.getUnionInterval(DAGInterval), *this);
auto SrcRangeFull = MemDGNodeIntervalBuilder::make(NewInterval, *this);
if (!DstRange.empty()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why will DstRange ever be empty here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might be empty if the intervals contain non-memory instructions. The .empty() check is mainly used because drop_begin() doesn't work with empty ranges.

This patch implements growing the DAG towards the top or bottom.
This does the necessary dependency checks and adds new mem dependencies.
@vporpo vporpo merged commit e8dd95e into llvm:main Oct 11, 2024
8 checks passed
DanielCChen pushed a commit to DanielCChen/llvm-project that referenced this pull request Oct 16, 2024
This patch implements growing the DAG towards the top or bottom. This
does the necessary dependency checks and adds new mem dependencies.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants