Skip to content

Commit d18ca56

Browse files
committed
[mlir][python] fix linalg.pack
1 parent 3e61c1a commit d18ca56

File tree

3 files changed

+70
-1
lines changed

3 files changed

+70
-1
lines changed

mlir/python/mlir/dialects/LinalgOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111

1212
include "mlir/Dialect/Linalg/IR/LinalgOps.td"
1313
include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
14+
include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td"
1415

1516
#endif

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,11 @@
5858
from .opdsl.ops.core_named_ops import *
5959

6060
from ...ir import *
61-
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
61+
from .._ods_common import (
62+
get_op_result_or_value as _get_op_result_or_value,
63+
get_op_result_or_op_results as _get_op_result_or_op_results,
64+
_dispatch_mixed_values,
65+
)
6266
from ...extras.meta import region_op
6367

6468

@@ -202,3 +206,36 @@ def contract(
202206
return create_op(
203207
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
204208
)
209+
210+
211+
def pack(
212+
source,
213+
dest,
214+
inner_dims_pos,
215+
inner_tiles,
216+
*,
217+
padding_value=None,
218+
outer_dims_perm=None,
219+
loc=None,
220+
ip=None,
221+
) -> ir.Value:
222+
(
223+
dynamic_inner_tiles,
224+
# packed here means %1:2 packing (results packing)
225+
_inner_tiles,
226+
static_inner_tiles,
227+
) = _dispatch_mixed_values(inner_tiles)
228+
229+
return _get_op_result_or_op_results(
230+
PackOp(
231+
source=source,
232+
dest=dest,
233+
inner_dims_pos=inner_dims_pos,
234+
inner_tiles=dynamic_inner_tiles,
235+
static_inner_tiles=static_inner_tiles,
236+
padding_value=padding_value,
237+
outer_dims_perm=outer_dims_perm,
238+
loc=loc,
239+
ip=ip,
240+
)
241+
)

mlir/test/python/dialects/linalg/ops.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,3 +566,34 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
566566
)
567567

568568
print(module)
569+
570+
571+
# CHECK-LABEL: TEST: testPackOp
572+
@run
573+
def testPackOp():
574+
with Context(), Location.unknown():
575+
module = Module.create()
576+
f32 = F32Type.get()
577+
with InsertionPoint(module.body):
578+
579+
@func.FuncOp.from_py_func(
580+
RankedTensorType.get((129, 47, 16, 16), f32),
581+
RankedTensorType.get((17, 2, 16, 16, 32, 8), f32),
582+
)
583+
def tensor_pack(src, dst):
584+
return linalg.pack(
585+
src,
586+
dst,
587+
inner_dims_pos=[1, 0],
588+
inner_tiles=[32, 8],
589+
padding_value=arith.constant(f32, 0.0),
590+
)
591+
592+
# CHECK-LABEL: func.func @tensor_pack(
593+
# CHECK-SAME: %[[VAL_0:.*]]: tensor<129x47x16x16xf32>,
594+
# CHECK-SAME: %[[VAL_1:.*]]: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
595+
# CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
596+
# CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %[[VAL_1]] : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32>
597+
# CHECK: return %[[VAL_3]] : tensor<17x2x16x16x32x8xf32>
598+
# CHECK: }
599+
print(module)

0 commit comments

Comments
 (0)