Skip to content

Commit 7b097f4

Browse files
committed
Not working yet.
1 parent ce84995 commit 7b097f4

File tree

1 file changed

+66
-2
lines changed

1 file changed

+66
-2
lines changed

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,69 @@ static void setInsertionPointToStart(OpBuilder &builder, Value val) {
4646
}
4747
}
4848

49+
OpFoldResult computeMemRefSpan(Value memref, OpBuilder &builder) {
50+
Location loc = memref.getLoc();
51+
MemRefType type = cast<MemRefType>(memref.getType());
52+
ArrayRef<int64_t> shape = type.getShape();
53+
54+
// Check for empty memref
55+
if (type.hasStaticShape() &&
56+
llvm::any_of(shape, [](int64_t dim) { return dim == 0; })) {
57+
return builder.getIndexAttr(0);
58+
}
59+
60+
// Get strides of the memref
61+
SmallVector<int64_t, 4> strides;
62+
int64_t offset;
63+
if (failed(type.getStridesAndOffset(strides, offset))) {
64+
// Cannot extract strides, return a dynamic value
65+
return Value();
66+
}
67+
68+
// Static case: compute at compile time if possible
69+
if (type.hasStaticShape()) {
70+
int64_t span = 0;
71+
for (unsigned i = 0; i < type.getRank(); ++i) {
72+
span += (shape[i] - 1) * strides[i];
73+
}
74+
return builder.getIndexAttr(span);
75+
}
76+
77+
// Dynamic case: emit IR to compute at runtime
78+
Value result = builder.create<arith::ConstantIndexOp>(loc, 0);
79+
80+
for (unsigned i = 0; i < type.getRank(); ++i) {
81+
// Get dimension size
82+
Value dimSize;
83+
if (shape[i] == ShapedType::kDynamic) {
84+
dimSize = builder.create<memref::DimOp>(loc, memref, i);
85+
} else {
86+
dimSize = builder.create<arith::ConstantIndexOp>(loc, shape[i]);
87+
}
88+
89+
// Compute (dim - 1)
90+
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
91+
Value dimMinusOne = builder.create<arith::SubIOp>(loc, dimSize, one);
92+
93+
// Get stride
94+
Value stride;
95+
if (strides[i] == ShapedType::kDynamicStrideOrOffset) {
96+
// For dynamic strides, need to extract from memref descriptor
97+
// This would require runtime support, possibly using extractStride
98+
// As a placeholder, return a dynamic value
99+
return Value();
100+
} else {
101+
stride = builder.create<arith::ConstantIndexOp>(loc, strides[i]);
102+
}
103+
104+
// Add (dim - 1) * stride to result
105+
Value term = builder.create<arith::MulIOp>(loc, dimMinusOne, stride);
106+
result = builder.create<arith::AddIOp>(loc, result, term);
107+
}
108+
109+
return result;
110+
}
111+
49112
static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>, OpFoldResult,
50113
OpFoldResult>
51114
getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
@@ -102,8 +165,9 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
102165
affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
103166

104167
// Compute collapsed size: (the outmost stride * outmost dimension).
105-
SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
106-
OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops);
168+
//SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
169+
//OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops);
170+
OpFoldResult collapsedSize = computeMemRefSpan(source, rewriter);
107171

108172
return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex,
109173
origStrides, origOffset, collapsedSize};

0 commit comments

Comments
 (0)