[SandboxVec][Scheduler] Boilerplate and initial implementation. #112449

Merged (1 commit) on Oct 18, 2024

vporpo
Contributor

@vporpo vporpo commented Oct 15, 2024

This patch implements a ready-list-based scheduler that operates on the DependencyGraph.
The sandbox vectorizer uses it to test whether vectorizing a group of instructions is legal.

SchedBundle is a helper container that holds all the DGNodes corresponding to the instructions we are attempting to schedule with trySchedule(Instrs).
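
For readers unfamiliar with ready-list scheduling, here is a minimal, self-contained sketch of the idea, not the code from this patch. It assumes a hypothetical `Node` type standing in for DGNode, integer program-order indices instead of instructions, and a free function `trySchedule` that mirrors the deferral logic of `tryScheduleUntil`. Unlike the initial implementation in this patch, as far as the diff shows, the sketch also re-inserts predecessors into the ready list once their last successor is scheduled; that part is an assumption about how the full algorithm proceeds.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for DGNode: identified by its program-order index.
struct Node {
  int Order;                     // Position in program order.
  std::vector<Node *> Preds;     // Nodes this node depends on.
  unsigned UnscheduledSuccs = 0; // Dependent successors not yet scheduled.
  bool Scheduled = false;
  explicit Node(int Order) : Order(Order) {}
  bool ready() const { return UnscheduledSuccs == 0; }
};

// Records that `To` depends on `From` (From must stay above To).
void addDep(Node &From, Node &To) {
  To.Preds.push_back(&From);
  ++From.UnscheduledSuccs;
}

// Bottom-up priority: the node latest in program order is popped first.
struct LaterFirst {
  bool operator()(const Node *A, const Node *B) const {
    return A->Order < B->Order;
  }
};

// Schedules ready nodes bottom-up until every node in `Bundle` is ready at
// the same time. Returns false if that never happens. `Order` receives the
// program-order indices in the order the nodes were scheduled.
bool trySchedule(const std::vector<Node *> &All,
                 const std::vector<Node *> &Bundle, std::vector<int> &Order) {
  std::priority_queue<Node *, std::vector<Node *>, LaterFirst> Ready;
  for (Node *N : All)
    if (!N->Scheduled && N->ready())
      Ready.push(N);

  std::unordered_set<Node *> ToDefer(Bundle.begin(), Bundle.end());
  std::size_t NumDeferred = 0;

  auto Schedule = [&](Node *N) {
    N->Scheduled = true;
    Order.push_back(N->Order);
    for (Node *P : N->Preds) {
      assert(P->UnscheduledSuccs > 0 && "Counting error!");
      // Assumption: newly ready predecessors go back into the ready list.
      if (--P->UnscheduledSuccs == 0)
        Ready.push(P);
    }
  };

  while (!Ready.empty()) {
    Node *N = Ready.top();
    Ready.pop();
    if (ToDefer.count(N)) {
      // Hold bundle members back until all of them are ready.
      if (++NumDeferred == Bundle.size()) {
        for (Node *B : Bundle)
          Schedule(B);
        return true;
      }
      continue;
    }
    Schedule(N);
  }
  return false;
}
```

With four independent nodes and a bundle made of the first and third, the sketch schedules the other two first and then the bundle together. With a dependency chain A -> B -> C and the bundle {A, C}, it returns false, because B must be scheduled between them.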

@llvmbot
Collaborator

llvmbot commented Oct 15, 2024

@llvm/pr-subscribers-vectorizers

Author: vporpo (vporpo)

Changes

This patch implements a ready-list-based scheduler that operates on the DependencyGraph.
The sandbox vectorizer uses it to test whether vectorizing a group of instructions is legal.

SchedBundle is a helper container that holds all the DGNodes corresponding to the instructions we are attempting to schedule with trySchedule(Instrs).


Full diff: https://github.com/llvm/llvm-project/pull/112449.diff

8 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h (+7)
  • (added) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h (+124)
  • (modified) llvm/lib/Transforms/Vectorize/CMakeLists.txt (+1)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (+3-1)
  • (added) llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp (+154)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt (+1)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (+12)
  • (added) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp (+167)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index ae3ceed447c40b..5be05bc80c4925 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -113,8 +113,15 @@ class DGNode {
   virtual ~DGNode() = default;
   /// \Returns the number of unscheduled successors.
   unsigned getNumUnscheduledSuccs() const { return UnscheduledSuccs; }
+  void decrUnscheduledSuccs() {
+    assert(UnscheduledSuccs > 0 && "Counting error!");
+    --UnscheduledSuccs;
+  }
+  /// \Returns true if all dependent successors have been scheduled.
+  bool ready() const { return UnscheduledSuccs == 0; }
   /// \Returns true if this node has been scheduled.
   bool scheduled() const { return Scheduled; }
+  void setScheduled(bool NewVal) { Scheduled = NewVal; }
   /// \Returns true if this is before \p Other in program order.
   bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); }
   using iterator = PredIterator;
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
new file mode 100644
index 00000000000000..60ebcc02e7f169
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
@@ -0,0 +1,124 @@
+//===- Scheduler.h ----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the bottom-up list scheduler used by the vectorizer. It is used for
+// checking the legality of vectorization and for scheduling instructions in
+// such a way that makes vectorization possible, if legal.
+//
+// The legality check is performed by `trySchedule(Instrs)`, which will try to
+// schedule the IR until all instructions in `Instrs` can be scheduled together
+// back-to-back. If this fails then it is illegal to vectorize `Instrs`.
+//
+// Internally the scheduler uses the vectorizer-specific DependencyGraph class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SCHEDULER_H
+#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SCHEDULER_H
+
+#include "llvm/SandboxIR/Instruction.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
+#include <queue>
+
+namespace llvm::sandboxir {
+
+class PriorityCmp {
+public:
+  bool operator()(const DGNode *N1, const DGNode *N2) {
+    // TODO: This should be a hierarchical comparator.
+    return N1->getInstruction()->comesBefore(N2->getInstruction());
+  }
+};
+
+/// The list holding nodes that are ready to schedule. Used by the scheduler.
+class ReadyList {
+  PriorityCmp Cmp;
+  /// Control/Other dependencies are not modeled by the DAG to save memory.
+  /// These have to be modeled in the ready list for correctness.
+  /// This means that the list will hold back nodes that need to meet such
+  /// unmodeled dependencies.
+  std::priority_queue<DGNode *, std::vector<DGNode *>, PriorityCmp> List;
+
+public:
+  ReadyList() : List(Cmp) {}
+  void insert(DGNode *N) { List.push(N); }
+  DGNode *pop() {
+    auto *Back = List.top();
+    List.pop();
+    return Back;
+  }
+  bool empty() const { return List.empty(); }
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const;
+#endif // NDEBUG
+};
+
+/// The nodes that need to be scheduled back-to-back in a single scheduling
+/// cycle form a SchedBundle.
+class SchedBundle {
+public:
+  using ContainerTy = SmallVector<DGNode *, 4>;
+
+private:
+  ContainerTy Nodes;
+
+public:
+  SchedBundle() = default;
+  SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) {}
+  using iterator = ContainerTy::iterator;
+  using const_iterator = ContainerTy::const_iterator;
+  iterator begin() { return Nodes.begin(); }
+  iterator end() { return Nodes.end(); }
+  const_iterator begin() const { return Nodes.begin(); }
+  const_iterator end() const { return Nodes.end(); }
+  /// \Returns the bundle node that comes before the others in program order.
+  DGNode *getTop() const;
+  /// \Returns the bundle node that comes after the others in program order.
+  DGNode *getBot() const;
+  /// Move all bundle instructions to \p Where back-to-back.
+  void cluster(BasicBlock::iterator Where);
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
+/// The list scheduler.
+class Scheduler {
+  ReadyList ReadyList;
+  DependencyGraph DAG;
+  std::optional<BasicBlock::iterator> ScheduleTopItOpt;
+  SmallVector<std::unique_ptr<SchedBundle>> Bndls;
+
+  /// \Returns a scheduling bundle containing \p Instrs.
+  SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
+  /// Schedule nodes until we can schedule \p Instrs back-to-back.
+  bool tryScheduleUntil(ArrayRef<Instruction *> Instrs);
+
+  void scheduleAndUpdateReadyList(SchedBundle &Bndl);
+
+  /// Disable copies.
+  Scheduler(const Scheduler &) = delete;
+  Scheduler &operator=(const Scheduler &) = delete;
+
+public:
+  Scheduler(AAResults &AA) : DAG(AA) {}
+  ~Scheduler() {}
+
+  bool trySchedule(ArrayRef<Instruction *> Instrs);
+
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
+} // namespace llvm::sandboxir
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SCHEDULER_H
diff --git a/llvm/lib/Transforms/Vectorize/CMakeLists.txt b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
index f4e98e576379a4..fc4355af5af6b9 100644
--- a/llvm/lib/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
@@ -9,6 +9,7 @@ add_llvm_component_library(LLVMVectorize
   SandboxVectorizer/Passes/RegionsFromMetadata.cpp
   SandboxVectorizer/SandboxVectorizer.cpp
   SandboxVectorizer/SandboxVectorizerPassBuilder.cpp
+  SandboxVectorizer/Scheduler.cpp
   SandboxVectorizer/SeedCollector.cpp
   SLPVectorizer.cpp
   Vectorize.cpp
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 9bbeca4fc15494..07435f0fb3151d 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -286,7 +286,9 @@ void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
         MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this);
     MemDGNode *LinkBotN =
         MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this);
-    assert(LinkTopN->comesBefore(LinkBotN) && "Wrong order!");
+    assert((LinkTopN == nullptr || LinkBotN == nullptr ||
+            LinkTopN->comesBefore(LinkBotN)) &&
+           "Wrong order!");
     if (LinkTopN != nullptr && LinkBotN != nullptr) {
       LinkTopN->setNextNode(LinkBotN);
       LinkBotN->setPrevNode(LinkTopN);
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp
new file mode 100644
index 00000000000000..37c197cc859980
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp
@@ -0,0 +1,154 @@
+//===- Scheduler.cpp ------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
+
+namespace llvm::sandboxir {
+
+DGNode *SchedBundle::getTop() const {
+  DGNode *TopN = Nodes.front();
+  for (auto *N : drop_begin(Nodes)) {
+    if (N->getInstruction()->comesBefore(TopN->getInstruction()))
+      TopN = N;
+  }
+  return TopN;
+}
+
+DGNode *SchedBundle::getBot() const {
+  DGNode *BotN = Nodes.front();
+  for (auto *N : drop_begin(Nodes)) {
+    if (BotN->getInstruction()->comesBefore(N->getInstruction()))
+      BotN = N;
+  }
+  return BotN;
+}
+
+void SchedBundle::cluster(BasicBlock::iterator Where) {
+  for (auto *N : Nodes) {
+    auto *I = N->getInstruction();
+    if (I->getIterator() == Where)
+      ++Where; // Try to maintain bundle order.
+    I->moveBefore(*Where.getNodeParent(), Where);
+  }
+}
+
+#ifndef NDEBUG
+void SchedBundle::dump(raw_ostream &OS) const {
+  for (auto *N : Nodes)
+    OS << *N;
+}
+
+void SchedBundle::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
+
+#ifndef NDEBUG
+void ReadyList::dump(raw_ostream &OS) const {
+  auto ListCopy = List;
+  while (!ListCopy.empty()) {
+    OS << *ListCopy.top() << "\n";
+    ListCopy.pop();
+  }
+}
+
+void ReadyList::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
+
+void Scheduler::scheduleAndUpdateReadyList(SchedBundle &Bndl) {
+  // Find where we should schedule the instructions.
+  assert(ScheduleTopItOpt && "Should have been set by now!");
+  auto Where = *ScheduleTopItOpt;
+  // Move all instructions in `Bndl` to `Where`.
+  Bndl.cluster(Where);
+  // Update the last scheduled bundle.
+  ScheduleTopItOpt = Bndl.getTop()->getInstruction()->getIterator();
+  // Set nodes as "scheduled" and decrement the UnscheduledSuccs counter of all
+  // dependency predecessors.
+  for (DGNode *N : Bndl) {
+    N->setScheduled(true);
+    for (auto *DepN : N->preds(DAG))
+      DepN->decrUnscheduledSuccs();
+  }
+}
+
+SchedBundle *Scheduler::createBundle(ArrayRef<Instruction *> Instrs) {
+  SchedBundle::ContainerTy Nodes;
+  Nodes.reserve(Instrs.size());
+  for (auto *I : Instrs)
+    Nodes.push_back(DAG.getNode(I));
+  auto BndlPtr = std::make_unique<SchedBundle>(std::move(Nodes));
+  auto *Bndl = BndlPtr.get();
+  Bndls.push_back(std::move(BndlPtr));
+  return Bndl;
+}
+
+bool Scheduler::tryScheduleUntil(ArrayRef<Instruction *> Instrs) {
+  // Use a set for fast lookups.
+  DenseSet<Instruction *> InstrsToDefer(Instrs.begin(), Instrs.end());
+  SmallVector<DGNode *, 8> DeferredNodes;
+
+  // Keep scheduling ready nodes.
+  while (!ReadyList.empty()) {
+    auto *ReadyN = ReadyList.pop();
+    // We defer scheduling of instructions in `Instrs` until we can schedule all
+    // of them at the same time in a single scheduling bundle.
+    if (InstrsToDefer.contains(ReadyN->getInstruction())) {
+      DeferredNodes.push_back(ReadyN);
+      bool ReadyToScheduleDeferred = DeferredNodes.size() == Instrs.size();
+      if (ReadyToScheduleDeferred) {
+        scheduleAndUpdateReadyList(*createBundle(Instrs));
+        return true;
+      }
+    } else {
+      scheduleAndUpdateReadyList(*createBundle({ReadyN->getInstruction()}));
+    }
+  }
+  assert(DeferredNodes.size() != Instrs.size() &&
+         "We should have successfully scheduled and early-returned!");
+  return false;
+}
+
+bool Scheduler::trySchedule(ArrayRef<Instruction *> Instrs) {
+  assert(all_of(drop_begin(Instrs),
+                [Instrs](Instruction *I) {
+                  return I->getParent() == (*Instrs.begin())->getParent();
+                }) &&
+         "Instrs not in the same BB!");
+  // Extend the DAG to include Instrs.
+  Interval<Instruction> Extension = DAG.extend(Instrs);
+  // TODO: Set the window of the DAG that we are interested in.
+  // We start scheduling at the bottom instr of Instrs.
+  auto getBottomI = [](ArrayRef<Instruction *> Instrs) -> Instruction * {
+    return *max_element(Instrs,
+                        [](auto *I1, auto *I2) { return I1->comesBefore(I2); });
+  };
+  ScheduleTopItOpt = std::next(getBottomI(Instrs)->getIterator());
+  // Add nodes to ready list.
+  for (auto &I : Extension) {
+    auto *N = DAG.getNode(&I);
+    if (N->ready())
+      ReadyList.insert(N);
+  }
+  // Try schedule all nodes until we can schedule Instrs back-to-back.
+  return tryScheduleUntil(Instrs);
+}
+
+#ifndef NDEBUG
+void Scheduler::dump(raw_ostream &OS) const {
+  OS << "ReadyList:\n";
+  ReadyList.dump(OS);
+}
+void Scheduler::dump() const { dump(dbgs()); }
+#endif // NDEBUG
+
+} // namespace llvm::sandboxir
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
index dcd7232db5f60c..24512cb0225e8e 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
@@ -11,5 +11,6 @@ add_llvm_unittest(SandboxVectorizerTests
   DependencyGraphTest.cpp
   IntervalTest.cpp
   LegalityTest.cpp
+  SchedulerTest.cpp
   SeedCollectorTest.cpp	
 )
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 3f84ad1f731de8..c00599ae1c4ef2 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -254,6 +254,18 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
   EXPECT_EQ(N0->getNumUnscheduledSuccs(), 1u); // N1
   EXPECT_EQ(N1->getNumUnscheduledSuccs(), 0u);
   EXPECT_EQ(N2->getNumUnscheduledSuccs(), 0u);
+
+  // Check decrUnscheduledSuccs.
+  N0->decrUnscheduledSuccs();
+  EXPECT_EQ(N0->getNumUnscheduledSuccs(), 0u);
+#ifndef NDEBUG
+  EXPECT_DEATH(N0->decrUnscheduledSuccs(), ".*Counting.*");
+#endif // NDEBUG
+
+  // Check scheduled(), setScheduled().
+  EXPECT_FALSE(N0->scheduled());
+  N0->setScheduled(true);
+  EXPECT_TRUE(N0->scheduled());
 }
 
 TEST_F(DependencyGraphTest, Preds) {
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp
new file mode 100644
index 00000000000000..14f48e77a6fb65
--- /dev/null
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp
@@ -0,0 +1,167 @@
+//===- SchedulerTest.cpp --------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/SandboxIR/Context.h"
+#include "llvm/SandboxIR/Function.h"
+#include "llvm/SandboxIR/Instruction.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gmock/gmock-matchers.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+struct SchedulerTest : public testing::Test {
+  LLVMContext C;
+  std::unique_ptr<Module> M;
+  std::unique_ptr<AssumptionCache> AC;
+  std::unique_ptr<DominatorTree> DT;
+  std::unique_ptr<BasicAAResult> BAA;
+  std::unique_ptr<AAResults> AA;
+
+  void parseIR(LLVMContext &C, const char *IR) {
+    SMDiagnostic Err;
+    M = parseAssemblyString(IR, Err, C);
+    if (!M)
+      Err.print("SchedulerTest", errs());
+  }
+
+  AAResults &getAA(llvm::Function &LLVMF) {
+    TargetLibraryInfoImpl TLII;
+    TargetLibraryInfo TLI(TLII);
+    AA = std::make_unique<AAResults>(TLI);
+    AC = std::make_unique<AssumptionCache>(LLVMF);
+    DT = std::make_unique<DominatorTree>(LLVMF);
+    BAA = std::make_unique<BasicAAResult>(M->getDataLayout(), LLVMF, TLI, *AC,
+                                          DT.get());
+    AA->addAAResult(*BAA);
+    return *AA;
+  }
+  /// \Returns true if there is a dependency: SrcN->DstN.
+  bool memDependency(sandboxir::DGNode *SrcN, sandboxir::DGNode *DstN) {
+    if (auto *MemDstN = dyn_cast<sandboxir::MemDGNode>(DstN))
+      return MemDstN->hasMemPred(SrcN);
+    return false;
+  }
+};
+
+TEST_F(SchedulerTest, SchedBundle) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
+  store i8 %v0, ptr %ptr
+  %other = add i8 %v0, %v1
+  store i8 %v1, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Other = &*It++;
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  DAG.extend({&*BB->begin(), BB->getTerminator()});
+  auto *SN0 = DAG.getNode(S0);
+  auto *SN1 = DAG.getNode(S1);
+  sandboxir::SchedBundle Bndl({SN0, SN1});
+
+  // Check getTop().
+  EXPECT_EQ(Bndl.getTop(), SN0);
+  // Check getBot().
+  EXPECT_EQ(Bndl.getBot(), SN1);
+  // Check cluster().
+  Bndl.cluster(S1->getIterator());
+  {
+    auto It = BB->begin();
+    EXPECT_EQ(&*It++, Other);
+    EXPECT_EQ(&*It++, S0);
+    EXPECT_EQ(&*It++, S1);
+    EXPECT_EQ(&*It++, Ret);
+    S0->moveBefore(Other);
+  }
+
+  Bndl.cluster(S0->getIterator());
+  {
+    auto It = BB->begin();
+    EXPECT_EQ(&*It++, S0);
+    EXPECT_EQ(&*It++, S1);
+    EXPECT_EQ(&*It++, Other);
+    EXPECT_EQ(&*It++, Ret);
+    S1->moveAfter(Other);
+  }
+
+  Bndl.cluster(Other->getIterator());
+  {
+    auto It = BB->begin();
+    EXPECT_EQ(&*It++, S0);
+    EXPECT_EQ(&*It++, S1);
+    EXPECT_EQ(&*It++, Other);
+    EXPECT_EQ(&*It++, Ret);
+    S1->moveAfter(Other);
+  }
+
+  Bndl.cluster(Ret->getIterator());
+  {
+    auto It = BB->begin();
+    EXPECT_EQ(&*It++, Other);
+    EXPECT_EQ(&*It++, S0);
+    EXPECT_EQ(&*It++, S1);
+    EXPECT_EQ(&*It++, Ret);
+    Other->moveBefore(S1);
+  }
+
+  Bndl.cluster(BB->end());
+  {
+    auto It = BB->begin();
+    EXPECT_EQ(&*It++, Other);
+    EXPECT_EQ(&*It++, Ret);
+    EXPECT_EQ(&*It++, S0);
+    EXPECT_EQ(&*It++, S1);
+    Ret->moveAfter(S1);
+    Other->moveAfter(S0);
+  }
+  // Check iterators.
+  EXPECT_THAT(Bndl, testing::ElementsAre(SN0, SN1));
+  EXPECT_THAT((const sandboxir::SchedBundle &)Bndl,
+              testing::ElementsAre(SN0, SN1));
+}
+
+TEST_F(SchedulerTest, Basic) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
+  store i8 %v0, ptr %ptr
+  store i8 %v1, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  sandboxir::Scheduler Sched(getAA(*LLVMF));
+  EXPECT_TRUE(Sched.trySchedule({Ret}));
+  EXPECT_TRUE(Sched.trySchedule({S1}));
+  EXPECT_TRUE(Sched.trySchedule({S0}));
+}

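The reordering that the SchedBundle unit test above expects from `cluster()` can be reproduced without SandboxIR. The following is a minimal sketch under stated assumptions: `Block` is a hypothetical `std::list` of instruction names standing in for the basic block, and the free function `cluster` mirrors the loop in `SchedBundle::cluster`, with `std::list::splice` playing the role of `moveBefore`.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <string>
#include <vector>

// Hypothetical stand-in for a basic block: instruction names in program order.
using Block = std::list<std::string>;

// Moves every bundle member so that the members end up back-to-back
// immediately before `Where`, keeping the bundle's own order.
void cluster(Block &BB, const std::vector<Block::iterator> &Bundle,
             Block::iterator Where) {
  for (Block::iterator I : Bundle) {
    // If the insertion point is the member itself, step past it so the
    // member stays in place and later members land after it.
    if (I == Where)
      ++Where;
    // splice() relinks a single element before `Where` without copying;
    // list iterators stay valid across the move.
    BB.splice(Where, BB, I);
  }
}
```

Starting from the order S0, Other, S1, Ret with the bundle {S0, S1}, clustering at S1 gives Other, S0, S1, Ret, and clustering at the end of the block gives Other, Ret, S0, S1, which matches the expectations in the test.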
@llvmbot
Collaborator

llvmbot commented Oct 15, 2024

@llvm/pr-subscribers-llvm-transforms

Author: vporpo (vporpo)

Changes

This patch implements a ready-list-based scheduler that operates on DependencyGraph.
It is used by the sandbox vectorizer to test the legality of vectorizing a group of instrs.

SchedBundle is a helper container, containing all DGNodes that correspond to the instructions that we are attempting to schedule with trySchedule(Instrs).


Full diff: https://github.com/llvm/llvm-project/pull/112449.diff

8 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h (+7)
  • (added) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h (+124)
  • (modified) llvm/lib/Transforms/Vectorize/CMakeLists.txt (+1)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (+3-1)
  • (added) llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp (+154)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt (+1)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (+12)
  • (added) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp (+167)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index ae3ceed447c40b..5be05bc80c4925 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -113,8 +113,15 @@ class DGNode {
   virtual ~DGNode() = default;
   /// \Returns the number of unscheduled successors.
   unsigned getNumUnscheduledSuccs() const { return UnscheduledSuccs; }
+  void decrUnscheduledSuccs() {
+    assert(UnscheduledSuccs > 0 && "Counting error!");
+    --UnscheduledSuccs;
+  }
+  /// \Returns true if all dependent successors have been scheduled.
+  bool ready() const { return UnscheduledSuccs == 0; }
   /// \Returns true if this node has been scheduled.
   bool scheduled() const { return Scheduled; }
+  void setScheduled(bool NewVal) { Scheduled = NewVal; }
   /// \Returns true if this is before \p Other in program order.
   bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); }
   using iterator = PredIterator;
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
new file mode 100644
index 00000000000000..60ebcc02e7f169
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
@@ -0,0 +1,124 @@
+//===- Scheduler.h ----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the bottom-up list scheduler used by the vectorizer. It is used for
+// checking the legality of vectorization and for scheduling instructions in
+// such a way that makes vectorization possible, if legal.
+//
+// The legality check is performed by `trySchedule(Instrs)`, which will try to
+// schedule the IR until all instructions in `Instrs` can be scheduled together
+// back-to-back. If this fails then it is illegal to vectorize `Instrs`.
+//
+// Internally the scheduler uses the vectorizer-specific DependencyGraph class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SCHEDULER_H
+#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SCHEDULER_H
+
+#include "llvm/SandboxIR/Instruction.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
+#include <queue>
+
+namespace llvm::sandboxir {
+
+class PriorityCmp {
+public:
+  bool operator()(const DGNode *N1, const DGNode *N2) {
+    // TODO: This should be a hierarchical comparator.
+    return N1->getInstruction()->comesBefore(N2->getInstruction());
+  }
+};
+
+/// The list holding nodes that are ready to schedule. Used by the scheduler.
+class ReadyList {
+  PriorityCmp Cmp;
+  /// Control/Other dependencies are not modeled by the DAG to save memory.
+  /// These have to be modeled in the ready list for correctness.
+  /// This means that the list will hold back nodes that need to meet such
+  /// unmodeled dependencies.
+  std::priority_queue<DGNode *, std::vector<DGNode *>, PriorityCmp> List;
+
+public:
+  ReadyList() : List(Cmp) {}
+  void insert(DGNode *N) { List.push(N); }
+  DGNode *pop() {
+    auto *Back = List.top();
+    List.pop();
+    return Back;
+  }
+  bool empty() const { return List.empty(); }
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const;
+#endif // NDEBUG
+};
+
+/// The nodes that need to be scheduled back-to-back in a single scheduling
+/// cycle form a SchedBundle.
+class SchedBundle {
+public:
+  using ContainerTy = SmallVector<DGNode *, 4>;
+
+private:
+  ContainerTy Nodes;
+
+public:
+  SchedBundle() = default;
+  SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) {}
+  using iterator = ContainerTy::iterator;
+  using const_iterator = ContainerTy::const_iterator;
+  iterator begin() { return Nodes.begin(); }
+  iterator end() { return Nodes.end(); }
+  const_iterator begin() const { return Nodes.begin(); }
+  const_iterator end() const { return Nodes.end(); }
+  /// \Returns the bundle node that comes before the others in program order.
+  DGNode *getTop() const;
+  /// \Returns the bundle node that comes after the others in program order.
+  DGNode *getBot() const;
+  /// Move all bundle instructions to \p Where back-to-back.
+  void cluster(BasicBlock::iterator Where);
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
+/// The list scheduler.
+class Scheduler {
+  ReadyList ReadyList;
+  DependencyGraph DAG;
+  std::optional<BasicBlock::iterator> ScheduleTopItOpt;
+  SmallVector<std::unique_ptr<SchedBundle>> Bndls;
+
+  /// \Returns a scheduling bundle containing \p Instrs.
+  SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
+  /// Schedule nodes until we can schedule \p Instrs back-to-back.
+  bool tryScheduleUntil(ArrayRef<Instruction *> Instrs);
+
+  void scheduleAndUpdateReadyList(SchedBundle &Bndl);
+
+  /// Disable copies.
+  Scheduler(const Scheduler &) = delete;
+  Scheduler &operator=(const Scheduler &) = delete;
+
+public:
+  Scheduler(AAResults &AA) : DAG(AA) {}
+  ~Scheduler() {}
+
+  bool trySchedule(ArrayRef<Instruction *> Instrs);
+
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
+} // namespace llvm::sandboxir
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SCHEDULER_H
diff --git a/llvm/lib/Transforms/Vectorize/CMakeLists.txt b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
index f4e98e576379a4..fc4355af5af6b9 100644
--- a/llvm/lib/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
@@ -9,6 +9,7 @@ add_llvm_component_library(LLVMVectorize
   SandboxVectorizer/Passes/RegionsFromMetadata.cpp
   SandboxVectorizer/SandboxVectorizer.cpp
   SandboxVectorizer/SandboxVectorizerPassBuilder.cpp
+  SandboxVectorizer/Scheduler.cpp
   SandboxVectorizer/SeedCollector.cpp
   SLPVectorizer.cpp
   Vectorize.cpp
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 9bbeca4fc15494..07435f0fb3151d 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -286,7 +286,9 @@ void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
         MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this);
     MemDGNode *LinkBotN =
         MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this);
-    assert(LinkTopN->comesBefore(LinkBotN) && "Wrong order!");
+    assert((LinkTopN == nullptr || LinkBotN == nullptr ||
+            LinkTopN->comesBefore(LinkBotN)) &&
+           "Wrong order!");
     if (LinkTopN != nullptr && LinkBotN != nullptr) {
       LinkTopN->setNextNode(LinkBotN);
       LinkBotN->setPrevNode(LinkTopN);
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp
new file mode 100644
index 00000000000000..37c197cc859980
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp
@@ -0,0 +1,154 @@
+//===- Scheduler.cpp ------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
+
+namespace llvm::sandboxir {
+
+DGNode *SchedBundle::getTop() const {
+  DGNode *TopN = Nodes.front();
+  for (auto *N : drop_begin(Nodes)) {
+    if (N->getInstruction()->comesBefore(TopN->getInstruction()))
+      TopN = N;
+  }
+  return TopN;
+}
+
+DGNode *SchedBundle::getBot() const {
+  DGNode *BotN = Nodes.front();
+  for (auto *N : drop_begin(Nodes)) {
+    if (BotN->getInstruction()->comesBefore(N->getInstruction()))
+      BotN = N;
+  }
+  return BotN;
+}
+
+void SchedBundle::cluster(BasicBlock::iterator Where) {
+  for (auto *N : Nodes) {
+    auto *I = N->getInstruction();
+    if (I->getIterator() == Where)
+      ++Where; // Try to maintain bundle order.
+    I->moveBefore(*Where.getNodeParent(), Where);
+  }
+}
+
+#ifndef NDEBUG
+void SchedBundle::dump(raw_ostream &OS) const {
+  for (auto *N : Nodes)
+    OS << *N;
+}
+
+void SchedBundle::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
+
+#ifndef NDEBUG
+void ReadyList::dump(raw_ostream &OS) const {
+  auto ListCopy = List;
+  while (!ListCopy.empty()) {
+    OS << *ListCopy.top() << "\n";
+    ListCopy.pop();
+  }
+}
+
+void ReadyList::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
+
+void Scheduler::scheduleAndUpdateReadyList(SchedBundle &Bndl) {
+  // Find where we should schedule the instructions.
+  assert(ScheduleTopItOpt && "Should have been set by now!");
+  auto Where = *ScheduleTopItOpt;
+  // Move all instructions in `Bndl` to `Where`.
+  Bndl.cluster(Where);
+  // Update the last scheduled bundle.
+  ScheduleTopItOpt = Bndl.getTop()->getInstruction()->getIterator();
+  // Set nodes as "scheduled" and decrement the UnscheduledSuccs counter of all
+  // dependency predecessors.
+  for (DGNode *N : Bndl) {
+    N->setScheduled(true);
+    for (auto *DepN : N->preds(DAG))
+      DepN->decrUnscheduledSuccs();
+  }
+}
+
+SchedBundle *Scheduler::createBundle(ArrayRef<Instruction *> Instrs) {
+  SchedBundle::ContainerTy Nodes;
+  Nodes.reserve(Instrs.size());
+  for (auto *I : Instrs)
+    Nodes.push_back(DAG.getNode(I));
+  auto BndlPtr = std::make_unique<SchedBundle>(std::move(Nodes));
+  auto *Bndl = BndlPtr.get();
+  Bndls.push_back(std::move(BndlPtr));
+  return Bndl;
+}
+
+bool Scheduler::tryScheduleUntil(ArrayRef<Instruction *> Instrs) {
+  // Use a set for fast lookups.
+  DenseSet<Instruction *> InstrsToDefer(Instrs.begin(), Instrs.end());
+  SmallVector<DGNode *, 8> DeferredNodes;
+
+  // Keep scheduling ready nodes.
+  while (!ReadyList.empty()) {
+    auto *ReadyN = ReadyList.pop();
+    // We defer scheduling of instructions in `Instrs` until we can schedule all
+    // of them at the same time in a single scheduling bundle.
+    if (InstrsToDefer.contains(ReadyN->getInstruction())) {
+      DeferredNodes.push_back(ReadyN);
+      bool ReadyToScheduleDeferred = DeferredNodes.size() == Instrs.size();
+      if (ReadyToScheduleDeferred) {
+        scheduleAndUpdateReadyList(*createBundle(Instrs));
+        return true;
+      }
+    } else {
+      scheduleAndUpdateReadyList(*createBundle({ReadyN->getInstruction()}));
+    }
+  }
+  assert(DeferredNodes.size() != Instrs.size() &&
+         "We should have successfully scheduled and early-returned!");
+  return false;
+}
+
+bool Scheduler::trySchedule(ArrayRef<Instruction *> Instrs) {
+  assert(all_of(drop_begin(Instrs),
+                [Instrs](Instruction *I) {
+                  return I->getParent() == (*Instrs.begin())->getParent();
+                }) &&
+         "Instrs not in the same BB!");
+  // Extend the DAG to include Instrs.
+  Interval<Instruction> Extension = DAG.extend(Instrs);
+  // TODO: Set the window of the DAG that we are interested in.
+  // We start scheduling at the bottom instr of Instrs.
+  auto getBottomI = [](ArrayRef<Instruction *> Instrs) -> Instruction * {
+    return *max_element(Instrs,
+                        [](auto *I1, auto *I2) { return I1->comesBefore(I2); });
+  };
+  ScheduleTopItOpt = std::next(getBottomI(Instrs)->getIterator());
+  // Add nodes to ready list.
+  for (auto &I : Extension) {
+    auto *N = DAG.getNode(&I);
+    if (N->ready())
+      ReadyList.insert(N);
+  }
+  // Try schedule all nodes until we can schedule Instrs back-to-back.
+  return tryScheduleUntil(Instrs);
+}
+
+#ifndef NDEBUG
+void Scheduler::dump(raw_ostream &OS) const {
+  OS << "ReadyList:\n";
+  ReadyList.dump(OS);
+}
+void Scheduler::dump() const { dump(dbgs()); }
+#endif // NDEBUG
+
+} // namespace llvm::sandboxir
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
index dcd7232db5f60c..24512cb0225e8e 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt
@@ -11,5 +11,6 @@ add_llvm_unittest(SandboxVectorizerTests
   DependencyGraphTest.cpp
   IntervalTest.cpp
   LegalityTest.cpp
+  SchedulerTest.cpp
   SeedCollectorTest.cpp	
 )
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 3f84ad1f731de8..c00599ae1c4ef2 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -254,6 +254,18 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
   EXPECT_EQ(N0->getNumUnscheduledSuccs(), 1u); // N1
   EXPECT_EQ(N1->getNumUnscheduledSuccs(), 0u);
   EXPECT_EQ(N2->getNumUnscheduledSuccs(), 0u);
+
+  // Check decrUnscheduledSuccs.
+  N0->decrUnscheduledSuccs();
+  EXPECT_EQ(N0->getNumUnscheduledSuccs(), 0u);
+#ifndef NDEBUG
+  EXPECT_DEATH(N0->decrUnscheduledSuccs(), ".*Counting.*");
+#endif // NDEBUG
+
+  // Check scheduled(), setScheduled().
+  EXPECT_FALSE(N0->scheduled());
+  N0->setScheduled(true);
+  EXPECT_TRUE(N0->scheduled());
 }
 
 TEST_F(DependencyGraphTest, Preds) {
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp
new file mode 100644
index 00000000000000..14f48e77a6fb65
--- /dev/null
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp
@@ -0,0 +1,167 @@
+//===- SchedulerTest.cpp --------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/SandboxIR/Context.h"
+#include "llvm/SandboxIR/Function.h"
+#include "llvm/SandboxIR/Instruction.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gmock/gmock-matchers.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+struct SchedulerTest : public testing::Test {
+  LLVMContext C;
+  std::unique_ptr<Module> M;
+  std::unique_ptr<AssumptionCache> AC;
+  std::unique_ptr<DominatorTree> DT;
+  std::unique_ptr<BasicAAResult> BAA;
+  std::unique_ptr<AAResults> AA;
+
+  void parseIR(LLVMContext &C, const char *IR) {
+    SMDiagnostic Err;
+    M = parseAssemblyString(IR, Err, C);
+    if (!M)
+      Err.print("SchedulerTest", errs());
+  }
+
+  AAResults &getAA(llvm::Function &LLVMF) {
+    TargetLibraryInfoImpl TLII;
+    TargetLibraryInfo TLI(TLII);
+    AA = std::make_unique<AAResults>(TLI);
+    AC = std::make_unique<AssumptionCache>(LLVMF);
+    DT = std::make_unique<DominatorTree>(LLVMF);
+    BAA = std::make_unique<BasicAAResult>(M->getDataLayout(), LLVMF, TLI, *AC,
+                                          DT.get());
+    AA->addAAResult(*BAA);
+    return *AA;
+  }
+  /// \Returns true if there is a dependency: SrcN->DstN.
+  bool memDependency(sandboxir::DGNode *SrcN, sandboxir::DGNode *DstN) {
+    if (auto *MemDstN = dyn_cast<sandboxir::MemDGNode>(DstN))
+      return MemDstN->hasMemPred(SrcN);
+    return false;
+  }
+};
+
+TEST_F(SchedulerTest, SchedBundle) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
+  store i8 %v0, ptr %ptr
+  %other = add i8 %v0, %v1
+  store i8 %v1, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Other = &*It++;
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  DAG.extend({&*BB->begin(), BB->getTerminator()});
+  auto *SN0 = DAG.getNode(S0);
+  auto *SN1 = DAG.getNode(S1);
+  sandboxir::SchedBundle Bndl({SN0, SN1});
+
+  // Check getTop().
+  EXPECT_EQ(Bndl.getTop(), SN0);
+  // Check getBot().
+  EXPECT_EQ(Bndl.getBot(), SN1);
+  // Check cluster().
+  Bndl.cluster(S1->getIterator());
+  {
+    auto It = BB->begin();
+    EXPECT_EQ(&*It++, Other);
+    EXPECT_EQ(&*It++, S0);
+    EXPECT_EQ(&*It++, S1);
+    EXPECT_EQ(&*It++, Ret);
+    S0->moveBefore(Other);
+  }
+
+  Bndl.cluster(S0->getIterator());
+  {
+    auto It = BB->begin();
+    EXPECT_EQ(&*It++, S0);
+    EXPECT_EQ(&*It++, S1);
+    EXPECT_EQ(&*It++, Other);
+    EXPECT_EQ(&*It++, Ret);
+    S1->moveAfter(Other);
+  }
+
+  Bndl.cluster(Other->getIterator());
+  {
+    auto It = BB->begin();
+    EXPECT_EQ(&*It++, S0);
+    EXPECT_EQ(&*It++, S1);
+    EXPECT_EQ(&*It++, Other);
+    EXPECT_EQ(&*It++, Ret);
+    S1->moveAfter(Other);
+  }
+
+  Bndl.cluster(Ret->getIterator());
+  {
+    auto It = BB->begin();
+    EXPECT_EQ(&*It++, Other);
+    EXPECT_EQ(&*It++, S0);
+    EXPECT_EQ(&*It++, S1);
+    EXPECT_EQ(&*It++, Ret);
+    Other->moveBefore(S1);
+  }
+
+  Bndl.cluster(BB->end());
+  {
+    auto It = BB->begin();
+    EXPECT_EQ(&*It++, Other);
+    EXPECT_EQ(&*It++, Ret);
+    EXPECT_EQ(&*It++, S0);
+    EXPECT_EQ(&*It++, S1);
+    Ret->moveAfter(S1);
+    Other->moveAfter(S0);
+  }
+  // Check iterators.
+  EXPECT_THAT(Bndl, testing::ElementsAre(SN0, SN1));
+  EXPECT_THAT((const sandboxir::SchedBundle &)Bndl,
+              testing::ElementsAre(SN0, SN1));
+}
+
+TEST_F(SchedulerTest, Basic) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
+  store i8 %v0, ptr %ptr
+  store i8 %v1, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  sandboxir::Scheduler Sched(getAA(*LLVMF));
+  EXPECT_TRUE(Sched.trySchedule({Ret}));
+  EXPECT_TRUE(Sched.trySchedule({S1}));
+  EXPECT_TRUE(Sched.trySchedule({S0}));
+}

@vporpo
Contributor Author

vporpo commented Oct 16, 2024

Removed unused function from test.

Member

@tmsri tmsri left a comment

LGTM with some comments.

      }
    } else {
      scheduleAndUpdateReadyList(*createBundle({ReadyN->getInstruction()}));
    }
Member

I feel like this can be structured differently. First, my understanding:

  • You are trying to schedule Instrs, and ReadyList is already populated with the DGNodes for those Instrs.
  • In the while loop, you check whether the ready node's instruction is part of Instrs. If not, you create a singleton bundle and schedule it.

Why would a DGNode's getInstruction() not be contained in InstrsToDefer? I believe it is because, when we schedule a node, its preds() could become ready and get added to the list in scheduleAndUpdateReadyList. But they would only get added to the end, since the ready list is a queue. Is this understanding correct?

If that is the case, you can split the if into two parts: first bundle all the DeferredNodes and schedule them, then schedule the remaining singletons.

Contributor Author

I think this is not entirely correct. The problem is that some or all instructions in Instrs won't be ready to get scheduled. So what we are trying to do is keep scheduling program instructions individually as they become ready, until either we get all instructions in Instrs ready at the same time, or we run out of instructions to schedule. As program instructions become ready, some of them may be in Instrs, and these are the ones that are "deferred", because we don't want to schedule them (which would make their predecessors ready etc.). Once the deferred ones match Instrs, then we proceed to schedule them.
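
The defer-until-all-ready loop described in this reply can be sketched outside of LLVM roughly as follows. This is a minimal toy model under stated assumptions, not the patch's code: `ToyGraph`, the integer node IDs (program positions, 0 at the top), and the max-heap priority are hypothetical, and the sketch re-inserts newly ready predecessors into the ready list, which is what the reply describes rather than something shown in the diff above.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <set>
#include <vector>

// Toy dependency graph (hypothetical): node IDs are program positions.
struct ToyGraph {
  std::vector<std::vector<int>> Preds; // Preds[N]: nodes that N depends on.
  std::vector<int> UnscheduledSuccs;   // Dependent nodes not yet scheduled.
  std::vector<bool> Scheduled;

  explicit ToyGraph(std::size_t NumNodes)
      : Preds(NumNodes), UnscheduledSuccs(NumNodes, 0),
        Scheduled(NumNodes, false) {}

  // Record that `To` depends on `From`, so `From` must stay above `To`.
  void addDep(int From, int To) {
    Preds[To].push_back(From);
    ++UnscheduledSuccs[From];
  }
};

// Bottom-up scheduling: keep scheduling ready nodes one by one, but hold back
// the nodes in `Instrs` until all of them are ready at the same time.
// Returns true if `Instrs` could be scheduled as a single bundle.
bool tryScheduleUntil(ToyGraph &G, const std::vector<int> &Instrs,
                      std::vector<std::vector<int>> &Schedule) {
  std::set<int> ToDefer(Instrs.begin(), Instrs.end());
  std::size_t NumDeferred = 0;
  // Max-heap on program position: the lowest ready node is popped first.
  std::priority_queue<int> Ready;
  for (int N = 0; N < static_cast<int>(G.Preds.size()); ++N)
    if (!G.Scheduled[N] && G.UnscheduledSuccs[N] == 0)
      Ready.push(N);

  auto ScheduleBundle = [&](const std::vector<int> &Bundle) {
    Schedule.push_back(Bundle);
    for (int N : Bundle) {
      G.Scheduled[N] = true;
      for (int P : G.Preds[N]) {
        assert(G.UnscheduledSuccs[P] > 0 && "Counting error!");
        // A predecessor becomes ready once its last successor is scheduled.
        if (--G.UnscheduledSuccs[P] == 0)
          Ready.push(P);
      }
    }
  };

  while (!Ready.empty()) {
    int N = Ready.top();
    Ready.pop();
    if (ToDefer.count(N)) {
      // Deferred: scheduling it now would release its predecessors early.
      if (++NumDeferred == Instrs.size()) {
        ScheduleBundle(Instrs);
        return true;
      }
    } else {
      ScheduleBundle({N});
    }
  }
  return false; // Ran out of ready nodes before all of Instrs were ready.
}
```

In this model, a group whose members depend on each other never becomes ready all at once, so the function returns false; that is the sense in which the scheduler can act as a legality check for a candidate group.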

@vporpo
Contributor Author

vporpo commented Oct 18, 2024

Added more comments and rebased.

Member

@tmsri tmsri left a comment

LGTM

@vporpo vporpo merged commit 1d09925 into llvm:main Oct 18, 2024
6 of 8 checks passed