Skip to content

Commit 5023699

Browse files
committed
[MLIR][Linalg] Introduce Python API for linalg.batch_matmul Ops.
As linalg.batch_matmul has moved into tablegen from OpDSL, its derived Python wrapper no longer exists. This patch adds the required Python wrapper. It also refactors the BatchMatmulOp printer to make it consistent with its parser.
1 parent 2b71df5 commit 5023699

File tree

5 files changed

+133
-11
lines changed

5 files changed

+133
-11
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,10 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
858858
let arguments = (ins
859859
Variadic<AnyType>:$inputs,
860860
Variadic<AnyShaped>:$outputs,
861-
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
861+
DefaultValuedOptionalAttr<
862+
AffineMapArrayAttr,
863+
"BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
864+
>:$indexing_maps
862865
);
863866
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
864867
let regions = (region AnyRegion:$region);

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4004,11 +4004,6 @@ ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
40044004
}
40054005

40064006
void BatchMatmulOp::print(OpAsmPrinter &p) {
4007-
SmallVector<StringRef, 3> elidedAttrs = {
4008-
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4009-
::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4010-
elidedAttrs);
4011-
40124007
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
40134008
BatchMatmulOp::getDefaultIndexingMaps(getContext()),
40144009
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -4018,6 +4013,11 @@ void BatchMatmulOp::print(OpAsmPrinter &p) {
40184013
[&](Attribute attr) { p.printAttribute(attr); });
40194014
p << "]";
40204015
}
4016+
4017+
SmallVector<StringRef, 3> elidedAttrs = {
4018+
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4019+
::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4020+
elidedAttrs);
40214021
}
40224022

40234023
/// Verify the user defined indexing maps.

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,23 @@ def contract(
193193
)
194194
fill_builtin_region(op.operation)
195195
return op
196+
197+
def batch_matmul(
    *ins: Union[Operation, OpView, Value],
    outs: Sequence[Union[Operation, OpView, Value]],
    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
):
    """Emits a `linalg.batch_matmul` op.

    Args:
      ins: input operands (LHS and RHS), given as ops, op views, or values.
      outs: a single-element sequence holding the init/output operand.
      indexing_maps: optional affine maps attached as the op's
        `indexing_maps` attribute (e.g. to express broadcast/transpose
        variants); when None the op's default maps apply.

    Returns:
      The created BatchMatmulOp, with its builtin region filled in.

    Raises:
      ValueError: if `outs` does not contain exactly one operand.
    """
    ins = [_get_op_result_or_value(input) for input in ins]
    # Reject both zero and multiple outs: the original `> 1` check let an
    # empty sequence fall through to a bare IndexError on `outs[0]`.
    if len(outs) != 1:
        raise ValueError(f"{outs=} must have length 1.")
    init = _get_op_result_or_value(outs[0])
    # Tensor semantics yield one result matching the init type; buffer
    # (memref) semantics produce no results.
    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []

    op = BatchMatmulOp(
        result_tensors=result_types,
        inputs=ins,
        outputs=[init],
        indexing_maps=indexing_maps,
    )
    fill_builtin_region(op.operation)
    return op

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,7 +1497,7 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
14971497
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
14981498
// CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>,
14991499
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
1500-
// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
1500+
// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
15011501
// CHECK: return
15021502
// CHECK: }
15031503
func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1520,7 +1520,7 @@ func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %ar
15201520
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
15211521
// CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>,
15221522
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
1523-
// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
1523+
// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
15241524
// CHECK: return
15251525
// CHECK: }
15261526
func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1543,7 +1543,7 @@ func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<
15431543
// CHECK-SAME: %[[VAL_0:.*]]: memref<2x3x5xf32>,
15441544
// CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>,
15451545
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
1546-
// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
1546+
// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
15471547
// CHECK: return
15481548
// CHECK: }
15491549
func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1566,7 +1566,7 @@ func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1:
15661566
// CHECK-SAME: %[[VAL_0:.*]]: memref<2x3x5xf32>,
15671567
// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
15681568
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
1569-
// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
1569+
// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
15701570
// CHECK: return
15711571
// CHECK: }
15721572

@@ -1622,7 +1622,7 @@ func.func @batch_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: me
16221622
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
16231623
// CHECK-SAME: %[[VAL_1:.*]]: memref<2x7x5xf32>,
16241624
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
1625-
// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
1625+
// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
16261626
// CHECK: return
16271627
// CHECK: }
16281628
func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {

mlir/test/python/dialects/linalg/ops.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,3 +466,102 @@ def matmul_as_contract_op(
466466
)
467467

468468
print(module)
469+
470+
# CHECK-LABEL: TEST: testBatchMatmulOp
471+
@run
472+
def testBatchMatmulOp():
# Exercises linalg.batch_matmul construction through both the raw
# BatchMatmulOp builder and the linalg.batch_matmul wrapper, on tensor and
# memref operands, with and without explicit indexing maps, then prints the
# module for FileCheck.
473+
with Context(), Location.unknown():
474+
module = Module.create()
475+
f32 = F32Type.get()
476+
with InsertionPoint(module.body):
477+
a_shape = (2, 4, 8)
478+
b_shape = (2, 8, 12)
479+
b_transposed_shape = (2, 12, 8)
480+
c_shape = (2, 4, 12)
481+
482+
# Affine dim exprs d0..d3 naming (batch, m, n, k); used below to build
# the explicit indexing maps for the transposed-B cases.
dimBatch = ir.AffineDimExpr.get(0)
483+
dimM = ir.AffineDimExpr.get(1)
484+
dimN = ir.AffineDimExpr.get(2)
485+
dimK = ir.AffineDimExpr.get(3)
486+
487+
# CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
488+
# CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
489+
# CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
490+
491+
a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK])
492+
b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK])
493+
c_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimN])
494+
495+
# CHECK: func.func @batch_matmul_op(
496+
@func.FuncOp.from_py_func(
497+
# CHECK-SAME: %[[A:.*]]: tensor<2x4x8xf32>,
498+
RankedTensorType.get(a_shape, f32),
499+
# CHECK-SAME: %[[Amem:.*]]: memref<2x4x8xf32>,
500+
MemRefType.get(a_shape, f32),
501+
# CHECK-SAME: %[[B:.*]]: tensor<2x8x12xf32>,
502+
RankedTensorType.get(b_shape, f32),
503+
# CHECK-SAME: %[[Bmem:.*]]: memref<2x8x12xf32>,
504+
MemRefType.get(b_shape, f32),
505+
# CHECK-SAME: %[[BTrans:.*]]: tensor<2x12x8xf32>,
506+
RankedTensorType.get(b_transposed_shape, f32),
507+
# CHECK-SAME: %[[BTransmem:.*]]: memref<2x12x8xf32>,
508+
MemRefType.get(b_transposed_shape, f32),
509+
# CHECK-SAME: %[[C:.*]]: tensor<2x4x12xf32>,
510+
RankedTensorType.get(c_shape, f32),
511+
# CHECK-SAME: %[[Cmem:.*]]: memref<2x4x12xf32>)
512+
MemRefType.get(c_shape, f32),
513+
)
514+
def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
515+
# Tensor case, default indexing maps, via the raw op builder.
# CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
516+
res = linalg.BatchMatmulOp(
517+
result_tensors=(C.type,),
518+
inputs=(A, B),
519+
outputs=(C,),
520+
)
521+
linalg.fill_builtin_region(res.operation)
522+
# Same op via the Python wrapper; expected output is identical.
# CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
523+
res = linalg.batch_matmul(A, B, outs=(C,))
524+
525+
# Tensor case with explicit indexing maps (transposed B).
# CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
526+
res = linalg.BatchMatmulOp(
527+
result_tensors=(C.type,),
528+
inputs=(A, Btransposed),
529+
outputs=(C,),
530+
indexing_maps=[a_map, b_transposed_map, c_map],
531+
)
532+
linalg.fill_builtin_region(res.operation)
533+
# CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
534+
res = linalg.batch_matmul(
535+
A,
536+
Btransposed,
537+
outs=(C,),
538+
indexing_maps=[a_map, b_transposed_map, c_map],
539+
)
540+
541+
# Memref (buffer-semantics) case: no result tensors.
# CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
542+
res = linalg.BatchMatmulOp(
543+
result_tensors=[],
544+
inputs=(Amem, Bmem),
545+
outputs=(Cmem,),
546+
)
547+
linalg.fill_builtin_region(res.operation)
548+
# CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
549+
linalg.batch_matmul(Amem, Bmem, outs=(Cmem,))
550+
551+
# Memref case with explicit indexing maps (transposed B).
# CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
552+
res = linalg.BatchMatmulOp(
553+
result_tensors=[],
554+
inputs=(Amem, Btransposedmem),
555+
outputs=(Cmem,),
556+
indexing_maps=[a_map, b_transposed_map, c_map],
557+
)
558+
linalg.fill_builtin_region(res.operation)
559+
# CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
560+
linalg.batch_matmul(
561+
Amem,
562+
Btransposedmem,
563+
outs=(Cmem,),
564+
indexing_maps=[a_map, b_transposed_map, c_map],
565+
)
566+
567+
print(module)

0 commit comments

Comments
 (0)