Skip to content

Commit ce8c1a7

Browse files
committed
Add missing linalg.batch_vecmat named op (llvm#70218)
Linalg currently has these named ops: * `matmul` * `matvec` * `vecmat` * `batch_matmul` * `batch_matvec` But it does not have: * `batch_vecmat` This PR adds it for consistency, and I have a short-term need for it ( iree-org/iree#15158 ), so not having this would cause some contortion on my end.
1 parent 6aaa03a commit ce8c1a7

File tree

3 files changed

+111
-0
lines changed

3 files changed

+111
-0
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1796,6 +1796,74 @@ structured_op: !LinalgStructuredOpConfig
17961796
- !ScalarExpression
17971797
scalar_arg: B
17981798
--- !LinalgOpConfig
1799+
metadata: !LinalgOpMetadata
1800+
name: batch_vecmat
1801+
cpp_class_name: BatchVecmatOp
1802+
doc: |-
1803+
Performs a batched vector-matrix multiplication.
1804+
1805+
Numeric casting is performed on the operands to the inner multiply, promoting
1806+
them to the same data type as the accumulator/output.
1807+
implements:
1808+
- LinalgContractionOpInterface
1809+
structured_op: !LinalgStructuredOpConfig
1810+
args:
1811+
- !LinalgOperandDefConfig
1812+
name: A
1813+
kind: input_tensor
1814+
type_var: T1
1815+
shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
1816+
- !LinalgOperandDefConfig
1817+
name: B
1818+
kind: input_tensor
1819+
type_var: T2
1820+
shape_map: affine_map<()[s0, s1, s2] -> (s0, s1, s2)>
1821+
- !LinalgOperandDefConfig
1822+
name: C
1823+
kind: output_tensor
1824+
type_var: U
1825+
shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
1826+
indexing_maps: !LinalgIndexingMapsConfig
1827+
static_indexing_maps:
1828+
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
1829+
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2, d1)>
1830+
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
1831+
iterator_types:
1832+
- parallel
1833+
- parallel
1834+
- reduction
1835+
assignments:
1836+
- !ScalarAssign
1837+
arg: C
1838+
value: !ScalarExpression
1839+
scalar_fn:
1840+
kind: binary
1841+
fn_name: add
1842+
operands:
1843+
- !ScalarExpression
1844+
scalar_arg: C
1845+
- !ScalarExpression
1846+
scalar_fn:
1847+
kind: binary
1848+
fn_name: mul
1849+
operands:
1850+
- !ScalarExpression
1851+
scalar_fn:
1852+
kind: type
1853+
fn_name: cast_signed
1854+
type_var: U
1855+
operands:
1856+
- !ScalarExpression
1857+
scalar_arg: A
1858+
- !ScalarExpression
1859+
scalar_fn:
1860+
kind: type
1861+
fn_name: cast_signed
1862+
type_var: U
1863+
operands:
1864+
- !ScalarExpression
1865+
scalar_arg: B
1866+
--- !LinalgOpConfig
17991867
metadata: !LinalgOpMetadata
18001868
name: dot
18011869
cpp_class_name: DotOp

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,24 @@ def batch_matvec(
517517
)
518518

519519

520+
@linalg_structured_op
521+
def batch_vecmat(
522+
A=TensorDef(T1, Batch, S.K),
523+
B=TensorDef(T2, Batch, S.K, S.N),
524+
C=TensorDef(U, Batch, S.N, output=True),
525+
):
526+
"""Performs a batched vector-matrix multiplication.
527+
528+
Numeric casting is performed on the operands to the inner multiply, promoting
529+
them to the same data type as the accumulator/output.
530+
"""
531+
domain(D.b, D.n, D.k)
532+
implements(ContractionOpInterface)
533+
C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed(
534+
U, B[D.b, D.k, D.n]
535+
)
536+
537+
520538
@linalg_structured_op
521539
def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
522540
"""Performs a dot product of two vectors to a scalar result.

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,31 @@ func.func @generalize_batch_matm_vec(%lhs : memref<?x?x?xi8>, %rhs: memref<?x?xi
251251

252252
// -----
253253

254+
func.func @generalize_batch_vecmat(%lhs : memref<?x?xi8>, %rhs: memref<?x?x?xi8>, %out: memref<?x?xf32>) {
255+
linalg.batch_vecmat ins(%lhs, %rhs: memref<?x?xi8>, memref<?x?x?xi8>)
256+
outs(%out: memref<?x?xf32>)
257+
return
258+
}
259+
// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
260+
// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
261+
// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
262+
263+
// CHECK: @generalize_batch_vecmat
264+
265+
// CHECK: linalg.generic
266+
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
267+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
268+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?xi8>, memref<?x?x?xi8>)
269+
// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
270+
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: f32)
271+
// CHECK: %[[BBARG0_F32:.+]] = arith.sitofp %[[BBARG0]] : i8 to f32
272+
// CHECK: %[[BBARG1_F32:.+]] = arith.sitofp %[[BBARG1]] : i8 to f32
273+
// CHECK: %[[MUL:.+]] = arith.mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
274+
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]]
275+
// CHECK: linalg.yield %[[ADD]] : f32
276+
277+
// -----
278+
254279
func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %out: memref<8x8xf32>) {
255280
linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xf32>, memref<7x9x8xf32>)
256281
outs(%out: memref<8x8xf32>)

0 commit comments

Comments
 (0)