@@ -46,6 +46,69 @@ static void setInsertionPointToStart(OpBuilder &builder, Value val) {
   }
 }
 
+static OpFoldResult computeMemRefSpan(Value memref, OpBuilder &builder) {
+  Location loc = memref.getLoc();
+  MemRefType type = cast<MemRefType>(memref.getType());
+  ArrayRef<int64_t> shape = type.getShape();
+
+  // A statically zero-sized dimension means the memref holds no elements.
+  if (type.hasStaticShape() &&
+      llvm::any_of(shape, [](int64_t dim) { return dim == 0; })) {
+    return builder.getIndexAttr(0);
+  }
+
+  // Get the strides of the memref.
+  SmallVector<int64_t, 4> strides;
+  int64_t offset;
+  if (failed(type.getStridesAndOffset(strides, offset))) {
+    // Cannot extract strides; return a null result to signal failure.
+    return Value();
+  }
+
+  // Static case: fold the span to a constant at compile time.
+  if (type.hasStaticShape()) {
+    int64_t span = 0;
+    for (int64_t i = 0; i < type.getRank(); ++i) {
+      span += (shape[i] - 1) * strides[i];
+    }
+    return builder.getIndexAttr(span);
+  }
+
+  // Dynamic case: emit IR to compute the span at runtime.
+  Value result = builder.create<arith::ConstantIndexOp>(loc, 0);
+
+  for (int64_t i = 0; i < type.getRank(); ++i) {
+    // Get the dimension size.
+    Value dimSize;
+    if (shape[i] == ShapedType::kDynamic) {
+      dimSize = builder.create<memref::DimOp>(loc, memref, i);
+    } else {
+      dimSize = builder.create<arith::ConstantIndexOp>(loc, shape[i]);
+    }
+
+    // Compute (dim - 1).
+    Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+    Value dimMinusOne = builder.create<arith::SubIOp>(loc, dimSize, one);
+
+    // Get the stride.
+    Value stride;
+    if (strides[i] == ShapedType::kDynamic) {
+      // Dynamic strides would have to be recovered at runtime, e.g. from
+      // memref.extract_strided_metadata; bail out with a null result for
+      // now. Any ops already created above are left for DCE to clean up.
+      return Value();
+    } else {
+      stride = builder.create<arith::ConstantIndexOp>(loc, strides[i]);
+    }
+
+    // Accumulate (dim - 1) * stride into the running span.
+    Value term = builder.create<arith::MulIOp>(loc, dimMinusOne, stride);
+    result = builder.create<arith::AddIOp>(loc, result, term);
+  }
+
+  return result;
+}
+
 static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>, OpFoldResult,
                   OpFoldResult>
 getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
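
Note on the dynamic path above: when the sizes are dynamic but the strides are static, the emitted IR is a plain running sum of (dim - 1) * stride terms. As a rough sketch, for a hypothetical input %src : memref<?x8xf32, strided<[8, 1]>> it would produce roughly the following (value names illustrative, duplicate constants merged):

    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c8 = arith.constant 8 : index
    %dim0 = memref.dim %src, %c0 : memref<?x8xf32, strided<[8, 1]>>
    %sub0 = arith.subi %dim0, %c1 : index
    %mul0 = arith.muli %sub0, %c8 : index   // (dim0 - 1) * stride 8
    %acc0 = arith.addi %c0, %mul0 : index
    %sub1 = arith.subi %c8, %c1 : index     // dim 1 has static size 8
    %mul1 = arith.muli %sub1, %c1 : index   // stride 1
    %span = arith.addi %acc0, %mul1 : index
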
@@ -102,8 +165,9 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
       affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
 
   // Compute collapsed size: (the outmost stride * outmost dimension).
-  SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
-  OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops);
+  // SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
+  // OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops);
+  OpFoldResult collapsedSize = computeMemRefSpan(source, rewriter);
 
   return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex,
           origStrides, origOffset, collapsedSize};
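
For intuition on the second hunk: the old computation, origStrides.front() * outmostDim, only measures the extent covered along the outermost dimension. For a non-contiguous layout such as memref<4x8xf32, strided<[10, 1]>>, it gives 10 * 4 = 40, whereas the distance from the base to the last element actually addressed, sum_i (shape[i] - 1) * strides[i], is (4 - 1) * 10 + (8 - 1) * 1 = 37; presumably this tighter span is what motivates switching collapsedSize over to computeMemRefSpan.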