Skip to content

Commit 9e8200c

Browse files
authored
[mlir][Affine] Expand affine.[de]linearize_index without affine maps (#116703)
As the documentation for -affine-expand-index-ops says, affine.delinearize_index and affine.linearize_index don't need to be expanded into the affine dialect. Expanding these operations into affine.apply operations can introduce unwanted "simplifications", mainly translations of `(dN mod C + ...)` to `(dN + ... - (dN floordiv C) * C)` and similar, which create worse generated code. This commit resolves this issue by expanding out affine.delinearize_index directly. In addition, the lowering of affine.linearize_index now sorts the operands by loop-independence, allowing an increased amount of loop-invariant code motion after lowering. The old behavior is preserved as -affine-expand-index-ops-as-affine but is no longer the default.
1 parent 0ac889b commit 9e8200c

File tree

11 files changed

+405
-109
lines changed

11 files changed

+405
-109
lines changed

mlir/include/mlir/Dialect/Affine/LoopUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,11 @@ separateFullTiles(MutableArrayRef<AffineForOp> nest,
301301
/// Walk an affine.for to find a band to coalesce.
302302
LogicalResult coalescePerfectlyNestedAffineLoops(AffineForOp op);
303303

304+
/// Count the number of loops surrounding `operand` such that operand could be
305+
/// hoisted above.
306+
/// Stop counting at the first loop over which the operand cannot be hoisted.
307+
/// This counts any LoopLikeOpInterface, not just affine.for.
308+
int64_t numEnclosingInvariantLoops(OpOperand &operand);
304309
} // namespace affine
305310
} // namespace mlir
306311

mlir/include/mlir/Dialect/Affine/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ std::unique_ptr<OperationPass<func::FuncOp>> createPipelineDataTransferPass();
116116
/// operations (not necessarily restricted to Affine dialect).
117117
std::unique_ptr<Pass> createAffineExpandIndexOpsPass();
118118

119+
/// Creates a pass to expand affine index operations into affine.apply
120+
/// operations.
121+
std::unique_ptr<Pass> createAffineExpandIndexOpsAsAffinePass();
122+
119123
//===----------------------------------------------------------------------===//
120124
// Registration
121125
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Affine/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,4 +408,9 @@ def AffineExpandIndexOps : Pass<"affine-expand-index-ops"> {
408408
let constructor = "mlir::affine::createAffineExpandIndexOpsPass()";
409409
}
410410

411+
def AffineExpandIndexOpsAsAffine : Pass<"affine-expand-index-ops-as-affine"> {
412+
let summary = "Lower affine operations operating on indices into affine.apply operations";
413+
let constructor = "mlir::affine::createAffineExpandIndexOpsAsAffinePass()";
414+
}
415+
411416
#endif // MLIR_DIALECT_AFFINE_PASSES

mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ class AffineApplyOp;
3737
/// operations (not necessarily restricted to Affine dialect).
3838
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns);
3939

40+
/// Populate patterns that expand affine index operations into their equivalent
41+
/// `affine.apply` representations.
42+
void populateAffineExpandIndexOpsAsAffinePatterns(RewritePatternSet &patterns);
43+
4044
/// Helper function to rewrite `op`'s affine map and reorder its operands such
4145
/// that they are in increasing order of hoistability (i.e. the least hoistable)
4246
/// operands come first in the operand list.

mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp

Lines changed: 134 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
// fundamental operations.
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "mlir/Dialect/Affine/LoopUtils.h"
1314
#include "mlir/Dialect/Affine/Passes.h"
1415

1516
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -28,6 +29,50 @@ namespace affine {
2829
using namespace mlir;
2930
using namespace mlir::affine;
3031

32+
/// Given a basis (in static and dynamic components), return the sequence of
33+
/// suffix products of the basis, including the product of the entire basis,
34+
/// which must **not** contain an outer bound.
35+
///
36+
/// If excess dynamic values are provided, the values at the beginning
37+
/// will be ignored. This allows for dropping the outer bound without
38+
/// needing to manipulate the dynamic value array.
39+
static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
40+
ValueRange dynamicBasis,
41+
ArrayRef<int64_t> staticBasis) {
42+
if (staticBasis.empty())
43+
return {};
44+
45+
SmallVector<Value> result;
46+
result.reserve(staticBasis.size());
47+
size_t dynamicIndex = dynamicBasis.size();
48+
Value dynamicPart = nullptr;
49+
int64_t staticPart = 1;
50+
for (int64_t elem : llvm::reverse(staticBasis)) {
51+
if (ShapedType::isDynamic(elem)) {
52+
if (dynamicPart)
53+
dynamicPart = rewriter.create<arith::MulIOp>(
54+
loc, dynamicPart, dynamicBasis[dynamicIndex - 1]);
55+
else
56+
dynamicPart = dynamicBasis[dynamicIndex - 1];
57+
--dynamicIndex;
58+
} else {
59+
staticPart *= elem;
60+
}
61+
62+
if (dynamicPart && staticPart == 1) {
63+
result.push_back(dynamicPart);
64+
} else {
65+
Value stride =
66+
rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
67+
if (dynamicPart)
68+
stride = rewriter.create<arith::MulIOp>(loc, dynamicPart, stride);
69+
result.push_back(stride);
70+
}
71+
}
72+
std::reverse(result.begin(), result.end());
73+
return result;
74+
}
75+
3176
namespace {
3277
/// Lowers `affine.delinearize_index` into a sequence of division and remainder
3378
/// operations.
@@ -36,18 +81,62 @@ struct LowerDelinearizeIndexOps
3681
using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
3782
LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
3883
PatternRewriter &rewriter) const override {
39-
FailureOr<SmallVector<Value>> multiIndex =
40-
delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
41-
op.getEffectiveBasis(), /*hasOuterBound=*/false);
42-
if (failed(multiIndex))
43-
return failure();
44-
rewriter.replaceOp(op, *multiIndex);
84+
Location loc = op.getLoc();
85+
Value linearIdx = op.getLinearIndex();
86+
unsigned numResults = op.getNumResults();
87+
ArrayRef<int64_t> staticBasis = op.getStaticBasis();
88+
if (numResults == staticBasis.size())
89+
staticBasis = staticBasis.drop_front();
90+
91+
if (numResults == 1) {
92+
rewriter.replaceOp(op, linearIdx);
93+
return success();
94+
}
95+
96+
SmallVector<Value> results;
97+
results.reserve(numResults);
98+
SmallVector<Value> strides =
99+
computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
100+
101+
Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
102+
103+
Value initialPart =
104+
rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
105+
results.push_back(initialPart);
106+
107+
auto emitModTerm = [&](Value stride) -> Value {
108+
Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
109+
Value remainderNegative = rewriter.create<arith::CmpIOp>(
110+
loc, arith::CmpIPredicate::slt, remainder, zero);
111+
Value corrected = rewriter.create<arith::AddIOp>(loc, remainder, stride);
112+
Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
113+
corrected, remainder);
114+
return mod;
115+
};
116+
117+
// Generate all the intermediate parts
118+
for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
119+
Value thisStride = strides[i];
120+
Value nextStride = strides[i + 1];
121+
Value modulus = emitModTerm(thisStride);
122+
// We know both inputs are positive, so floorDiv == div.
123+
// This could potentially be a divui, but it's not clear if that would
124+
// cause issues.
125+
Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
126+
results.push_back(divided);
127+
}
128+
129+
results.push_back(emitModTerm(strides.back()));
130+
131+
rewriter.replaceOp(op, results);
45132
return success();
46133
}
47134
};
48135

49136
/// Lowers `affine.linearize_index` into a sequence of multiplications and
50-
/// additions.
137+
/// additions. Make a best effort to sort the input indices so that
138+
/// the most loop-invariant terms are at the left of the additions
139+
/// to enable loop-invariant code motion.
51140
struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
52141
using OpRewritePattern::OpRewritePattern;
53142
LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
@@ -58,13 +147,44 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
58147
return success();
59148
}
60149

61-
SmallVector<OpFoldResult> multiIndex =
62-
getAsOpFoldResult(op.getMultiIndex());
63-
OpFoldResult linearIndex =
64-
linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
65-
Value linearIndexValue =
66-
getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
67-
rewriter.replaceOp(op, linearIndexValue);
150+
Location loc = op.getLoc();
151+
ValueRange multiIndex = op.getMultiIndex();
152+
size_t numIndexes = multiIndex.size();
153+
ArrayRef<int64_t> staticBasis = op.getStaticBasis();
154+
if (numIndexes == staticBasis.size())
155+
staticBasis = staticBasis.drop_front();
156+
157+
SmallVector<Value> strides =
158+
computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
159+
SmallVector<std::pair<Value, int64_t>> scaledValues;
160+
scaledValues.reserve(numIndexes);
161+
162+
// Note: strides doesn't contain a value for the final element (stride 1)
163+
// and everything else lines up. We use the "mutable" accessor so we can get
164+
// our hands on an `OpOperand&` for the loop invariant counting function.
165+
for (auto [stride, idxOp] :
166+
llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
167+
Value scaledIdx =
168+
rewriter.create<arith::MulIOp>(loc, idxOp.get(), stride);
169+
int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
170+
scaledValues.emplace_back(scaledIdx, numHoistableLoops);
171+
}
172+
scaledValues.emplace_back(
173+
multiIndex.back(),
174+
numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
175+
176+
// Sort by how many enclosing loops there are, ties implicitly broken by
177+
// size of the stride.
178+
llvm::stable_sort(scaledValues,
179+
[&](auto l, auto r) { return l.second > r.second; });
180+
181+
Value result = scaledValues.front().first;
182+
for (auto [scaledValue, numHoistableLoops] :
183+
llvm::drop_begin(scaledValues)) {
184+
std::ignore = numHoistableLoops;
185+
result = rewriter.create<arith::AddIOp>(loc, result, scaledValue);
186+
}
187+
rewriter.replaceOp(op, result);
68188
return success();
69189
}
70190
};
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
//===- AffineExpandIndexOpsAsAffine.cpp - Expand index ops to apply pass --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to expand affine index ops into their
10+
// equivalent affine.apply operations.
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Affine/Passes.h"
14+
15+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
16+
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
17+
#include "mlir/Dialect/Affine/Utils.h"
18+
#include "mlir/Dialect/Arith/Utils/Utils.h"
19+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20+
21+
namespace mlir {
22+
namespace affine {
23+
#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPSASAFFINE
24+
#include "mlir/Dialect/Affine/Passes.h.inc"
25+
} // namespace affine
26+
} // namespace mlir
27+
28+
using namespace mlir;
29+
using namespace mlir::affine;
30+
31+
namespace {
32+
/// Lowers `affine.delinearize_index` into a sequence of division and remainder
33+
/// operations.
34+
struct LowerDelinearizeIndexOps
35+
: public OpRewritePattern<AffineDelinearizeIndexOp> {
36+
using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
37+
LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
38+
PatternRewriter &rewriter) const override {
39+
FailureOr<SmallVector<Value>> multiIndex =
40+
delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
41+
op.getEffectiveBasis(), /*hasOuterBound=*/false);
42+
if (failed(multiIndex))
43+
return failure();
44+
rewriter.replaceOp(op, *multiIndex);
45+
return success();
46+
}
47+
};
48+
49+
/// Lowers `affine.linearize_index` into a sequence of multiplications and
50+
/// additions.
51+
struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
52+
using OpRewritePattern::OpRewritePattern;
53+
LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
54+
PatternRewriter &rewriter) const override {
55+
// Should be folded away, included here for safety.
56+
if (op.getMultiIndex().empty()) {
57+
rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
58+
return success();
59+
}
60+
61+
SmallVector<OpFoldResult> multiIndex =
62+
getAsOpFoldResult(op.getMultiIndex());
63+
OpFoldResult linearIndex =
64+
linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
65+
Value linearIndexValue =
66+
getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
67+
rewriter.replaceOp(op, linearIndexValue);
68+
return success();
69+
}
70+
};
71+
72+
class ExpandAffineIndexOpsAsAffinePass
73+
: public affine::impl::AffineExpandIndexOpsAsAffineBase<
74+
ExpandAffineIndexOpsAsAffinePass> {
75+
public:
76+
ExpandAffineIndexOpsAsAffinePass() = default;
77+
78+
void runOnOperation() override {
79+
MLIRContext *context = &getContext();
80+
RewritePatternSet patterns(context);
81+
populateAffineExpandIndexOpsAsAffinePatterns(patterns);
82+
if (failed(
83+
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
84+
return signalPassFailure();
85+
}
86+
};
87+
88+
} // namespace
89+
90+
/// Adds the `LowerDelinearizeIndexOps` and `LowerLinearizeIndexOps` patterns
/// defined in this file to `patterns`, constructed with the set's context.
void mlir::affine::populateAffineExpandIndexOpsAsAffinePatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(context);
}
95+
96+
std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsAsAffinePass() {
97+
return std::make_unique<ExpandAffineIndexOpsAsAffinePass>();
98+
}

mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_dialect_library(MLIRAffineTransforms
22
AffineDataCopyGeneration.cpp
33
AffineExpandIndexOps.cpp
4+
AffineExpandIndexOpsAsAffine.cpp
45
AffineLoopInvariantCodeMotion.cpp
56
AffineLoopNormalize.cpp
67
AffineParallelize.cpp

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2772,3 +2772,15 @@ LogicalResult affine::coalescePerfectlyNestedAffineLoops(AffineForOp op) {
27722772
}
27732773
return result;
27742774
}
2775+
2776+
/// Counts how many consecutive enclosing loops, starting from the innermost
/// `LoopLikeOpInterface` around the operand's owner and moving outwards,
/// report the operand's value as defined outside of them. Counting stops at
/// the first loop for which that is not the case, or when no enclosing loop
/// remains.
int64_t mlir::affine::numEnclosingInvariantLoops(OpOperand &operand) {
  int64_t numLoops = 0;
  for (auto loop = operand.getOwner()->getParentOfType<LoopLikeOpInterface>();
       loop && loop.isDefinedOutsideOfLoop(operand.get());
       loop = loop->getParentOfType<LoopLikeOpInterface>())
    ++numLoops;
  return numLoops;
}

0 commit comments

Comments
 (0)