Skip to content

Commit d050d09

Browse files
committed
fix striding calcs for memref (depends on llvm/llvm-project#79393)
1 parent 8cb90ec commit d050d09

File tree

10 files changed

+243
-184
lines changed

10 files changed

+243
-184
lines changed

mlir/extras/dialects/ext/arith.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,20 @@
55
from typing import Union, Optional, Tuple
66

77
import numpy as np
8+
9+
from ...util import get_user_code_loc, infer_mlir_type, mlir_type_to_np_dtype
10+
from ...._mlir_libs._mlir import register_value_caster
811
from ....dialects import arith as arith_dialect
9-
from ....dialects.arith import *
1012
from ....dialects import complex as complex_dialect
1113
from ....dialects._arith_enum_gen import (
1214
_arith_cmpfpredicateattr,
1315
CmpFPredicate,
1416
CmpIPredicate,
1517
_arith_cmpipredicateattr,
1618
)
17-
from ....dialects.arith import _is_integer_like_type
1819
from ....dialects._ods_common import get_op_result_or_value, get_op_result_or_op_results
20+
from ....dialects.arith import *
21+
from ....dialects.arith import _is_integer_like_type
1922
from ....dialects.linalg.opdsl.lang.emitter import (
2023
_is_floating_point_type,
2124
_is_integer_type,
@@ -44,9 +47,6 @@
4447
FloatAttr,
4548
)
4649

47-
from ...util import get_user_code_loc, infer_mlir_type, mlir_type_to_np_dtype
48-
from ...._mlir_libs._mlir import register_value_caster
49-
5050

5151
def constant(
5252
value: Union[int, float, bool, np.ndarray],

mlir/extras/dialects/ext/cf.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
from typing import Union, List
22

3-
from ....dialects.cf import *
3+
from ...util import get_user_code_loc, Successor
44
from ....dialects._cf_ops_gen import _Dialect
55
from ....dialects._ods_common import (
6-
get_op_result_or_value,
7-
get_op_results_or_values,
8-
get_default_loc_context,
9-
segmented_accessor,
106
_cext,
117
)
12-
from ....ir import Value, InsertionPoint, Block, OpView
13-
from ...util import get_user_code_loc, Successor
8+
from ....dialects.cf import *
9+
from ....ir import Value, InsertionPoint, Block
1410

1511

1612
@_cext.register_operation(_Dialect, replace=True)

mlir/extras/dialects/ext/func.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,10 @@
66
from ...meta import op_region_builder
77
from ...util import get_user_code_loc, make_maybe_no_args_decorator
88
from ....dialects.func import *
9-
from ....extras import types as T
109
from ....ir import (
1110
FlatSymbolRefAttr,
1211
FunctionType,
1312
InsertionPoint,
14-
StringAttr,
1513
Type,
1614
TypeAttr,
1715
Value,

mlir/extras/dialects/ext/gpu.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,17 @@
22
from functools import partial
33
from typing import Optional, Any, List, Tuple
44

5+
from .arith import constant
6+
from .func import FuncBase
7+
from ... import types as T
8+
from ...meta import (
9+
region_op,
10+
)
11+
from ...util import get_user_code_loc, make_maybe_no_args_decorator, ModuleMeta
12+
from ....dialects._gpu_ops_gen import _Dialect
513
from ....dialects._ods_common import get_default_loc_context, _cext
14+
from ....dialects._ods_common import get_op_result_or_op_results
615
from ....dialects.gpu import *
7-
from ....dialects._gpu_ops_gen import _Dialect
816
from ....ir import (
917
Type,
1018
Attribute,
@@ -17,15 +25,6 @@
1725
Value,
1826
)
1927

20-
from ... import types as T
21-
from .arith import constant
22-
from .func import FuncBase
23-
from ...meta import (
24-
region_op,
25-
)
26-
from ....dialects._ods_common import get_op_result_or_op_results
27-
from ...util import get_user_code_loc, make_maybe_no_args_decorator, ModuleMeta
28-
2928

3029
def block_id_x():
3130
return block_id("x")

mlir/extras/dialects/ext/linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from ....dialects import linalg
21
from ...util import get_user_code_loc
2+
from ....dialects import linalg
33

44

55
def abs(I, O, *, loc=None, ip=None):

mlir/extras/dialects/ext/llvm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from ....ir import Type
21
from ....dialects.llvm import *
2+
from ....ir import Type
33

44

55
def llvm_ptr_t():

mlir/extras/dialects/ext/memref.py

Lines changed: 11 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
1-
import re
21
from functools import cached_property, reduce
3-
from typing import Tuple, Sequence, Optional, Union
2+
from typing import Tuple, Sequence, Union
43

5-
from ....ir import Type, Value, MemRefType, ShapedType, MLIRError
6-
7-
from ... import types as T
8-
from ....dialects.memref import *
9-
from ....dialects import memref, arith
104
from .arith import Scalar, constant
115
from .tensor import _indices_to_indexer, compute_result_shape_reassoc_list
6+
from ... import types as T
127
from ...meta import region_op
13-
from ...._mlir_libs._mlir import register_value_caster
148
from ...util import get_user_code_loc
9+
from ...._mlir_libs._mlir import register_value_caster
10+
from ....dialects import memref, arith
1511
from ....dialects._ods_common import get_op_result_or_op_results
12+
from ....dialects.memref import *
13+
from ....ir import Type, Value, MemRefType, ShapedType
1614

1715
S = ShapedType.get_dynamic_size()
1816

@@ -70,71 +68,6 @@ def store(
7068
return get_op_result_or_op_results(StoreOp(value, mem, indices, loc=loc, ip=ip))
7169

7270

73-
def subview(
74-
source: "MemRef",
75-
offsets: Optional[Sequence[Value]] = None,
76-
strides: Optional[Sequence[Value]] = None,
77-
static_offsets: Optional[Sequence[int]] = None,
78-
static_sizes: Optional[Sequence[int]] = None,
79-
static_strides: Optional[Sequence[int]] = None,
80-
*,
81-
loc=None,
82-
ip=None,
83-
):
84-
if loc is None:
85-
loc = get_user_code_loc()
86-
if offsets is None:
87-
offsets = []
88-
if static_offsets is None:
89-
static_offsets = []
90-
if strides is None:
91-
strides = []
92-
if static_strides is None:
93-
static_strides = []
94-
assert static_sizes, f"this convenience method only handles static sizes"
95-
sizes = []
96-
wrong_type = T.memref(*static_sizes, source.dtype)
97-
if offsets and static_offsets:
98-
assert all(s == S for s in static_offsets)
99-
if strides and static_strides:
100-
assert all(s == S for s in static_strides)
101-
val = memref.subview(
102-
wrong_type,
103-
source,
104-
offsets,
105-
sizes,
106-
strides,
107-
static_offsets,
108-
static_sizes,
109-
static_strides,
110-
loc=loc,
111-
ip=ip,
112-
)
113-
# dumbest hack ever - the default builder doesn't connect to inferReturnTypes
114-
# but the diag message does
115-
try:
116-
val.owner.verify()
117-
return val
118-
except MLIRError as e:
119-
diag = str(e.error_diagnostics[0])
120-
correct_type = re.findall(r"'memref<(.*)>'", diag)
121-
assert len(correct_type) == 1
122-
correct_type = Type.parse(f"memref<{correct_type[0]}>")
123-
val.owner.erase()
124-
return memref.subview(
125-
correct_type,
126-
source,
127-
offsets,
128-
sizes,
129-
strides,
130-
static_offsets,
131-
static_sizes,
132-
static_strides,
133-
loc=loc,
134-
ip=ip,
135-
)
136-
137-
13871
@register_value_caster(MemRefType.static_typeid)
13972
class MemRef(Value):
14073
def __str__(self):
@@ -266,16 +199,15 @@ def _subview(
266199
if indexer.is_constant():
267200
out = subview(
268201
out,
269-
static_offsets=indexer.static_offsets(),
270-
static_sizes=indexer.static_sizes(),
271-
static_strides=indexer.static_strides(),
202+
offsets=indexer.static_offsets(),
203+
sizes=indexer.static_sizes(),
204+
strides=indexer.static_strides(),
272205
loc=loc,
273206
ip=ip,
274207
)
275208
else:
276209
# special tile case
277210
offsets = [None] * len(indexer.in_shape)
278-
static_offsets = [None] * len(indexer.in_shape)
279211
static_sizes = [None] * len(indexer.in_shape)
280212
static_strides = [None] * len(indexer.in_shape)
281213
for i, ind in enumerate(indexer.indices):
@@ -292,15 +224,13 @@ def _subview(
292224
and ind.step.is_constant()
293225
):
294226
offsets[i] = ind.start
295-
static_offsets[i] = S
296227
static_sizes[i] = maybe_size.literal_value
297228
static_strides[i] = (
298229
ind.step.literal_value if isinstance(ind.step, Scalar) else ind.step
299230
)
300231
else:
301232
raise RuntimeError(f"indexing not supported {indexer.indices}")
302233
offsets = list(filter(None, offsets))
303-
static_offsets = list(filter(None, static_offsets))
304234
static_sizes = list(filter(None, static_sizes))
305235
static_strides = list(filter(None, static_strides))
306236
assert (
@@ -312,9 +242,8 @@ def _subview(
312242
out = subview(
313243
out,
314244
offsets=offsets,
315-
static_offsets=static_offsets,
316-
static_sizes=static_sizes,
317-
static_strides=static_strides,
245+
sizes=static_sizes,
246+
strides=static_strides,
318247
loc=loc,
319248
ip=ip,
320249
)

mlir/extras/dialects/ext/nvgpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
TensorMapOOBKind,
77
TensorMapInterleaveKind,
88
)
9-
from ....ir import Attribute, Type
9+
from ....ir import Type
1010

1111

1212
def tensormap_descriptor(

mlir/extras/testing/testing.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
from ..context import MLIRContext, mlir_mod_ctx
1313
from .generate_test_checks import main
1414
from ..runtime.refbackend import LLVMJITBackend
15+
from ...ir import Module
1516

1617

1718
def filecheck(correct: str, module):
19+
if isinstance(module, Module):
20+
assert module.operation.verify()
1821
filecheck_name = "FileCheck"
1922
if platform.system() == "Windows":
2023
filecheck_name += ".exe"

0 commit comments

Comments
 (0)