Skip to content

Commit 6e772de

Browse files
committed
fixup
1 parent d18ca56 commit 6e772de

File tree

3 files changed

+65
-16
lines changed

3 files changed

+65
-16
lines changed

mlir/docs/Dialects/Linalg/_index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,3 +695,4 @@ the same IR.
695695
## Operations
696696

697697
[include "Dialects/LinalgOps.md"]
698+
[include "Dialects/LinalgRelayoutOps.md"]

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def __init__(
153153
generic = region_op(GenericOp_, terminator=YieldOp)
154154

155155

156-
def create_op(
156+
def _create_matmul_like_op(
157157
op_type,
158158
*ins: Union[Operation, OpView, Value],
159159
outs: Sequence[Union[Operation, OpView, Value]],
@@ -183,7 +183,11 @@ def matmul(
183183
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
184184
cast: Optional[Union[TypeFn, Attribute]] = None,
185185
):
186-
return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
186+
return _get_op_result_or_op_results(
187+
_create_matmul_like_op(
188+
MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
189+
)
190+
)
187191

188192

189193
def batch_matmul(
@@ -192,8 +196,10 @@ def batch_matmul(
192196
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
193197
cast: Optional[Union[TypeFn, Attribute]] = None,
194198
):
195-
return create_op(
196-
BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
199+
return _get_op_result_or_op_results(
200+
_create_matmul_like_op(
201+
BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
202+
)
197203
)
198204

199205

@@ -203,8 +209,10 @@ def contract(
203209
indexing_maps: Sequence[AffineMapAttr],
204210
cast: Optional[Union[TypeFn, Attribute]] = None,
205211
):
206-
return create_op(
207-
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
212+
return _get_op_result_or_op_results(
213+
_create_matmul_like_op(
214+
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
215+
)
208216
)
209217

210218

@@ -239,3 +247,34 @@ def pack(
239247
ip=ip,
240248
)
241249
)
250+
251+
252+
def unpack(
    source,
    dest,
    inner_dims_pos,
    inner_tiles,
    *,
    outer_dims_perm=None,
    loc=None,
    ip=None,
) -> ir.Value:
    """Create a `linalg.unpack` op and return its result value.

    Args:
        source: the packed tensor to unpack.
        dest: the destination tensor the result is written into.
        inner_dims_pos: positions of the tiled dimensions.
        inner_tiles: tile sizes; each entry may be a static int or a dynamic
            Value (mixed static/dynamic operands are dispatched below).
        outer_dims_perm: optional permutation of the outer dimensions.
        loc/ip: optional location and insertion point.
    """
    # Split the possibly-mixed tile sizes into dynamic SSA operands and the
    # static attribute form expected by UnPackOp. The middle element (the
    # packed OpFoldResult form) is not needed here.
    dyn_tiles, _packed, static_tiles = _dispatch_mixed_values(inner_tiles)

    op = UnPackOp(
        source=source,
        dest=dest,
        inner_dims_pos=inner_dims_pos,
        inner_tiles=dyn_tiles,
        static_inner_tiles=static_tiles,
        outer_dims_perm=outer_dims_perm,
        loc=loc,
        ip=ip,
    )
    return _get_op_result_or_op_results(op)

mlir/test/python/dialects/linalg/ops.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -568,32 +568,41 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
568568
print(module)
569569

570570

571-
# CHECK-LABEL: TEST: testPackUnPackOp
@run
def testPackUnPackOp():
    """Round-trip a 128x128 tensor through linalg.pack / linalg.unpack."""
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                RankedTensorType.get((128, 128), f32),
                RankedTensorType.get((16, 16, 8, 8), f32),
            )
            def tensor_pack(src, dst):
                # Pack into 8x8 tiles (dims transposed), then unpack back
                # into the original 128x128 shape.
                packed = linalg.pack(
                    src,
                    dst,
                    inner_dims_pos=[1, 0],
                    inner_tiles=[8, 8],
                    padding_value=arith.constant(f32, 0.0),
                )

                unpacked = linalg.unpack(
                    packed,
                    src,
                    inner_dims_pos=[0, 1],
                    inner_tiles=[8, 8],
                )

                return unpacked

        # CHECK-LABEL: func.func @tensor_pack(
        # CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<16x16x8x8xf32>) -> tensor<128x128xf32> {
        # CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
        # CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
        # CHECK: %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
        # CHECK: return %[[VAL_4]] : tensor<128x128xf32>
        print(module)

0 commit comments

Comments
 (0)