
Commit 760ec2c

[MLIR][Linalg] Introduce Python API for linalg.batch_matmul Ops. (#127614)
As linalg.batch_matmul has been moved from OpDSL into TableGen, its derived Python wrapper no longer exists. This patch adds the required Python wrapper. It also refactors the BatchMatmulOp printer to make it consistent with its parser.
1 parent 01d0793 commit 760ec2c
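For reference, the new wrapper can be exercised as in the minimal sketch below. This assumes an MLIR build with the Python bindings enabled; the function name batch_matmul_fn and the tensor shapes are illustrative only, not part of the patch.

from mlir.ir import Context, Location, Module, InsertionPoint, F32Type, RankedTensorType
from mlir.dialects import func, linalg

with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
        a_type = RankedTensorType.get((2, 4, 8), f32)
        b_type = RankedTensorType.get((2, 8, 12), f32)
        c_type = RankedTensorType.get((2, 4, 12), f32)

        # The wrapper emits linalg.batch_matmul; default indexing maps and
        # cast behavior apply when those arguments are omitted.
        @func.FuncOp.from_py_func(a_type, b_type, c_type)
        def batch_matmul_fn(a, b, c):
            return linalg.batch_matmul(a, b, outs=(c,))

    print(module)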

File tree

5 files changed: +153 additions, -32 deletions


mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 7 additions & 2 deletions
@@ -858,7 +858,11 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
     let arguments = (ins
       Variadic<AnyType>:$inputs,
       Variadic<AnyShaped>:$outputs,
-      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+      DefaultValuedOptionalAttr<
+        AffineMapArrayAttr,
+        "BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+      >:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
     );
     let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
     let regions = (region AnyRegion:$region);
@@ -884,9 +888,10 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
     }]>,
     OpBuilder<
       (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
-           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+           "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addOperands(operands);
+       $_state.addAttribute("cast", cast);
        $_state.addAttributes(attributes);
        $_state.addTypes(resultTensorTypes);
        (void)$_state.addRegion(),
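With the two defaults above, both indexing_maps and cast may be omitted when building the op. A hedged sketch of overriding the cast behavior from Python, reusing a, b, and c from the sketch above and assuming TypeFn is exposed via mlir.dialects.linalg (as the wrapper signatures in this patch suggest):

# Hedged: assumes TypeFn is reachable as linalg.TypeFn.
# cast_unsigned replaces the default signed conversions with unsigned ones.
op = linalg.batch_matmul(a, b, outs=(c,), cast=linalg.TypeFn.cast_unsigned)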

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 16 additions & 9 deletions
@@ -3951,11 +3951,18 @@ void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
   RegionBuilderHelper helper(b, block);
   SmallVector<Value> yields;

+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
   auto toType = block.getArgument(2).getType();
-  Value castValA =
-      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
-  Value castValB =
-      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
+  Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
+  Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
   Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
   Value addVal =
       helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
@@ -4004,11 +4011,6 @@ ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
 }

 void BatchMatmulOp::print(OpAsmPrinter &p) {
-  SmallVector<StringRef, 3> elidedAttrs = {
-      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
-  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
-                           elidedAttrs);
-
   SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
       BatchMatmulOp::getDefaultIndexingMaps(getContext()),
       [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -4018,6 +4020,11 @@ void BatchMatmulOp::print(OpAsmPrinter &p) {
         [&](Attribute attr) { p.printAttribute(attr); });
     p << "]";
   }
+
+  SmallVector<StringRef, 3> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                           elidedAttrs);
 }

 /// Verify the user defined indexing maps.
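After this change, non-default indexing maps print before the ins/outs clauses, matching what the parser expects. Continuing the sketch above, where a_map, b_map, and c_map are hypothetical AffineMaps over the (batch, m, n, k) iteration space:

# Hedged: a_map/b_map/c_map are hypothetical AffineMaps, e.g. built with
# AffineMap.get(4, 0, [...]) as in the test added below.
op = linalg.batch_matmul(a, b, outs=(c,), indexing_maps=[a_map, b_map, c_map])
# Round-tripping the module now prints roughly:
#   linalg.batch_matmul indexing_maps = [#map, #map1, #map2]
#       ins(%0, %1 : tensor<2x4x8xf32>, tensor<2x12x8xf32>)
#       outs(%2 : tensor<2x4x12xf32>)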

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 25 additions & 16 deletions
@@ -149,7 +149,8 @@ def __init__(
 generic = region_op(GenericOp_, terminator=YieldOp)


-def matmul(
+def create_op(
+    op_type,
     *ins: Union[Operation, OpView, Value],
     outs: Sequence[Union[Operation, OpView, Value]],
     indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
@@ -161,7 +162,7 @@ def matmul(
     init = _get_op_result_or_value(outs[0])
     result_types = [init.type] if isinstance(init.type, RankedTensorType) else []

-    op = MatmulOp(
+    op = op_type(
         result_tensors=result_types,
         inputs=ins,
         outputs=[init],
@@ -172,24 +173,32 @@ def matmul(
     return op


+def matmul(
+    *ins: Union[Operation, OpView, Value],
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
+    cast: Optional[Union[TypeFn, Attribute]] = None,
+):
+    return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
+
+
+def batch_matmul(
+    *ins: Union[Operation, OpView, Value],
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
+    cast: Optional[Union[TypeFn, Attribute]] = None,
+):
+    return create_op(
+        BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+    )
+
+
 def contract(
     *ins: Union[Operation, OpView, Value],
     outs: Sequence[Union[Operation, OpView, Value]],
     indexing_maps: Sequence[AffineMapAttr],
     cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
-    ins = [_get_op_result_or_value(input) for input in ins]
-    if len(outs) > 1:
-        raise ValueError(f"{outs=} must have length 1.")
-    init = _get_op_result_or_value(outs[0])
-    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
-
-    op = ContractOp(
-        result_tensors=result_types,
-        inputs=ins,
-        outputs=[init],
-        indexing_maps=indexing_maps,
-        cast=cast,
+    return create_op(
+        ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
     )
-    fill_builtin_region(op.operation)
-    return op
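With this refactor, matmul, batch_matmul, and contract are all thin shims over the shared create_op helper and accept the same keyword arguments. A hedged usage sketch, assuming suitably shaped values are in scope:

# Hedged: x, y, z are hypothetical 2-D tensors; a, b, c are 3-D ones, and
# a_map/b_map/c_map are matching AffineMaps.
linalg.matmul(x, y, outs=(z,))        # dispatches to create_op(MatmulOp, ...)
linalg.batch_matmul(a, b, outs=(c,))  # dispatches to create_op(BatchMatmulOp, ...)
# contract still requires explicit indexing_maps:
linalg.contract(a, b, outs=(c,), indexing_maps=[a_map, b_map, c_map])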

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 5 additions & 5 deletions
@@ -1497,7 +1497,7 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
 // CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
 // CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>,
 // CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
 // CHECK: return
 // CHECK: }
 func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1520,7 +1520,7 @@ func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %ar
 // CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
 // CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>,
 // CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
 // CHECK: return
 // CHECK: }
 func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1543,7 +1543,7 @@ func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<
 // CHECK-SAME: %[[VAL_0:.*]]: memref<2x3x5xf32>,
 // CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>,
 // CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
 // CHECK: return
 // CHECK: }
 func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1566,7 +1566,7 @@ func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1:
 // CHECK-SAME: %[[VAL_0:.*]]: memref<2x3x5xf32>,
 // CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
 // CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
 // CHECK: return
 // CHECK: }
@@ -1622,7 +1622,7 @@ func.func @batch_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: me
 // CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
 // CHECK-SAME: %[[VAL_1:.*]]: memref<2x7x5xf32>,
 // CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
 // CHECK: return
 // CHECK: }
 func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {

mlir/test/python/dialects/linalg/ops.py

Lines changed: 100 additions & 0 deletions
@@ -466,3 +466,103 @@ def matmul_as_contract_op(
     )

     print(module)
+
+
+# CHECK-LABEL: TEST: testBatchMatmulOp
+@run
+def testBatchMatmulOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_shape = (2, 4, 8)
+            b_shape = (2, 8, 12)
+            b_transposed_shape = (2, 12, 8)
+            c_shape = (2, 4, 12)
+
+            dimBatch = ir.AffineDimExpr.get(0)
+            dimM = ir.AffineDimExpr.get(1)
+            dimN = ir.AffineDimExpr.get(2)
+            dimK = ir.AffineDimExpr.get(3)
+
+            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+            a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK])
+            b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK])
+            c_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimN])
+
+            # CHECK: func.func @batch_matmul_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME: %[[A:.*]]: tensor<2x4x8xf32>,
+                RankedTensorType.get(a_shape, f32),
+                # CHECK-SAME: %[[Amem:.*]]: memref<2x4x8xf32>,
+                MemRefType.get(a_shape, f32),
+                # CHECK-SAME: %[[B:.*]]: tensor<2x8x12xf32>,
+                RankedTensorType.get(b_shape, f32),
+                # CHECK-SAME: %[[Bmem:.*]]: memref<2x8x12xf32>,
+                MemRefType.get(b_shape, f32),
+                # CHECK-SAME: %[[BTrans:.*]]: tensor<2x12x8xf32>,
+                RankedTensorType.get(b_transposed_shape, f32),
+                # CHECK-SAME: %[[BTransmem:.*]]: memref<2x12x8xf32>,
+                MemRefType.get(b_transposed_shape, f32),
+                # CHECK-SAME: %[[C:.*]]: tensor<2x4x12xf32>,
+                RankedTensorType.get(c_shape, f32),
+                # CHECK-SAME: %[[Cmem:.*]]: memref<2x4x12xf32>)
+                MemRefType.get(c_shape, f32),
+            )
+            def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+                # CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, B),
+                    outputs=(C,),
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.batch_matmul(A, B, outs=(C,))
+
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, Btransposed),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.batch_matmul(
+                    A,
+                    Btransposed,
+                    outs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+
+                # CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Bmem),
+                    outputs=(Cmem,),
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                linalg.batch_matmul(Amem, Bmem, outs=(Cmem,))
+
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Btransposedmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                linalg.batch_matmul(
+                    Amem,
+                    Btransposedmem,
+                    outs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+
+        print(module)
