Skip to content

[mlir][memref] memref.subview: Verify result strides with rank reductions #80158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

matthias-springer
Copy link
Member

This is a follow-up on #79865. Result strides are now also verified if the memref.subview op has rank reductions.

@llvmbot
Copy link
Member

llvmbot commented Jan 31, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Matthias Springer (matthias-springer)

Changes

This is a follow-up on #79865. Result strides are now also verified if the memref.subview op has rank reductions.


Full diff: https://github.com/llvm/llvm-project/pull/80158.diff

5 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+15-8)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp (+13-4)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+3-3)
  • (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+2-2)
  • (modified) mlir/test/Dialect/MemRef/invalid.mlir (+9)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f43217f6f27ae..a6624e9c8482c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2756,17 +2756,26 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
 }
 
 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
-/// static value).
-static bool haveCompatibleStrides(MemRefType t1, MemRefType t2) {
+/// static value). Dimensions of `t1` may be dropped in `t2`; these must be
+/// marked as dropped in `droppedDims`.
+static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
+                                  const llvm::SmallBitVector &droppedDims) {
+  assert(t1.getRank() == droppedDims.size() && "incorrect number of bits");
+  assert(t1.getRank() - t2.getRank() == droppedDims.count() &&
+         "incorrect number of dropped dims");
   int64_t t1Offset, t2Offset;
   SmallVector<int64_t> t1Strides, t2Strides;
   auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
   auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
   if (failed(res1) || failed(res2))
     return false;
-  for (auto [s1, s2] : llvm::zip_equal(t1Strides, t2Strides))
-    if (s1 != s2)
+  for (int64_t i = 0, j = 0; i < t1.getRank(); ++i) {
+    if (droppedDims[i])
+      continue;
+    if (t1Strides[i] != t2Strides[j])
       return false;
+    ++j;
+  }
   return true;
 }
 
@@ -2843,10 +2852,8 @@ LogicalResult SubViewOp::verify() {
     return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                   *this, expectedType);
 
-  // Strides must match if there are no rank reductions.
-  // TODO: Verify strides when there are rank reductions. Strides are partially
-  // checked in `computeMemRefRankReductionMask`.
-  if (unusedDims->none() && !haveCompatibleStrides(expectedType, subViewType))
+  // Strides must match.
+  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
     return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                   *this, expectedType);
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index f6af0791ba756..96eb7cfd2db69 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -144,16 +144,25 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
   SmallVector<OpFoldResult> finalStrides;
   finalStrides.reserve(subRank);
 
+#ifndef NDEBUG
+  // Iteration variable for result dimensions of the subview op.
+  int64_t j = 0;
+#endif // NDEBUG
   for (unsigned i = 0; i < sourceRank; ++i) {
     if (droppedDims.test(i))
       continue;
 
     finalSizes.push_back(subSizes[i]);
     finalStrides.push_back(strides[i]);
-    // TODO: Assert that the computed stride matches the respective stride of
-    // the result type of the subview op (if both are static), once the verifier
-    // of memref.subview verfies result strides correctly for ops with rank
-    // reductions.
+#ifndef NDEBUG
+    // Assert that the computed stride matches the stride of the result type of
+    // the subview op (if both are static).
+    std::optional<int64_t> computedStride = getConstantIntValue(strides[i]);
+    if (computedStride && !ShapedType::isDynamic(resultStrides[j]))
+      assert(*computedStride == resultStrides[j] &&
+             "mismatch between computed stride and result type stride");
+    ++j;
+#endif // NDEBUG
   }
   assert(finalSizes.size() == subRank &&
          "Should have populated all the values at this point");
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 993ef32edc9d4..a772a25da5738 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -62,13 +62,13 @@ func.func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
 // -----
 
 func.func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
-  %arg2 : index) -> memref<?x?xf32, strided<[?, 1], offset: ?>>
+  %arg2 : index) -> memref<?x?xf32, strided<[?, ?], offset: ?>>
 {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
-  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-  return %0 : memref<?x?xf32, strided<[?, 1], offset: ?>>
+  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return %0 : memref<?x?xf32, strided<[?, ?], offset: ?>>
 }
 // CHECK-LABEL: func @rank_reducing_subview_canonicalize
 //  CHECK-SAME:   %[[ARG0:.+]]: memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 3407bdbc7c8f9..5b853a6cc5a37 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -613,9 +613,9 @@ func.func @subview_of_subview_rank_reducing(%m: memref<?x?x?xf32>,
 {
   %0 = memref.subview %m[3, 1, 8] [1, %sz, 1] [1, 1, 1]
       : memref<?x?x?xf32>
-        to memref<?xf32, strided<[1], offset: ?>>
+        to memref<?xf32, strided<[?], offset: ?>>
   %1 = memref.subview %0[6] [1] [1]
-      : memref<?xf32, strided<[1], offset: ?>>
+      : memref<?xf32, strided<[?], offset: ?>>
         to memref<f32, strided<[], offset: ?>>
   return %1 : memref<f32, strided<[], offset: ?>>
 }
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index be60a3dcb1b20..8f5ba5ea8fc78 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1082,3 +1082,12 @@ func.func @subview_invalid_strides(%m: memref<7x22x333x4444xi32>) {
       : memref<7x22x333x4444xi32> to memref<7x11x333x4444xi32>
   return
 }
+
+// -----
+
+func.func @subview_invalid_strides_rank_reduction(%m: memref<7x22x333x4444xi32>) {
+  // expected-error @below{{expected result type to be 'memref<7x11x1x4444xi32, strided<[32556744, 2959704, 4444, 1]>>' or a rank-reduced version. (mismatch of result layout)}}
+  %subview = memref.subview %m[0, 0, 0, 0] [7, 11, 1, 4444] [1, 2, 1, 1]
+      : memref<7x22x333x4444xi32> to memref<7x11x4444xi32>
+  return
+}

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks!

…ctions

This is a follow-up on llvm#79865. Result strides are now also verified if the `memref.subview` op has rank reductions.
@matthias-springer matthias-springer force-pushed the subview_verify_strides_rank_reductions branch from 881c580 to 17c0745 Compare February 2, 2024 08:53
@matthias-springer matthias-springer merged commit 9efdccb into llvm:main Feb 2, 2024
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request Feb 5, 2024
…ctions (llvm#80158)

This is a follow-up on llvm#79865. Result strides are now also verified if
the `memref.subview` op has rank reductions.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants