Skip to content

[OpenMPIRBuilder] Split calculation of canonical loop trip count, NFC #127820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 25, 2025

Conversation

skatrak
Copy link
Member

@skatrak skatrak commented Feb 19, 2025

This patch splits off the calculation of canonical loop trip counts from the creation of canonical loops. This makes it possible to reuse this logic to, for instance, populate the __tgt_target_kernel runtime call for SPMD kernels.

This feature is used to simplify one of the existing OpenMPIRBuilder tests.

@llvmbot
Copy link
Member

llvmbot commented Feb 19, 2025

@llvm/pr-subscribers-flang-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch splits off the calculation of canonical loop trip counts from the creation of canonical loops. This makes it possible to reuse this logic to, for instance, populate the __tgt_target_kernel runtime call for SPMD kernels.

This feature is used to simplify one of the existing OpenMPIRBuilder tests.


Full diff: https://github.com/llvm/llvm-project/pull/127820.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+31-7)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+18-9)
  • (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+3-13)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 9ad85413acd34..207ca7fb05f62 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -728,13 +728,12 @@ class OpenMPIRBuilder {
                       LoopBodyGenCallbackTy BodyGenCB, Value *TripCount,
                       const Twine &Name = "loop");
 
-  /// Generator for the control flow structure of an OpenMP canonical loop.
+  /// Calculate the trip count of a canonical loop.
   ///
-  /// Instead of a logical iteration space, this allows specifying user-defined
-  /// loop counter values using increment, upper- and lower bounds. To
-  /// disambiguate the terminology when counting downwards, instead of lower
-  /// bounds we use \p Start for the loop counter value in the first body
-  /// iteration.
+  /// This allows specifying user-defined loop counter values using increment,
+  /// upper- and lower bounds. To disambiguate the terminology when counting
+  /// downwards, instead of lower bounds we use \p Start for the loop counter
+  /// value in the first body iteration.
   ///
   /// Consider the following limitations:
   ///
@@ -758,7 +757,32 @@ class OpenMPIRBuilder {
   ///
   ///      for (int i = 0; i < 42; i -= 1u)
   ///
-  //
+  /// \param Loc       The insert and source location description.
+  /// \param Start     Value of the loop counter for the first iterations.
+  /// \param Stop      Loop counter values past this will stop the loop.
+  /// \param Step      Loop counter increment after each iteration; negative
+  ///                  means counting down.
+  /// \param IsSigned  Whether Start, Stop and Step are signed integers.
+  /// \param InclusiveStop Whether \p Stop itself is a valid value for the loop
+  ///                      counter.
+  /// \param Name      Base name used to derive instruction names.
+  ///
+  /// \returns The value holding the calculated trip count.
+  Value *calculateCanonicalLoopTripCount(const LocationDescription &Loc,
+                                         Value *Start, Value *Stop, Value *Step,
+                                         bool IsSigned, bool InclusiveStop,
+                                         const Twine &Name = "loop");
+
+  /// Generator for the control flow structure of an OpenMP canonical loop.
+  ///
+  /// Instead of a logical iteration space, this allows specifying user-defined
+  /// loop counter values using increment, upper- and lower bounds. To
+  /// disambiguate the terminology when counting downwards, instead of lower
+  /// bounds we use \p Start for the loop counter value in the first body
+  ///
+  /// It calls \see calculateCanonicalLoopTripCount for trip count calculations,
+  /// so limitations of that method apply here as well.
+  ///
   /// \param Loc       The insert and source location description.
   /// \param BodyGenCB Callback that will generate the loop body code.
   /// \param Start     Value of the loop counter for the first iterations.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 7788897fc0795..eee6e3e54d615 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4059,10 +4059,9 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
   return CL;
 }
 
-Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
-    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
-    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
-    InsertPointTy ComputeIP, const Twine &Name) {
+Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
+    const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
+    bool IsSigned, bool InclusiveStop, const Twine &Name) {
 
   // Consider the following difficulties (assuming 8-bit signed integers):
   //  * Adding \p Step to the loop counter which passes \p Stop may overflow:
@@ -4075,9 +4074,7 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
   assert(IndVarTy == Stop->getType() && "Stop type mismatch");
   assert(IndVarTy == Step->getType() && "Step type mismatch");
 
-  LocationDescription ComputeLoc =
-      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
-  updateToLocation(ComputeLoc);
+  updateToLocation(Loc);
 
   ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
   ConstantInt *One = ConstantInt::get(IndVarTy, 1);
@@ -4117,8 +4114,20 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
     Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr);
     CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
   }
-  Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
-                                          "omp_" + Name + ".tripcount");
+
+  return Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
+                              "omp_" + Name + ".tripcount");
+}
+
+Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
+    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
+    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
+    InsertPointTy ComputeIP, const Twine &Name) {
+  LocationDescription ComputeLoc =
+      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
+
+  Value *TripCount = calculateCanonicalLoopTripCount(
+      ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);
 
   auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
     Builder.restoreIP(CodeGenIP);
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 83c8f7e932b2b..9f4946a32d9b1 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -1441,8 +1441,7 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) {
   EXPECT_EQ(&Loop->getAfter()->front(), RetInst);
 }
 
-TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
-  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+TEST_F(OpenMPIRBuilderTest, CanonicalLoopTripCount) {
   OpenMPIRBuilder OMPBuilder(*M);
   OMPBuilder.initialize();
   IRBuilder<> Builder(BB);
@@ -1458,17 +1457,8 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
     Value *StartVal = ConstantInt::get(LCTy, Start);
     Value *StopVal = ConstantInt::get(LCTy, Stop);
     Value *StepVal = ConstantInt::get(LCTy, Step);
-    auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) {
-      return Error::success();
-    };
-    ASSERT_EXPECTED_INIT_RETURN(
-        CanonicalLoopInfo *, Loop,
-        OMPBuilder.createCanonicalLoop(Loc, LoopBodyGenCB, StartVal, StopVal,
-                                       StepVal, IsSigned, InclusiveStop),
-        -1);
-    Loop->assertOK();
-    Builder.restoreIP(Loop->getAfterIP());
-    Value *TripCount = Loop->getTripCount();
+    Value *TripCount = OMPBuilder.calculateCanonicalLoopTripCount(
+        Loc, StartVal, StopVal, StepVal, IsSigned, InclusiveStop);
     return cast<ConstantInt>(TripCount)->getValue().getZExtValue();
   };
 

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@skatrak skatrak force-pushed the users/skatrak/target-spmd-05-mlir-composite branch from 38ba269 to 33d5af4 Compare February 21, 2025 10:07
@skatrak skatrak force-pushed the users/skatrak/target-spmd-06-canonical-refactor branch from 5153e0d to 082d8e1 Compare February 21, 2025 10:08
@skatrak skatrak force-pushed the users/skatrak/target-spmd-05-mlir-composite branch from 33d5af4 to aad04fa Compare February 21, 2025 10:20
@skatrak skatrak force-pushed the users/skatrak/target-spmd-06-canonical-refactor branch 2 times, most recently from 033091e to f14d964 Compare February 21, 2025 10:51
Copy link
Member

@Meinersbur Meinersbur left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@skatrak skatrak force-pushed the users/skatrak/target-spmd-05-mlir-composite branch from 1172e9b to e61c797 Compare February 24, 2025 12:55
@skatrak skatrak force-pushed the users/skatrak/target-spmd-06-canonical-refactor branch from f14d964 to aa6182c Compare February 24, 2025 12:55
@skatrak skatrak force-pushed the users/skatrak/target-spmd-05-mlir-composite branch from e61c797 to ac8b967 Compare February 25, 2025 10:31
Base automatically changed from users/skatrak/target-spmd-05-mlir-composite to main February 25, 2025 10:31
This patch splits off the calculation of canonical loop trip counts from the
creation of canonical loops. This makes it possible to reuse this logic to, for
instance, populate the `__tgt_target_kernel` runtime call for SPMD kernels.

This feature is used to simplify one of the existing OpenMPIRBuilder tests.
@skatrak skatrak force-pushed the users/skatrak/target-spmd-06-canonical-refactor branch from aa6182c to 35078fd Compare February 25, 2025 10:32
@skatrak skatrak merged commit 56975b4 into main Feb 25, 2025
6 of 10 checks passed
@skatrak skatrak deleted the users/skatrak/target-spmd-06-canonical-refactor branch February 25, 2025 10:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:openmp OpenMP related changes to Clang flang:openmp
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants