1
1
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
2
2
3
- // CHECK-LABEL: func @lower_permutation_with_mask_fixed_width(
3
+ ///----------------------------------------------------------------------------------------
4
+ /// vector.transfer_write
5
+ ///----------------------------------------------------------------------------------------
6
+ /// Input:
7
+ /// * vector.transfer_write op with a map which _is not_ the permutation of a
8
+ /// minor identity
9
+ /// Output:
10
+ /// * vector.broadcast + vector.transfer_write with a map which _is_ the permutation of a
11
+ /// minor identity
12
+
13
+ // CHECK-LABEL: func @permutation_with_mask_xfer_write_fixed_width(
4
14
// CHECK: %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32>
5
15
// CHECK: %[[mask:.*]] = arith.constant dense<[true, false, true, false, true, true, true]> : vector<7xi1>
6
16
// CHECK: %[[b:.*]] = vector.broadcast %[[mask]] : vector<7xi1> to vector<1x7xi1>
7
17
// CHECK: %[[tp:.*]] = vector.transpose %[[b]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
8
18
// CHECK: vector.transfer_write %[[vec]], %{{.*}}[%{{.*}}, %{{.*}}], %[[tp]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
9
- func.func @lower_permutation_with_mask_fixed_width (%A : memref <?x?xf32 >, %base1 : index ,
10
- %base2 : index ) {
19
+ func.func @permutation_with_mask_xfer_write_fixed_width (%mem : memref <?x?xf32 >, %base1 : index ,
20
+ %base2 : index ) {
21
+
11
22
%fn1 = arith.constant -2.0 : f32
12
23
%vf0 = vector.splat %fn1 : vector <7 xf32 >
13
24
%mask = arith.constant dense <[1 , 0 , 1 , 0 , 1 , 1 , 1 ]> : vector <7 xi1 >
14
- vector.transfer_write %vf0 , %A [%base1 , %base2 ], %mask
25
+ vector.transfer_write %vf0 , %mem [%base1 , %base2 ], %mask
15
26
{permutation_map = affine_map <(d0 , d1 ) -> (d0 )>, in_bounds = [false ]}
16
27
: vector <7 xf32 >, memref <?x?xf32 >
17
28
return
18
29
}
19
30
20
- // CHECK-LABEL: func.func @permutation_with_mask_scalable(
31
+ // CHECK: func.func @permutation_with_mask_xfer_write_scalable(
32
+ // CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
33
+ // CHECK-SAME: %[[ARG_1:.*]]: memref<1x4x?x1xi16>,
34
+ // CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
35
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
36
+ // CHECK: %[[BCAST_1:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x4x[8]xi16>
37
+ // CHECK: %[[BCAST_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x4x[8]xi1>
38
+ // CHECK: %[[TRANSPOSE_1:.*]] = vector.transpose %[[BCAST_2]], [1, 2, 0] : vector<1x4x[8]xi1> to vector<4x[8]x1xi1>
39
+ // CHECK: %[[TRANSPOSE_2:.*]] = vector.transpose %[[BCAST_1]], [1, 2, 0] : vector<1x4x[8]xi16> to vector<4x[8]x1xi16>
40
+ // CHECK: vector.transfer_write %[[TRANSPOSE_2]], %[[ARG_1]]{{.*}}, %[[TRANSPOSE_1]] {in_bounds = [true, true, true]} : vector<4x[8]x1xi16>, memref<1x4x?x1xi16>
41
+ func.func @permutation_with_mask_xfer_write_scalable (%arg0: vector <4 x[8 ]xi16 >, %mem: memref <1 x4 x?x1 xi16 >, %mask: vector <4 x[8 ]xi1 >){
42
+ %c0 = arith.constant 0 : index
43
+ vector.transfer_write %arg0 , %mem [%c0 , %c0 , %c0 , %c0 ], %mask {in_bounds = [true , true ], permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d1 , d2 )>
44
+ } : vector <4 x[8 ]xi16 >, memref <1 x4 x?x1 xi16 >
45
+
46
+ return
47
+ }
48
+
49
+ ///----------------------------------------------------------------------------------------
50
+ /// vector.transfer_read
51
+ ///----------------------------------------------------------------------------------------
52
+ /// Input:
53
+ /// * vector.transfer_read op with a permutation map
54
+ /// Output:
55
+ /// * vector.transfer_read with a permutation map composed of leading zeros followed by a minor identiy +
56
+ /// vector.transpose op
57
+
58
+ // CHECK-LABEL: func.func @permutation_with_mask_xfer_read_fixed_width(
59
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<?x?xf32>,
60
+ // CHECK-SAME: %[[IDX_1:.*]]: index,
61
+ // CHECK-SAME: %[[IDX_2:.*]]: index) -> vector<8x4x2xf32> {
62
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
63
+ // CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
64
+ // CHECK: %[[MASK:.*]] = vector.create_mask %[[IDX_2]], %[[IDX_1]] : vector<2x4xi1>
65
+ // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[ARG_0]]{{\[}}%[[C0]], %[[C0]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32>
66
+ // CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x4xf32> to vector<8x2x4xf32>
67
+ // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x4xf32> to vector<8x4x2xf32>
68
+ // CHECK: return %[[TRANSPOSE]] : vector<8x4x2xf32>
69
+ func.func @permutation_with_mask_xfer_read_fixed_width (%mem: memref <?x?xf32 >, %dim_1: index , %dim_2: index ) -> (vector <8 x4 x2 xf32 >) {
70
+
71
+ %c0 = arith.constant 0 : index
72
+ %cst_0 = arith.constant 0.000000e+00 : f32
73
+
74
+ %mask = vector.create_mask %dim_2 , %dim_1 : vector <2 x4 xi1 >
75
+ %1 = vector.transfer_read %mem [%c0 , %c0 ], %cst_0 , %mask
76
+ {in_bounds = [true , true , true ], permutation_map = affine_map <(d0 , d1 ) -> (0 , d1 , d0 )>}
77
+ : memref <?x?xf32 >, vector <8 x4 x2 xf32 >
78
+ return %1 : vector <8 x4 x2 xf32 >
79
+ }
80
+
81
+ // CHECK-LABEL: func.func @permutation_with_mask_xfer_read_scalable(
21
82
// CHECK-SAME: %[[ARG_0:.*]]: memref<?x?xf32>,
22
83
// CHECK-SAME: %[[IDX_1:.*]]: index,
23
84
// CHECK-SAME: %[[IDX_2:.*]]: index) -> vector<8x[4]x2xf32> {
@@ -28,37 +89,18 @@ func.func @lower_permutation_with_mask_fixed_width(%A : memref<?x?xf32>, %base1
28
89
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x[4]xf32> to vector<8x2x[4]xf32>
29
90
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x[4]xf32> to vector<8x[4]x2xf32>
30
91
// CHECK: return %[[TRANSPOSE]] : vector<8x[4]x2xf32>
31
- // CHECK: }
32
- func.func @permutation_with_mask_scalable (%2: memref <?x?xf32 >, %dim_1: index , %dim_2: index ) -> (vector <8 x[4 ]x2 xf32 >) {
92
+ func.func @permutation_with_mask_xfer_read_scalable (%mem: memref <?x?xf32 >, %dim_1: index , %dim_2: index ) -> (vector <8 x[4 ]x2 xf32 >) {
33
93
34
94
%c0 = arith.constant 0 : index
35
95
%cst_0 = arith.constant 0.000000e+00 : f32
36
96
37
97
%mask = vector.create_mask %dim_2 , %dim_1 : vector <2 x[4 ]xi1 >
38
- %1 = vector.transfer_read %2 [%c0 , %c0 ], %cst_0 , %mask
98
+ %1 = vector.transfer_read %mem [%c0 , %c0 ], %cst_0 , %mask
39
99
{in_bounds = [true , true , true ], permutation_map = affine_map <(d0 , d1 ) -> (0 , d1 , d0 )>}
40
100
: memref <?x?xf32 >, vector <8 x[4 ]x2 xf32 >
41
101
return %1 : vector <8 x[4 ]x2 xf32 >
42
102
}
43
103
44
- // CHECK: func.func @permutation_with_mask_transfer_write_scalable(
45
- // CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
46
- // CHECK-SAME: %[[ARG_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
47
- // CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
48
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
49
- // CHECK: %[[BCAST_1:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
50
- // CHECK: %[[BCAST_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
51
- // CHECK: %[[TRANSPOSE_1:.*]] = vector.transpose %[[BCAST_2]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
52
- // CHECK: %[[TRANSPOSE_2:.*]] = vector.transpose %[[BCAST_1]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
53
- // CHECK: vector.transfer_write %[[TRANSPOSE_2]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[TRANSPOSE_1]] {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
54
- // CHECK: return
55
- func.func @permutation_with_mask_transfer_write_scalable (%arg0: vector <4 x[8 ]xi16 >, %arg1: memref <1 x4 x?x1 x1 x1 x1 xi16 >, %mask: vector <4 x[8 ]xi1 >){
56
- %c0 = arith.constant 0 : index
57
- vector.transfer_write %arg0 , %arg1 [%c0 , %c0 , %c0 , %c0 , %c0 , %c0 , %c0 ], %mask {in_bounds = [true , true ], permutation_map = affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 ) -> (d1 , d2 )>
58
- } : vector <4 x[8 ]xi16 >, memref <1 x4 x?x1 x1 x1 x1 xi16 >
59
-
60
- return
61
- }
62
104
module attributes {transform.with_named_sequence } {
63
105
transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
64
106
%f = transform.structured.match ops {[" func.func" ]} in %module_op
0 commit comments