@@ -466,3 +466,102 @@ def matmul_as_contract_op(
                 )
 
         print(module)
+
+# CHECK-LABEL: TEST: testBatchMatmulOp
+@run
+def testBatchMatmulOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_shape = (2, 4, 8)
+            b_shape = (2, 8, 12)
+            b_transposed_shape = (2, 12, 8)
+            c_shape = (2, 4, 12)
+
+            dimBatch = ir.AffineDimExpr.get(0)
+            dimM = ir.AffineDimExpr.get(1)
+            dimN = ir.AffineDimExpr.get(2)
+            dimK = ir.AffineDimExpr.get(3)
+
+            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+            a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK])
+            b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK])
+            c_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimN])
+
+            # CHECK: func.func @batch_matmul_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME: %[[A:.*]]: tensor<2x4x8xf32>,
+                RankedTensorType.get(a_shape, f32),
+                # CHECK-SAME: %[[Amem:.*]]: memref<2x4x8xf32>,
+                MemRefType.get(a_shape, f32),
+                # CHECK-SAME: %[[B:.*]]: tensor<2x8x12xf32>,
+                RankedTensorType.get(b_shape, f32),
+                # CHECK-SAME: %[[Bmem:.*]]: memref<2x8x12xf32>,
+                MemRefType.get(b_shape, f32),
+                # CHECK-SAME: %[[BTrans:.*]]: tensor<2x12x8xf32>,
+                RankedTensorType.get(b_transposed_shape, f32),
+                # CHECK-SAME: %[[BTransmem:.*]]: memref<2x12x8xf32>,
+                MemRefType.get(b_transposed_shape, f32),
+                # CHECK-SAME: %[[C:.*]]: tensor<2x4x12xf32>,
+                RankedTensorType.get(c_shape, f32),
+                # CHECK-SAME: %[[Cmem:.*]]: memref<2x4x12xf32>)
+                MemRefType.get(c_shape, f32),
+            )
+            def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+                # CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, B),
+                    outputs=(C,),
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.batch_matmul(A, B, outs=(C,))
+
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, Btransposed),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.batch_matmul(
+                    A,
+                    Btransposed,
+                    outs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+
+                # CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Bmem),
+                    outputs=(Cmem,),
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                linalg.batch_matmul(Amem, Bmem, outs=(Cmem,))
+
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Btransposedmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                linalg.batch_matmul(
+                    Amem,
+                    Btransposedmem,
+                    outs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+
+        print(module)