1
- import re
2
1
from functools import cached_property , reduce
3
- from typing import Tuple , Sequence , Optional , Union
2
+ from typing import Tuple , Sequence , Union
4
3
5
- from ....ir import Type , Value , MemRefType , ShapedType , MLIRError
6
-
7
- from ... import types as T
8
- from ....dialects .memref import *
9
- from ....dialects import memref , arith
10
4
from .arith import Scalar , constant
11
5
from .tensor import _indices_to_indexer , compute_result_shape_reassoc_list
6
+ from ... import types as T
12
7
from ...meta import region_op
13
- from ...._mlir_libs ._mlir import register_value_caster
14
8
from ...util import get_user_code_loc
9
+ from ...._mlir_libs ._mlir import register_value_caster
10
+ from ....dialects import memref , arith
15
11
from ....dialects ._ods_common import get_op_result_or_op_results
12
+ from ....dialects .memref import *
13
+ from ....ir import Type , Value , MemRefType , ShapedType
16
14
17
15
S = ShapedType .get_dynamic_size ()
18
16
@@ -70,71 +68,6 @@ def store(
70
68
return get_op_result_or_op_results (StoreOp (value , mem , indices , loc = loc , ip = ip ))
71
69
72
70
73
- def subview (
74
- source : "MemRef" ,
75
- offsets : Optional [Sequence [Value ]] = None ,
76
- strides : Optional [Sequence [Value ]] = None ,
77
- static_offsets : Optional [Sequence [int ]] = None ,
78
- static_sizes : Optional [Sequence [int ]] = None ,
79
- static_strides : Optional [Sequence [int ]] = None ,
80
- * ,
81
- loc = None ,
82
- ip = None ,
83
- ):
84
- if loc is None :
85
- loc = get_user_code_loc ()
86
- if offsets is None :
87
- offsets = []
88
- if static_offsets is None :
89
- static_offsets = []
90
- if strides is None :
91
- strides = []
92
- if static_strides is None :
93
- static_strides = []
94
- assert static_sizes , f"this convenience method only handles static sizes"
95
- sizes = []
96
- wrong_type = T .memref (* static_sizes , source .dtype )
97
- if offsets and static_offsets :
98
- assert all (s == S for s in static_offsets )
99
- if strides and static_strides :
100
- assert all (s == S for s in static_strides )
101
- val = memref .subview (
102
- wrong_type ,
103
- source ,
104
- offsets ,
105
- sizes ,
106
- strides ,
107
- static_offsets ,
108
- static_sizes ,
109
- static_strides ,
110
- loc = loc ,
111
- ip = ip ,
112
- )
113
- # dumbest hack ever - the default builder doesn't connect to inferReturnTypes
114
- # but the diag message does
115
- try :
116
- val .owner .verify ()
117
- return val
118
- except MLIRError as e :
119
- diag = str (e .error_diagnostics [0 ])
120
- correct_type = re .findall (r"'memref<(.*)>'" , diag )
121
- assert len (correct_type ) == 1
122
- correct_type = Type .parse (f"memref<{ correct_type [0 ]} >" )
123
- val .owner .erase ()
124
- return memref .subview (
125
- correct_type ,
126
- source ,
127
- offsets ,
128
- sizes ,
129
- strides ,
130
- static_offsets ,
131
- static_sizes ,
132
- static_strides ,
133
- loc = loc ,
134
- ip = ip ,
135
- )
136
-
137
-
138
71
@register_value_caster (MemRefType .static_typeid )
139
72
class MemRef (Value ):
140
73
def __str__ (self ):
@@ -266,16 +199,15 @@ def _subview(
266
199
if indexer .is_constant ():
267
200
out = subview (
268
201
out ,
269
- static_offsets = indexer .static_offsets (),
270
- static_sizes = indexer .static_sizes (),
271
- static_strides = indexer .static_strides (),
202
+ offsets = indexer .static_offsets (),
203
+ sizes = indexer .static_sizes (),
204
+ strides = indexer .static_strides (),
272
205
loc = loc ,
273
206
ip = ip ,
274
207
)
275
208
else :
276
209
# special tile case
277
210
offsets = [None ] * len (indexer .in_shape )
278
- static_offsets = [None ] * len (indexer .in_shape )
279
211
static_sizes = [None ] * len (indexer .in_shape )
280
212
static_strides = [None ] * len (indexer .in_shape )
281
213
for i , ind in enumerate (indexer .indices ):
@@ -292,15 +224,13 @@ def _subview(
292
224
and ind .step .is_constant ()
293
225
):
294
226
offsets [i ] = ind .start
295
- static_offsets [i ] = S
296
227
static_sizes [i ] = maybe_size .literal_value
297
228
static_strides [i ] = (
298
229
ind .step .literal_value if isinstance (ind .step , Scalar ) else ind .step
299
230
)
300
231
else :
301
232
raise RuntimeError (f"indexing not supported { indexer .indices } " )
302
233
offsets = list (filter (None , offsets ))
303
- static_offsets = list (filter (None , static_offsets ))
304
234
static_sizes = list (filter (None , static_sizes ))
305
235
static_strides = list (filter (None , static_strides ))
306
236
assert (
@@ -312,9 +242,8 @@ def _subview(
312
242
out = subview (
313
243
out ,
314
244
offsets = offsets ,
315
- static_offsets = static_offsets ,
316
- static_sizes = static_sizes ,
317
- static_strides = static_strides ,
245
+ sizes = static_sizes ,
246
+ strides = static_strides ,
318
247
loc = loc ,
319
248
ip = ip ,
320
249
)
0 commit comments