Skip to content

Commit 1e9c352

Browse files
committed
[Task] : Add back memref.dim canonicalization with dominator fix.
1 parent e9e298d commit 1e9c352

File tree

4 files changed

+144
-2
lines changed

4 files changed

+144
-2
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,7 @@ def MemRef_DimOp : MemRef_Op<"dim", [
629629
Speculation::Speculatability getSpeculatability();
630630
}];
631631

632+
let hasCanonicalizer = 1;
632633
let hasFolder = 1;
633634
}
634635

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,6 +1069,52 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
10691069
return {};
10701070
}
10711071

1072+
namespace {
1073+
/// Fold dim of a memref reshape operation to a load from the reshape's shape
1074+
/// operand.
1075+
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1076+
using OpRewritePattern<DimOp>::OpRewritePattern;
1077+
1078+
LogicalResult matchAndRewrite(DimOp dim,
1079+
PatternRewriter &rewriter) const override {
1080+
auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1081+
1082+
if (!reshape)
1083+
return rewriter.notifyMatchFailure(
1084+
dim, "Dim op is not defined by a reshape op.");
1085+
1086+
if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1087+
if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1088+
if (reshape->isBeforeInBlock(definingOp))
1089+
return rewriter.notifyMatchFailure(
1090+
dim,
1091+
"dim.getIndex is not defined before reshape in the same block.");
1092+
} // else dim.getIndex is a block argument to reshape->getBlock
1093+
} else if (!dim.getIndex().getParentRegion()->isProperAncestor(
1094+
reshape->getParentRegion()))
1095+
return rewriter.notifyMatchFailure(
1096+
dim, "dim.getIndex does not dominate reshape.");
1097+
1098+
// Place the load directly after the reshape to ensure that the shape memref
1099+
// was not mutated.
1100+
rewriter.setInsertionPointAfter(reshape);
1101+
Location loc = dim.getLoc();
1102+
Value load =
1103+
rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
1104+
if (load.getType() != dim.getType())
1105+
load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
1106+
rewriter.replaceOp(dim, load);
1107+
return success();
1108+
}
1109+
};
1110+
1111+
} // namespace
1112+
1113+
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1114+
MLIRContext *context) {
1115+
results.add<DimOfMemRefReshape>(context);
1116+
}
1117+
10721118
// ---------------------------------------------------------------------------
10731119
// DmaStartOp
10741120
// ---------------------------------------------------------------------------

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,101 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
242242

243243
// -----
244244

245+
// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
246+
// CHECK-LABEL: func @dim_of_memref_reshape(
247+
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
248+
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
249+
// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
250+
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
251+
// CHECK-NEXT: memref.store
252+
// CHECK-NOT: memref.dim
253+
// CHECK: return %[[DIM]] : index
254+
func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
255+
-> index {
256+
%c3 = arith.constant 3 : index
257+
%0 = memref.reshape %arg0(%arg1)
258+
: (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
259+
// Update the shape to test that the load ends up in the right place.
260+
memref.store %c3, %arg1[%c3] : memref<?xindex>
261+
%1 = memref.dim %0, %c3 : memref<*xf32>
262+
return %1 : index
263+
}
264+
265+
// -----
266+
267+
// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
268+
// CHECK-LABEL: func @dim_of_memref_reshape_i32(
269+
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
270+
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
271+
// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
272+
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
273+
// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast %[[DIM]]
274+
// CHECK-NOT: memref.dim
275+
// CHECK: return %[[CAST]] : index
276+
func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
277+
-> index {
278+
%c3 = arith.constant 3 : index
279+
%0 = memref.reshape %arg0(%arg1)
280+
: (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
281+
%1 = memref.dim %0, %c3 : memref<*xf32>
282+
return %1 : index
283+
}
284+
285+
// -----
286+
287+
// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
288+
// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
289+
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
290+
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>,
291+
// CHECK-SAME: %[[IDX:[0-9a-z]+]]: index
292+
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
293+
// CHECK-NOT: memref.dim
294+
// CHECK: return %[[DIM]] : index
295+
func.func @dim_of_memref_reshape_block_arg_index(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
296+
%reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
297+
%dim = memref.dim %reshape, %arg2 : memref<*xf32>
298+
return %dim : index
299+
}
300+
301+
// -----
302+
303+
// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
304+
// CHECK-LABEL: func @dim_of_memref_reshape_for(
305+
// CHECK: memref.reshape
306+
// CHECK: memref.dim
307+
// CHECK-NOT: memref.load
308+
func.func @dim_of_memref_reshape_for( %arg0: memref<*xf32>, %arg1: memref<?xindex>) -> index {
309+
%c0 = arith.constant 0 : index
310+
%c1 = arith.constant 1 : index
311+
%c4 = arith.constant 4 : index
312+
313+
%0 = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
314+
315+
%1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
316+
%2 = memref.dim %0, %arg2 : memref<*xf32>
317+
%3 = arith.muli %arg3, %2 : index
318+
scf.yield %3 : index
319+
}
320+
return %1 : index
321+
}
322+
323+
// -----
324+
325+
// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
326+
// CHECK-LABEL: func @dim_of_memref_reshape_undominated(
327+
// CHECK: memref.reshape
328+
// CHECK: memref.dim
329+
// CHECK-NOT: memref.load
330+
func.func @dim_of_memref_reshape_undominated(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
331+
%c4 = arith.constant 4 : index
332+
%reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
333+
%0 = arith.muli %arg2, %c4 : index
334+
%dim = memref.dim %reshape, %0 : memref<*xf32>
335+
return %dim : index
336+
}
337+
338+
// -----
339+
245340
// CHECK-LABEL: func @alloc_const_fold
246341
func.func @alloc_const_fold() -> memref<?xf32> {
247342
// CHECK-NEXT: memref.alloc() : memref<4xf32>

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2294,7 +2294,7 @@ func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
22942294

22952295
// -----
22962296

2297-
// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
2297+
// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
22982298
// CHECK-LABEL: func @dim_of_reshape_for(
22992299
// CHECK: scf.for
23002300
// CHECK-NEXT: tensor.extract
@@ -2317,7 +2317,7 @@ func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) ->
23172317

23182318
// -----
23192319

2320-
// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
2320+
// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
23212321
// CHECK-LABEL: func @dim_of_reshape_undominated(
23222322
// CHECK: arith.muli
23232323
// CHECK-NEXT: tensor.extract

0 commit comments

Comments
 (0)