1
1
// RUN: gc-opt --split-input-file --deep-tile-contraction-named-op %s
2
2
3
- // -----
3
+ // // -----
4
4
5
- /// CHECK-LABEL: @blocked_matmul_f32
6
- func.func @blocked_matmul_f32 (%arg0: tensor <128 x128 x32 x32 xf32 >) -> tensor <128 x128 x32 x32 xf32 > {
7
- %cst = arith.constant dense <1.000000e+00 > : tensor <128 x128 x32 x32 xf32 >
8
- %cst_0 = arith.constant 0.000000e+00 : f32
9
- %0 = tensor.empty () : tensor <128 x128 x32 x32 xf32 >
10
- %1 = linalg.fill ins (%cst_0 : f32 ) outs (%0 : tensor <128 x128 x32 x32 xf32 >) -> tensor <128 x128 x32 x32 xf32 >
11
- %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<128x128x32x32xf32>, tensor<128x128x32x32xf32>) outs(%1 : tensor<128x128x32x32xf32>) {
12
- ^bb0 (%in: f32 , %in_1: f32 , %out: f32 ):
13
- %3 = arith.mulf %in , %in_1 : f32
14
- %4 = arith.addf %out , %3 : f32
15
- linalg.yield %4 : f32
16
- } -> tensor <128 x128 x32 x32 xf32 >
17
- return %2 : tensor <128 x128 x32 x32 xf32 >
18
- }
5
+ // /// CHECK-LABEL: @matmul_4Dx4D_f32
6
+ // func.func @matmul_4Dx4D_f32(%arg0: tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> {
7
+ // %cst = arith.constant dense<1.000000e+00> : tensor<128x128x32x32x1xf32>
8
+ // %cst_0 = arith.constant 0.000000e+00 : f32
9
+ // %0 = tensor.empty() : tensor<128x128x32x32xf32>
10
+ // %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32>
11
+ // %2 = linalgx.mm4d_vnni ins(%arg0, %cst : tensor<128x128x32x32xf32>, tensor<128x128x32x32x1xf32>) outs(%1 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32>
12
+ // return %2 : tensor<128x128x32x32xf32>
13
+ // }
19
14
20
15
// -----
21
16
22
- /// CHECK-LABEL: @plain_matmul_f32
23
- func.func @plain_matmul_f32 (%arg0: tensor <4096 x4096 xf32 >) -> tensor <4096 x4096 xf32 > {
17
+ /// CHECK-LABEL: @matmul_2Dx2D_f32
18
+ func.func @matmul_2Dx2D_f32 (%arg0: tensor <4096 x4096 xf32 >) -> tensor <4096 x4096 xf32 > {
24
19
%cst = arith.constant dense <1.000000e+00 > : tensor <4096 x4096 xf32 >
25
20
%cst_0 = arith.constant 0.000000e+00 : f32
26
21
%0 = tensor.empty () : tensor <4096 x4096 xf32 >
@@ -29,20 +24,39 @@ func.func @plain_matmul_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf3
29
24
return %2 : tensor <4096 x4096 xf32 >
30
25
}
31
26
27
+ // // -----
28
+
29
+ // /// CHECK-LABEL: @matmul_2Dx4D_f32
30
+ // func.func @matmul_2Dx4D_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> {
31
+ // %cst = arith.constant dense<1.000000e+00> : tensor<128x128x32x32x1xf32>
32
+ // %cst_0 = arith.constant 0.000000e+00 : f32
33
+ // %0 = tensor.empty() : tensor<4096x4096xf32>
34
+ // %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
35
+ // %2 = linalgx.mm2d_vnni ins(%arg0, %cst : tensor<4096x4096xf32>, tensor<128x128x32x32x1xf32>) outs(%1 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
36
+ // return %2 : tensor<4096x4096xf32>
37
+ // }
38
+
32
39
// -----
33
40
34
- /// CHECK-LABEL: @blocked_matmul_bf16
35
- func.func @blocked_matmul_bf16 (%arg0: tensor <128 x128 x32 x32 xbf16 >) -> tensor <128 x128 x32 x32 xbf16 > {
41
+ /// CHECK-LABEL: @matmul_4Dx4D_bf16
42
+ func.func @matmul_4Dx4D_bf16 (%arg0: tensor <128 x128 x32 x32 xbf16 >) -> tensor <128 x128 x32 x32 xbf16 > {
36
43
%cst = arith.constant dense <1.000000e+00 > : tensor <128 x128 x16 x32 x2 xbf16 >
37
44
%cst_0 = arith.constant 0.000000e+00 : bf16
38
45
%0 = tensor.empty () : tensor <128 x128 x32 x32 xbf16 >
39
46
%1 = linalg.fill ins (%cst_0 : bf16 ) outs (%0 : tensor <128 x128 x32 x32 xbf16 >) -> tensor <128 x128 x32 x32 xbf16 >
40
- %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)>], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %cst : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>) {
41
- ^bb0 (%in: bf16 , %in_1: bf16 , %out: bf16 ):
42
- %3 = arith.mulf %in , %in_1 : bf16
43
- %4 = arith.addf %out , %3 : bf16
44
- linalg.yield %4 : bf16
45
- } -> tensor <128 x128 x32 x32 xbf16 >
47
+ %2 = linalgx.mm4d_vnni ins (%arg0 , %cst : tensor <128 x128 x32 x32 xbf16 >, tensor <128 x128 x16 x32 x2 xbf16 >) outs (%1 : tensor <128 x128 x32 x32 xbf16 >) -> tensor <128 x128 x32 x32 xbf16 >
46
48
return %2 : tensor <128 x128 x32 x32 xbf16 >
47
49
}
48
50
51
+ // // -----
52
+
53
+ // /// CHECK-LABEL: @matmul_2Dx4D_bf16
54
+ // func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> {
55
+ // %cst = arith.constant dense<1.000000e+00> : tensor<128x128x16x32x2xbf16>
56
+ // %cst_0 = arith.constant 0.000000e+00 : bf16
57
+ // %0 = tensor.empty() : tensor<4096x4096xbf16>
58
+ // %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
59
+ // %2 = linalgx.mm2d_vnni ins(%arg0, %cst : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
60
+ // return %2 : tensor<4096x4096xbf16>
61
+ // }
62
+
0 commit comments