|
2 | 2 | # See https://llvm.org/LICENSE.txt for license information.
|
3 | 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
4 | 4 |
|
5 |
| -# Provide a convenient name for sub-packages to resolve the main C-extension |
6 |
| -# with a relative import. |
7 |
| -from .._mlir_libs import _mlir as _cext |
8 | 5 | from typing import (
|
| 6 | + List as _List, |
| 7 | + Optional as _Optional, |
9 | 8 | Sequence as _Sequence,
|
| 9 | + Tuple as _Tuple, |
10 | 10 | Type as _Type,
|
11 | 11 | TypeVar as _TypeVar,
|
12 | 12 | Union as _Union,
|
13 | 13 | )
|
14 | 14 |
|
| 15 | +from .._mlir_libs import _mlir as _cext |
| 16 | +from ..ir import ( |
| 17 | + ArrayAttr, |
| 18 | + Attribute, |
| 19 | + BoolAttr, |
| 20 | + DenseI64ArrayAttr, |
| 21 | + IntegerAttr, |
| 22 | + IntegerType, |
| 23 | + OpView, |
| 24 | + Operation, |
| 25 | + ShapedType, |
| 26 | + Value, |
| 27 | +) |
| 28 | + |
15 | 29 | __all__ = [
|
16 | 30 | "equally_sized_accessor",
|
17 | 31 | "get_default_loc_context",
|
@@ -138,3 +152,157 @@ def get_op_result_or_op_results(
|
138 | 152 | ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
|
139 | 153 | ResultValueT = _Union[ResultValueTypeTuple]
|
140 | 154 | VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
|
| 155 | + |
| 156 | +StaticIntLike = _Union[int, IntegerAttr] |
| 157 | +ValueLike = _Union[Operation, OpView, Value] |
| 158 | +MixedInt = _Union[StaticIntLike, ValueLike] |
| 159 | + |
| 160 | +IntOrAttrList = _Sequence[_Union[IntegerAttr, int]] |
| 161 | +OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]] |
| 162 | + |
| 163 | +BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]] |
| 164 | +OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]] |
| 165 | + |
| 166 | +MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike] |
| 167 | + |
| 168 | +DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]] |
| 169 | + |
| 170 | + |
| 171 | +def _dispatch_dynamic_index_list( |
| 172 | + indices: _Union[DynamicIndexList, ArrayAttr], |
| 173 | +) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]: |
| 174 | + """Dispatches a list of indices to the appropriate form. |
| 175 | +
|
| 176 | + This is similar to the custom `DynamicIndexList` directive upstream: |
| 177 | + provided indices may be in the form of dynamic SSA values or static values, |
| 178 | + and they may be scalable (i.e., as a singleton list) or not. This function |
| 179 | + dispatches each index into its respective form. It also extracts the SSA |
| 180 | + values and static indices from various similar structures, respectively. |
| 181 | + """ |
| 182 | + dynamic_indices = [] |
| 183 | + static_indices = [ShapedType.get_dynamic_size()] * len(indices) |
| 184 | + scalable_indices = [False] * len(indices) |
| 185 | + |
| 186 | + # ArrayAttr: Extract index values. |
| 187 | + if isinstance(indices, ArrayAttr): |
| 188 | + indices = [idx for idx in indices] |
| 189 | + |
| 190 | + def process_nonscalable_index(i, index): |
| 191 | + """Processes any form of non-scalable index. |
| 192 | +
|
| 193 | + Returns False if the given index was scalable and thus remains |
| 194 | + unprocessed; True otherwise. |
| 195 | + """ |
| 196 | + if isinstance(index, int): |
| 197 | + static_indices[i] = index |
| 198 | + elif isinstance(index, IntegerAttr): |
| 199 | + static_indices[i] = index.value # pytype: disable=attribute-error |
| 200 | + elif isinstance(index, (Operation, Value, OpView)): |
| 201 | + dynamic_indices.append(index) |
| 202 | + else: |
| 203 | + return False |
| 204 | + return True |
| 205 | + |
| 206 | + # Process each index at a time. |
| 207 | + for i, index in enumerate(indices): |
| 208 | + if not process_nonscalable_index(i, index): |
| 209 | + # If it wasn't processed, it must be a scalable index, which is |
| 210 | + # provided as a _Sequence of one value, so extract and process that. |
| 211 | + scalable_indices[i] = True |
| 212 | + assert len(index) == 1 |
| 213 | + ret = process_nonscalable_index(i, index[0]) |
| 214 | + assert ret |
| 215 | + |
| 216 | + return dynamic_indices, static_indices, scalable_indices |
| 217 | + |
| 218 | + |
| 219 | +# Dispatches `MixedValues` that all represents integers in various forms into |
| 220 | +# the following three categories: |
| 221 | +# - `dynamic_values`: a list of `Value`s, potentially from op results; |
| 222 | +# - `packed_values`: a value handle, potentially from an op result, associated |
| 223 | +# to one or more payload operations of integer type; |
| 224 | +# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python |
| 225 | +# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`. |
| 226 | +# The input is in the form for `packed_values`, only that result is set and the |
| 227 | +# other two are empty. Otherwise, the input can be a mix of the other two forms, |
| 228 | +# and for each dynamic value, a special value is added to the `static_values`. |
| 229 | +def _dispatch_mixed_values( |
| 230 | + values: MixedValues, |
| 231 | +) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]: |
| 232 | + dynamic_values = [] |
| 233 | + packed_values = None |
| 234 | + static_values = None |
| 235 | + if isinstance(values, ArrayAttr): |
| 236 | + static_values = values |
| 237 | + elif isinstance(values, (Operation, Value, OpView)): |
| 238 | + packed_values = values |
| 239 | + else: |
| 240 | + static_values = [] |
| 241 | + for size in values or []: |
| 242 | + if isinstance(size, int): |
| 243 | + static_values.append(size) |
| 244 | + else: |
| 245 | + static_values.append(ShapedType.get_dynamic_size()) |
| 246 | + dynamic_values.append(size) |
| 247 | + static_values = DenseI64ArrayAttr.get(static_values) |
| 248 | + |
| 249 | + return (dynamic_values, packed_values, static_values) |
| 250 | + |
| 251 | + |
| 252 | +def _get_value_or_attribute_value( |
| 253 | + value_or_attr: _Union[any, Attribute, ArrayAttr] |
| 254 | +) -> any: |
| 255 | + if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"): |
| 256 | + return value_or_attr.value |
| 257 | + if isinstance(value_or_attr, ArrayAttr): |
| 258 | + return _get_value_list(value_or_attr) |
| 259 | + return value_or_attr |
| 260 | + |
| 261 | + |
| 262 | +def _get_value_list( |
| 263 | + sequence_or_array_attr: _Union[_Sequence[any], ArrayAttr] |
| 264 | +) -> _Sequence[any]: |
| 265 | + return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr] |
| 266 | + |
| 267 | + |
| 268 | +def _get_int_array_attr( |
| 269 | + values: _Optional[_Union[ArrayAttr, IntOrAttrList]] |
| 270 | +) -> ArrayAttr: |
| 271 | + if values is None: |
| 272 | + return None |
| 273 | + |
| 274 | + # Turn into a Python list of Python ints. |
| 275 | + values = _get_value_list(values) |
| 276 | + |
| 277 | + # Make an ArrayAttr of IntegerAttrs out of it. |
| 278 | + return ArrayAttr.get( |
| 279 | + [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values] |
| 280 | + ) |
| 281 | + |
| 282 | + |
| 283 | +def _get_int_array_array_attr( |
| 284 | + values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]] |
| 285 | +) -> ArrayAttr: |
| 286 | + """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs. |
| 287 | +
|
| 288 | + The input has to be a collection of a collection of integers, where any |
| 289 | + Python _Sequence and ArrayAttr are admissible collections and Python ints and |
| 290 | + any IntegerAttr are admissible integers. Both levels of collections are |
| 291 | + turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s. |
| 292 | + If the input is None, an empty ArrayAttr is returned. |
| 293 | + """ |
| 294 | + if values is None: |
| 295 | + return None |
| 296 | + |
| 297 | + # Make sure the outer level is a list. |
| 298 | + values = _get_value_list(values) |
| 299 | + |
| 300 | + # The inner level is now either invalid or a mixed sequence of ArrayAttrs and |
| 301 | + # Sequences. Make sure the nested values are all lists. |
| 302 | + values = [_get_value_list(nested) for nested in values] |
| 303 | + |
| 304 | + # Turn each nested list into an ArrayAttr. |
| 305 | + values = [_get_int_array_attr(nested) for nested in values] |
| 306 | + |
| 307 | + # Turn the outer list into an ArrayAttr. |
| 308 | + return ArrayAttr.get(values) |
0 commit comments