Skip to content

Commit aa0208d

Browse files
authored
[mlir][scf] Implement getSingle... of LoopLikeOpinterface for scf::ParallelOp (#68511)
This adds implementations for `getSingleIterationVar`, `getSingleLowerBound`, `getSingleUpperBound`, `getSingleStep` of `LoopLikeOpInterface` to `scf::ParallelOp`. Until now, the implementations for these methods defaulted to returning `std::nullopt`, even in the special case where the parallel Op only has one dimension. Related: #67883
1 parent b8ad68f commit aa0208d

File tree

5 files changed

+124
-1
lines changed

5 files changed

+124
-1
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,8 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
791791
def ParallelOp : SCF_Op<"parallel",
792792
[AutomaticAllocationScope,
793793
AttrSizedOperandSegments,
794-
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
794+
DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getSingleInductionVar",
795+
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
795796
RecursiveMemoryEffects,
796797
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
797798
SingleBlockImplicitTerminator<"scf::YieldOp">]> {

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2936,6 +2936,30 @@ void ParallelOp::print(OpAsmPrinter &p) {
29362936

29372937
SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
29382938

2939+
std::optional<Value> ParallelOp::getSingleInductionVar() {
2940+
if (getNumLoops() != 1)
2941+
return std::nullopt;
2942+
return getBody()->getArgument(0);
2943+
}
2944+
2945+
std::optional<OpFoldResult> ParallelOp::getSingleLowerBound() {
2946+
if (getNumLoops() != 1)
2947+
return std::nullopt;
2948+
return getLowerBound()[0];
2949+
}
2950+
2951+
std::optional<OpFoldResult> ParallelOp::getSingleUpperBound() {
2952+
if (getNumLoops() != 1)
2953+
return std::nullopt;
2954+
return getUpperBound()[0];
2955+
}
2956+
2957+
std::optional<OpFoldResult> ParallelOp::getSingleStep() {
2958+
if (getNumLoops() != 1)
2959+
return std::nullopt;
2960+
return getStep()[0];
2961+
}
2962+
29392963
ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
29402964
auto ivArg = llvm::dyn_cast<BlockArgument>(val);
29412965
if (!ivArg)

mlir/unittests/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ target_link_libraries(MLIRDialectTests
99
add_subdirectory(Index)
1010
add_subdirectory(LLVMIR)
1111
add_subdirectory(MemRef)
12+
add_subdirectory(SCF)
1213
add_subdirectory(SparseTensor)
1314
add_subdirectory(SPIRV)
1415
add_subdirectory(Transform)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
add_mlir_unittest(MLIRSCFTests
2+
LoopLikeSCFOpsTest.cpp
3+
)
4+
target_link_libraries(MLIRSCFTests
5+
PRIVATE
6+
MLIRIR
7+
MLIRSCFDialect
8+
)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
//===- LoopLikeSCFOpsTest.cpp - SCF LoopLikeOpInterface Tests -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Arith/IR/Arith.h"
10+
#include "mlir/Dialect/SCF/IR/SCF.h"
11+
#include "mlir/IR/Diagnostics.h"
12+
#include "mlir/IR/MLIRContext.h"
13+
#include "gtest/gtest.h"
14+
15+
using namespace mlir;
16+
using namespace mlir::scf;
17+
18+
//===----------------------------------------------------------------------===//
19+
// Test Fixture
20+
//===----------------------------------------------------------------------===//
21+
22+
class SCFLoopLikeTest : public ::testing::Test {
23+
protected:
24+
SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
25+
context.loadDialect<arith::ArithDialect, scf::SCFDialect>();
26+
}
27+
28+
void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
29+
std::optional<OpFoldResult> maybeLb = loopLikeOp.getSingleLowerBound();
30+
EXPECT_TRUE(maybeLb.has_value());
31+
std::optional<OpFoldResult> maybeUb = loopLikeOp.getSingleUpperBound();
32+
EXPECT_TRUE(maybeUb.has_value());
33+
std::optional<OpFoldResult> maybeStep = loopLikeOp.getSingleStep();
34+
EXPECT_TRUE(maybeStep.has_value());
35+
std::optional<OpFoldResult> maybeIndVar =
36+
loopLikeOp.getSingleInductionVar();
37+
EXPECT_TRUE(maybeIndVar.has_value());
38+
}
39+
40+
void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
41+
std::optional<OpFoldResult> maybeLb = loopLikeOp.getSingleLowerBound();
42+
EXPECT_FALSE(maybeLb.has_value());
43+
std::optional<OpFoldResult> maybeUb = loopLikeOp.getSingleUpperBound();
44+
EXPECT_FALSE(maybeUb.has_value());
45+
std::optional<OpFoldResult> maybeStep = loopLikeOp.getSingleStep();
46+
EXPECT_FALSE(maybeStep.has_value());
47+
std::optional<OpFoldResult> maybeIndVar =
48+
loopLikeOp.getSingleInductionVar();
49+
EXPECT_FALSE(maybeIndVar.has_value());
50+
}
51+
52+
MLIRContext context;
53+
OpBuilder b;
54+
Location loc;
55+
};
56+
57+
TEST_F(SCFLoopLikeTest, queryUnidimensionalLooplikes) {
58+
Value lb = b.create<arith::ConstantIndexOp>(loc, 0);
59+
Value ub = b.create<arith::ConstantIndexOp>(loc, 10);
60+
Value step = b.create<arith::ConstantIndexOp>(loc, 2);
61+
62+
auto forOp = b.create<scf::ForOp>(loc, lb, ub, step);
63+
checkUnidimensional(forOp);
64+
65+
auto forallOp = b.create<scf::ForallOp>(
66+
loc, ArrayRef<OpFoldResult>(lb), ArrayRef<OpFoldResult>(ub),
67+
ArrayRef<OpFoldResult>(step), ValueRange(), std::nullopt);
68+
checkUnidimensional(forallOp);
69+
70+
auto parallelOp = b.create<scf::ParallelOp>(
71+
loc, ValueRange(lb), ValueRange(ub), ValueRange(step), ValueRange());
72+
checkUnidimensional(parallelOp);
73+
}
74+
75+
TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
76+
Value lb = b.create<arith::ConstantIndexOp>(loc, 0);
77+
Value ub = b.create<arith::ConstantIndexOp>(loc, 10);
78+
Value step = b.create<arith::ConstantIndexOp>(loc, 2);
79+
80+
auto forallOp = b.create<scf::ForallOp>(
81+
loc, ArrayRef<OpFoldResult>({lb, lb}), ArrayRef<OpFoldResult>({ub, ub}),
82+
ArrayRef<OpFoldResult>({step, step}), ValueRange(), std::nullopt);
83+
checkMultidimensional(forallOp);
84+
85+
auto parallelOp =
86+
b.create<scf::ParallelOp>(loc, ValueRange({lb, lb}), ValueRange({ub, ub}),
87+
ValueRange({step, step}), ValueRange());
88+
checkMultidimensional(parallelOp);
89+
}

0 commit comments

Comments
 (0)