Skip to content

Commit 932dc9d

Browse files
authored
[mlir][MemRef] Add a pattern to simplify `extract_strided_metadata(ca… (#68291)
…st)` `expand-strided-metadata` was missing a pattern to get rid of `memref.cast`. The pattern is straight foward: Produce a new `extract_strided_metadata` with the source of the cast and fold the static information (sizes, strides, offset) along the way.
1 parent 253ee85 commit 932dc9d

File tree

2 files changed

+213
-0
lines changed

2 files changed

+213
-0
lines changed

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

+88
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,92 @@ class ExtractStridedMetadataOpReinterpretCastFolder
870870
}
871871
};
872872

873+
/// Replace `base, offset, sizes, strides =
874+
/// extract_strided_metadata(
875+
/// cast(src) to dstTy)`
876+
/// With
877+
/// ```
878+
/// base, ... = extract_strided_metadata(src)
879+
/// offset = !dstTy.srcOffset.isDynamic()
880+
/// ? dstTy.srcOffset
881+
/// : extract_strided_metadata(src).offset
882+
/// sizes = for each srcSize in dstTy.srcSizes:
883+
/// !srcSize.isDynamic()
884+
/// ? srcSize
885+
// : extract_strided_metadata(src).sizes[i]
886+
/// strides = for each srcStride in dstTy.srcStrides:
887+
/// !srcStrides.isDynamic()
888+
/// ? srcStrides
889+
/// : extract_strided_metadata(src).strides[i]
890+
/// ```
891+
///
892+
/// In other words, consume the `cast` and apply its effects
893+
/// on the offset, sizes, and strides or compute them directly from `src`.
894+
class ExtractStridedMetadataOpCastFolder
895+
: public OpRewritePattern<memref::ExtractStridedMetadataOp> {
896+
using OpRewritePattern::OpRewritePattern;
897+
898+
LogicalResult
899+
matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
900+
PatternRewriter &rewriter) const override {
901+
Value source = extractStridedMetadataOp.getSource();
902+
auto castOp = source.getDefiningOp<memref::CastOp>();
903+
if (!castOp)
904+
return failure();
905+
906+
Location loc = extractStridedMetadataOp.getLoc();
907+
// Check if the source is suitable for extract_strided_metadata.
908+
SmallVector<Type> inferredReturnTypes;
909+
if (failed(extractStridedMetadataOp.inferReturnTypes(
910+
rewriter.getContext(), loc, {castOp.getSource()},
911+
/*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
912+
inferredReturnTypes)))
913+
return rewriter.notifyMatchFailure(castOp,
914+
"cast source's type is incompatible");
915+
916+
auto memrefType = cast<MemRefType>(source.getType());
917+
unsigned rank = memrefType.getRank();
918+
SmallVector<OpFoldResult> results;
919+
results.resize_for_overwrite(rank * 2 + 2);
920+
921+
auto newExtractStridedMetadata =
922+
rewriter.create<memref::ExtractStridedMetadataOp>(loc,
923+
castOp.getSource());
924+
925+
// Register the base_buffer.
926+
results[0] = newExtractStridedMetadata.getBaseBuffer();
927+
928+
auto getConstantOrValue = [&rewriter](int64_t constant,
929+
OpFoldResult ofr) -> OpFoldResult {
930+
return !ShapedType::isDynamic(constant)
931+
? OpFoldResult(rewriter.getIndexAttr(constant))
932+
: ofr;
933+
};
934+
935+
auto [sourceStrides, sourceOffset] = getStridesAndOffset(memrefType);
936+
assert(sourceStrides.size() == rank && "unexpected number of strides");
937+
938+
// Register the new offset.
939+
results[1] =
940+
getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
941+
942+
const unsigned sizeStartIdx = 2;
943+
const unsigned strideStartIdx = sizeStartIdx + rank;
944+
ArrayRef<int64_t> sourceSizes = memrefType.getShape();
945+
946+
SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
947+
SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
948+
for (unsigned i = 0; i < rank; ++i) {
949+
results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
950+
results[strideStartIdx + i] =
951+
getConstantOrValue(sourceStrides[i], strides[i]);
952+
}
953+
rewriter.replaceOp(extractStridedMetadataOp,
954+
getValueOrCreateConstantIndexOp(rewriter, loc, results));
955+
return success();
956+
}
957+
};
958+
873959
/// Replace `base, offset =
874960
/// extract_strided_metadata(extract_strided_metadata(src)#0)`
875961
/// With
@@ -911,6 +997,7 @@ void memref::populateExpandStridedMetadataPatterns(
911997
ExtractStridedMetadataOpGetGlobalFolder,
912998
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
913999
ExtractStridedMetadataOpReinterpretCastFolder,
1000+
ExtractStridedMetadataOpCastFolder,
9141001
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
9151002
patterns.getContext());
9161003
}
@@ -923,6 +1010,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
9231010
ExtractStridedMetadataOpSubviewFolder,
9241011
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
9251012
ExtractStridedMetadataOpReinterpretCastFolder,
1013+
ExtractStridedMetadataOpCastFolder,
9261014
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
9271015
patterns.getContext());
9281016
}

mlir/test/Dialect/MemRef/expand-strided-metadata.mlir

+125
Original file line numberDiff line numberDiff line change
@@ -1369,3 +1369,128 @@ func.func @extract_strided_metadata_of_get_global_with_offset()
13691369
return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
13701370
memref<i32>, index, index, index, index, index
13711371
}
1372+
1373+
// -----
1374+
1375+
// Check that we simplify extract_strided_metadata of cast
1376+
// when the source of the cast is compatible with what
1377+
// `extract_strided_metadata`s accept.
1378+
//
1379+
// When we apply the transformation the resulting offset, sizes and strides
1380+
// should come straight from the inputs of the cast.
1381+
// Additionally the folder on extract_strided_metadata should propagate the
1382+
// static information.
1383+
//
1384+
// CHECK-LABEL: func @extract_strided_metadata_of_cast
1385+
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
1386+
//
1387+
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
1388+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
1389+
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
1390+
//
1391+
// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
1392+
func.func @extract_strided_metadata_of_cast(
1393+
%arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
1394+
-> (memref<i32>, index,
1395+
index, index,
1396+
index, index) {
1397+
1398+
%cast =
1399+
memref.cast %arg :
1400+
memref<3x?xi32, strided<[4, ?], offset: ?>> to
1401+
memref<?x?xi32, strided<[?, ?], offset: ?>>
1402+
1403+
%base, %base_offset, %sizes:2, %strides:2 =
1404+
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
1405+
-> memref<i32>, index,
1406+
index, index,
1407+
index, index
1408+
1409+
return %base, %base_offset,
1410+
%sizes#0, %sizes#1,
1411+
%strides#0, %strides#1 :
1412+
memref<i32>, index,
1413+
index, index,
1414+
index, index
1415+
}
1416+
1417+
// -----
1418+
1419+
// Check that we simplify extract_strided_metadata of cast
1420+
// when the source of the cast is compatible with what
1421+
// `extract_strided_metadata`s accept.
1422+
//
1423+
// Same as extract_strided_metadata_of_cast but with constant sizes and strides
1424+
// in the destination type.
1425+
//
1426+
// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
1427+
// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
1428+
//
1429+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
1430+
// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
1431+
// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
1432+
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
1433+
//
1434+
// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
1435+
func.func @extract_strided_metadata_of_cast_w_csts(
1436+
%arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
1437+
-> (memref<i32>, index,
1438+
index, index,
1439+
index, index) {
1440+
1441+
%cast =
1442+
memref.cast %arg :
1443+
memref<?x?xi32, strided<[?, ?], offset: ?>> to
1444+
memref<4x?xi32, strided<[?, 18], offset: 25>>
1445+
1446+
%base, %base_offset, %sizes:2, %strides:2 =
1447+
memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
1448+
-> memref<i32>, index,
1449+
index, index,
1450+
index, index
1451+
1452+
return %base, %base_offset,
1453+
%sizes#0, %sizes#1,
1454+
%strides#0, %strides#1 :
1455+
memref<i32>, index,
1456+
index, index,
1457+
index, index
1458+
}
1459+
// -----
1460+
1461+
// Check that we don't simplify extract_strided_metadata of
1462+
// cast when the source of the cast is unranked.
1463+
// Unranked memrefs cannot feed into extract_strided_metadata operations.
1464+
// Note: Technically we could still fold the sizes and strides.
1465+
//
1466+
// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
1467+
// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
1468+
//
1469+
// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
1470+
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
1471+
//
1472+
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
1473+
func.func @extract_strided_metadata_of_cast_unranked(
1474+
%arg : memref<*xi32>)
1475+
-> (memref<i32>, index,
1476+
index, index,
1477+
index, index) {
1478+
1479+
%cast =
1480+
memref.cast %arg :
1481+
memref<*xi32> to
1482+
memref<?x?xi32, strided<[?, ?], offset: ?>>
1483+
1484+
%base, %base_offset, %sizes:2, %strides:2 =
1485+
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
1486+
-> memref<i32>, index,
1487+
index, index,
1488+
index, index
1489+
1490+
return %base, %base_offset,
1491+
%sizes#0, %sizes#1,
1492+
%strides#0, %strides#1 :
1493+
memref<i32>, index,
1494+
index, index,
1495+
index, index
1496+
}

0 commit comments

Comments
 (0)