Skip to content

Commit dc73663

Browse files
[mlir][memref][WIP] memref.subview: Verify result strides
The strides of the result types are currently not verified for `memref.subview` ops that have no rank reductions. WIP: This is still failing some test cases. It also looks like the verification of result strides is incomplete (maybe also incorrect) for ops with rank reduction.
1 parent 4a39d08 commit dc73663

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,7 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
931931

932932
// Early exit for the case where the number of unused dims matches the number
933933
// of ranks reduced.
934+
// TODO: Verify strides.
934935
if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
935936
originalType.getRank())
936937
return unusedDims;
@@ -2745,7 +2746,7 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
27452746
/// For ViewLikeOpInterface.
27462747
Value SubViewOp::getViewSource() { return getSource(); }
27472748

2748-
/// Return true if t1 and t2 have equal offsets (both dynamic or of same
2749+
/// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
27492750
/// static value).
27502751
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
27512752
int64_t t1Offset, t2Offset;
@@ -2755,6 +2756,21 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
27552756
return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
27562757
}
27572758

2759+
/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2760+
/// static value).
2761+
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2) {
2762+
int64_t t1Offset, t2Offset;
2763+
SmallVector<int64_t> t1Strides, t2Strides;
2764+
auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2765+
auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2766+
if (failed(res1) || failed(res2))
2767+
return false;
2768+
for (auto [s1, s2] : llvm::zip_equal(t1Strides, t2Strides))
2769+
if (s1 != s2)
2770+
return false;
2771+
return true;
2772+
}
2773+
27582774
/// Checks if `original` Type type can be rank reduced to `reduced` type.
27592775
/// This function is slight variant of `is subsequence` algorithm where
27602776
/// not matching dimension must be 1.
@@ -2781,6 +2797,12 @@ isRankReducedMemRefType(MemRefType originalType,
27812797
if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
27822798
return SliceVerificationResult::LayoutMismatch;
27832799

2800+
// Strides must match if there are no rank reductions. In case of rank
2801+
// reductions, the strides are checked by `computeMemRefRankReductionMask`.
2802+
if (optionalUnusedDimsMask->none() &&
2803+
!haveCompatibleStrides(originalType, candidateRankReducedType))
2804+
return SliceVerificationResult::LayoutMismatch;
2805+
27842806
return SliceVerificationResult::Success;
27852807
}
27862808

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,3 +1073,12 @@ func.func @dim_0_ranked(%arg : memref<f32>, %arg1 : index) {
10731073
memref.dim %arg, %arg1 : memref<f32> // expected-error {{'memref.dim' op operand #0 must be unranked.memref of any type values or non-0-ranked.memref of any type values, but got 'memref<f32>'}}
10741074
return
10751075
}
1076+
1077+
// -----
1078+
1079+
func.func @subview_invalid_strides(%m: memref<7x22x333x4444xi32>) {
1080+
// expected-error @below{{expected result type to be 'memref<7x11x333x4444xi32, strided<[32556744, 2959704, 4444, 1]>>' or a rank-reduced version. (mismatch of result layout)}}
1081+
%subview = memref.subview %m[0, 0, 0, 0] [7, 11, 333, 4444] [1, 2, 1, 1]
1082+
: memref<7x22x333x4444xi32> to memref<7x11x333x4444xi32>
1083+
return
1084+
}

0 commit comments

Comments
 (0)