Skip to content

Commit a72616d

Browse files
authored
[mlir][python] fix linalg.pack/unpack (#127729)
PR #123902 broke the Python bindings for `tensor.pack`/`unpack`. This PR fixes that. It also: 1. adds convenience wrappers for pack/unpack, 2. cleans up matmul-like ops in the linalg bindings, and 3. fixes the linalg docs that were missing pack/unpack.
1 parent 60c6202 commit a72616d

File tree

4 files changed

+125
-7
lines changed

4 files changed

+125
-7
lines changed

mlir/docs/Dialects/Linalg/_index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,3 +695,4 @@ the same IR.
695695
## Operations
696696

697697
[include "Dialects/LinalgOps.md"]
698+
[include "Dialects/LinalgRelayoutOps.td"]

mlir/python/mlir/dialects/LinalgOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111

1212
include "mlir/Dialect/Linalg/IR/LinalgOps.td"
1313
include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
14+
include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td"
1415

1516
#endif

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 83 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,11 @@
5858
from .opdsl.ops.core_named_ops import *
5959

6060
from ...ir import *
61-
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
61+
from .._ods_common import (
62+
get_op_result_or_value as _get_op_result_or_value,
63+
get_op_result_or_op_results as _get_op_result_or_op_results,
64+
_dispatch_mixed_values,
65+
)
6266
from ...extras.meta import region_op
6367

6468

@@ -149,7 +153,7 @@ def __init__(
149153
generic = region_op(GenericOp_, terminator=YieldOp)
150154

151155

152-
def create_op(
156+
def _create_matmul_like_op(
153157
op_type,
154158
*ins: Union[Operation, OpView, Value],
155159
outs: Sequence[Union[Operation, OpView, Value]],
@@ -179,7 +183,11 @@ def matmul(
179183
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
180184
cast: Optional[Union[TypeFn, Attribute]] = None,
181185
):
182-
return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
186+
return _get_op_result_or_op_results(
187+
_create_matmul_like_op(
188+
MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
189+
)
190+
)
183191

184192

185193
def batch_matmul(
@@ -188,8 +196,10 @@ def batch_matmul(
188196
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
189197
cast: Optional[Union[TypeFn, Attribute]] = None,
190198
):
191-
return create_op(
192-
BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
199+
return _get_op_result_or_op_results(
200+
_create_matmul_like_op(
201+
BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
202+
)
193203
)
194204

195205

@@ -199,6 +209,72 @@ def contract(
199209
indexing_maps: Sequence[AffineMapAttr],
200210
cast: Optional[Union[TypeFn, Attribute]] = None,
201211
):
202-
return create_op(
203-
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
212+
return _get_op_result_or_op_results(
213+
_create_matmul_like_op(
214+
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
215+
)
216+
)
217+
218+
219+
def pack(
220+
source,
221+
dest,
222+
inner_dims_pos,
223+
inner_tiles,
224+
*,
225+
padding_value=None,
226+
outer_dims_perm=None,
227+
loc=None,
228+
ip=None,
229+
) -> ir.Value:
230+
(
231+
dynamic_inner_tiles,
232+
# packed here means %1:2 packing (results packing)
233+
_inner_tiles,
234+
static_inner_tiles,
235+
) = _dispatch_mixed_values(inner_tiles)
236+
237+
return _get_op_result_or_op_results(
238+
PackOp(
239+
source=source,
240+
dest=dest,
241+
inner_dims_pos=inner_dims_pos,
242+
inner_tiles=dynamic_inner_tiles,
243+
static_inner_tiles=static_inner_tiles,
244+
padding_value=padding_value,
245+
outer_dims_perm=outer_dims_perm,
246+
loc=loc,
247+
ip=ip,
248+
)
249+
)
250+
251+
252+
def unpack(
253+
source,
254+
dest,
255+
inner_dims_pos,
256+
inner_tiles,
257+
*,
258+
outer_dims_perm=None,
259+
loc=None,
260+
ip=None,
261+
) -> ir.Value:
262+
(
263+
dynamic_inner_tiles,
264+
# packed here means %1:2 packing (results packing)
265+
_inner_tiles,
266+
static_inner_tiles,
267+
) = _dispatch_mixed_values(inner_tiles)
268+
269+
return _get_op_result_or_op_results(
270+
UnPackOp(
271+
source=source,
272+
dest=dest,
273+
inner_dims_pos=inner_dims_pos,
274+
inner_tiles=dynamic_inner_tiles,
275+
static_inner_tiles=static_inner_tiles,
276+
outer_dims_perm=outer_dims_perm,
277+
loc=loc,
278+
ip=ip,
279+
)
204280
)

mlir/test/python/dialects/linalg/ops.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,3 +566,43 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
566566
)
567567

568568
print(module)
569+
570+
571+
# CHECK-LABEL: TEST: testPackUnPackOp
572+
@run
573+
def testPackUnPackOp():
574+
with Context(), Location.unknown():
575+
module = Module.create()
576+
f32 = F32Type.get()
577+
with InsertionPoint(module.body):
578+
579+
@func.FuncOp.from_py_func(
580+
RankedTensorType.get((128, 128), f32),
581+
RankedTensorType.get((16, 16, 8, 8), f32),
582+
)
583+
def tensor_pack(src, dst):
584+
packed = linalg.pack(
585+
src,
586+
dst,
587+
inner_dims_pos=[1, 0],
588+
inner_tiles=[8, 8],
589+
padding_value=arith.constant(f32, 0.0),
590+
)
591+
592+
unpacked = linalg.unpack(
593+
packed,
594+
src,
595+
inner_dims_pos=[0, 1],
596+
inner_tiles=[8, 8],
597+
)
598+
599+
return unpacked
600+
601+
# CHECK-LABEL: func.func @tensor_pack(
602+
# CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<16x16x8x8xf32>) -> tensor<128x128xf32> {
603+
# CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
604+
# CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
605+
# CHECK: %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
606+
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
607+
# CHECK: }
608+
print(module)

0 commit comments

Comments (0)