Skip to content

Commit b07ff6f

Browse files
committed
[mlir][python] enable memref.subview
1 parent 72ce629 commit b07ff6f

File tree

5 files changed

+113
-0
lines changed

5 files changed

+113
-0
lines changed

mlir/include/mlir-c/BuiltinTypes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,13 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
408408
/// Returns the memory space of the given MemRef type.
409409
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
410410

411+
/// Returns the strides of the MemRef if the layout map is in strided form.
412+
/// Both strides and offset are out params. strides must point to pre-allocated
413+
/// memory of length equal to the rank of the memref.
414+
MLIR_CAPI_EXPORTED void mlirMemRefTypeGetStridesAndOffset(MlirType type,
415+
int64_t *strides,
416+
int64_t *offset);
417+
411418
/// Returns the memory spcae of the given Unranked MemRef type.
412419
MLIR_CAPI_EXPORTED MlirAttribute
413420
mlirUnrankedMemrefGetMemorySpace(MlirType type);

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,15 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
618618
return mlirMemRefTypeGetLayout(self);
619619
},
620620
"The layout of the MemRef type.")
621+
.def_property_readonly(
622+
"strides_and_offset",
623+
[](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
624+
std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
625+
int64_t offset;
626+
mlirMemRefTypeGetStridesAndOffset(self, strides.data(), &offset);
627+
return {strides, offset};
628+
},
629+
"The strides and offset of the MemRef type.")
621630
.def_property_readonly(
622631
"affine_map",
623632
[](PyMemRefType &self) -> PyAffineMap {

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include "mlir/IR/BuiltinTypes.h"
1717
#include "mlir/IR/Types.h"
1818

19+
#include <algorithm>
20+
1921
using namespace mlir;
2022

2123
//===----------------------------------------------------------------------===//
@@ -426,6 +428,18 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
426428
return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
427429
}
428430

431+
void mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides,
432+
int64_t *offset) {
433+
MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
434+
std::pair<SmallVector<int64_t>, int64_t> stridesOffsets =
435+
getStridesAndOffset(memrefType);
436+
assert(stridesOffsets.first.size() == memrefType.getRank() &&
437+
"Strides and rank don't match for memref");
438+
(void)std::copy(stridesOffsets.first.begin(), stridesOffsets.first.end(),
439+
strides);
440+
*offset = stridesOffsets.second;
441+
}
442+
429443
MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
430444
return wrap(UnrankedMemRefType::getTypeID());
431445
}

mlir/python/mlir/dialects/memref.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,74 @@
11
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
from typing import Optional, Sequence
45

56
from ._memref_ops_gen import *
7+
from ..ir import Value, ShapedType, MemRefType, StridedLayoutAttr
8+
9+
10+
def _infer_memref_subview_result_type(
11+
source_memref_type, static_offsets, static_sizes, static_strides
12+
):
13+
source_strides, source_offset = source_memref_type.strides_and_offset
14+
target_offset = source_offset
15+
for static_offset, target_stride in zip(static_offsets, source_strides):
16+
target_offset += static_offset * target_stride
17+
18+
target_strides = []
19+
for source_stride, static_stride in zip(source_strides, static_strides):
20+
target_strides.append(source_stride * static_stride)
21+
22+
layout = StridedLayoutAttr.get(target_offset, target_strides)
23+
return MemRefType.get(
24+
static_sizes,
25+
source_memref_type.element_type,
26+
layout,
27+
source_memref_type.memory_space,
28+
)
29+
30+
31+
_generated_subview = subview
32+
33+
34+
def subview(
35+
source: Value,
36+
offsets: Optional[Sequence[Value]] = None,
37+
strides: Optional[Sequence[Value]] = None,
38+
static_offsets: Optional[Sequence[int]] = None,
39+
static_sizes: Optional[Sequence[int]] = None,
40+
static_strides: Optional[Sequence[int]] = None,
41+
*,
42+
loc=None,
43+
ip=None,
44+
):
45+
if offsets is None:
46+
offsets = []
47+
if static_offsets is None:
48+
static_offsets = []
49+
if strides is None:
50+
strides = []
51+
if static_strides is None:
52+
static_strides = []
53+
assert static_sizes, f"this convenience method only handles static sizes"
54+
sizes = []
55+
S = ShapedType.get_dynamic_size()
56+
if offsets and static_offsets:
57+
assert all(s == S for s in static_offsets)
58+
if strides and static_strides:
59+
assert all(s == S for s in static_strides)
60+
result_type = _infer_memref_subview_result_type(
61+
source.type, static_offsets, static_sizes, static_strides
62+
)
63+
return _generated_subview(
64+
result_type,
65+
source,
66+
offsets,
67+
sizes,
68+
strides,
69+
static_offsets,
70+
static_sizes,
71+
static_strides,
72+
loc=loc,
73+
ip=ip,
74+
)

mlir/test/python/dialects/memref.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,17 @@ def testMemRefAttr():
8888
memref.global_("objFifo_in0", T.memref(16, T.i32()))
8989
# CHECK: memref.global @objFifo_in0 : memref<16xi32>
9090
print(module)
91+
92+
93+
# CHECK-LABEL: TEST: testSubViewOpInferReturnTypes
94+
@run
95+
def testSubViewOpInferReturnTypes():
96+
with Context() as ctx, Location.unknown(ctx):
97+
module = Module.create()
98+
with InsertionPoint(module.body):
99+
x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
100+
# CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
101+
print(x.owner)
102+
y = memref.subview(x, [], [], [1, 1], [3, 3], [1, 1])
103+
# CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
104+
print(y.owner)

0 commit comments

Comments
 (0)