[mlir][python] enable memref.subview #79393
Conversation
@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/79393.diff

5 Files Affected:
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 1fd5691f41eec35..2523bddc475d823 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -408,6 +408,13 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
/// Returns the memory space of the given MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
+/// Returns the strides of the MemRef if the layout map is in strided form.
+/// Both strides and offset are out params. strides must point to pre-allocated
+/// memory of length equal to the rank of the memref.
+MLIR_CAPI_EXPORTED void mlirMemRefTypeGetStridesAndOffset(MlirType type,
+ int64_t *strides,
+ int64_t *offset);
+
/// Returns the memory spcae of the given Unranked MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute
mlirUnrankedMemrefGetMemorySpace(MlirType type);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 56e895d3053796e..86f01a6381ae4e0 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -618,6 +618,15 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
return mlirMemRefTypeGetLayout(self);
},
"The layout of the MemRef type.")
+ .def_property_readonly(
+ "strides_and_offset",
+ [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
+ std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
+ int64_t offset;
+ mlirMemRefTypeGetStridesAndOffset(self, strides.data(), &offset);
+ return {strides, offset};
+ },
+ "The strides and offset of the MemRef type.")
.def_property_readonly(
"affine_map",
[](PyMemRefType &self) -> PyAffineMap {
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 6e645188dac8616..6a3653d8baf304a 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -16,6 +16,8 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
+#include <algorithm>
+
using namespace mlir;
//===----------------------------------------------------------------------===//
@@ -426,6 +428,18 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
}
+void mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides,
+ int64_t *offset) {
+ MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
+ std::pair<SmallVector<int64_t>, int64_t> stridesOffsets =
+ getStridesAndOffset(memrefType);
+ assert(stridesOffsets.first.size() == memrefType.getRank() &&
+ "Strides and rank don't match for memref");
+ (void)std::copy(stridesOffsets.first.begin(), stridesOffsets.first.end(),
+ strides);
+ *offset = stridesOffsets.second;
+}
+
MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
return wrap(UnrankedMemRefType::getTypeID());
}
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 3afb6a70cb9e0db..8023cbccd7a4183 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -1,5 +1,74 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import Optional, Sequence
from ._memref_ops_gen import *
+from ..ir import Value, ShapedType, MemRefType, StridedLayoutAttr
+
+
+def _infer_memref_subview_result_type(
+ source_memref_type, static_offsets, static_sizes, static_strides
+):
+ source_strides, source_offset = source_memref_type.strides_and_offset
+ target_offset = source_offset
+ for static_offset, target_stride in zip(static_offsets, source_strides):
+ target_offset += static_offset * target_stride
+
+ target_strides = []
+ for source_stride, static_stride in zip(source_strides, static_strides):
+ target_strides.append(source_stride * static_stride)
+
+ layout = StridedLayoutAttr.get(target_offset, target_strides)
+ return MemRefType.get(
+ static_sizes,
+ source_memref_type.element_type,
+ layout,
+ source_memref_type.memory_space,
+ )
+
+
+_generated_subview = subview
+
+
+def subview(
+ source: Value,
+ offsets: Optional[Sequence[Value]] = None,
+ strides: Optional[Sequence[Value]] = None,
+ static_offsets: Optional[Sequence[int]] = None,
+ static_sizes: Optional[Sequence[int]] = None,
+ static_strides: Optional[Sequence[int]] = None,
+ *,
+ loc=None,
+ ip=None,
+):
+ if offsets is None:
+ offsets = []
+ if static_offsets is None:
+ static_offsets = []
+ if strides is None:
+ strides = []
+ if static_strides is None:
+ static_strides = []
+ assert static_sizes, f"this convenience method only handles static sizes"
+ sizes = []
+ S = ShapedType.get_dynamic_size()
+ if offsets and static_offsets:
+ assert all(s == S for s in static_offsets)
+ if strides and static_strides:
+ assert all(s == S for s in static_strides)
+ result_type = _infer_memref_subview_result_type(
+ source.type, static_offsets, static_sizes, static_strides
+ )
+ return _generated_subview(
+ result_type,
+ source,
+ offsets,
+ sizes,
+ strides,
+ static_offsets,
+ static_sizes,
+ static_strides,
+ loc=loc,
+ ip=ip,
+ )
diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
index 0c8a7ee282fe161..47c8ff86d30097c 100644
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -88,3 +88,17 @@ def testMemRefAttr():
memref.global_("objFifo_in0", T.memref(16, T.i32()))
# CHECK: memref.global @objFifo_in0 : memref<16xi32>
print(module)
+
+
+# CHECK-LABEL: TEST: testSubViewOpInferReturnTypes
+@run
+def testSubViewOpInferReturnTypes():
+ with Context() as ctx, Location.unknown(ctx):
+ module = Module.create()
+ with InsertionPoint(module.body):
+ x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
+ # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
+ print(x.owner)
+ y = memref.subview(x, [], [], [1, 1], [3, 3], [1, 1])
+ # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
+ print(y.owner)
mlir/lib/CAPI/IR/BuiltinTypes.cpp (outdated)

      getStridesAndOffset(memrefType);
  assert(stridesOffsets.first.size() == memrefType.getRank() &&
         "Strides and rank don't match for memref");
  (void)std::copy(stridesOffsets.first.begin(), stridesOffsets.first.end(),
I don't think there is any way to verify here that the space provided is sufficient... not sure if we need to document it; probably the C API is the raw, sharp-edges part. Was just thinking this could allow some rather nasty overwrites. But Python or another language binding should be guarding.
Yeah, sure, no way to check (can you say "null-terminated strings"...), but I don't know another C-ism for this. The header documents the requirement.
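The header's contract (the caller pre-allocates `strides` to the memref's rank) is the usual C out-parameter convention being discussed. As a hedged illustration only, here is a pure-Python/ctypes sketch of that calling convention; `get_strides_and_offset` is a local stand-in that fills row-major strides for a hypothetical 4x5 shape, not the real `mlirMemRefTypeGetStridesAndOffset`:

```python
import ctypes

def get_strides_and_offset(rank, strides_out, offset_out):
    # Stand-in for the C API: it writes through caller-provided buffers and
    # has no way to check their length -- exactly the sharp edge discussed.
    sizes = [4, 5]  # hypothetical memref shape
    stride = 1
    for i in reversed(range(rank)):
        strides_out[i] = stride  # row-major: innermost stride is 1
        stride *= sizes[i]
    offset_out[0] = 0

rank = 2
strides = (ctypes.c_int64 * rank)()  # caller allocates exactly `rank` slots
offset = (ctypes.c_int64 * 1)()
get_strides_and_offset(rank, strides, offset)
print(list(strides), offset[0])  # [5, 1] 0
```

A buffer shorter than `rank` would be silently overrun, which is why the Python binding sizes its vector with `mlirShapedTypeGetRank` before calling in.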
from ..ir import Value, ShapedType, MemRefType, StridedLayoutAttr

def _infer_memref_subview_result_type( |
This seems separable from the above?
(I can see how it makes the full use more compact, not sure if you have a simpler initial test for just this, else I don't object too strongly)
Separable meaning living elsewhere? I did add a test for the asserts...
Meaning in a follow-up commit. But the verify is probably closer to top of mind :-)
Okay, corners explored and polished.
mlir/python/mlir/dialects/memref.py (outdated)

def _is_constant(i):
    return isinstance(i, Value) and isinstance(i.owner.opview, ConstantOp)
Maybe not for this commit, but we should rather expose the m_ConstantInt functionality to the C API somehow and drop the dependency on the arith dialect here. There are other constant-like operations that we may want to support, and ConstantOp may be defining a floating-point value or a vector that we don't want to support here.
My thought was the verifiers would catch floats and other things. I added an _is_integer_like_type check here.
This may well be the case. The op and its verifier were initially introduced when layouts had only affine forms and may not have been updated correctly.

Looks like strides are currently not verified at all when there are no rank reductions. See #79865.
Right now you can't (without enormous difficulty¹) emit a memref.subview, because to compute the result type you need to know the striding of the source, which (unless you're lucky) is only available as an AffineMap. This PR exposes std::pair<...> getStridesAndOffset(MemRefType) through the C API and also binds MemRefType.strides_and_offset to it. In addition, a builder/convenience wrapper for memref.subview is provided, along with a Python implementation of inferReturnTypes for static offsets/sizes/strides.

Before I handle corner cases/polish, let me know if I've missed something and this is superfluous.

EDIT:

So the deal is that SubViewOp::verify() is either broken or incomplete: a subview annotated with one set of strides/offsets passes the verifier (feel free to check!), but (if I'm not missing something) the correct strides/offsets are different, and a subview annotated with those also passes the verifier.

Now maybe I am missing something - like maybe strides/offsets aren't printed if they're the "default" strides (accumulate(static_sizes[::-1], *)) - but even if that's true, note that for %subview those "default" strides are [11*333*4444, 333*4444, 4444, 1] == [32556744, 1479852, 4444, 1] and not the correct answer [32556744, 2959704, 4444, 1] == np.array(np.zeros([7, 11, 333, 4444], dtype=np.int32)[:, 0:22:2].strides) // 4. 🤷

Added tests verify/compare against np.strides.

Footnotes

1. I could be wrong - I might've missed some magic somewhere. ↩
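The arithmetic behind the PR's _infer_memref_subview_result_type can be sanity-checked without MLIR. Below is a minimal sketch: infer_subview_strides_and_offset is a local stand-in mirroring the helper's offset/stride math, cross-checked against NumPy on the 10x10 case from the added test.

```python
import numpy as np

def infer_subview_strides_and_offset(source_strides, source_offset,
                                     static_offsets, static_strides):
    # Offset: source offset plus each static offset scaled by that
    # dimension's source stride. Strides: source strides scaled by the step.
    target_offset = source_offset
    for off, stride in zip(static_offsets, source_strides):
        target_offset += off * stride
    target_strides = [s * st for s, st in zip(source_strides, static_strides)]
    return target_strides, target_offset

# 10x10 row-major source: strides [10, 1], offset 0; subview [1, 1][3, 3][1, 1].
strides, offset = infer_subview_strides_and_offset([10, 1], 0, [1, 1], [1, 1])
print(strides, offset)  # [10, 1] 11 -- matches strided<[10, 1], offset: 11>

# Cross-check against NumPy's element strides and data-pointer offset.
a = np.zeros((10, 10), dtype=np.int32)
v = a[1:4, 1:4]
np_strides = [s // a.itemsize for s in v.strides]
np_offset = (v.__array_interface__["data"][0]
             - a.__array_interface__["data"][0]) // a.itemsize
assert (np_strides, np_offset) == (strides, offset)
```

This is the same check the PR's tests perform against np.strides, just inlined for a single small case.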