-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][memref] memref.subview
: Verify result strides with rank reductions
#80158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][memref] memref.subview
: Verify result strides with rank reductions
#80158
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref Author: Matthias Springer (matthias-springer) ChangesThis is a follow-up on #79865. Result strides are now also verified if the Full diff: https://github.com/llvm/llvm-project/pull/80158.diff 5 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f43217f6f27ae..a6624e9c8482c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2756,17 +2756,26 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
}
/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
-/// static value).
-static bool haveCompatibleStrides(MemRefType t1, MemRefType t2) {
+/// static value). Dimensions of `t1` may be dropped in `t2`; these must be
+/// marked as dropped in `droppedDims`.
+static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
+ const llvm::SmallBitVector &droppedDims) {
+ assert(t1.getRank() == droppedDims.size() && "incorrect number of bits");
+ assert(t1.getRank() - t2.getRank() == droppedDims.count() &&
+ "incorrect number of dropped dims");
int64_t t1Offset, t2Offset;
SmallVector<int64_t> t1Strides, t2Strides;
auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
if (failed(res1) || failed(res2))
return false;
- for (auto [s1, s2] : llvm::zip_equal(t1Strides, t2Strides))
- if (s1 != s2)
+ for (int64_t i = 0, j = 0; i < t1.getRank(); ++i) {
+ if (droppedDims[i])
+ continue;
+ if (t1Strides[i] != t2Strides[j])
return false;
+ ++j;
+ }
return true;
}
@@ -2843,10 +2852,8 @@ LogicalResult SubViewOp::verify() {
return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
*this, expectedType);
- // Strides must match if there are no rank reductions.
- // TODO: Verify strides when there are rank reductions. Strides are partially
- // checked in `computeMemRefRankReductionMask`.
- if (unusedDims->none() && !haveCompatibleStrides(expectedType, subViewType))
+ // Strides must match.
+ if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
*this, expectedType);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index f6af0791ba756..96eb7cfd2db69 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -144,16 +144,25 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
SmallVector<OpFoldResult> finalStrides;
finalStrides.reserve(subRank);
+#ifndef NDEBUG
+ // Iteration variable for result dimensions of the subview op.
+ int64_t j = 0;
+#endif // NDEBUG
for (unsigned i = 0; i < sourceRank; ++i) {
if (droppedDims.test(i))
continue;
finalSizes.push_back(subSizes[i]);
finalStrides.push_back(strides[i]);
- // TODO: Assert that the computed stride matches the respective stride of
- // the result type of the subview op (if both are static), once the verifier
- // of memref.subview verfies result strides correctly for ops with rank
- // reductions.
+#ifndef NDEBUG
+ // Assert that the computed stride matches the stride of the result type of
+ // the subview op (if both are static).
+ std::optional<int64_t> computedStride = getConstantIntValue(strides[i]);
+ if (computedStride && !ShapedType::isDynamic(resultStrides[j]))
+ assert(*computedStride == resultStrides[j] &&
+ "mismatch between computed stride and result type stride");
+ ++j;
+#endif // NDEBUG
}
assert(finalSizes.size() == subRank &&
"Should have populated all the values at this point");
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 993ef32edc9d4..a772a25da5738 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -62,13 +62,13 @@ func.func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
// -----
func.func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
- %arg2 : index) -> memref<?x?xf32, strided<[?, 1], offset: ?>>
+ %arg2 : index) -> memref<?x?xf32, strided<[?, ?], offset: ?>>
{
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
- %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
- return %0 : memref<?x?xf32, strided<[?, 1], offset: ?>>
+ %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return %0 : memref<?x?xf32, strided<[?, ?], offset: ?>>
}
// CHECK-LABEL: func @rank_reducing_subview_canonicalize
// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 3407bdbc7c8f9..5b853a6cc5a37 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -613,9 +613,9 @@ func.func @subview_of_subview_rank_reducing(%m: memref<?x?x?xf32>,
{
%0 = memref.subview %m[3, 1, 8] [1, %sz, 1] [1, 1, 1]
: memref<?x?x?xf32>
- to memref<?xf32, strided<[1], offset: ?>>
+ to memref<?xf32, strided<[?], offset: ?>>
%1 = memref.subview %0[6] [1] [1]
- : memref<?xf32, strided<[1], offset: ?>>
+ : memref<?xf32, strided<[?], offset: ?>>
to memref<f32, strided<[], offset: ?>>
return %1 : memref<f32, strided<[], offset: ?>>
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index be60a3dcb1b20..8f5ba5ea8fc78 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1082,3 +1082,12 @@ func.func @subview_invalid_strides(%m: memref<7x22x333x4444xi32>) {
: memref<7x22x333x4444xi32> to memref<7x11x333x4444xi32>
return
}
+
+// -----
+
+func.func @subview_invalid_strides_rank_reduction(%m: memref<7x22x333x4444xi32>) {
+ // expected-error @below{{expected result type to be 'memref<7x11x1x4444xi32, strided<[32556744, 2959704, 4444, 1]>>' or a rank-reduced version. (mismatch of result layout)}}
+ %subview = memref.subview %m[0, 0, 0, 0] [7, 11, 1, 4444] [1, 2, 1, 1]
+ : memref<7x22x333x4444xi32> to memref<7x11x4444xi32>
+ return
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thanks!
…ctions This is a follow-up on llvm#79865. Result strides are now also verified if the `memref.subview` op has rank reductions.
881c580
to
17c0745
Compare
…ctions (llvm#80158) This is a follow-up on llvm#79865. Result strides are now also verified if the `memref.subview` op has rank reductions.
This is a follow-up on #79865. Result strides are now also verified if the
memref.subview
op has rank reductions.