[MLIR][Linalg] Introduce Python API for linalg.batch_matmul Ops. #127614

Merged: 4 commits, Feb 19, 2025
9 changes: 7 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -858,7 +858,11 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
  let arguments = (ins
    Variadic<AnyType>:$inputs,
    Variadic<AnyShaped>:$outputs,
-   DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+   DefaultValuedOptionalAttr<
+     AffineMapArrayAttr,
+     "BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+   >:$indexing_maps,
+   DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
  );
  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
  let regions = (region AnyRegion:$region);
@@ -884,9 +888,10 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
    }]>,
    OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
-          CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+          "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addOperands(operands);
+       $_state.addAttribute("cast", cast);
        $_state.addAttributes(attributes);
        $_state.addTypes(resultTensorTypes);
        (void)$_state.addRegion(),
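For reference, the maps that BatchMatmulOp::getDefaultIndexingMaps memoizes correspond to the conventional batch-matmul access pattern. A minimal sketch (not part of this diff; it assumes only the upstream mlir.ir Python bindings, using the same helpers the new test below uses) that builds those maps:

from mlir import ir

with ir.Context(), ir.Location.unknown():
    # Iteration space (batch, m, n, k) = (d0, d1, d2, d3).
    d0, d1, d2, d3 = (ir.AffineDimExpr.get(i) for i in range(4))
    a_map = ir.AffineMap.get(4, 0, [d0, d1, d3])  # A: batch x m x k
    b_map = ir.AffineMap.get(4, 0, [d0, d3, d2])  # B: batch x k x n
    c_map = ir.AffineMap.get(4, 0, [d0, d1, d2])  # C: batch x m x n
    print(a_map, b_map, c_map)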
25 changes: 16 additions & 9 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3951,11 +3951,18 @@ void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
  auto toType = block.getArgument(2).getType();
-  Value castValA =
-      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
-  Value castValB =
-      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
+  Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
+  Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
  Value addVal =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
@@ -4004,11 +4011,6 @@ ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
}

void BatchMatmulOp::print(OpAsmPrinter &p) {
-  SmallVector<StringRef, 3> elidedAttrs = {
-      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
-  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
-                           elidedAttrs);
-
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      BatchMatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -4018,6 +4020,11 @@ void BatchMatmulOp::print(OpAsmPrinter &p) {
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }
+
+  SmallVector<StringRef, 3> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                           elidedAttrs);
}

/// Verify the user defined indexing maps.
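With the regionBuilder now honoring the cast attribute, unsigned accumulation becomes expressible from Python. A hedged sketch (A, B, and C are assumed to be integer-typed SSA values already in scope; the TypeFn import path follows this dialect module, where the name is already referenced):

from mlir.dialects import linalg
from mlir.dialects.linalg import TypeFn

# cast=TypeFn.cast_unsigned asks the materialized region to widen the
# narrower operands as unsigned values instead of the signed default.
linalg.batch_matmul(A, B, outs=(C,), cast=TypeFn.cast_unsigned)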
41 changes: 25 additions & 16 deletions mlir/python/mlir/dialects/linalg/__init__.py
@@ -149,7 +149,8 @@ def __init__(
generic = region_op(GenericOp_, terminator=YieldOp)


-def matmul(
+def create_op(
+    op_type,
    *ins: Union[Operation, OpView, Value],
    outs: Sequence[Union[Operation, OpView, Value]],
    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
@@ -161,7 +162,7 @@ def matmul(
    init = _get_op_result_or_value(outs[0])
    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []

-    op = MatmulOp(
+    op = op_type(
        result_tensors=result_types,
        inputs=ins,
        outputs=[init],
@@ -172,24 +173,32 @@ def matmul(
    return op


+def matmul(
+    *ins: Union[Operation, OpView, Value],
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
+    cast: Optional[Union[TypeFn, Attribute]] = None,
+):
+    return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
+
+
+def batch_matmul(
+    *ins: Union[Operation, OpView, Value],
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
+    cast: Optional[Union[TypeFn, Attribute]] = None,
+):
+    return create_op(
+        BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+    )
+
+
def contract(
    *ins: Union[Operation, OpView, Value],
    outs: Sequence[Union[Operation, OpView, Value]],
    indexing_maps: Sequence[AffineMapAttr],
    cast: Optional[Union[TypeFn, Attribute]] = None,
):
-    ins = [_get_op_result_or_value(input) for input in ins]
-    if len(outs) > 1:
-        raise ValueError(f"{outs=} must have length 1.")
-    init = _get_op_result_or_value(outs[0])
-    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
-
-    op = ContractOp(
-        result_tensors=result_types,
-        inputs=ins,
-        outputs=[init],
-        indexing_maps=indexing_maps,
-        cast=cast,
+    return create_op(
+        ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
    )
-    fill_builtin_region(op.operation)
-    return op
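For context, matmul, batch_matmul, and contract are now thin wrappers over the shared create_op helper. A short usage sketch (tensor values A, B, Btransposed, C and the affine maps are assumed to be in scope, as in the test below):

# Default indexing maps and signed casts:
res = linalg.batch_matmul(A, B, outs=(C,))

# Explicit maps, e.g. for a transposed B operand:
res = linalg.batch_matmul(
    A,
    Btransposed,
    outs=(C,),
    indexing_maps=[a_map, b_transposed_map, c_map],
)

# Equivalent call through the shared helper:
res = linalg.create_op(linalg.BatchMatmulOp, A, B, outs=(C,))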
10 changes: 5 additions & 5 deletions mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1497,7 +1497,7 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
// CHECK: return
// CHECK: }
func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1520,7 +1520,7 @@ func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%ar
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
// CHECK: return
// CHECK: }
func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1543,7 +1543,7 @@ func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<
// CHECK-SAME: %[[VAL_0:.*]]: memref<2x3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
// CHECK: return
// CHECK: }
func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1566,7 +1566,7 @@ func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1:
// CHECK-SAME: %[[VAL_0:.*]]: memref<2x3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
// CHECK: return
// CHECK: }

@@ -1622,7 +1622,7 @@ func.func @batch_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: me
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<2x7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
// CHECK: return
// CHECK: }
func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
100 changes: 100 additions & 0 deletions mlir/test/python/dialects/linalg/ops.py
@@ -466,3 +466,103 @@ def matmul_as_contract_op(
)

print(module)


+# CHECK-LABEL: TEST: testBatchMatmulOp
+@run
+def testBatchMatmulOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_shape = (2, 4, 8)
+            b_shape = (2, 8, 12)
+            b_transposed_shape = (2, 12, 8)
+            c_shape = (2, 4, 12)
+
+            dimBatch = ir.AffineDimExpr.get(0)
+            dimM = ir.AffineDimExpr.get(1)
+            dimN = ir.AffineDimExpr.get(2)
+            dimK = ir.AffineDimExpr.get(3)
+
+            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+            a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK])
+            b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK])
+            c_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimN])
+
+            # CHECK: func.func @batch_matmul_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME: %[[A:.*]]: tensor<2x4x8xf32>,
+                RankedTensorType.get(a_shape, f32),
+                # CHECK-SAME: %[[Amem:.*]]: memref<2x4x8xf32>,
+                MemRefType.get(a_shape, f32),
+                # CHECK-SAME: %[[B:.*]]: tensor<2x8x12xf32>,
+                RankedTensorType.get(b_shape, f32),
+                # CHECK-SAME: %[[Bmem:.*]]: memref<2x8x12xf32>,
+                MemRefType.get(b_shape, f32),
+                # CHECK-SAME: %[[BTrans:.*]]: tensor<2x12x8xf32>,
+                RankedTensorType.get(b_transposed_shape, f32),
+                # CHECK-SAME: %[[BTransmem:.*]]: memref<2x12x8xf32>,
+                MemRefType.get(b_transposed_shape, f32),
+                # CHECK-SAME: %[[C:.*]]: tensor<2x4x12xf32>,
+                RankedTensorType.get(c_shape, f32),
+                # CHECK-SAME: %[[Cmem:.*]]: memref<2x4x12xf32>)
+                MemRefType.get(c_shape, f32),
+            )
+            def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+                # CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, B),
+                    outputs=(C,),
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.batch_matmul(A, B, outs=(C,))
+
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, Btransposed),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.batch_matmul(
+                    A,
+                    Btransposed,
+                    outs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+
+                # CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Bmem),
+                    outputs=(Cmem,),
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                linalg.batch_matmul(Amem, Bmem, outs=(Cmem,))
+
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Btransposedmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                linalg.batch_matmul(
+                    Amem,
+                    Btransposedmem,
+                    outs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+
+        print(module)
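Putting it together, a minimal self-contained driver in the same style as the test above (a sketch; it assumes an MLIR Python build exposing mlir.ir plus the func and linalg dialect modules):

from mlir import ir
from mlir.dialects import func, linalg

with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    f32 = ir.F32Type.get()
    with ir.InsertionPoint(module.body):
        a_type = ir.RankedTensorType.get((2, 4, 8), f32)
        b_type = ir.RankedTensorType.get((2, 8, 12), f32)
        c_type = ir.RankedTensorType.get((2, 4, 12), f32)

        # Should print linalg.batch_matmul ins(...) outs(...) with the
        # default indexing maps elided.
        @func.FuncOp.from_py_func(a_type, b_type, c_type)
        def bmm(a, b, c):
            linalg.batch_matmul(a, b, outs=(c,))

    print(module)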