 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -38,141 +39,6 @@ namespace memref {
 
 using namespace mlir;
 
-static void setInsertionPointToStart(OpBuilder &builder, Value val) {
-  if (auto *parentOp = val.getDefiningOp()) {
-    builder.setInsertionPointAfter(parentOp);
-  } else {
-    builder.setInsertionPointToStart(val.getParentBlock());
-  }
-}
-
-OpFoldResult computeMemRefSpan(Value memref, OpBuilder &builder) {
-  Location loc = memref.getLoc();
-  MemRefType type = cast<MemRefType>(memref.getType());
-  ArrayRef<int64_t> shape = type.getShape();
-
-  // Check for empty memref
-  if (type.hasStaticShape() &&
-      llvm::any_of(shape, [](int64_t dim) { return dim == 0; })) {
-    return builder.getIndexAttr(0);
-  }
-
-  // Get strides of the memref
-  SmallVector<int64_t, 4> strides;
-  int64_t offset;
-  if (failed(type.getStridesAndOffset(strides, offset))) {
-    // Cannot extract strides, return a dynamic value
-    return Value();
-  }
-
-  // Static case: compute at compile time if possible
-  if (type.hasStaticShape()) {
-    int64_t span = 0;
-    for (unsigned i = 0; i < type.getRank(); ++i) {
-      span += (shape[i] - 1) * strides[i];
-    }
-    return builder.getIndexAttr(span);
-  }
-
-  // Dynamic case: emit IR to compute at runtime
-  Value result = builder.create<arith::ConstantIndexOp>(loc, 0);
-
-  for (unsigned i = 0; i < type.getRank(); ++i) {
-    // Get dimension size
-    Value dimSize;
-    if (shape[i] == ShapedType::kDynamic) {
-      dimSize = builder.create<memref::DimOp>(loc, memref, i);
-    } else {
-      dimSize = builder.create<arith::ConstantIndexOp>(loc, shape[i]);
-    }
-
-    // Compute (dim - 1)
-    Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
-    Value dimMinusOne = builder.create<arith::SubIOp>(loc, dimSize, one);
-
-    // Get stride
-    Value stride;
-    if (strides[i] == ShapedType::kDynamicStrideOrOffset) {
-      // For dynamic strides, need to extract from memref descriptor
-      // This would require runtime support, possibly using extractStride
-      // As a placeholder, return a dynamic value
-      return Value();
-    } else {
-      stride = builder.create<arith::ConstantIndexOp>(loc, strides[i]);
-    }
-
-    // Add (dim - 1) * stride to result
-    Value term = builder.create<arith::MulIOp>(loc, dimMinusOne, stride);
-    result = builder.create<arith::AddIOp>(loc, result, term);
-  }
-
-  return result;
-}
-
-static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>, OpFoldResult,
-                  OpFoldResult>
-getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
-                        ArrayRef<OpFoldResult> subOffsets,
-                        ArrayRef<OpFoldResult> subStrides = std::nullopt) {
-  auto sourceType = cast<MemRefType>(source.getType());
-  auto sourceRank = static_cast<unsigned>(sourceType.getRank());
-
-  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
-  {
-    OpBuilder::InsertionGuard g(rewriter);
-    setInsertionPointToStart(rewriter, source);
-    newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
-  }
-
-  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
-
-  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
-    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
-                                      : rewriter.getIndexAttr(dim);
-  };
-
-  OpFoldResult origOffset =
-      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
-  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
-  OpFoldResult outmostDim =
-      getDim(sourceType.getShape().front(),
-             newExtractStridedMetadata.getSizes().front());
-
-  SmallVector<OpFoldResult> origStrides;
-  origStrides.reserve(sourceRank);
-
-  SmallVector<OpFoldResult> strides;
-  strides.reserve(sourceRank);
-
-  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
-  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
-  for (auto i : llvm::seq(0u, sourceRank)) {
-    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
-
-    if (!subStrides.empty()) {
-      strides.push_back(affine::makeComposedFoldedAffineApply(
-          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
-    }
-
-    origStrides.emplace_back(origStride);
-  }
-
-  // Compute linearized index:
-  auto &&[expr, values] =
-      computeLinearIndex(rewriter.getIndexAttr(0), origStrides, subOffsets);
-  OpFoldResult linearizedIndex =
-      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
-
-  // Compute collapsed size: (the outmost stride * outmost dimension).
-  // SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
-  // OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops);
-  OpFoldResult collapsedSize = computeMemRefSpan(source, rewriter);
-
-  return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex,
-          origStrides, origOffset, collapsedSize};
-}
-
 static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
                                       OpFoldResult in) {
   if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
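For orientation: the deleted computeMemRefSpan helper returns the largest in-bounds linear element offset of a strided memref. Writing the sizes as s_i and the strides as t_i (symbols introduced here only for illustration), the static branch above evaluates

    \mathrm{span} = \sum_{i=0}^{r-1} (s_i - 1)\, t_i

and a memref with any zero-sized dimension yields a span of 0. The rewritten code in the next hunk obtains the flattened offset, size, and index through memref::getLinearizedMemRefOffsetAndSize from the newly included MemRefUtils.h instead.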
@@ -188,17 +54,36 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
                                                          Location loc,
                                                          Value source,
                                                          ValueRange indices) {
-  auto &&[base, index, strides, offset, collapsedShape] =
-      getFlatOffsetAndStrides(rewriter, loc, source,
-                              getAsOpFoldResult(indices));
+  int64_t sourceOffset;
+  SmallVector<int64_t, 4> sourceStrides;
+  auto sourceType = cast<MemRefType>(source.getType());
+  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
+    assert(false);
+  }
+
+  memref::ExtractStridedMetadataOp stridedMetadata =
+      rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+
+  auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
+  OpFoldResult linearizedIndices;
+  memref::LinearizedMemRefInfo linearizedInfo;
+  std::tie(linearizedInfo, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          rewriter, loc, typeBit, typeBit,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(),
+          getAsOpFoldResult(indices));
 
   return std::make_pair(
       rewriter.create<memref::ReinterpretCastOp>(
           loc, source,
-          /* offset = */ offset,
-          /* shapes = */ ArrayRef<OpFoldResult>{collapsedShape},
-          /* strides = */ ArrayRef<OpFoldResult>{strides.back()}),
-      getValueFromOpFoldResult(rewriter, loc, index));
+          /* offset = */ linearizedInfo.linearizedOffset,
+          /* shapes = */ ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
+          /* strides = */
+          ArrayRef<OpFoldResult>{
+              stridedMetadata.getConstifiedMixedStrides().back()}),
+      getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
 }
 
 static bool needFlattening(Value val) {
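A sketch of what memref::getLinearizedMemRefOffsetAndSize folds in the call above, under the assumption that source and destination element widths are equal (both arguments are typeBit) and writing the access indices as idx_i and the strides as t_i:

    \mathrm{linearizedIndices} = \sum_{i=0}^{r-1} \mathrm{idx}_i\, t_i

linearizedInfo.linearizedOffset and linearizedInfo.linearizedSize then give the offset and extent of the flattened one-dimensional view from which the memref.reinterpret_cast is built.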
@@ -313,8 +198,23 @@ struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
     SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
     SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
     SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
-    auto &&[base, finalOffset, strides, _, __] =
-        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);
+
+    // base, finalOffset, strides
+    memref::ExtractStridedMetadataOp stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, memref);
+
+    auto sourceType = cast<MemRefType>(memref.getType());
+    auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
+    OpFoldResult linearizedIndices;
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, typeBit, typeBit,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(), op.getMixedOffsets());
+    auto finalOffset = linearizedInfo.linearizedOffset;
+    auto strides = stridedMetadata.getConstifiedMixedStrides();
 
     auto srcType = cast<MemRefType>(memref.getType());
     auto resultType = cast<MemRefType>(op.getType());
@@ -337,7 +237,7 @@ struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
     }
 
     rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
-        op, resultType, base, finalOffset, finalSizes, finalStrides);
+        op, resultType, memref, finalOffset, finalSizes, finalStrides);
     return success();
   }
 };
@@ -364,12 +264,13 @@ struct FlattenMemrefsPass
 } // namespace
 
 void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
-  patterns
-      .insert<MemRefRewritePattern<memref::LoadOp>,
-              MemRefRewritePattern<memref::StoreOp>,
-              MemRefRewritePattern<vector::LoadOp>,
-              MemRefRewritePattern<vector::StoreOp>,
-              MemRefRewritePattern<vector::TransferReadOp>,
-              MemRefRewritePattern<vector::TransferWriteOp>, FlattenSubview>(
-          patterns.getContext());
+  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
+                  MemRefRewritePattern<memref::StoreOp>,
+                  MemRefRewritePattern<vector::LoadOp>,
+                  MemRefRewritePattern<vector::StoreOp>,
+                  MemRefRewritePattern<vector::TransferReadOp>,
+                  MemRefRewritePattern<vector::TransferWriteOp>,
+                  MemRefRewritePattern<vector::MaskedLoadOp>,
+                  MemRefRewritePattern<vector::MaskedStoreOp>, FlattenSubview>(
+      patterns.getContext());
 }
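For context, a minimal sketch of how a downstream pass could drive the patterns registered above. ExampleFlattenPass and the choice of greedy-driver entry point are assumptions for illustration, not part of this commit; the in-tree FlattenMemrefsPass performs the equivalent wiring.

    // Standalone pass that flattens multi-dimensional memref accesses by
    // applying the patterns populated by populateFlattenMemrefsPatterns.
    #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
    #include "mlir/IR/PatternMatch.h"
    #include "mlir/Pass/Pass.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    namespace {
    struct ExampleFlattenPass
        : public mlir::PassWrapper<ExampleFlattenPass, mlir::OperationPass<>> {
      void runOnOperation() override {
        mlir::RewritePatternSet patterns(&getContext());
        // Registers the memref/vector load-store rewrites (now including the
        // masked variants) and FlattenSubview.
        mlir::memref::populateFlattenMemrefsPatterns(patterns);
        // Newer MLIR revisions spell this applyPatternsGreedily.
        if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
                getOperation(), std::move(patterns))))
          signalPassFailure();
      }
    };
    } // namespace

One could then add getArgument()/getDescription() overrides and register the pass with mlir::PassRegistration to expose it in a pass pipeline.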