@@ -46,6 +46,55 @@ func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %
46
46
return
47
47
}
48
48
49
+ // transfer_write in MaskOp case not supported.
50
+ // CHECK-LABEL: func @masked_permutation_xfer_write_fixed_width
51
+ // CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
52
+ // CHECK-SAME: %[[ARG_1:.*]]: vector<16xf32>,
53
+ // CHECK-SAME: %[[IDX:.*]]: index,
54
+ // CHECK-SAME: %[[MASK:.*]]: vector<16xi1>
55
+ // CHECK-NOT: vector.transpose
56
+ // CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]][%[[IDX]], %[[IDX]]] {{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
57
+ // CHECK: return %[[RES]]
58
+ func.func @masked_permutation_xfer_write_fixed_width (%t: tensor <?x?xf32 >, %val: vector <16 xf32 >, %idx: index , %mask: vector <16 xi1 >) -> tensor <?x?xf32 > {
59
+ %r = vector.mask %mask { vector.transfer_write %val , %t[%idx , %idx ] {permutation_map = affine_map <(d0 , d1 ) -> (d0 )>} : vector <16 xf32 >, tensor <?x?xf32 > } : vector <16 xi1 > -> tensor <?x?xf32 >
60
+ return %r : tensor <?x?xf32 >
61
+ }
62
+
63
+ // CHECK-LABEL: func.func @masked_permutation_xfer_write_scalable(
64
+ // CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
65
+ // CHECK-SAME: %[[ARG_1:.*]]: tensor<?x?x?x?xf32>,
66
+ // CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>)
67
+ // CHECK-SAME: -> tensor<?x?x?x?xf32> {
68
+ // CHECK-NOT: vector.transpose
69
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
70
+ // CHECK: %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]][%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = #[[MAP:.*]]} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
71
+ // CHECK: return %[[R]] : tensor<?x?x?x?xf32>
72
+ func.func @masked_permutation_xfer_write_scalable (%arg0: vector <4 x[8 ]xi16 >, %t: tensor <?x?x?x?xf32 >, %mask: vector <4 x[8 ]xi1 >) -> tensor <?x?x?x?xf32 > {
73
+ %c0 = arith.constant 0 : index
74
+ %r = vector.mask %mask { vector.transfer_write %arg0 , %t [%c0 , %c0 , %c0 , %c0 ] {in_bounds = [true , true ], permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d1 , d2 )>
75
+ } : vector <4 x[8 ]xi16 >, tensor <?x?x?x?xf32 > } : vector <4 x[8 ]xi1 > -> tensor <?x?x?x?xf32 >
76
+
77
+ return %r : tensor <?x?x?x?xf32 >
78
+ }
79
+
80
+ // transfer_write in MaskOp case not supported.
81
+ // CHECK-LABEL: func @masked_non_permutation_xfer_write_fixed_width
82
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?x?xf32>
83
+ // CHECK-SAME: %[[ARG1:.*]]: vector<14x8x16xf32>
84
+ // CHECK-SAME: %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
85
+ // CHECK-NOT: vector.broadcast
86
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
87
+ // CHECK: %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}}permutation_map = #[[MAP:.*]]} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
88
+ func.func @masked_non_permutation_xfer_write_fixed_width (
89
+ %arg0 : tensor <?x?x?x?xf32 >,
90
+ %v1 : vector <14 x8 x16 xf32 >, %dim : index ) -> tensor <?x?x?x?xf32 > {
91
+ %c0 = arith.constant 0 : index
92
+ %mask = vector.create_mask %dim , %dim , %dim : vector <14 x8 x16 xi1 >
93
+ %0 = vector.mask %mask { vector.transfer_write %v1 , %arg0 [%c0 , %c0 , %c0 , %c0 ] {in_bounds = [false , false , true ], permutation_map = affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d1 , d3 )>} : vector <14 x8 x16 xf32 >, tensor <?x?x?x?xf32 > } : vector <14 x8 x16 xi1 > -> tensor <?x?x?x?xf32 >
94
+
95
+ return %0 : tensor <?x?x?x?xf32 >
96
+ }
97
+
49
98
///----------------------------------------------------------------------------------------
50
99
/// vector.transfer_read
51
100
///----------------------------------------------------------------------------------------
@@ -101,6 +150,39 @@ func.func @permutation_with_mask_xfer_read_scalable(%mem: memref<?x?xf32>, %dim_
101
150
return %1 : vector <8 x[4 ]x2 xf32 >
102
151
}
103
152
153
+ // transfer_read in MaskOp case not supported.
154
+ // CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width
155
+ // CHECK-SAME: %[[ARG_0:.*]]: tensor<?x1xf32>,
156
+ // CHECK-SAME: %[[ARG_1:.*]]: vector<4x1xi1>
157
+ // CHECK-NOT: vector.transpose
158
+ // CHECK: vector.mask %[[ARG_1]] { vector.transfer_read %[[ARG_0]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
159
+ func.func @masked_permutation_xfer_read_fixed_width (%arg0: tensor <?x1 xf32 >, %mask : vector <4 x1 xi1 >) {
160
+ %cst = arith.constant 0.000000e+00 : f32
161
+ %c0 = arith.constant 0 : index
162
+ %3 = vector.mask %mask { vector.transfer_read %arg0 [%c0 , %c0 ], %cst {permutation_map = affine_map <(d0 , d1 ) -> (d1 , 0 , d0 )>} : tensor <?x1 xf32 >, vector <1 x4 x4 xf32 > } : vector <4 x1 xi1 > -> vector <1 x4 x4 xf32 >
163
+ call @test.some_use (%3 ) : (vector <1 x4 x4 xf32 >) -> ()
164
+ return
165
+ }
166
+ func.func private @test.some_use (vector <1 x4 x4 xf32 >)
167
+
168
+ // CHECK-LABEL: func.func @masked_permutation_xfer_read_scalable(
169
+ // CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
170
+ // CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
171
+ // CHECK-NOT: vector.transpose
172
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
173
+ // CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true], permutation_map = #[[MAP:.*]]} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
174
+ // CHECK: return %[[T_READ]] : vector<8x[4]x2xf32>
175
+ func.func @masked_permutation_xfer_read_scalable (%t: tensor <?x?xf32 >, %mask : vector <2 x[4 ]xi1 >) -> vector <8 x[4 ]x2 xf32 > {
176
+
177
+ %c0 = arith.constant 0 : index
178
+ %cst_0 = arith.constant 0.000000e+00 : f32
179
+
180
+ %1 = vector.mask %mask { vector.transfer_read %t [%c0 , %c0 ], %cst_0
181
+ {in_bounds = [true , true , true ], permutation_map = affine_map <(d0 , d1 ) -> (0 , d1 , d0 )>}
182
+ : tensor <?x?xf32 >, vector <8 x[4 ]x2 xf32 > } :vector <2 x[4 ]xi1 > -> vector <8 x[4 ]x2 xf32 >
183
+ return %1 : vector <8 x[4 ]x2 xf32 >
184
+ }
185
+
104
186
module attributes {transform.with_named_sequence } {
105
187
transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
106
188
%f = transform.structured.match ops {[" func.func" ]} in %module_op
0 commit comments