
Commit c9db9bf

Some updates
1 parent 7b097f4 commit c9db9bf

2 files changed: +69 -164 lines changed


mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 54 additions & 153 deletions
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -38,141 +39,6 @@ namespace memref {
 
 using namespace mlir;
 
-static void setInsertionPointToStart(OpBuilder &builder, Value val) {
-  if (auto *parentOp = val.getDefiningOp()) {
-    builder.setInsertionPointAfter(parentOp);
-  } else {
-    builder.setInsertionPointToStart(val.getParentBlock());
-  }
-}
-
-OpFoldResult computeMemRefSpan(Value memref, OpBuilder &builder) {
-  Location loc = memref.getLoc();
-  MemRefType type = cast<MemRefType>(memref.getType());
-  ArrayRef<int64_t> shape = type.getShape();
-
-  // Check for empty memref
-  if (type.hasStaticShape() &&
-      llvm::any_of(shape, [](int64_t dim) { return dim == 0; })) {
-    return builder.getIndexAttr(0);
-  }
-
-  // Get strides of the memref
-  SmallVector<int64_t, 4> strides;
-  int64_t offset;
-  if (failed(type.getStridesAndOffset(strides, offset))) {
-    // Cannot extract strides, return a dynamic value
-    return Value();
-  }
-
-  // Static case: compute at compile time if possible
-  if (type.hasStaticShape()) {
-    int64_t span = 0;
-    for (unsigned i = 0; i < type.getRank(); ++i) {
-      span += (shape[i] - 1) * strides[i];
-    }
-    return builder.getIndexAttr(span);
-  }
-
-  // Dynamic case: emit IR to compute at runtime
-  Value result = builder.create<arith::ConstantIndexOp>(loc, 0);
-
-  for (unsigned i = 0; i < type.getRank(); ++i) {
-    // Get dimension size
-    Value dimSize;
-    if (shape[i] == ShapedType::kDynamic) {
-      dimSize = builder.create<memref::DimOp>(loc, memref, i);
-    } else {
-      dimSize = builder.create<arith::ConstantIndexOp>(loc, shape[i]);
-    }
-
-    // Compute (dim - 1)
-    Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
-    Value dimMinusOne = builder.create<arith::SubIOp>(loc, dimSize, one);
-
-    // Get stride
-    Value stride;
-    if (strides[i] == ShapedType::kDynamicStrideOrOffset) {
-      // For dynamic strides, need to extract from memref descriptor
-      // This would require runtime support, possibly using extractStride
-      // As a placeholder, return a dynamic value
-      return Value();
-    } else {
-      stride = builder.create<arith::ConstantIndexOp>(loc, strides[i]);
-    }
-
-    // Add (dim - 1) * stride to result
-    Value term = builder.create<arith::MulIOp>(loc, dimMinusOne, stride);
-    result = builder.create<arith::AddIOp>(loc, result, term);
-  }
-
-  return result;
-}
-
-static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>, OpFoldResult,
-                  OpFoldResult>
-getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
-                        ArrayRef<OpFoldResult> subOffsets,
-                        ArrayRef<OpFoldResult> subStrides = std::nullopt) {
-  auto sourceType = cast<MemRefType>(source.getType());
-  auto sourceRank = static_cast<unsigned>(sourceType.getRank());
-
-  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
-  {
-    OpBuilder::InsertionGuard g(rewriter);
-    setInsertionPointToStart(rewriter, source);
-    newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
-  }
-
-  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
-
-  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
-    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
-                                      : rewriter.getIndexAttr(dim);
-  };
-
-  OpFoldResult origOffset =
-      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
-  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
-  OpFoldResult outmostDim =
-      getDim(sourceType.getShape().front(),
-             newExtractStridedMetadata.getSizes().front());
-
-  SmallVector<OpFoldResult> origStrides;
-  origStrides.reserve(sourceRank);
-
-  SmallVector<OpFoldResult> strides;
-  strides.reserve(sourceRank);
-
-  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
-  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
-  for (auto i : llvm::seq(0u, sourceRank)) {
-    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
-
-    if (!subStrides.empty()) {
-      strides.push_back(affine::makeComposedFoldedAffineApply(
-          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
-    }
-
-    origStrides.emplace_back(origStride);
-  }
-
-  // Compute linearized index:
-  auto &&[expr, values] =
-      computeLinearIndex(rewriter.getIndexAttr(0), origStrides, subOffsets);
-  OpFoldResult linearizedIndex =
-      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
-
-  // Compute collapsed size: (the outmost stride * outmost dimension).
-  //SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
-  //OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops);
-  OpFoldResult collapsedSize = computeMemRefSpan(source, rewriter);
-
-  return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex,
-          origStrides, origOffset, collapsedSize};
-}
-
 static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
                                       OpFoldResult in) {
   if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
@@ -188,17 +54,36 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
                                                          Location loc,
                                                          Value source,
                                                          ValueRange indices) {
-  auto &&[base, index, strides, offset, collapsedShape] =
-      getFlatOffsetAndStrides(rewriter, loc, source,
-                              getAsOpFoldResult(indices));
+  int64_t sourceOffset;
+  SmallVector<int64_t, 4> sourceStrides;
+  auto sourceType = cast<MemRefType>(source.getType());
+  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
+    assert(false);
+  }
+
+  memref::ExtractStridedMetadataOp stridedMetadata =
+      rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+
+  auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
+  OpFoldResult linearizedIndices;
+  memref::LinearizedMemRefInfo linearizedInfo;
+  std::tie(linearizedInfo, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          rewriter, loc, typeBit, typeBit,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(),
+          getAsOpFoldResult(indices));
 
   return std::make_pair(
       rewriter.create<memref::ReinterpretCastOp>(
           loc, source,
-          /* offset = */ offset,
-          /* shapes = */ ArrayRef<OpFoldResult>{collapsedShape},
-          /* strides = */ ArrayRef<OpFoldResult>{strides.back()}),
-      getValueFromOpFoldResult(rewriter, loc, index));
+          /* offset = */ linearizedInfo.linearizedOffset,
+          /* shapes = */ ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
+          /* strides = */
+          ArrayRef<OpFoldResult>{
+              stridedMetadata.getConstifiedMixedStrides().back()}),
+      getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
 }
 
 static bool needFlattening(Value val) {
@@ -313,8 +198,23 @@ struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
     SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
     SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
     SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
-    auto &&[base, finalOffset, strides, _, __] =
-        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);
+
+    // base, finalOffset, strides
+    memref::ExtractStridedMetadataOp stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, memref);
+
+    auto sourceType = cast<MemRefType>(memref.getType());
+    auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
+    OpFoldResult linearizedIndices;
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, typeBit, typeBit,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(), op.getMixedOffsets());
+    auto finalOffset = linearizedInfo.linearizedOffset;
+    auto strides = stridedMetadata.getConstifiedMixedStrides();
 
     auto srcType = cast<MemRefType>(memref.getType());
     auto resultType = cast<MemRefType>(op.getType());
@@ -337,7 +237,7 @@ struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
     }
 
     rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
-        op, resultType, base, finalOffset, finalSizes, finalStrides);
+        op, resultType, memref, finalOffset, finalSizes, finalStrides);
     return success();
   }
 };
@@ -364,12 +264,13 @@ struct FlattenMemrefsPass
 } // namespace
 
 void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
-  patterns
-      .insert<MemRefRewritePattern<memref::LoadOp>,
-              MemRefRewritePattern<memref::StoreOp>,
-              MemRefRewritePattern<vector::LoadOp>,
-              MemRefRewritePattern<vector::StoreOp>,
-              MemRefRewritePattern<vector::TransferReadOp>,
-              MemRefRewritePattern<vector::TransferWriteOp>, FlattenSubview>(
-          patterns.getContext());
+  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
+                  MemRefRewritePattern<memref::StoreOp>,
+                  MemRefRewritePattern<vector::LoadOp>,
+                  MemRefRewritePattern<vector::StoreOp>,
+                  MemRefRewritePattern<vector::TransferReadOp>,
+                  MemRefRewritePattern<vector::TransferWriteOp>,
+                  MemRefRewritePattern<vector::MaskedLoadOp>,
+                  MemRefRewritePattern<vector::MaskedStoreOp>, FlattenSubview>(
+      patterns.getContext());
 }
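
Note on the replacement path above: both rewrites now read the constified offset, sizes, and strides from memref.extract_strided_metadata and hand them to memref::getLinearizedMemRefOffsetAndSize (source and destination bit widths are both typeBit here, so no index rescaling is involved), then reuse the returned linearized offset, size, and index for the reinterpret_cast and the flattened access. The snippet below is only a standalone sketch of that arithmetic for intuition; the struct and function names are invented for the example and are not the MLIR helper itself.

#include <cassert>
#include <cstdint>
#include <vector>

// Standalone illustration (hypothetical names) of the linearization that the
// pass now delegates to memref::getLinearizedMemRefOffsetAndSize when the
// source and destination element bit widths are equal.
struct FlatAccess {
  int64_t offset; // offset carried over to the flattened reinterpret_cast
  int64_t size;   // flattened size = product of the dimension sizes
  int64_t index;  // flattened access index = sum of index[i] * stride[i]
};

inline FlatAccess flattenAccess(int64_t offset,
                                const std::vector<int64_t> &sizes,
                                const std::vector<int64_t> &strides,
                                const std::vector<int64_t> &indices) {
  assert(sizes.size() == strides.size() && sizes.size() == indices.size());
  FlatAccess result{offset, 1, 0};
  for (size_t i = 0; i < sizes.size(); ++i) {
    result.size *= sizes[i];
    result.index += indices[i] * strides[i];
  }
  return result;
}

// For the first test below, memref<4x8xf32, strided<[8, 1], offset: 100>>
// accessed at [1, 2]: size = 4 * 8 = 32 and index = 1 * 8 + 2 * 1 = 10,
// matching the reinterpret_cast to sizes: [32] and the constant index 10.

The deleted computeMemRefSpan instead summed (dim - 1) * stride, which is why the CHECK lines on %[[SIZE]] in the test file change from a stride-based expression to a product of the metadata sizes.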

mlir/test/Dialect/MemRef/flatten_memref.mlir

Lines changed: 15 additions & 11 deletions
@@ -6,7 +6,7 @@ func.func @load_scalar_from_memref(%input: memref<4x8xf32, strided<[8, 1], offse
   %value = memref.load %input[%c1, %c2] : memref<4x8xf32, strided<[8, 1], offset: 100>>
   return %value : f32
 }
-// CHECK: func @load_scalar_from_memref
+// CHECK-LABEL: func @load_scalar_from_memref
 // CHECK: %[[C10:.*]] = arith.constant 10 : index
 // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [100], sizes: [32], strides: [1]
 // CHECK-SAME: memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
@@ -18,6 +18,7 @@ func.func @load_scalar_from_memref_static_dim_2(%input: memref<4x8xf32, strided<
   %value = memref.load %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>>
   return %value : f32
 }
+
 // CHECK: [[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 12)>
 // CHECK: func @load_scalar_from_memref_static_dim_2
 // CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 12], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
@@ -39,7 +40,7 @@ func.func @load_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<[
 // CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
 // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
 // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
-// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0]
+// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[SIZES]]#0, %[[SIZES]]#1]
 // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [%[[STRIDES]]#1]
 // CHECK: memref.load %[[REINT]][%[[IDX]]]
 
@@ -49,7 +50,9 @@ func.func @load_scalar_from_memref_subview(%input: memref<4x8xf32>, %row: index,
   %subview = memref.subview %input[%col, %row] [1, 1] [1, 1] : memref<4x8xf32> to memref<1x1xf32, strided<[8, 1], offset: ?>>
   return %subview : memref<1x1xf32, strided<[8, 1], offset: ?>>
 }
-// CHECK: func @load_scalar_from_memref_subview
+// CHECK-LABEL: func @load_scalar_from_memref_subview
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [1, 1], strides: [8, 1]
 
 // -----
 
@@ -76,7 +79,7 @@ func.func @store_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<
 // CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32)
 // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
 // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
-// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0]
+// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[SIZES]]#0, %[[SIZES]]#1]
 // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [%[[STRIDES]]#1]
 // CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]]
 
@@ -88,7 +91,7 @@ func.func @load_vector_from_memref(%input: memref<4x8xf32>) -> vector<8xf32> {
   %value = vector.load %input[%c3, %c6] : memref<4x8xf32>, vector<8xf32>
   return %value : vector<8xf32>
 }
-// CHECK: func @load_vector_from_memref
+// CHECK-LABEL: func @load_vector_from_memref
 // CHECK: %[[C30:.*]] = arith.constant 30
 // CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [32], strides: [1]
 // CHECK-NEXT: vector.load %[[REINT]][%[[C30]]]
@@ -101,7 +104,7 @@ func.func @load_vector_from_memref_odd(%input: memref<3x7xi2>) -> vector<3xi2> {
   %value = vector.load %input[%c1, %c3] : memref<3x7xi2>, vector<3xi2>
   return %value : vector<3xi2>
 }
-// CHECK: func @load_vector_from_memref_odd
+// CHECK-LABEL: func @load_vector_from_memref_odd
 // CHECK: %[[C10:.*]] = arith.constant 10 : index
 // CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
 // CHECK-NEXT: vector.load %[[REINT]][%[[C10]]]
@@ -126,10 +129,11 @@ func.func @store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vector<3xi
   vector.store %value, %input[%c1, %c3] : memref<3x7xi2>, vector<3xi2>
   return
 }
-// CHECK: func @store_vector_to_memref_odd
+// CHECK-LABEL: func @store_vector_to_memref_odd
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>)
 // CHECK: %[[C10:.*]] = arith.constant 10 : index
 // CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
-// CHECK-NEXT: vector.store %arg1, %[[REINT]][%[[C10]]] : memref<21xi2, strided<[1]>
+// CHECK-NEXT: vector.store %[[ARG1]], %[[REINT]][%[[C10]]] : memref<21xi2, strided<[1]>
 
 // -----
 
@@ -152,7 +156,7 @@ func.func @mask_store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vecto
   vector.maskedstore %input[%c1, %c3], %mask, %value : memref<3x7xi2>, vector<3xi1>, vector<3xi2>
   return
 }
-// CHECK: func @mask_store_vector_to_memref_odd
+// CHECK-LABEL: func @mask_store_vector_to_memref_odd
 // CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: vector<3xi1>)
 // CHECK: %[[C10:.*]] = arith.constant 10 : index
 // CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
@@ -178,7 +182,7 @@ func.func @mask_load_vector_from_memref_odd(%input: memref<3x7xi2>, %mask: vecto
   %result = vector.maskedload %input[%c1, %c3], %mask, %passthru : memref<3x7xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
   return %result : vector<3xi2>
 }
-// CHECK: func @mask_load_vector_from_memref_odd
+// CHECK-LABEL: func @mask_load_vector_from_memref_odd
 // CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[MASK:.*]]: vector<3xi1>, %[[PASSTHRU:.*]]: vector<3xi2>)
 // CHECK: %[[C10:.*]] = arith.constant 10 : index
 // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [21], strides: [1]
@@ -204,7 +208,7 @@ func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %r
   %0 = vector.transfer_read %input[%col, %row], %c0 : memref<4x8xi2>, vector<8xi2>
   return %0 : vector<8xi2>
 }
-// CHECK: func @transfer_read_memref
+// CHECK-LABEL: func @transfer_read_memref
 // CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
 // CHECK: %[[C0:.*]] = arith.constant 0 : i2
 // CHECK: %[[IDX:.*]] = affine.apply #map()[%[[ARG3]], %[[ARG2]]]
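
The two %[[SIZE]] updates in the dynamic-shape tests follow from the same switch: the flattened size is now an affine product of the two sizes returned by memref.extract_strided_metadata (%[[SIZES]]#0 and %[[SIZES]]#1) rather than an expression over the outermost stride, while the flat index remains the strided dot product of the runtime indices. Below is a minimal sketch of what those two affine.apply results evaluate to for memref<?x?xf32, strided<[?, ?], offset: ?>>; the helper names are invented for illustration, and it assumes the #MAP/#MAP1 definitions elsewhere in the test file match the operand order shown in the CHECK lines above.

#include <cstdint>

// %[[IDX]]  ~ #MAP ()[%arg2, %strides#0, %arg1, %strides#1]
// %[[SIZE]] ~ #MAP1()[%sizes#0, %sizes#1]
inline int64_t flatIndex(int64_t idx0, int64_t stride0, int64_t idx1,
                         int64_t stride1) {
  return idx0 * stride0 + idx1 * stride1; // strided dot product of indices
}

inline int64_t flatSize(int64_t size0, int64_t size1) {
  return size0 * size1; // product of sizes, no longer derived from strides
}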

0 commit comments
