Skip to content

Commit 404af14

Browse files
authored
[mlir][python] enable memref.subview (#79393)
1 parent a356e6c commit 404af14

File tree

7 files changed

+516
-159
lines changed

7 files changed

+516
-159
lines changed

mlir/include/mlir-c/BuiltinTypes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,12 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
408408
/// Returns the memory space of the given MemRef type.
409409
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
410410

411+
/// Returns the strides of the MemRef if the layout map is in strided form.
412+
/// Both strides and offset are out params. strides must point to pre-allocated
413+
/// memory of length equal to the rank of the memref.
414+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(
415+
MlirType type, int64_t *strides, int64_t *offset);
416+
411417
/// Returns the memory spcae of the given Unranked MemRef type.
412418
MLIR_CAPI_EXPORTED MlirAttribute
413419
mlirUnrankedMemrefGetMemorySpace(MlirType type);

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
#include "mlir-c/BuiltinAttributes.h"
1414
#include "mlir-c/BuiltinTypes.h"
15+
#include "mlir-c/Support.h"
16+
1517
#include <optional>
1618

1719
namespace py = pybind11;
@@ -618,6 +620,18 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
618620
return mlirMemRefTypeGetLayout(self);
619621
},
620622
"The layout of the MemRef type.")
623+
.def(
624+
"get_strides_and_offset",
625+
[](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
626+
std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
627+
int64_t offset;
628+
if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
629+
self, strides.data(), &offset)))
630+
throw std::runtime_error(
631+
"Failed to extract strides and offset from memref.");
632+
return {strides, offset};
633+
},
634+
"The strides and offset of the MemRef type.")
621635
.def_property_readonly(
622636
"affine_map",
623637
[](PyMemRefType &self) -> PyAffineMap {

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,16 @@
99
#include "mlir-c/BuiltinTypes.h"
1010
#include "mlir-c/AffineMap.h"
1111
#include "mlir-c/IR.h"
12+
#include "mlir-c/Support.h"
1213
#include "mlir/CAPI/AffineMap.h"
1314
#include "mlir/CAPI/IR.h"
1415
#include "mlir/CAPI/Support.h"
1516
#include "mlir/IR/AffineMap.h"
1617
#include "mlir/IR/BuiltinTypes.h"
1718
#include "mlir/IR/Types.h"
19+
#include "mlir/Support/LogicalResult.h"
20+
21+
#include <algorithm>
1822

1923
using namespace mlir;
2024

@@ -426,6 +430,18 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
426430
return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
427431
}
428432

433+
MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type,
434+
int64_t *strides,
435+
int64_t *offset) {
436+
MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
437+
SmallVector<int64_t> strides_;
438+
if (failed(getStridesAndOffset(memrefType, strides_, *offset)))
439+
return mlirLogicalResultFailure();
440+
441+
(void)std::copy(strides_.begin(), strides_.end(), strides);
442+
return mlirLogicalResultSuccess();
443+
}
444+
429445
MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
430446
return wrap(UnrankedMemRefType::getTypeID());
431447
}

mlir/python/mlir/dialects/_ods_common.py

Lines changed: 171 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,30 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5-
# Provide a convenient name for sub-packages to resolve the main C-extension
6-
# with a relative import.
7-
from .._mlir_libs import _mlir as _cext
85
from typing import (
6+
List as _List,
7+
Optional as _Optional,
98
Sequence as _Sequence,
9+
Tuple as _Tuple,
1010
Type as _Type,
1111
TypeVar as _TypeVar,
1212
Union as _Union,
1313
)
1414

15+
from .._mlir_libs import _mlir as _cext
16+
from ..ir import (
17+
ArrayAttr,
18+
Attribute,
19+
BoolAttr,
20+
DenseI64ArrayAttr,
21+
IntegerAttr,
22+
IntegerType,
23+
OpView,
24+
Operation,
25+
ShapedType,
26+
Value,
27+
)
28+
1529
__all__ = [
1630
"equally_sized_accessor",
1731
"get_default_loc_context",
@@ -138,3 +152,157 @@ def get_op_result_or_op_results(
138152
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
139153
ResultValueT = _Union[ResultValueTypeTuple]
140154
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
155+
156+
StaticIntLike = _Union[int, IntegerAttr]
157+
ValueLike = _Union[Operation, OpView, Value]
158+
MixedInt = _Union[StaticIntLike, ValueLike]
159+
160+
IntOrAttrList = _Sequence[_Union[IntegerAttr, int]]
161+
OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]]
162+
163+
BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]]
164+
OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]]
165+
166+
MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
167+
168+
DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]]
169+
170+
171+
def _dispatch_dynamic_index_list(
172+
indices: _Union[DynamicIndexList, ArrayAttr],
173+
) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]:
174+
"""Dispatches a list of indices to the appropriate form.
175+
176+
This is similar to the custom `DynamicIndexList` directive upstream:
177+
provided indices may be in the form of dynamic SSA values or static values,
178+
and they may be scalable (i.e., as a singleton list) or not. This function
179+
dispatches each index into its respective form. It also extracts the SSA
180+
values and static indices from various similar structures, respectively.
181+
"""
182+
dynamic_indices = []
183+
static_indices = [ShapedType.get_dynamic_size()] * len(indices)
184+
scalable_indices = [False] * len(indices)
185+
186+
# ArrayAttr: Extract index values.
187+
if isinstance(indices, ArrayAttr):
188+
indices = [idx for idx in indices]
189+
190+
def process_nonscalable_index(i, index):
191+
"""Processes any form of non-scalable index.
192+
193+
Returns False if the given index was scalable and thus remains
194+
unprocessed; True otherwise.
195+
"""
196+
if isinstance(index, int):
197+
static_indices[i] = index
198+
elif isinstance(index, IntegerAttr):
199+
static_indices[i] = index.value # pytype: disable=attribute-error
200+
elif isinstance(index, (Operation, Value, OpView)):
201+
dynamic_indices.append(index)
202+
else:
203+
return False
204+
return True
205+
206+
# Process each index at a time.
207+
for i, index in enumerate(indices):
208+
if not process_nonscalable_index(i, index):
209+
# If it wasn't processed, it must be a scalable index, which is
210+
# provided as a _Sequence of one value, so extract and process that.
211+
scalable_indices[i] = True
212+
assert len(index) == 1
213+
ret = process_nonscalable_index(i, index[0])
214+
assert ret
215+
216+
return dynamic_indices, static_indices, scalable_indices
217+
218+
219+
# Dispatches `MixedValues` that all represents integers in various forms into
220+
# the following three categories:
221+
# - `dynamic_values`: a list of `Value`s, potentially from op results;
222+
# - `packed_values`: a value handle, potentially from an op result, associated
223+
# to one or more payload operations of integer type;
224+
# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
225+
# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
226+
# The input is in the form for `packed_values`, only that result is set and the
227+
# other two are empty. Otherwise, the input can be a mix of the other two forms,
228+
# and for each dynamic value, a special value is added to the `static_values`.
229+
def _dispatch_mixed_values(
230+
values: MixedValues,
231+
) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]:
232+
dynamic_values = []
233+
packed_values = None
234+
static_values = None
235+
if isinstance(values, ArrayAttr):
236+
static_values = values
237+
elif isinstance(values, (Operation, Value, OpView)):
238+
packed_values = values
239+
else:
240+
static_values = []
241+
for size in values or []:
242+
if isinstance(size, int):
243+
static_values.append(size)
244+
else:
245+
static_values.append(ShapedType.get_dynamic_size())
246+
dynamic_values.append(size)
247+
static_values = DenseI64ArrayAttr.get(static_values)
248+
249+
return (dynamic_values, packed_values, static_values)
250+
251+
252+
def _get_value_or_attribute_value(
253+
value_or_attr: _Union[any, Attribute, ArrayAttr]
254+
) -> any:
255+
if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
256+
return value_or_attr.value
257+
if isinstance(value_or_attr, ArrayAttr):
258+
return _get_value_list(value_or_attr)
259+
return value_or_attr
260+
261+
262+
def _get_value_list(
263+
sequence_or_array_attr: _Union[_Sequence[any], ArrayAttr]
264+
) -> _Sequence[any]:
265+
return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
266+
267+
268+
def _get_int_array_attr(
269+
values: _Optional[_Union[ArrayAttr, IntOrAttrList]]
270+
) -> ArrayAttr:
271+
if values is None:
272+
return None
273+
274+
# Turn into a Python list of Python ints.
275+
values = _get_value_list(values)
276+
277+
# Make an ArrayAttr of IntegerAttrs out of it.
278+
return ArrayAttr.get(
279+
[IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
280+
)
281+
282+
283+
def _get_int_array_array_attr(
284+
values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]]
285+
) -> ArrayAttr:
286+
"""Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
287+
288+
The input has to be a collection of a collection of integers, where any
289+
Python _Sequence and ArrayAttr are admissible collections and Python ints and
290+
any IntegerAttr are admissible integers. Both levels of collections are
291+
turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
292+
If the input is None, an empty ArrayAttr is returned.
293+
"""
294+
if values is None:
295+
return None
296+
297+
# Make sure the outer level is a list.
298+
values = _get_value_list(values)
299+
300+
# The inner level is now either invalid or a mixed sequence of ArrayAttrs and
301+
# Sequences. Make sure the nested values are all lists.
302+
values = [_get_value_list(nested) for nested in values]
303+
304+
# Turn each nested list into an ArrayAttr.
305+
values = [_get_int_array_attr(nested) for nested in values]
306+
307+
# Turn the outer list into an ArrayAttr.
308+
return ArrayAttr.get(values)

0 commit comments

Comments
 (0)