[mlir][python] enable memref.subview #79393

Merged
merged 3 commits into from
Jan 30, 2024

makslevental
Contributor

@makslevental makslevental commented Jan 25, 2024

Right now you can't (without enormous difficulty [1]) emit a memref.subview: to compute the result type you need to know the strides of the source, which (unless you're lucky) are only available as an AffineMap. This PR exposes std::pair<...> getStridesAndOffset(MemRefType) through the C API and binds MemRefType.strides_and_offset to it. In addition, it provides a builder/convenience wrapper for memref.subview along with a Python implementation of inferReturnTypes for static offsets/sizes/strides.

Before I handle corner cases/polish, let me know if I've missed something and this is superfluous.

EDIT:

So the deal is that SubViewOp::verify() is either broken or incomplete:

module {
  %alloc = memref.alloc() : memref<7x22x333x4444xi32>
  %subview = memref.subview %alloc[0, 0, 0, 0] [7, 11, 333, 4444] [1, 2, 1, 1] : memref<7x22x333x4444xi32> to memref<7x11x333x4444xi32>
  %subview_0 = memref.subview %alloc[0, 0, 0, 0] [7, 11, 11, 4444] [1, 2, 30, 1] : memref<7x22x333x4444xi32> to memref<7x11x11x4444xi32>
  %subview_1 = memref.subview %alloc[0, 0, 0, 0] [7, 11, 11, 11] [1, 2, 30, 400] : memref<7x22x333x4444xi32> to memref<7x11x11x11xi32>
  %subview_2 = memref.subview %alloc[0, 0, 100, 1000] [7, 22, 20, 20] [1, 1, 5, 50] : memref<7x22x333x4444xi32> to memref<7x22x20x20xi32, strided<[32556744, 1479852, 22220, 50], offset: 445400>>
}

passes the verifier (feel free to check!) but (if I'm not missing something) the correct strides/offsets are

module {
  %alloc = memref.alloc() : memref<7x22x333x4444xi32>
  %subview = memref.subview %alloc[0, 0, 0, 0] [7, 11, 333, 4444] [1, 2, 1, 1] : memref<7x22x333x4444xi32> to memref<7x11x333x4444xi32, strided<[32556744, 2959704, 4444, 1]>>
  %subview_0 = memref.subview %alloc[0, 0, 0, 0] [7, 11, 11, 4444] [1, 2, 30, 1] : memref<7x22x333x4444xi32> to memref<7x11x11x4444xi32, strided<[32556744, 2959704, 133320, 1]>>
  %subview_1 = memref.subview %alloc[0, 0, 0, 0] [7, 11, 11, 11] [1, 2, 30, 400] : memref<7x22x333x4444xi32> to memref<7x11x11x11xi32, strided<[32556744, 2959704, 133320, 400]>>
  %subview_2 = memref.subview %alloc[0, 0, 100, 1000] [7, 22, 20, 20] [1, 1, 5, 50] : memref<7x22x333x4444xi32> to memref<7x22x20x20xi32, strided<[32556744, 1479852, 22220, 50], offset: 445400>>
}

(which also passes the verifier).

Now maybe I am missing something - like maybe strides/offsets aren't printed if they're the "default" strides (accumulate(static_sizes[::-1], *)) - but even if that's true, note that for %subview those "default" strides are [11*333*4444, 333*4444, 4444, 1] == [16278372, 1479852, 4444, 1] and not the correct answer [32556744, 2959704, 4444, 1] == np.array(np.zeros([7, 22, 333, 4444], dtype=np.int32)[:, 0:22:2].strides) // 4.

🤷

Added tests verify/compare against np.strides.
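
The stride arithmetic discussed above can be sketched in plain Python, without MLIR or NumPy. This is only an illustration under stated assumptions: a contiguous row-major source and fully static offsets/strides. The helper names row_major_strides and subview_strides_and_offset are hypothetical and are not part of this PR.

```python
def row_major_strides(shape):
    """Strides (in elements) of a contiguous row-major buffer of `shape`."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def subview_strides_and_offset(source_strides, source_offset, offsets, strides):
    """Layout of a subview: each source stride is scaled by the subview
    stride, and the offset moves by sum(offset_i * source_stride_i)."""
    target_offset = source_offset + sum(
        o * s for o, s in zip(offsets, source_strides)
    )
    target_strides = [s * t for s, t in zip(source_strides, strides)]
    return target_strides, target_offset


# Source is memref<7x22x333x4444xi32> (assumed identity layout).
source = row_major_strides([7, 22, 333, 4444])
print(source)  # [32556744, 1479852, 4444, 1]

# %subview: offsets [0, 0, 0, 0], strides [1, 2, 1, 1]
print(subview_strides_and_offset(source, 0, [0, 0, 0, 0], [1, 2, 1, 1]))
# ([32556744, 2959704, 4444, 1], 0)

# %subview_2: offsets [0, 0, 100, 1000], strides [1, 1, 5, 50]
print(subview_strides_and_offset(source, 0, [0, 0, 100, 1000], [1, 1, 5, 50]))
# ([32556744, 1479852, 22220, 50], 445400)
```

Note that the sizes do not enter the computation at all; only the source strides, the offsets, and the subview strides determine the result layout.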

Footnotes

  1. I could be wrong - I might've missed some magic somewhere.

@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Jan 25, 2024
@llvmbot
Member

llvmbot commented Jan 25, 2024

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/79393.diff

5 Files Affected:

  • (modified) mlir/include/mlir-c/BuiltinTypes.h (+7)
  • (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+9)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+14)
  • (modified) mlir/python/mlir/dialects/memref.py (+69)
  • (modified) mlir/test/python/dialects/memref.py (+14)
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 1fd5691f41eec35..2523bddc475d823 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -408,6 +408,13 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
 /// Returns the memory space of the given MemRef type.
 MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
 
+/// Returns the strides of the MemRef if the layout map is in strided form.
+/// Both strides and offset are out params. strides must point to pre-allocated
+/// memory of length equal to the rank of the memref.
+MLIR_CAPI_EXPORTED void mlirMemRefTypeGetStridesAndOffset(MlirType type,
+                                                          int64_t *strides,
+                                                          int64_t *offset);
+
 /// Returns the memory spcae of the given Unranked MemRef type.
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirUnrankedMemrefGetMemorySpace(MlirType type);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 56e895d3053796e..86f01a6381ae4e0 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -618,6 +618,15 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
               return mlirMemRefTypeGetLayout(self);
             },
             "The layout of the MemRef type.")
+        .def_property_readonly(
+            "strides_and_offset",
+            [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
+              std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
+              int64_t offset;
+              mlirMemRefTypeGetStridesAndOffset(self, strides.data(), &offset);
+              return {strides, offset};
+            },
+            "The strides and offset of the MemRef type.")
         .def_property_readonly(
             "affine_map",
             [](PyMemRefType &self) -> PyAffineMap {
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 6e645188dac8616..6a3653d8baf304a 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -16,6 +16,8 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
 
+#include <algorithm>
+
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -426,6 +428,18 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
   return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
 }
 
+void mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides,
+                                       int64_t *offset) {
+  MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
+  std::pair<SmallVector<int64_t>, int64_t> stridesOffsets =
+      getStridesAndOffset(memrefType);
+  assert(stridesOffsets.first.size() == memrefType.getRank() &&
+         "Strides and rank don't match for memref");
+  (void)std::copy(stridesOffsets.first.begin(), stridesOffsets.first.end(),
+                  strides);
+  *offset = stridesOffsets.second;
+}
+
 MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
   return wrap(UnrankedMemRefType::getTypeID());
 }
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 3afb6a70cb9e0db..8023cbccd7a4183 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -1,5 +1,74 @@
 #  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import Optional, Sequence
 
 from ._memref_ops_gen import *
+from ..ir import Value, ShapedType, MemRefType, StridedLayoutAttr
+
+
+def _infer_memref_subview_result_type(
+    source_memref_type, static_offsets, static_sizes, static_strides
+):
+    source_strides, source_offset = source_memref_type.strides_and_offset
+    target_offset = source_offset
+    for static_offset, target_stride in zip(static_offsets, source_strides):
+        target_offset += static_offset * target_stride
+
+    target_strides = []
+    for source_stride, static_stride in zip(source_strides, static_strides):
+        target_strides.append(source_stride * static_stride)
+
+    layout = StridedLayoutAttr.get(target_offset, target_strides)
+    return MemRefType.get(
+        static_sizes,
+        source_memref_type.element_type,
+        layout,
+        source_memref_type.memory_space,
+    )
+
+
+_generated_subview = subview
+
+
+def subview(
+    source: Value,
+    offsets: Optional[Sequence[Value]] = None,
+    strides: Optional[Sequence[Value]] = None,
+    static_offsets: Optional[Sequence[int]] = None,
+    static_sizes: Optional[Sequence[int]] = None,
+    static_strides: Optional[Sequence[int]] = None,
+    *,
+    loc=None,
+    ip=None,
+):
+    if offsets is None:
+        offsets = []
+    if static_offsets is None:
+        static_offsets = []
+    if strides is None:
+        strides = []
+    if static_strides is None:
+        static_strides = []
+    assert static_sizes, f"this convenience method only handles static sizes"
+    sizes = []
+    S = ShapedType.get_dynamic_size()
+    if offsets and static_offsets:
+        assert all(s == S for s in static_offsets)
+    if strides and static_strides:
+        assert all(s == S for s in static_strides)
+    result_type = _infer_memref_subview_result_type(
+        source.type, static_offsets, static_sizes, static_strides
+    )
+    return _generated_subview(
+        result_type,
+        source,
+        offsets,
+        sizes,
+        strides,
+        static_offsets,
+        static_sizes,
+        static_strides,
+        loc=loc,
+        ip=ip,
+    )
diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
index 0c8a7ee282fe161..47c8ff86d30097c 100644
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -88,3 +88,17 @@ def testMemRefAttr():
             memref.global_("objFifo_in0", T.memref(16, T.i32()))
         # CHECK: memref.global @objFifo_in0 : memref<16xi32>
         print(module)
+
+
+# CHECK-LABEL: TEST: testSubViewOpInferReturnTypes
+@run
+def testSubViewOpInferReturnTypes():
+    with Context() as ctx, Location.unknown(ctx):
+        module = Module.create()
+        with InsertionPoint(module.body):
+            x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
+            # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
+            print(x.owner)
+            y = memref.subview(x, [], [], [1, 1], [3, 3], [1, 1])
+            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
+            print(y.owner)

getStridesAndOffset(memrefType);
assert(stridesOffsets.first.size() == memrefType.getRank() &&
"Strides and rank don't match for memref");
(void)std::copy(stridesOffsets.first.begin(), stridesOffsets.first.end(),
Member

I don't think there is any way to verify here that the space provided is sufficient... not sure if we need to document it; the C API is probably the raw, sharp-edges part. Was just thinking this could allow some rather nasty overwrites. But Python or other language bindings should be guarding.

Contributor Author

Yea sure no way to check (can you say "null terminated strings"...) but I don't know another C-ism for this. The header documents the requirement.

from ..ir import Value, ShapedType, MemRefType, StridedLayoutAttr


def _infer_memref_subview_result_type(
Member

This seems separable from the above?

Member

(I can see how it makes the full use more compact, not sure if you have a simpler initial test for just this, else I don't object too strongly)

Contributor Author

Separable meaning living elsewhere? I did add a test for the asserts...

Member

Meaning in a follow-up commit. But the verify is probably closer to top of mind :-)

@makslevental makslevental force-pushed the infer_return_types_subview branch from b07ff6f to bdb9ace Compare January 25, 2024 02:13
@makslevental makslevental requested a review from jpienaar January 25, 2024 02:19
@makslevental
Contributor Author

makslevental commented Jan 25, 2024

Okay corners explored and polished.

@makslevental makslevental force-pushed the infer_return_types_subview branch from bdb9ace to b2a3b31 Compare January 25, 2024 08:25
makslevental added a commit to makslevental/mlir-python-extras that referenced this pull request Jan 25, 2024
makslevental added a commit to makslevental/mlir-python-extras that referenced this pull request Jan 26, 2024
@makslevental makslevental force-pushed the infer_return_types_subview branch 2 times, most recently from eb0f298 to a192a09 Compare January 26, 2024 00:55
makslevental added a commit to makslevental/mlir-python-extras that referenced this pull request Jan 26, 2024
Comment on lines 14 to 15
def _is_constant(i):
    return isinstance(i, Value) and isinstance(i.owner.opview, ConstantOp)
Member

Maybe not for this commit, but we should rather expose the m_ConstantInt functionality to the C API somehow and drop the dependency on the arith dialect here. There are other constant-like operations that we may want to support, and ConstantOp may be defining a floating-point value or a vector that we don't want to support here.

Contributor Author

My thought was the verifiers would catch floats and other things. I added a _is_integer_like_type check here.

@ftynse
Member

ftynse commented Jan 29, 2024

> So the deal is that SubViewOp::verify() is either broken or incomplete

This may well be the case. The op and its verifier were initially introduced when layouts had only affine forms and may not have been updated correctly.

@matthias-springer
Member

> > So the deal is that SubViewOp::verify() is either broken or incomplete
>
> This may well be the case. The op and its verifier were initially introduced when layouts had only affine forms and may not have been updated correctly.

Looks like strides are currently not verified at all when there are no rank reductions. See #79865.
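
For illustration, a stride check of the kind that seems to be missing could look roughly like the following plain-Python sketch. It is not the actual fix tracked in #79865. It assumes a contiguous row-major source, fully static offsets/sizes/strides, no rank reduction, and that a result type written without a layout means identity (row-major) strides with offset 0. The function names are hypothetical, and strict equality may well be stricter than what the real verifier should accept (for example, for unit dimensions).

```python
def row_major_strides(shape):
    """Strides (in elements) of a contiguous row-major buffer of `shape`."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def result_layout_is_consistent(
    source_shape, offsets, sizes, strides,
    declared_strides=None, declared_offset=0,
):
    """Compare the layout declared on the result type with the layout
    implied by the source type and the subview offsets/strides."""
    source_strides = row_major_strides(source_shape)
    expected_offset = sum(o * s for o, s in zip(offsets, source_strides))
    expected_strides = [s * t for s, t in zip(source_strides, strides)]
    if declared_strides is None:
        # Assumption: no explicit layout means identity layout of the result.
        declared_strides = row_major_strides(sizes)
    return (
        list(declared_strides) == expected_strides
        and declared_offset == expected_offset
    )


shape = [7, 22, 333, 4444]

# %subview as written in the first module (no layout on the result type).
print(result_layout_is_consistent(
    shape, [0, 0, 0, 0], [7, 11, 333, 4444], [1, 2, 1, 1]))  # False

# %subview as written in the second module (explicit strided layout).
print(result_layout_is_consistent(
    shape, [0, 0, 0, 0], [7, 11, 333, 4444], [1, 2, 1, 1],
    declared_strides=[32556744, 2959704, 4444, 1]))  # True
```

Under these assumptions the first form of %subview is rejected and the second is accepted, which matches the discrepancy described in the PR description.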

@makslevental makslevental force-pushed the infer_return_types_subview branch from 3d5939e to ad62ed7 Compare January 29, 2024 19:45
@makslevental makslevental requested a review from ftynse January 29, 2024 19:46
@makslevental makslevental force-pushed the infer_return_types_subview branch 2 times, most recently from d8565f9 to 73be0c6 Compare January 29, 2024 21:27
@makslevental makslevental force-pushed the infer_return_types_subview branch from c08f7b8 to 77f4514 Compare January 30, 2024 19:04
@makslevental makslevental force-pushed the infer_return_types_subview branch from 77f4514 to df3c508 Compare January 30, 2024 20:57
@makslevental makslevental merged commit 404af14 into llvm:main Jan 30, 2024
@makslevental makslevental deleted the infer_return_types_subview branch January 30, 2024 22:22
makslevental added a commit to Xilinx/mlir-aie that referenced this pull request Jan 30, 2024
makslevental added a commit to makslevental/mlir-python-extras that referenced this pull request Jan 31, 2024
makslevental added a commit to makslevental/mlir-python-extras that referenced this pull request Jan 31, 2024
makslevental added a commit to makslevental/mlir-python-extras that referenced this pull request Jan 31, 2024
makslevental added a commit to makslevental/mlir-python-extras that referenced this pull request Jan 31, 2024
makslevental added a commit to makslevental/mlir-python-extras that referenced this pull request Jan 31, 2024
makslevental added a commit to makslevental/mlir-python-extras that referenced this pull request Jan 31, 2024
makslevental added a commit to Xilinx/mlir-aie that referenced this pull request Jan 31, 2024
makslevental added a commit to Xilinx/mlir-aie that referenced this pull request Jan 31, 2024
Labels
mlir:python MLIR Python bindings mlir
5 participants