Skip to content

Commit 23b3a7f

Browse files
support unroll by the gpu.launchOp.
1 parent ecb7f5a commit 23b3a7f

File tree

7 files changed

+285
-25
lines changed

7 files changed

+285
-25
lines changed

mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *map,
4343
/// constant trip count in non-trivial cases.
4444
std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);
4545

46+
/// On a GPU, the trip count of a loop may differ from thread to thread.
47+
/// This function returns the maximum trip count, or std::nullopt if it is not
/// a constant.
48+
std::optional<uint64_t> getMaxConstantTripCount(AffineForOp forOp);
49+
4650
/// Returns the greatest known integral divisor of the trip count. Affine
4751
/// expression analysis is used (indirectly through getTripCount), and
4852
/// this method is thus able to determine non-trivial divisors.

mlir/include/mlir/Dialect/Affine/LoopUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ LogicalResult loopUnrollJamUpToFactor(AffineForOp forOp,
8686
/// was known to have a single iteration.
8787
LogicalResult promoteIfSingleIteration(AffineForOp forOp);
8888

89+
/// Eliminate loops that will never actually execute.
90+
LogicalResult removeInvalidLoop(AffineForOp forOp);
91+
8992
/// Promotes all single iteration AffineForOp's in the Function, i.e., moves
9093
/// their body into the containing Block.
9194
void promoteSingleIterationLoops(func::FuncOp f);

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,6 +1035,12 @@ def GPU_LaunchOp : GPU_Op<"launch", [
10351035
static StringRef getNumWorkgroupAttributionsAttrName() {
10361036
return "workgroup_attributions";
10371037
}
1038+
1039+
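/// Returns the block size on the axis whose thread-id block argument of this
/// gpu.launch is `threadId`, or a null Value if `threadId` is not one of them.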
1040+
Value getBlockSizeOnAxis(Value threadId);
1041+
1042+
/// Returns the block size of this gpu.launch on the axis given by `dimension`.
1043+
Value getBlockSizeOnAxis(Dimension dimension);
10381044
}];
10391045

10401046
let hasCanonicalizer = 1;

mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp

Lines changed: 96 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
1919
#include "mlir/Dialect/Affine/IR/AffineOps.h"
2020
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
21+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
2122
#include "llvm/Support/MathExtras.h"
2223

2324
#include "llvm/ADT/DenseSet.h"
@@ -84,6 +85,67 @@ void mlir::affine::getTripCountMapAndOperands(
8485
tripCountValueMap.getOperands().end());
8586
}
8687

88+
/// Replace thread_id with its maximum value, if `replaceWithZero` is true,
89+
/// thread_id will be replaced by its minimum value 0.
90+
static void replaceGPUOperands(AffineForOp forOp,
91+
SmallVectorImpl<Value> &operands,
92+
SmallVectorImpl<AffineExpr> &symReplacements,
93+
unsigned numDim, bool replaceWithZero = false) {
94+
auto launchOp = forOp->getParentOfType<gpu::LaunchOp>();
95+
if (!launchOp)
96+
return;
97+
98+
// `b` is only used to create `AffineExpr`.
99+
Builder b(forOp.getContext());
100+
unsigned idx = 0;
101+
102+
for (unsigned i = numDim, e = operands.size(); i < e; ++i) {
103+
Value operand = operands[i];
104+
if (Value blockSize = launchOp.getBlockSizeOnAxis(operand)) {
105+
operands[i] = blockSize;
106+
if (!replaceWithZero)
107+
symReplacements.push_back(b.getAffineSymbolExpr(idx++) - 1);
108+
else
109+
symReplacements.push_back(b.getAffineConstantExpr(0));
110+
continue;
111+
}
112+
113+
Operation *defOp = operand.getDefiningOp();
114+
if (!defOp) {
115+
++idx;
116+
continue;
117+
}
118+
119+
if (auto threadIdOp = mlir::dyn_cast<gpu::ThreadIdOp>(defOp)) {
120+
gpu::Dimension dimension = threadIdOp.getDimension();
121+
operands[i] = launchOp.getBlockSizeOnAxis(dimension);
122+
if (!replaceWithZero)
123+
symReplacements.push_back(b.getAffineSymbolExpr(idx++) - 1);
124+
else
125+
symReplacements.push_back(b.getAffineConstantExpr(0));
126+
continue;
127+
}
128+
++idx;
129+
}
130+
}
131+
132+
/// Take the min if all trip counts are constant.
133+
static std::optional<uint64_t>
134+
getConstantTripCountFromAffineMap(AffineMap map) {
135+
std::optional<uint64_t> tripCount;
136+
for (auto resultExpr : map.getResults()) {
137+
auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr);
138+
if (!constExpr)
139+
return std::nullopt;
140+
if (tripCount.has_value())
141+
tripCount =
142+
std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
143+
else
144+
tripCount = constExpr.getValue();
145+
}
146+
return tripCount;
147+
}
148+
87149
/// Returns the trip count of the loop if it's a constant, std::nullopt
88150
/// otherwise. This method uses affine expression analysis (in turn using
89151
/// getTripCount) and is able to determine constant trip count in non-trivial
@@ -95,20 +157,34 @@ std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
95157

96158
if (!map)
97159
return std::nullopt;
160+
SmallVector<AffineExpr, 4> symReplacements;
161+
replaceGPUOperands(forOp, operands, symReplacements, map.getNumDims());
162+
map = map.replaceDimsAndSymbols({}, symReplacements, map.getNumDims(),
163+
map.getNumSymbols());
164+
affine::AffineValueMap valueMap(map, operands);
165+
(void)valueMap.canonicalize();
166+
map = valueMap.getAffineMap();
167+
return getConstantTripCountFromAffineMap(map);
168+
}
98169

99-
// Take the min if all trip counts are constant.
100-
std::optional<uint64_t> tripCount;
101-
for (auto resultExpr : map.getResults()) {
102-
if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
103-
if (tripCount.has_value())
104-
tripCount =
105-
std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
106-
else
107-
tripCount = constExpr.getValue();
108-
} else
109-
return std::nullopt;
110-
}
111-
return tripCount;
170+
/// In some scenarios, such as GPU, the number of trip of each thread in the
171+
/// loop is inconsistent. This function returns the maximum number of trip.
172+
std::optional<uint64_t>
173+
mlir::affine::getMaxConstantTripCount(AffineForOp forOp) {
174+
SmallVector<Value, 4> operands;
175+
AffineMap map;
176+
getTripCountMapAndOperands(forOp, &map, &operands);
177+
178+
if (!map)
179+
return std::nullopt;
180+
SmallVector<AffineExpr, 4> symReplacements;
181+
replaceGPUOperands(forOp, operands, symReplacements, map.getNumDims(), true);
182+
map = map.replaceDimsAndSymbols({}, symReplacements, map.getNumDims(),
183+
map.getNumSymbols());
184+
affine::AffineValueMap valueMap(map, operands);
185+
(void)valueMap.canonicalize();
186+
map = valueMap.getAffineMap();
187+
return getConstantTripCountFromAffineMap(map);
112188
}
113189

114190
/// Returns the greatest known integral divisor of the trip count. Affine
@@ -121,7 +197,13 @@ uint64_t mlir::affine::getLargestDivisorOfTripCount(AffineForOp forOp) {
121197

122198
if (!map)
123199
return 1;
124-
200+
SmallVector<AffineExpr, 4> symReplacements;
201+
replaceGPUOperands(forOp, operands, symReplacements, map.getNumDims());
202+
map = map.replaceDimsAndSymbols({}, symReplacements, map.getNumDims(),
203+
map.getNumSymbols());
204+
affine::AffineValueMap valueMap(map, operands);
205+
(void)valueMap.canonicalize();
206+
map = valueMap.getAffineMap();
125207
// The largest divisor of the trip count is the GCD of the individual largest
126208
// divisors.
127209
assert(map.getNumResults() >= 1 && "expected one or more results");

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
1818
#include "mlir/Dialect/Affine/Utils.h"
1919
#include "mlir/Dialect/Func/IR/FuncOps.h"
20+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
2021
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2122
#include "mlir/Dialect/SCF/IR/SCF.h"
2223
#include "mlir/IR/IRMapping.h"
@@ -113,11 +114,29 @@ static void replaceIterArgsAndYieldResults(AffineForOp forOp) {
113114
std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
114115
}
115116

117+
/// Eliminate loops that will never actually execute
118+
LogicalResult mlir::affine::removeInvalidLoop(AffineForOp forOp) {
119+
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
120+
std::optional<uint64_t> maxTripCount = getMaxConstantTripCount(forOp);
121+
if (!tripCount || *tripCount > 0 || !maxTripCount || *maxTripCount > 0)
122+
return failure();
123+
124+
auto iterOperands = forOp.getInits();
125+
auto results = forOp.getResults();
126+
for (auto [result, operand] : llvm::zip(results, iterOperands))
127+
result.replaceAllUsesWith(operand);
128+
129+
IRRewriter b(forOp);
130+
b.eraseOp(forOp);
131+
return success();
132+
}
133+
116134
/// Promotes the loop body of a forOp to its containing block if the forOp
117135
/// was known to have a single iteration.
118136
LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) {
119137
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
120-
if (!tripCount || *tripCount != 1)
138+
std::optional<uint64_t> maxTripCount = getMaxConstantTripCount(forOp);
139+
if (!tripCount || *tripCount != 1 || !maxTripCount || *maxTripCount != 1)
121140
return failure();
122141

123142
// TODO: extend this for arbitrary affine bounds.
@@ -160,7 +179,8 @@ LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) {
160179
forOp.getBody()->back().erase();
161180
parentBlock->getOperations().splice(Block::iterator(forOp),
162181
forOp.getBody()->getOperations());
163-
forOp.erase();
182+
IRRewriter b(forOp.getContext());
183+
b.eraseOp(forOp);
164184
return success();
165185
}
166186

@@ -884,15 +904,27 @@ void mlir::affine::getTileableBands(
884904
/// Unrolls this loop completely.
885905
LogicalResult mlir::affine::loopUnrollFull(AffineForOp forOp) {
886906
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
887-
if (mayBeConstantTripCount.has_value()) {
888-
uint64_t tripCount = *mayBeConstantTripCount;
889-
if (tripCount == 0)
890-
return success();
891-
if (tripCount == 1)
892-
return promoteIfSingleIteration(forOp);
893-
return loopUnrollByFactor(forOp, tripCount);
894-
}
895-
return failure();
907+
std::optional<uint64_t> maxMayBeConstantTripCount =
908+
getMaxConstantTripCount(forOp);
909+
910+
if (!mayBeConstantTripCount.has_value() &&
911+
!maxMayBeConstantTripCount.has_value())
912+
return failure();
913+
914+
uint64_t tripCount = *mayBeConstantTripCount;
915+
uint64_t maxTripCount = *maxMayBeConstantTripCount;
916+
917+
// The values of Trip are all 0, and the invalid loop is deleted.
918+
if (tripCount <= 0 && maxTripCount <= 0)
919+
return removeInvalidLoop(forOp);
920+
921+
// In special cases, such as in a GPU, only some threads execute this loop.
922+
if (tripCount == 0 && maxTripCount == 1)
923+
return success();
924+
925+
if (tripCount == 1 && maxTripCount == 1)
926+
return promoteIfSingleIteration(forOp);
927+
return loopUnrollByFactor(forOp, tripCount);
896928
}
897929

898930
/// Unrolls this loop by the specified factor or by the trip count (if constant)
@@ -1013,8 +1045,11 @@ LogicalResult mlir::affine::loopUnrollByFactor(
10131045
assert(unrollFactor > 0 && "unroll factor should be positive");
10141046

10151047
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
1048+
std::optional<uint64_t> maxMayBeConstantTripCount =
1049+
getMaxConstantTripCount(forOp);
10161050
if (unrollFactor == 1) {
10171051
if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
1052+
maxMayBeConstantTripCount && *maxMayBeConstantTripCount == 1 &&
10181053
failed(promoteIfSingleIteration(forOp)))
10191054
return failure();
10201055
return success();

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,26 @@ std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
799799
return KernelDim3{operands[6], operands[7], operands[8]};
800800
}
801801

802+
/// Returns the block size of this launch on the axis selected by `dimension`.
/// Any dimension other than `x` and `y` selects the z block size.
Value LaunchOp::getBlockSizeOnAxis(Dimension dimension) {
  switch (dimension) {
  case Dimension::x:
    return getBlockSizeX();
  case Dimension::y:
    return getBlockSizeY();
  default:
    return getBlockSizeZ();
  }
}
810+
811+
/// Returns the block size on the axis whose thread id (as reported by
/// `getThreadIds`) equals `threadId`. Returns a null Value when `threadId`
/// matches none of the three thread ids of this launch.
Value LaunchOp::getBlockSizeOnAxis(Value threadId) {
  const KernelDim3 ids = getThreadIds();
  if (threadId == ids.x)
    return getBlockSizeX();
  if (threadId == ids.y)
    return getBlockSizeY();
  if (threadId == ids.z)
    return getBlockSizeZ();
  return Value();
}
821+
802822
LogicalResult LaunchOp::verify() {
803823
if (!(hasClusterSize()) &&
804824
(getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))

0 commit comments

Comments
 (0)