@@ -187,3 +187,49 @@ module attributes {transform.with_named_sequence} {
187
187
transform.yield
188
188
}
189
189
}
190
+
191
+ // -----
192
+
193
+
194
// Scalable-vector variant of the rank-reducing rewrite: the permutation map
// (0, d1, 0, d3) has a leading broadcast dimension. The expected output reads
// a lower-rank vector<[4]x2x3xf32> through the map (d1, 0, d3) and restores
// the leading dimension of size 8 with a vector.broadcast.

// CHECK:           #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
// CHECK-LABEL:     func.func @transfer_read_reduce_rank_scalable(
// CHECK-SAME:        %[[ARG_0:.*]]: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
// CHECK:             %[[C0:.*]] = arith.constant 0 : index
// CHECK:             %[[TFR:.*]] = vector.transfer_read %[[ARG_0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
// CHECK:             %[[BC:.*]] = vector.broadcast %[[TFR]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
// CHECK:             return %[[BC]] : vector<8x[4]x2x3xf32>
func.func @transfer_read_reduce_rank_scalable(%mem: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.000000e+00 : f32
  // Results 0 and 2 of the permutation map are broadcast dimensions.
  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %pad
    {in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
    : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>
  return %res : vector<8x[4]x2x3xf32>
}
209
+
210
// Masked case not supported: when the same transfer_read is wrapped in a
// vector.mask, the op is expected to be left as-is, i.e. it keeps its full
// vector<8x[4]x2x3xf32> result type and no vector.broadcast appears before it.

// CHECK-LABEL:   func.func @masked_transfer_read_reduce_rank(
// CHECK-SAME:      %[[ARG_0:.*]]: memref<?x?x?x?xf32>,
// CHECK-SAME:      %[[DIM:.*]]: index) -> vector<8x[4]x2x3xf32> {
// CHECK-NOT:       vector.broadcast
// CHECK:           %{{.*}} = vector.mask %{{.*}} { vector.transfer_read %[[ARG_0]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
func.func @masked_transfer_read_reduce_rank(%mem: memref<?x?x?x?xf32>, %dim: index) -> vector<8x[4]x2x3xf32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.000000e+00 : f32
  // The mask type has rank 2 while the transfer result has rank 4.
  %mask = vector.create_mask %dim, %dim : vector<[4]x3xi1>
  %res = vector.mask %mask { vector.transfer_read %mem[%c0, %c0, %c0, %c0], %pad
    {in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
    : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
  return %res : vector<8x[4]x2x3xf32>
}
225
+
226
// Transform script for this split: select every func.func in the payload
// module and apply the vector transfer-permutation patterns to it.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    // The op name must be spelled exactly; whitespace inside the string
    // literal would no longer name `func.func`.
    %f = transform.structured.match ops{["func.func"]} in %module_op
      : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %f {
      transform.apply_patterns.vector.transfer_permutation_patterns
    } : !transform.any_op
    transform.yield
  }
}
0 commit comments