Skip to content

Commit 5767e4d

Browse files
authored
[MLIR][NFC] Return MemRefType in memref.subview return type inference functions (#120024)
Avoids the need for cast, and matches the extra build functions, which take a `MemRefType`
1 parent 23cb0de commit 5767e4d

File tree

7 files changed

+70
-77
lines changed

7 files changed

+70
-77
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2081,14 +2081,14 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
20812081
/// A subview result type can be fully inferred from the source type and the
20822082
/// static representation of offsets, sizes and strides. Special sentinels
20832083
/// encode the dynamic case.
2084-
static Type inferResultType(MemRefType sourceMemRefType,
2085-
ArrayRef<int64_t> staticOffsets,
2086-
ArrayRef<int64_t> staticSizes,
2087-
ArrayRef<int64_t> staticStrides);
2088-
static Type inferResultType(MemRefType sourceMemRefType,
2089-
ArrayRef<OpFoldResult> staticOffsets,
2090-
ArrayRef<OpFoldResult> staticSizes,
2091-
ArrayRef<OpFoldResult> staticStrides);
2084+
static MemRefType inferResultType(MemRefType sourceMemRefType,
2085+
ArrayRef<int64_t> staticOffsets,
2086+
ArrayRef<int64_t> staticSizes,
2087+
ArrayRef<int64_t> staticStrides);
2088+
static MemRefType inferResultType(MemRefType sourceMemRefType,
2089+
ArrayRef<OpFoldResult> staticOffsets,
2090+
ArrayRef<OpFoldResult> staticSizes,
2091+
ArrayRef<OpFoldResult> staticStrides);
20922092

20932093
/// A rank-reducing result type can be inferred from the desired result
20942094
/// shape. Only the layout map is inferred.
@@ -2097,16 +2097,16 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
20972097
/// and the desired sizes. In case there are more "ones" among the sizes
20982098
/// than the difference in source/result rank, it is not clear which dims of
20992099
/// size one should be dropped.
2100-
static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2101-
MemRefType sourceMemRefType,
2102-
ArrayRef<int64_t> staticOffsets,
2103-
ArrayRef<int64_t> staticSizes,
2104-
ArrayRef<int64_t> staticStrides);
2105-
static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2106-
MemRefType sourceMemRefType,
2107-
ArrayRef<OpFoldResult> staticOffsets,
2108-
ArrayRef<OpFoldResult> staticSizes,
2109-
ArrayRef<OpFoldResult> staticStrides);
2100+
static MemRefType inferRankReducedResultType(
2101+
ArrayRef<int64_t> resultShape, MemRefType sourceMemRefType,
2102+
ArrayRef<int64_t> staticOffsets,
2103+
ArrayRef<int64_t> staticSizes,
2104+
ArrayRef<int64_t> staticStrides);
2105+
static MemRefType inferRankReducedResultType(
2106+
ArrayRef<int64_t> resultShape, MemRefType sourceMemRefType,
2107+
ArrayRef<OpFoldResult> staticOffsets,
2108+
ArrayRef<OpFoldResult> staticSizes,
2109+
ArrayRef<OpFoldResult> staticStrides);
21102110

21112111
/// Return the expected rank of each of the`static_offsets`, `static_sizes`
21122112
/// and `static_strides` attributes.

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2702,10 +2702,10 @@ void SubViewOp::getAsmResultNames(
27022702
/// A subview result type can be fully inferred from the source type and the
27032703
/// static representation of offsets, sizes and strides. Special sentinels
27042704
/// encode the dynamic case.
2705-
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2706-
ArrayRef<int64_t> staticOffsets,
2707-
ArrayRef<int64_t> staticSizes,
2708-
ArrayRef<int64_t> staticStrides) {
2705+
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2706+
ArrayRef<int64_t> staticOffsets,
2707+
ArrayRef<int64_t> staticSizes,
2708+
ArrayRef<int64_t> staticStrides) {
27092709
unsigned rank = sourceMemRefType.getRank();
27102710
(void)rank;
27112711
assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
@@ -2744,10 +2744,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
27442744
sourceMemRefType.getMemorySpace());
27452745
}
27462746

2747-
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2748-
ArrayRef<OpFoldResult> offsets,
2749-
ArrayRef<OpFoldResult> sizes,
2750-
ArrayRef<OpFoldResult> strides) {
2747+
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2748+
ArrayRef<OpFoldResult> offsets,
2749+
ArrayRef<OpFoldResult> sizes,
2750+
ArrayRef<OpFoldResult> strides) {
27512751
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
27522752
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
27532753
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -2763,13 +2763,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
27632763
staticSizes, staticStrides);
27642764
}
27652765

2766-
Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2767-
MemRefType sourceRankedTensorType,
2768-
ArrayRef<int64_t> offsets,
2769-
ArrayRef<int64_t> sizes,
2770-
ArrayRef<int64_t> strides) {
2771-
auto inferredType = llvm::cast<MemRefType>(
2772-
inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2766+
MemRefType SubViewOp::inferRankReducedResultType(
2767+
ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2768+
ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2769+
ArrayRef<int64_t> strides) {
2770+
MemRefType inferredType =
2771+
inferResultType(sourceRankedTensorType, offsets, sizes, strides);
27732772
assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
27742773
"expected ");
27752774
if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
@@ -2795,11 +2794,10 @@ Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
27952794
inferredType.getMemorySpace());
27962795
}
27972796

2798-
Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2799-
MemRefType sourceRankedTensorType,
2800-
ArrayRef<OpFoldResult> offsets,
2801-
ArrayRef<OpFoldResult> sizes,
2802-
ArrayRef<OpFoldResult> strides) {
2797+
MemRefType SubViewOp::inferRankReducedResultType(
2798+
ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2799+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2800+
ArrayRef<OpFoldResult> strides) {
28032801
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
28042802
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
28052803
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -2826,8 +2824,8 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
28262824
auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
28272825
// Structuring implementation this way avoids duplication between builders.
28282826
if (!resultType) {
2829-
resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
2830-
sourceMemRefType, staticOffsets, staticSizes, staticStrides));
2827+
resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2828+
staticSizes, staticStrides);
28312829
}
28322830
result.addAttributes(attrs);
28332831
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2992,8 +2990,8 @@ LogicalResult SubViewOp::verify() {
29922990

29932991
// Compute the expected result type, assuming that there are no rank
29942992
// reductions.
2995-
auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
2996-
baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
2993+
MemRefType expectedType = SubViewOp::inferResultType(
2994+
baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
29972995

29982996
// Verify all properties of a shaped type: rank, element type and dimension
29992997
// sizes. This takes into account potential rank reductions.
@@ -3075,8 +3073,8 @@ static MemRefType getCanonicalSubViewResultType(
30753073
MemRefType currentResultType, MemRefType currentSourceType,
30763074
MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
30773075
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3078-
auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
3079-
sourceType, mixedOffsets, mixedSizes, mixedStrides));
3076+
MemRefType nonRankReducedType = SubViewOp::inferResultType(
3077+
sourceType, mixedOffsets, mixedSizes, mixedStrides);
30803078
FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
30813079
currentSourceType, currentResultType, mixedSizes);
30823080
if (failed(unusedDims))
@@ -3110,9 +3108,8 @@ Value mlir::memref::createCanonicalRankReducingSubViewOp(
31103108
SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
31113109
SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
31123110
SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3113-
auto targetType =
3114-
llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
3115-
targetShape, memrefType, offsets, sizes, strides));
3111+
MemRefType targetType = SubViewOp::inferRankReducedResultType(
3112+
targetShape, memrefType, offsets, sizes, strides);
31163113
return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
31173114
sizes, strides);
31183115
}
@@ -3256,11 +3253,11 @@ struct SubViewReturnTypeCanonicalizer {
32563253
ArrayRef<OpFoldResult> mixedSizes,
32573254
ArrayRef<OpFoldResult> mixedStrides) {
32583255
// Infer a memref type without taking into account any rank reductions.
3259-
auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
3260-
mixedSizes, mixedStrides);
3256+
MemRefType resTy = SubViewOp::inferResultType(
3257+
op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
32613258
if (!resTy)
32623259
return {};
3263-
MemRefType nonReducedType = cast<MemRefType>(resTy);
3260+
MemRefType nonReducedType = resTy;
32643261

32653262
// Directly return the non-rank reduced type if there are no dropped dims.
32663263
llvm::SmallBitVector droppedDims = op.getDroppedDims();

mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ propagateSubViewOp(RewriterBase &rewriter,
7070
UnrealizedConversionCastOp conversionOp, SubViewOp op) {
7171
OpBuilder::InsertionGuard g(rewriter);
7272
rewriter.setInsertionPoint(op);
73-
auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
73+
MemRefType newResultType = SubViewOp::inferRankReducedResultType(
7474
op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
75-
op.getMixedSizes(), op.getMixedStrides()));
75+
op.getMixedSizes(), op.getMixedStrides());
7676
Value newSubview = rewriter.create<SubViewOp>(
7777
op.getLoc(), newResultType, conversionOp.getOperand(0),
7878
op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());

mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,13 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter,
6060
// `subview(old_op)` is replaced by a new `subview(val)`.
6161
OpBuilder::InsertionGuard g(rewriter);
6262
rewriter.setInsertionPoint(subviewUse);
63-
Type newType = memref::SubViewOp::inferRankReducedResultType(
63+
MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
6464
subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
6565
subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
6666
subviewUse.getStaticStrides());
6767
Value newSubview = rewriter.create<memref::SubViewOp>(
68-
subviewUse->getLoc(), cast<MemRefType>(newType), val,
69-
subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
70-
subviewUse.getMixedStrides());
68+
subviewUse->getLoc(), newType, val, subviewUse.getMixedOffsets(),
69+
subviewUse.getMixedSizes(), subviewUse.getMixedStrides());
7170

7271
// Ouch recursion ... is this really necessary?
7372
replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
@@ -211,9 +210,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
211210
for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
212211
sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
213212
// Strides is [1, 1 ... 1 ].
214-
auto dstMemref =
215-
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
216-
originalShape, mbMemRefType, offsets, sizes, strides));
213+
MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
214+
originalShape, mbMemRefType, offsets, sizes, strides);
217215
Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
218216
offsets, sizes, strides);
219217
LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -407,10 +407,10 @@ struct ExtractSliceOpInterface
407407
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
408408
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
409409
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
410-
return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
410+
return memref::SubViewOp::inferRankReducedResultType(
411411
extractSliceOp.getType().getShape(),
412412
llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
413-
mixedStrides));
413+
mixedStrides);
414414
}
415415
};
416416

@@ -692,10 +692,10 @@ struct InsertSliceOpInterface
692692

693693
// Take a subview of the destination buffer.
694694
auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
695-
auto subviewMemRefType =
696-
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
695+
MemRefType subviewMemRefType =
696+
memref::SubViewOp::inferRankReducedResultType(
697697
insertSliceOp.getSourceType().getShape(), dstMemrefType,
698-
mixedOffsets, mixedSizes, mixedStrides));
698+
mixedOffsets, mixedSizes, mixedStrides);
699699
Value subView = rewriter.create<memref::SubViewOp>(
700700
loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
701701
mixedStrides);
@@ -960,12 +960,12 @@ struct ParallelInsertSliceOpInterface
960960

961961
// Take a subview of the destination buffer.
962962
auto destBufferType = cast<MemRefType>(destBuffer->getType());
963-
auto subviewMemRefType =
964-
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
963+
MemRefType subviewMemRefType =
964+
memref::SubViewOp::inferRankReducedResultType(
965965
parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
966966
parallelInsertSliceOp.getMixedOffsets(),
967967
parallelInsertSliceOp.getMixedSizes(),
968-
parallelInsertSliceOp.getMixedStrides()));
968+
parallelInsertSliceOp.getMixedStrides());
969969
Value subview = rewriter.create<memref::SubViewOp>(
970970
parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
971971
parallelInsertSliceOp.getMixedOffsets(),

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,9 @@ static MemRefType dropUnitDims(MemRefType inputType,
265265
ArrayRef<OpFoldResult> sizes,
266266
ArrayRef<OpFoldResult> strides) {
267267
auto targetShape = getReducedShape(sizes);
268-
Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
268+
MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
269269
targetShape, inputType, offsets, sizes, strides);
270-
return cast<MemRefType>(rankReducedType).canonicalizeStridedLayout();
270+
return rankReducedType.canonicalizeStridedLayout();
271271
}
272272

273273
/// Creates a rank-reducing memref.subview op that drops unit dims from its

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,10 +1326,9 @@ class DropInnerMostUnitDimsTransferRead
13261326
rewriter.getIndexAttr(0));
13271327
SmallVector<OpFoldResult> strides(srcType.getRank(),
13281328
rewriter.getIndexAttr(1));
1329-
auto resultMemrefType =
1330-
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
1331-
srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1332-
strides));
1329+
MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1330+
srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1331+
strides);
13331332
ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
13341333
readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
13351334
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
@@ -1417,10 +1416,9 @@ class DropInnerMostUnitDimsTransferWrite
14171416
rewriter.getIndexAttr(0));
14181417
SmallVector<OpFoldResult> strides(srcType.getRank(),
14191418
rewriter.getIndexAttr(1));
1420-
auto resultMemrefType =
1421-
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
1422-
srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1423-
strides));
1419+
MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1420+
srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1421+
strides);
14241422
ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
14251423
writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
14261424

0 commit comments

Comments
 (0)