Skip to content

Commit bdb9ace

Browse files
committed
[mlir][python] enable memref.subview
1 parent 72ce629 commit bdb9ace

File tree

5 files changed

+226
-0
lines changed

5 files changed

+226
-0
lines changed

mlir/include/mlir-c/BuiltinTypes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,13 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
408408
/// Returns the memory space of the given MemRef type.
409409
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
410410

411+
/// Returns the strides of the MemRef if the layout map is in strided form.
/// Both strides and offset are out params. strides must point to pre-allocated
/// memory of length equal to the rank of the memref; offset must point to a
/// single writable int64_t.
MLIR_CAPI_EXPORTED void mlirMemRefTypeGetStridesAndOffset(MlirType type,
                                                          int64_t *strides,
                                                          int64_t *offset);
417+
411418
/// Returns the memory space of the given Unranked MemRef type.
412419
MLIR_CAPI_EXPORTED MlirAttribute
413420
mlirUnrankedMemrefGetMemorySpace(MlirType type);

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,15 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
618618
return mlirMemRefTypeGetLayout(self);
619619
},
620620
"The layout of the MemRef type.")
621+
      // Exposes the strided-layout decomposition of the memref as a
      // (strides, offset) pair of python ints.
      .def_property_readonly(
          "strides_and_offset",
          [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
            // One stride per dimension; the C API fills pre-allocated storage.
            std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
            int64_t offset;
            mlirMemRefTypeGetStridesAndOffset(self, strides.data(), &offset);
            return {strides, offset};
          },
          "The strides and offset of the MemRef type.")
621630
.def_property_readonly(
622631
"affine_map",
623632
[](PyMemRefType &self) -> PyAffineMap {

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include "mlir/IR/BuiltinTypes.h"
1717
#include "mlir/IR/Types.h"
1818

19+
#include <algorithm>
20+
1921
using namespace mlir;
2022

2123
//===----------------------------------------------------------------------===//
@@ -426,6 +428,18 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
426428
return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
427429
}
428430

431+
/// Decomposes the layout of a (strided) MemRef type into its strides and
/// offset. `strides` must point to pre-allocated storage of length equal to
/// the rank of the memref; `offset` receives the single layout offset.
void mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides,
                                       int64_t *offset) {
  MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
  std::pair<SmallVector<int64_t>, int64_t> stridesOffsets =
      getStridesAndOffset(memrefType);
  // getRank() is signed; cast to avoid a signed/unsigned comparison warning.
  assert(stridesOffsets.first.size() ==
             static_cast<size_t>(memrefType.getRank()) &&
         "Strides and rank don't match for memref");
  (void)std::copy(stridesOffsets.first.begin(), stridesOffsets.first.end(),
                  strides);
  *offset = stridesOffsets.second;
}
442+
429443
MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
430444
return wrap(UnrankedMemRefType::getTypeID());
431445
}

mlir/python/mlir/dialects/memref.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,105 @@
11
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
from typing import Optional
45

56
from ._memref_ops_gen import *
7+
from .arith import ConstantOp, _is_integer_like_type
8+
from .transform.structured import _dispatch_mixed_values, MixedValues
9+
from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType
10+
11+
12+
def _infer_memref_subview_result_type(
    source_memref_type, static_offsets, static_sizes, static_strides
):
    """Computes the result MemRefType of a memref.subview from static operands.

    All offsets/sizes/strides must be either static python ints or
    arith.constant-produced Values; otherwise the assertion below fires.

    NOTE: this mutates static_offsets/static_sizes/static_strides in place,
    replacing each ConstantOp Value with its python literal value. The
    `subview` wrapper relies on this side effect so that such constants are
    later dispatched as static values rather than dynamic operands.
    """
    source_strides, source_offset = source_memref_type.strides_and_offset
    # Every entry must be a non-dynamic int, or an integer-typed arith
    # constant; the source layout itself must also be fully static.
    assert all(
        all(
            (isinstance(i, int) and not ShapedType.is_dynamic_size(i))
            or (isinstance(i, Value) and isinstance(i.owner.opview, ConstantOp))
            and _is_integer_like_type(i.type)
            for i in s
        )
        for s in [
            static_offsets,
            static_sizes,
            static_strides,
            source_strides,
            [source_offset],
        ]
    ), f"Only inferring from python or mlir integer constant is supported"
    # Fold arith.constant Values to their python ints (in place — see note).
    for s in [static_offsets, static_sizes, static_strides]:
        for idx, i in enumerate(s):
            if isinstance(i, Value):
                s[idx] = i.owner.opview.literal_value

    # New offset = source offset plus the per-dim offsets scaled by the
    # source strides.
    target_offset = source_offset
    for static_offset, target_stride in zip(static_offsets, source_strides):
        target_offset += static_offset * target_stride

    # New strides = source strides scaled by the subview strides.
    target_strides = []
    for source_stride, static_stride in zip(source_strides, static_strides):
        target_strides.append(source_stride * static_stride)

    layout = StridedLayoutAttr.get(target_offset, target_strides)
    return MemRefType.get(
        static_sizes,
        source_memref_type.element_type,
        layout,
        source_memref_type.memory_space,
    )
51+
52+
53+
# Keep a handle to the tablegen-generated builder; the wrapper below shadows
# the name `subview`.
_generated_subview = subview


def subview(
    source: Value,
    offsets: MixedValues,
    sizes: MixedValues,
    strides: MixedValues,
    *,
    result_type: Optional[MemRefType] = None,
    loc=None,
    ip=None,
):
    """Builds a memref.subview op, inferring the result type when possible.

    The result type is inferred only when every offset/size/stride — and the
    source memref's own strides/offset — is a static python int or an
    arith.constant; otherwise an explicit `result_type` must be supplied.
    """
    if offsets is None:
        offsets = []
    if sizes is None:
        sizes = []
    if strides is None:
        strides = []

    source_strides, source_offset = source.type.strides_and_offset
    # Inference is possible only for fully static operands (ints that are not
    # the dynamic sentinel, or arith.constant-produced Values).
    if all(
        all(
            (isinstance(i, int) and not ShapedType.is_dynamic_size(i))
            or (isinstance(i, Value) and isinstance(i.owner.opview, ConstantOp))
            for i in s
        )
        for s in [offsets, sizes, strides, source_strides, [source_offset]]
    ):
        # NOTE: the helper mutates offsets/sizes/strides in place, folding
        # ConstantOp Values to python ints so the dispatch below treats them
        # as static values rather than dynamic operands.
        result_type = _infer_memref_subview_result_type(
            source.type, offsets, sizes, strides
        )
    else:
        assert (
            result_type is not None
        ), "mixed static/dynamic offset/sizes/strides requires explicit result type"

    # Split each mixed list into dynamic operands and static attribute values.
    offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
    sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
    strides, _packed_strides, static_strides = _dispatch_mixed_values(strides)

    return _generated_subview(
        result_type,
        source,
        offsets,
        sizes,
        strides,
        static_offsets,
        static_sizes,
        static_strides,
        loc=loc,
        ip=ip,
    )

mlir/test/python/dialects/memref.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from mlir.ir import *
44
import mlir.dialects.func as func
55
import mlir.dialects.memref as memref
6+
from mlir.dialects.memref import _infer_memref_subview_result_type
7+
import mlir.dialects.arith as arith
68
import mlir.extras.types as T
79

810

@@ -88,3 +90,97 @@ def testMemRefAttr():
8890
memref.global_("objFifo_in0", T.memref(16, T.i32()))
8991
# CHECK: memref.global @objFifo_in0 : memref<16xi32>
9092
print(module)
93+
94+
95+
# CHECK-LABEL: TEST: testSubViewOpInferReturnType
@run
def testSubViewOpInferReturnType():
    """Exercises memref.subview result-type inference (static ints,
    arith.constant offsets, failure on dynamic/non-constant operands, and an
    explicit result_type for a dynamic-offset layout)."""
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
            # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
            print(x.owner)

            # Fully static offsets/sizes/strides: type is inferred.
            y = memref.subview(x, [1, 1], [3, 3], [1, 1])
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
            print(y.owner)

            # arith.constant offsets fold to static values during inference.
            z = memref.subview(
                x,
                [arith.constant(T.index(), 1), 1],
                [3, 3],
                [1, 1],
            )
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
            print(z.owner)

            z = memref.subview(
                x,
                [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                [3, 3],
                [1, 1],
            )
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][3, 4] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 34>>
            print(z.owner)

            # A non-constant (arith.addi) offset cannot be inferred.
            try:
                memref.subview(
                    x,
                    [
                        arith.addi(
                            arith.constant(T.index(), 3), arith.constant(T.index(), 4)
                        ),
                        0,
                    ],
                    [3, 3],
                    [1, 1],
                )
            except AssertionError as e:
                # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
                print(e)

            # Dynamic-size sentinel rejected by the inference helper itself.
            try:
                _infer_memref_subview_result_type(
                    x.type,
                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                    [ShapedType.get_dynamic_size(), 3],
                    [1, 1],
                )
            except AssertionError as e:
                # CHECK: Only inferring from python or mlir integer constant is supported
                print(e)

            try:
                memref.subview(
                    x,
                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                    [ShapedType.get_dynamic_size(), 3],
                    [1, 1],
                )
            except AssertionError as e:
                # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
                print(e)

            # Source with a dynamic-offset strided layout: caller must pass
            # an explicit result_type.
            layout = StridedLayoutAttr.get(ShapedType.get_dynamic_size(), [10, 1])
            x = memref.alloc(
                T.memref(
                    10,
                    10,
                    T.i32(),
                    layout=layout,
                ),
                [],
                [arith.constant(T.index(), 42)],
            )
            # CHECK: %[[DYNAMICALLOC:.*]] = memref.alloc()[%c42] : memref<10x10xi32, strided<[10, 1], offset: ?>>
            print(x.owner)
            y = memref.subview(
                x,
                [1, 1],
                [3, 3],
                [1, 1],
                result_type=T.memref(3, 3, T.i32(), layout=layout),
            )
            # CHECK: %subview_9 = memref.subview %[[DYNAMICALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32, strided<[10, 1], offset: ?>> to memref<3x3xi32, strided<[10, 1], offset: ?>>
            print(y.owner)

0 commit comments

Comments
 (0)