Commit 6326b57

[mlir][memref] memref.subview: Verify result strides
The `memref.subview` verifier currently checks the result shape, element type, memory space and offset of the result type. However, the strides of the result type are not verified. This commit adds verification of result strides for non-rank-reducing ops and fixes invalid IR in test cases. Verification of result strides for ops with rank reductions is more complex (and there can be multiple possible result types); that is left for a separate commit.

Also refactor the implementation a bit:

* If `computeMemRefRankReductionMask` could not compute the dropped dimensions, there must be something wrong with the op. Return `FailureOr` instead of `std::optional`.
* `isRankReducedMemRefType` did much more than just check whether the op has rank reductions or not. Inline the implementation into the verifier and add better comments.
* `produceSubViewErrorMsg` does not have to be templatized.
1 parent cbe5985 commit 6326b57
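The rule the new check enforces: each result stride must equal the corresponding source stride multiplied by the subview stride (with `?` wherever either factor is dynamic). A minimal sketch of what is now accepted and rejected, using a hypothetical function and values not taken from this commit's tests:

```mlir
// memref<8x16xf32> has identity strides [16, 1]. With subview strides
// [2, 2], the expected result strides are [16 * 2, 1 * 2] = [32, 2].
func.func @stride_check_sketch(%m: memref<8x16xf32>) {
  // Accepted: the declared strides [32, 2] match the inferred ones.
  %0 = memref.subview %m[0, 0] [4, 4] [2, 2]
      : memref<8x16xf32> to memref<4x4xf32, strided<[32, 2]>>
  // Rejected after this commit (previously accepted): declared strides
  // [16, 1] do not match the inferred [32, 2].
  // %1 = memref.subview %m[0, 0] [4, 4] [2, 2]
  //     : memref<8x16xf32> to memref<4x4xf32, strided<[16, 1]>>
  return
}
```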

5 files changed (+90, −65 lines)

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 68 additions & 52 deletions

@@ -917,7 +917,7 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
 /// This accounts for cases where there are multiple unit-dims, but only a
 /// subset of those are dropped. For MemRefTypes these can be disambiguated
 /// using the strides. If a dimension is dropped the stride must be dropped too.
-static std::optional<llvm::SmallBitVector>
+static FailureOr<llvm::SmallBitVector>
 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                                ArrayRef<OpFoldResult> sizes) {
   llvm::SmallBitVector unusedDims(originalType.getRank());
@@ -941,7 +941,7 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
           getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
       failed(
           getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
-    return std::nullopt;
+    return failure();
 
   // For memrefs, a dimension is truly dropped if its corresponding stride is
   // also dropped. This is particularly important when more than one of the dims
@@ -976,22 +976,22 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
         candidateStridesNumOccurences[originalStride]) {
       // This should never happen. Cant have a stride in the reduced rank type
       // that wasnt in the original one.
-      return std::nullopt;
+      return failure();
     }
   }
 
   if ((int64_t)unusedDims.count() + reducedType.getRank() !=
       originalType.getRank())
-    return std::nullopt;
+    return failure();
   return unusedDims;
 }
 
 llvm::SmallBitVector SubViewOp::getDroppedDims() {
   MemRefType sourceType = getSourceType();
   MemRefType resultType = getType();
-  std::optional<llvm::SmallBitVector> unusedDims =
+  FailureOr<llvm::SmallBitVector> unusedDims =
       computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
-  assert(unusedDims && "unable to find unused dims of subview");
+  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
   return *unusedDims;
 }
 
@@ -2745,7 +2745,7 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
 /// For ViewLikeOpInterface.
 Value SubViewOp::getViewSource() { return getSource(); }
 
-/// Return true if t1 and t2 have equal offsets (both dynamic or of same
+/// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
 /// static value).
 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
   int64_t t1Offset, t2Offset;
@@ -2755,56 +2755,41 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
   return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
 }
 
-/// Checks if `original` Type type can be rank reduced to `reduced` type.
-/// This function is slight variant of `is subsequence` algorithm where
-/// not matching dimension must be 1.
-static SliceVerificationResult
-isRankReducedMemRefType(MemRefType originalType,
-                        MemRefType candidateRankReducedType,
-                        ArrayRef<OpFoldResult> sizes) {
-  auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
-  if (partialRes != SliceVerificationResult::Success)
-    return partialRes;
-
-  auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
-      originalType, candidateRankReducedType, sizes);
-
-  // Sizes cannot be matched in case empty vector is returned.
-  if (!optionalUnusedDimsMask)
-    return SliceVerificationResult::LayoutMismatch;
-
-  if (originalType.getMemorySpace() !=
-      candidateRankReducedType.getMemorySpace())
-    return SliceVerificationResult::MemSpaceMismatch;
-
-  // No amount of stride dropping can reconcile incompatible offsets.
-  if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
-    return SliceVerificationResult::LayoutMismatch;
-
-  return SliceVerificationResult::Success;
+/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
+/// static value).
+static bool haveCompatibleStrides(MemRefType t1, MemRefType t2) {
+  int64_t t1Offset, t2Offset;
+  SmallVector<int64_t> t1Strides, t2Strides;
+  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
+  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
+  if (failed(res1) || failed(res2))
+    return false;
+  for (auto [s1, s2] : llvm::zip_equal(t1Strides, t2Strides))
+    if (s1 != s2)
+      return false;
+  return true;
 }
 
-template <typename OpTy>
 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
-                                            OpTy op, Type expectedType) {
+                                            Operation *op, Type expectedType) {
   auto memrefType = llvm::cast<ShapedType>(expectedType);
   switch (result) {
   case SliceVerificationResult::Success:
     return success();
   case SliceVerificationResult::RankTooLarge:
-    return op.emitError("expected result rank to be smaller or equal to ")
+    return op->emitError("expected result rank to be smaller or equal to ")
           << "the source rank. ";
   case SliceVerificationResult::SizeMismatch:
-    return op.emitError("expected result type to be ")
+    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) ";
   case SliceVerificationResult::ElemTypeMismatch:
-    return op.emitError("expected result element type to be ")
+    return op->emitError("expected result element type to be ")
           << memrefType.getElementType();
   case SliceVerificationResult::MemSpaceMismatch:
-    return op.emitError("expected result and source memory spaces to match.");
+    return op->emitError("expected result and source memory spaces to match.");
   case SliceVerificationResult::LayoutMismatch:
-    return op.emitError("expected result type to be ")
+    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result layout) ";
   }
@@ -2826,13 +2811,46 @@ LogicalResult SubViewOp::verify() {
   if (!isStrided(baseType))
     return emitError("base type ") << baseType << " is not strided";
 
-  // Verify result type against inferred type.
-  auto expectedType = SubViewOp::inferResultType(
-      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
+  // Compute the expected result type, assuming that there are no rank
+  // reductions.
+  auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
+      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
+
+  // Verify all properties of a shaped type: rank, element type and dimension
+  // sizes. This takes into account potential rank reductions.
+  auto shapedTypeVerification = isRankReducedType(
+      /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
+  if (shapedTypeVerification != SliceVerificationResult::Success)
+    return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
+
+  // Make sure that the memory space did not change.
+  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
+    return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
+                                  *this, expectedType);
+
+  // Verify the offset of the layout map.
+  if (!haveCompatibleOffsets(expectedType, subViewType))
+    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
+                                  *this, expectedType);
+
+  // The only thing that's left to verify now are the strides. First, compute
+  // the unused dimensions due to rank reductions. We have to look at sizes and
+  // strides to decide which dimensions were dropped. This function also
+  // verifies strides in case of rank reductions.
+  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
+                                                   getMixedSizes());
+  if (failed(unusedDims))
+    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
+                                  *this, expectedType);
+
+  // Strides must match if there are no rank reductions.
+  // TODO: Verify strides when there are rank reductions. Strides are partially
+  // checked in `computeMemRefRankReductionMask`.
+  if (unusedDims->none() && !haveCompatibleStrides(expectedType, subViewType))
+    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
+                                  *this, expectedType);
 
-  auto result = isRankReducedMemRefType(llvm::cast<MemRefType>(expectedType),
-                                        subViewType, getMixedSizes());
-  return produceSubViewErrorMsg(result, *this, expectedType);
+  return success();
 }
 
 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
@@ -2882,11 +2900,9 @@ static MemRefType getCanonicalSubViewResultType(
     ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
   auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
       sourceType, mixedOffsets, mixedSizes, mixedStrides));
-  std::optional<llvm::SmallBitVector> unusedDims =
-      computeMemRefRankReductionMask(currentSourceType, currentResultType,
-                                     mixedSizes);
-  // Return nullptr as failure mode.
-  if (!unusedDims)
+  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
+      currentSourceType, currentResultType, mixedSizes);
+  if (failed(unusedDims))
     return nullptr;
 
   auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
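For context, the disambiguation that `computeMemRefRankReductionMask` performs matters when several unit dims share the same stride and only some of them are dropped. A hypothetical sketch (not from this commit's tests) of such a rank-reducing subview, whose result strides are only partially checked for now:

```mlir
// memref<4x1x1x16xf32> has identity strides [16, 16, 16, 1]: the two unit
// dims share stride 16, and only one of them is dropped below. The dropped-
// dims mask is recovered by matching sizes and strides between the source
// and the result type.
func.func @drop_one_of_two_unit_dims(%m: memref<4x1x1x16xf32>) {
  %0 = memref.subview %m[0, 0, 0, 0] [4, 1, 1, 16] [1, 1, 1, 1]
      : memref<4x1x1x16xf32> to memref<4x1x16xf32, strided<[16, 16, 1]>>
  return
}
```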

mlir/test/Dialect/GPU/decompose-memrefs.mlir

Lines changed: 3 additions & 3 deletions

@@ -119,7 +119,7 @@ func.func @decompose_subview(%arg0 : memref<?x?x?xf32>) {
 // CHECK: %[[IDX1:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#1]
 // CHECK: %[[IDX2:.*]] = affine.apply #[[MAP2]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
 // CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX2]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[IDX]], %[[IDX1]], 4]
-// CHECK: "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
+// CHECK: "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>) -> ()
 func.func @decompose_subview_strided(%arg0 : memref<?x?x?xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -129,8 +129,8 @@ func.func @decompose_subview_strided(%arg0 : memref<?x?x?xf32>) {
   %block_dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
              threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
-    %res = memref.subview %arg0[%tx, %ty, %tz] [%c2, %c2, %c2] [2, 3, 4] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
-    "test.test"(%res) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
+    %res = memref.subview %arg0[%tx, %ty, %tz] [%c2, %c2, %c2] [2, 3, 4] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>
+    "test.test"(%res) : (memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>) -> ()
     gpu.terminator
   }
   return
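The fix above follows from the stride rule: the innermost stride of the contiguous source is the static value 1, so a subview stride of 4 yields a static result stride of 4 even though the outer strides stay dynamic. A hypothetical standalone reduction of the test:

```mlir
// Result strides: [? * 2, ? * 3, 1 * 4] = [?, ?, 4].
func.func @innermost_stride(%m: memref<?x?x?xf32>, %i: index, %sz: index) {
  %0 = memref.subview %m[%i, %i, %i] [%sz, %sz, %sz] [2, 3, 4]
      : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>
  return
}
```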

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 6 additions & 6 deletions

@@ -595,9 +595,9 @@ func.func @subview_of_subview(%m: memref<1x1024xf32, 3>, %pos: index)
 {
   %0 = memref.subview %m[3, %pos] [1, 2] [1, 1]
       : memref<1x1024xf32, 3>
-        to memref<1x2xf32, strided<[1024, 2], offset: ?>, 3>
+        to memref<1x2xf32, strided<[1024, 1], offset: ?>, 3>
   %1 = memref.subview %0[1, 2] [1, 1] [1, 1]
-      : memref<1x2xf32, strided<[1024, 2], offset: ?>, 3>
+      : memref<1x2xf32, strided<[1024, 1], offset: ?>, 3>
        to memref<f32, strided<[], offset: ?>, 3>
   return %1 : memref<f32, strided<[], offset: ?>, 3>
 }
@@ -675,9 +675,9 @@ func.func @fold_gpu_subgroup_mma_store_matrix_1d(%dst: memref<?xvector<4xf32>>,
 // CHECK-LABEL: func.func @fold_gpu_subgroup_mma_load_matrix_2d
 // CHECK-SAME: %[[SRC:.+]]: memref<128x128xf32>
 func.func @fold_gpu_subgroup_mma_load_matrix_2d(%arg0 : memref<128x128xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> !gpu.mma_matrix<16x16xf16, "COp"> {
-  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[64, 1], offset: ?>>
+  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[256, 1], offset: ?>>
   // CHECK: gpu.subgroup_mma_load_matrix %[[SRC]][{{.+}}] {leadDimension = 32 : index} : memref<128x128xf32> -> !gpu.mma_matrix<16x16xf16, "COp">
-  %matrix = gpu.subgroup_mma_load_matrix %subview[%arg3, %arg4] {leadDimension = 32 : index} : memref<64x32xf32, strided<[64, 1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp">
+  %matrix = gpu.subgroup_mma_load_matrix %subview[%arg3, %arg4] {leadDimension = 32 : index} : memref<64x32xf32, strided<[256, 1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp">
   return %matrix : !gpu.mma_matrix<16x16xf16, "COp">
 }
@@ -686,9 +686,9 @@ func.func @fold_gpu_subgroup_mma_load_matrix_2d(%arg0 : memref<128x128xf32>, %ar
 // CHECK-LABEL: func.func @fold_gpu_subgroup_mma_load_matrix_2d
 // CHECK-SAME: %[[DST:.+]]: memref<128x128xf32>
 func.func @fold_gpu_subgroup_mma_load_matrix_2d(%arg0 : memref<128x128xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %matrix: !gpu.mma_matrix<16x16xf16, "COp">) {
-  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[64, 1], offset: ?>>
+  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[256, 1], offset: ?>>
   // CHECK: gpu.subgroup_mma_store_matrix %{{.+}}, %[[DST]][{{.+}}] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf32>
-  gpu.subgroup_mma_store_matrix %matrix, %subview[%arg3, %arg4] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<64x32xf32, strided<[64, 1], offset: ?>>
+  gpu.subgroup_mma_store_matrix %matrix, %subview[%arg3, %arg4] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<64x32xf32, strided<[256, 1], offset: ?>>
   return
 }
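The arithmetic behind this fix, as a hypothetical standalone sketch: `memref<128x128xf32>` has strides [128, 1], and subview strides [2, 1] give result strides [128 * 2, 1 * 1] = [256, 1], not [64, 1].

```mlir
func.func @mma_subview_strides(%arg0: memref<128x128xf32>, %i: index, %j: index) {
  // The outer stride doubles because every other row is skipped.
  %s = memref.subview %arg0[%i, %j] [64, 32] [2, 1]
      : memref<128x128xf32> to memref<64x32xf32, strided<[256, 1], offset: ?>>
  return
}
```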

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 9 additions & 0 deletions

@@ -1073,3 +1073,12 @@ func.func @dim_0_ranked(%arg : memref<f32>, %arg1 : index) {
   memref.dim %arg, %arg1 : memref<f32> // expected-error {{'memref.dim' op operand #0 must be unranked.memref of any type values or non-0-ranked.memref of any type values, but got 'memref<f32>'}}
   return
 }
+
+// -----
+
+func.func @subview_invalid_strides(%m: memref<7x22x333x4444xi32>) {
+  // expected-error @below{{expected result type to be 'memref<7x11x333x4444xi32, strided<[32556744, 2959704, 4444, 1]>>' or a rank-reduced version. (mismatch of result layout)}}
+  %subview = memref.subview %m[0, 0, 0, 0] [7, 11, 333, 4444] [1, 2, 1, 1]
+      : memref<7x22x333x4444xi32> to memref<7x11x333x4444xi32>
+  return
+}
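The expected type in the diagnostic can be derived by hand: the source strides of `memref<7x22x333x4444xi32>` are [22 * 333 * 4444, 333 * 4444, 4444, 1] = [32556744, 1479852, 4444, 1], and multiplying by the subview strides [1, 2, 1, 1] gives [32556744, 2959704, 4444, 1]. Declaring that layout makes the op valid; a sketch with a hypothetical function name not part of the test file:

```mlir
func.func @subview_valid_strides(%m: memref<7x22x333x4444xi32>) {
  %subview = memref.subview %m[0, 0, 0, 0] [7, 11, 333, 4444] [1, 2, 1, 1]
      : memref<7x22x333x4444xi32>
        to memref<7x11x333x4444xi32, strided<[32556744, 2959704, 4444, 1]>>
  return
}
```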

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir

Lines changed: 4 additions & 4 deletions

@@ -88,10 +88,10 @@ module {
     // Prepare a buffer for x0, x1, x2, y0 and a buffer for y1.
     %xys = memref.alloc() : memref<20xi32>
     %xy = memref.cast %xys : memref<20xi32> to memref<?xi32>
-    %x0 = memref.subview %xy[%i0][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
-    %x1 = memref.subview %xy[%i1][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
-    %x2 = memref.subview %xy[%i2][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
-    %y0 = memref.subview %xy[%i3][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+    %x0 = memref.subview %xy[%i0][%i5][4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+    %x1 = memref.subview %xy[%i1][%i5][4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+    %x2 = memref.subview %xy[%i2][%i5][4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+    %y0 = memref.subview %xy[%i3][%i5][4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
    %y1s = memref.alloc() : memref<7xi32>
     %y1 = memref.cast %y1s : memref<7xi32> to memref<?xi32>
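This IR was invalid because the stride operand `%i4` is dynamic: a dynamic stride infers a result type of `memref<?xi32, strided<[?], offset: ?>>`, which no longer matches the declared static stride 4. A hypothetical sketch of the contrast:

```mlir
func.func @static_vs_dynamic_stride(%xy: memref<?xi32>, %off: index,
                                    %sz: index, %st: index) {
  // Accepted: the static stride operand 4 matches strided<[4], offset: ?>.
  %a = memref.subview %xy[%off] [%sz] [4]
      : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
  // Accepted: a dynamic stride operand must be declared as `?`.
  %b = memref.subview %xy[%off] [%sz] [%st]
      : memref<?xi32> to memref<?xi32, strided<[?], offset: ?>>
  return
}
```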
