Skip to content

Commit 6f78ea7

Browse files
committed
[mlir][python] fix linalg.pack
1 parent 3430bc3 commit 6f78ea7

File tree

3 files changed

+67
-1
lines changed

3 files changed

+67
-1
lines changed

mlir/python/mlir/dialects/LinalgOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111

1212
include "mlir/Dialect/Linalg/IR/LinalgOps.td"
1313
include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
14+
include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td"
1415

1516
#endif

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,10 @@
5858
from .opdsl.ops.core_named_ops import *
5959

6060
from ...ir import *
61-
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
61+
from .._ods_common import (
62+
get_op_result_or_value as _get_op_result_or_value,
63+
_dispatch_mixed_values,
64+
)
6265
from ...extras.meta import region_op
6366

6467

@@ -193,3 +196,34 @@ def contract(
193196
)
194197
fill_builtin_region(op.operation)
195198
return op
199+
200+
201+
def pack(
202+
source,
203+
dest,
204+
inner_dims_pos,
205+
inner_tiles,
206+
*,
207+
padding_value=None,
208+
outer_dims_perm=None,
209+
loc=None,
210+
ip=None,
211+
) -> ir.Value:
212+
(
213+
dynamic_inner_tiles,
214+
# packed here means %1:2 packing (results packing)
215+
_inner_tiles,
216+
static_inner_tiles,
217+
) = _dispatch_mixed_values(inner_tiles)
218+
219+
return PackOp(
220+
source=source,
221+
dest=dest,
222+
inner_dims_pos=inner_dims_pos,
223+
inner_tiles=dynamic_inner_tiles,
224+
static_inner_tiles=static_inner_tiles,
225+
padding_value=padding_value,
226+
outer_dims_perm=outer_dims_perm,
227+
loc=loc,
228+
ip=ip,
229+
).result

mlir/test/python/dialects/linalg/ops.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,3 +466,34 @@ def matmul_as_contract_op(
466466
)
467467

468468
print(module)
469+
470+
471+
# CHECK-LABEL: TEST: testPackOp
472+
@run
473+
def testPackOp():
474+
with Context(), Location.unknown():
475+
module = Module.create()
476+
f32 = F32Type.get()
477+
with InsertionPoint(module.body):
478+
479+
@func.FuncOp.from_py_func(
480+
RankedTensorType.get((129, 47, 16, 16), f32),
481+
RankedTensorType.get((17, 2, 16, 16, 32, 8), f32),
482+
)
483+
def tensor_pack(src, dst):
484+
return linalg.pack(
485+
src,
486+
dst,
487+
inner_dims_pos=[1, 0],
488+
inner_tiles=[32, 8],
489+
padding_value=arith.constant(f32, 0.0),
490+
)
491+
492+
# CHECK-LABEL: func.func @tensor_pack(
493+
# CHECK-SAME: %[[VAL_0:.*]]: tensor<129x47x16x16xf32>,
494+
# CHECK-SAME: %[[VAL_1:.*]]: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
495+
# CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
496+
# CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %[[VAL_1]] : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32>
497+
# CHECK: return %[[VAL_3]] : tensor<17x2x16x16x32x8xf32>
498+
# CHECK: }
499+
print(module)

0 commit comments

Comments
 (0)