Skip to content

Commit a642f31

Browse files
committed
[mlir][python] fix linalg.pack
1 parent 3430bc3 commit a642f31

File tree

3 files changed

+60
-1
lines changed

3 files changed

+60
-1
lines changed

mlir/python/mlir/dialects/LinalgOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111

1212
include "mlir/Dialect/Linalg/IR/LinalgOps.td"
1313
include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
14+
include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td"
1415

1516
#endif

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,10 @@
5858
from .opdsl.ops.core_named_ops import *
5959

6060
from ...ir import *
61-
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
61+
from .._ods_common import (
62+
get_op_result_or_value as _get_op_result_or_value,
63+
_dispatch_mixed_values,
64+
)
6265
from ...extras.meta import region_op
6366

6467

@@ -193,3 +196,35 @@ def contract(
193196
)
194197
fill_builtin_region(op.operation)
195198
return op
199+
200+
201+
def pack(
202+
source,
203+
dest,
204+
inner_dims_pos,
205+
inner_tiles,
206+
*,
207+
padding_value=None,
208+
outer_dims_perm=None,
209+
loc=None,
210+
ip=None,
211+
) -> ir.Value:
212+
213+
(
214+
dynamic_inner_tiles,
215+
# packed here means %1:2 packing (results packing)
216+
_inner_tiles,
217+
static_inner_tiles,
218+
) = _dispatch_mixed_values(inner_tiles)
219+
220+
return PackOp(
221+
source=source,
222+
dest=dest,
223+
inner_dims_pos=inner_dims_pos,
224+
inner_tiles=dynamic_inner_tiles,
225+
static_inner_tiles=static_inner_tiles,
226+
padding_value=padding_value,
227+
outer_dims_perm=outer_dims_perm,
228+
loc=loc,
229+
ip=ip,
230+
).result

mlir/test/python/dialects/linalg/ops.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,3 +466,26 @@ def matmul_as_contract_op(
466466
)
467467

468468
print(module)
469+
470+
471+
# CHECK-LABEL: TEST: testPackOp
472+
@run
473+
def testPackOp():
474+
with Context(), Location.unknown():
475+
module = Module.create()
476+
f32 = F32Type.get()
477+
478+
@func.FuncOp.from_py_func(
479+
RankedTensorType.get((129, 47, 16, 16), f32),
480+
RankedTensorType.get((17, 2, 16, 16, 32, 8), f32)[
481+
RankedTensorType.get((17, 2, 16, 16, 32, 8), f32)
482+
],
483+
)
484+
def tensor_pack(src, dst):
485+
return linalg.pack(
486+
src,
487+
dst,
488+
inner_dims_pos=[1, 0],
489+
inner_tiles=[32, 8],
490+
padding_value=arith.constant(0.0),
491+
)

0 commit comments

Comments
 (0)