10
10
// fundamental operations.
11
11
// ===----------------------------------------------------------------------===//
12
12
13
+ #include " mlir/Dialect/Affine/LoopUtils.h"
13
14
#include " mlir/Dialect/Affine/Passes.h"
14
15
15
16
#include " mlir/Dialect/Affine/IR/AffineOps.h"
@@ -28,6 +29,50 @@ namespace affine {
28
29
using namespace mlir ;
29
30
using namespace mlir ::affine;
30
31
32
+ // / Given a basis (in static and dynamic components), return the sequence of
33
+ // / suffix products of the basis, including the product of the entire basis,
34
+ // / which must **not** contain an outer bound.
35
+ // /
36
+ // / If excess dynamic values are provided, the values at the beginning
37
+ // / will be ignored. This allows for dropping the outer bound without
38
+ // / needing to manipulate the dynamic value array.
39
+ static SmallVector<Value> computeStrides (Location loc, RewriterBase &rewriter,
40
+ ValueRange dynamicBasis,
41
+ ArrayRef<int64_t > staticBasis) {
42
+ if (staticBasis.empty ())
43
+ return {};
44
+
45
+ SmallVector<Value> result;
46
+ result.reserve (staticBasis.size ());
47
+ size_t dynamicIndex = dynamicBasis.size ();
48
+ Value dynamicPart = nullptr ;
49
+ int64_t staticPart = 1 ;
50
+ for (int64_t elem : llvm::reverse (staticBasis)) {
51
+ if (ShapedType::isDynamic (elem)) {
52
+ if (dynamicPart)
53
+ dynamicPart = rewriter.create <arith::MulIOp>(
54
+ loc, dynamicPart, dynamicBasis[dynamicIndex - 1 ]);
55
+ else
56
+ dynamicPart = dynamicBasis[dynamicIndex - 1 ];
57
+ --dynamicIndex;
58
+ } else {
59
+ staticPart *= elem;
60
+ }
61
+
62
+ if (dynamicPart && staticPart == 1 ) {
63
+ result.push_back (dynamicPart);
64
+ } else {
65
+ Value stride =
66
+ rewriter.createOrFold <arith::ConstantIndexOp>(loc, staticPart);
67
+ if (dynamicPart)
68
+ stride = rewriter.create <arith::MulIOp>(loc, dynamicPart, stride);
69
+ result.push_back (stride);
70
+ }
71
+ }
72
+ std::reverse (result.begin (), result.end ());
73
+ return result;
74
+ }
75
+
31
76
namespace {
32
77
// / Lowers `affine.delinearize_index` into a sequence of division and remainder
33
78
// / operations.
@@ -36,18 +81,62 @@ struct LowerDelinearizeIndexOps
36
81
using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
37
82
LogicalResult matchAndRewrite (AffineDelinearizeIndexOp op,
38
83
PatternRewriter &rewriter) const override {
39
- FailureOr<SmallVector<Value>> multiIndex =
40
- delinearizeIndex (rewriter, op->getLoc (), op.getLinearIndex (),
41
- op.getEffectiveBasis (), /* hasOuterBound=*/ false );
42
- if (failed (multiIndex))
43
- return failure ();
44
- rewriter.replaceOp (op, *multiIndex);
84
+ Location loc = op.getLoc ();
85
+ Value linearIdx = op.getLinearIndex ();
86
+ unsigned numResults = op.getNumResults ();
87
+ ArrayRef<int64_t > staticBasis = op.getStaticBasis ();
88
+ if (numResults == staticBasis.size ())
89
+ staticBasis = staticBasis.drop_front ();
90
+
91
+ if (numResults == 1 ) {
92
+ rewriter.replaceOp (op, linearIdx);
93
+ return success ();
94
+ }
95
+
96
+ SmallVector<Value> results;
97
+ results.reserve (numResults);
98
+ SmallVector<Value> strides =
99
+ computeStrides (loc, rewriter, op.getDynamicBasis (), staticBasis);
100
+
101
+ Value zero = rewriter.createOrFold <arith::ConstantIndexOp>(loc, 0 );
102
+
103
+ Value initialPart =
104
+ rewriter.create <arith::FloorDivSIOp>(loc, linearIdx, strides.front ());
105
+ results.push_back (initialPart);
106
+
107
+ auto emitModTerm = [&](Value stride) -> Value {
108
+ Value remainder = rewriter.create <arith::RemSIOp>(loc, linearIdx, stride);
109
+ Value remainderNegative = rewriter.create <arith::CmpIOp>(
110
+ loc, arith::CmpIPredicate::slt, remainder , zero);
111
+ Value corrected = rewriter.create <arith::AddIOp>(loc, remainder , stride);
112
+ Value mod = rewriter.create <arith::SelectOp>(loc, remainderNegative,
113
+ corrected, remainder );
114
+ return mod;
115
+ };
116
+
117
+ // Generate all the intermediate parts
118
+ for (size_t i = 0 , e = strides.size () - 1 ; i < e; ++i) {
119
+ Value thisStride = strides[i];
120
+ Value nextStride = strides[i + 1 ];
121
+ Value modulus = emitModTerm (thisStride);
122
+ // We know both inputs are positive, so floorDiv == div.
123
+ // This could potentially be a divui, but it's not clear if that would
124
+ // cause issues.
125
+ Value divided = rewriter.create <arith::DivSIOp>(loc, modulus, nextStride);
126
+ results.push_back (divided);
127
+ }
128
+
129
+ results.push_back (emitModTerm (strides.back ()));
130
+
131
+ rewriter.replaceOp (op, results);
45
132
return success ();
46
133
}
47
134
};
48
135
49
136
// / Lowers `affine.linearize_index` into a sequence of multiplications and
50
- // / additions.
137
+ // / additions. Make a best effort to sort the input indices so that
138
+ // / the most loop-invariant terms are at the left of the additions
139
+ // / to enable loop-invariant code motion.
51
140
struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
52
141
using OpRewritePattern::OpRewritePattern;
53
142
LogicalResult matchAndRewrite (AffineLinearizeIndexOp op,
@@ -58,13 +147,44 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
58
147
return success ();
59
148
}
60
149
61
- SmallVector<OpFoldResult> multiIndex =
62
- getAsOpFoldResult (op.getMultiIndex ());
63
- OpFoldResult linearIndex =
64
- linearizeIndex (rewriter, op.getLoc (), multiIndex, op.getMixedBasis ());
65
- Value linearIndexValue =
66
- getValueOrCreateConstantIntOp (rewriter, op.getLoc (), linearIndex);
67
- rewriter.replaceOp (op, linearIndexValue);
150
+ Location loc = op.getLoc ();
151
+ ValueRange multiIndex = op.getMultiIndex ();
152
+ size_t numIndexes = multiIndex.size ();
153
+ ArrayRef<int64_t > staticBasis = op.getStaticBasis ();
154
+ if (numIndexes == staticBasis.size ())
155
+ staticBasis = staticBasis.drop_front ();
156
+
157
+ SmallVector<Value> strides =
158
+ computeStrides (loc, rewriter, op.getDynamicBasis (), staticBasis);
159
+ SmallVector<std::pair<Value, int64_t >> scaledValues;
160
+ scaledValues.reserve (numIndexes);
161
+
162
+ // Note: strides doesn't contain a value for the final element (stride 1)
163
+ // and everything else lines up. We use the "mutable" accessor so we can get
164
+ // our hands on an `OpOperand&` for the loop invariant counting function.
165
+ for (auto [stride, idxOp] :
166
+ llvm::zip_equal (strides, llvm::drop_end (op.getMultiIndexMutable ()))) {
167
+ Value scaledIdx =
168
+ rewriter.create <arith::MulIOp>(loc, idxOp.get (), stride);
169
+ int64_t numHoistableLoops = numEnclosingInvariantLoops (idxOp);
170
+ scaledValues.emplace_back (scaledIdx, numHoistableLoops);
171
+ }
172
+ scaledValues.emplace_back (
173
+ multiIndex.back (),
174
+ numEnclosingInvariantLoops (op.getMultiIndexMutable ()[numIndexes - 1 ]));
175
+
176
+ // Sort by how many enclosing loops there are, ties implicitly broken by
177
+ // size of the stride.
178
+ llvm::stable_sort (scaledValues,
179
+ [&](auto l, auto r) { return l.second > r.second ; });
180
+
181
+ Value result = scaledValues.front ().first ;
182
+ for (auto [scaledValue, numHoistableLoops] :
183
+ llvm::drop_begin (scaledValues)) {
184
+ std::ignore = numHoistableLoops;
185
+ result = rewriter.create <arith::AddIOp>(loc, result, scaledValue);
186
+ }
187
+ rewriter.replaceOp (op, result);
68
188
return success ();
69
189
}
70
190
};
0 commit comments