Skip to content

Commit e671720

Browse files
committed
fix striding calcs for memref (depends on llvm/llvm-project#79393)
1 parent 8cb90ec commit e671720

File tree

3 files changed

+223
-157
lines changed

3 files changed

+223
-157
lines changed

mlir/extras/dialects/ext/memref.py

Lines changed: 11 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
1-
import re
21
from functools import cached_property, reduce
3-
from typing import Tuple, Sequence, Optional, Union
2+
from typing import Tuple, Sequence, Union
43

5-
from ....ir import Type, Value, MemRefType, ShapedType, MLIRError
6-
7-
from ... import types as T
8-
from ....dialects.memref import *
9-
from ....dialects import memref, arith
104
from .arith import Scalar, constant
115
from .tensor import _indices_to_indexer, compute_result_shape_reassoc_list
6+
from ... import types as T
127
from ...meta import region_op
13-
from ...._mlir_libs._mlir import register_value_caster
148
from ...util import get_user_code_loc
9+
from ...._mlir_libs._mlir import register_value_caster
10+
from ....dialects import memref, arith
1511
from ....dialects._ods_common import get_op_result_or_op_results
12+
from ....dialects.memref import *
13+
from ....ir import Type, Value, MemRefType, ShapedType
1614

1715
S = ShapedType.get_dynamic_size()
1816

@@ -70,71 +68,6 @@ def store(
7068
return get_op_result_or_op_results(StoreOp(value, mem, indices, loc=loc, ip=ip))
7169

7270

73-
def subview(
74-
source: "MemRef",
75-
offsets: Optional[Sequence[Value]] = None,
76-
strides: Optional[Sequence[Value]] = None,
77-
static_offsets: Optional[Sequence[int]] = None,
78-
static_sizes: Optional[Sequence[int]] = None,
79-
static_strides: Optional[Sequence[int]] = None,
80-
*,
81-
loc=None,
82-
ip=None,
83-
):
84-
if loc is None:
85-
loc = get_user_code_loc()
86-
if offsets is None:
87-
offsets = []
88-
if static_offsets is None:
89-
static_offsets = []
90-
if strides is None:
91-
strides = []
92-
if static_strides is None:
93-
static_strides = []
94-
assert static_sizes, f"this convenience method only handles static sizes"
95-
sizes = []
96-
wrong_type = T.memref(*static_sizes, source.dtype)
97-
if offsets and static_offsets:
98-
assert all(s == S for s in static_offsets)
99-
if strides and static_strides:
100-
assert all(s == S for s in static_strides)
101-
val = memref.subview(
102-
wrong_type,
103-
source,
104-
offsets,
105-
sizes,
106-
strides,
107-
static_offsets,
108-
static_sizes,
109-
static_strides,
110-
loc=loc,
111-
ip=ip,
112-
)
113-
# dumbest hack ever - the default builder doesn't connect to inferReturnTypes
114-
# but the diag message does
115-
try:
116-
val.owner.verify()
117-
return val
118-
except MLIRError as e:
119-
diag = str(e.error_diagnostics[0])
120-
correct_type = re.findall(r"'memref<(.*)>'", diag)
121-
assert len(correct_type) == 1
122-
correct_type = Type.parse(f"memref<{correct_type[0]}>")
123-
val.owner.erase()
124-
return memref.subview(
125-
correct_type,
126-
source,
127-
offsets,
128-
sizes,
129-
strides,
130-
static_offsets,
131-
static_sizes,
132-
static_strides,
133-
loc=loc,
134-
ip=ip,
135-
)
136-
137-
13871
@register_value_caster(MemRefType.static_typeid)
13972
class MemRef(Value):
14073
def __str__(self):
@@ -266,16 +199,15 @@ def _subview(
266199
if indexer.is_constant():
267200
out = subview(
268201
out,
269-
static_offsets=indexer.static_offsets(),
270-
static_sizes=indexer.static_sizes(),
271-
static_strides=indexer.static_strides(),
202+
offsets=indexer.static_offsets(),
203+
sizes=indexer.static_sizes(),
204+
strides=indexer.static_strides(),
272205
loc=loc,
273206
ip=ip,
274207
)
275208
else:
276209
# special tile case
277210
offsets = [None] * len(indexer.in_shape)
278-
static_offsets = [None] * len(indexer.in_shape)
279211
static_sizes = [None] * len(indexer.in_shape)
280212
static_strides = [None] * len(indexer.in_shape)
281213
for i, ind in enumerate(indexer.indices):
@@ -292,15 +224,13 @@ def _subview(
292224
and ind.step.is_constant()
293225
):
294226
offsets[i] = ind.start
295-
static_offsets[i] = S
296227
static_sizes[i] = maybe_size.literal_value
297228
static_strides[i] = (
298229
ind.step.literal_value if isinstance(ind.step, Scalar) else ind.step
299230
)
300231
else:
301232
raise RuntimeError(f"indexing not supported {indexer.indices}")
302233
offsets = list(filter(None, offsets))
303-
static_offsets = list(filter(None, static_offsets))
304234
static_sizes = list(filter(None, static_sizes))
305235
static_strides = list(filter(None, static_strides))
306236
assert (
@@ -312,9 +242,8 @@ def _subview(
312242
out = subview(
313243
out,
314244
offsets=offsets,
315-
static_offsets=static_offsets,
316-
static_sizes=static_sizes,
317-
static_strides=static_strides,
245+
sizes=static_sizes,
246+
strides=static_strides,
318247
loc=loc,
319248
ip=ip,
320249
)

mlir/extras/testing/testing.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
from ..context import MLIRContext, mlir_mod_ctx
1313
from .generate_test_checks import main
1414
from ..runtime.refbackend import LLVMJITBackend
15+
from ...ir import Module
1516

1617

1718
def filecheck(correct: str, module):
19+
if isinstance(module, Module):
20+
assert module.operation.verify()
1821
filecheck_name = "FileCheck"
1922
if platform.system() == "Windows":
2023
filecheck_name += ".exe"

0 commit comments

Comments
 (0)