Skip to content

Commit f32b3e1

Browse files
authored
[mlir][memref] Fix index delinearization for CollapseShapeOp folding (#68833)
The `resolveSourceIndicesCollapseShape` method is used to compute indices into the source `MemRef` of a `CollapseShapeOp` from the collapsed indices. This method didn't check for dynamic sizes of the source shape which led to a crash. Fix #68483
1 parent 894927b commit f32b3e1

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,16 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
128128
dynamicIndices.push_back(indices[cnt++]);
129129
int64_t groupSize = groups.size();
130130

131-
// Calculate suffix product for all collapse op source dimension sizes.
132-
SmallVector<int64_t> sizes(groupSize);
133-
for (int64_t i = 0; i < groupSize; ++i)
131+
// Calculate suffix product for all collapse op source dimension sizes
132+
// except the most major one of each group.
133+
// We allow the most major source dimension to be dynamic but enforce all
134+
// others to be known statically.
135+
SmallVector<int64_t> sizes(groupSize, 1);
136+
for (int64_t i = 1; i < groupSize; ++i) {
134137
sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
138+
if (sizes[i] == ShapedType::kDynamic)
139+
return failure();
140+
}
135141
SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
136142

137143
// Derive the index values along all dimensions of the source corresponding

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,22 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg
317317

318318
// -----
319319

320+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
321+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
322+
// CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
323+
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
324+
func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
325+
%0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x6x32xf32> into memref<?x32xf32>
326+
%1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
327+
return %1 : f32
328+
}
329+
// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
330+
// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
331+
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<?x6x32xf32>
332+
// CHECK-NEXT: return %[[RESULT]] : f32
333+
334+
// -----
335+
320336
// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 6 + s1 * 3 + s2)>
321337
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
322338
// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {

0 commit comments

Comments
 (0)