@@ -110,31 +110,64 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
110
110
111
111
func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices (
112
112
%arg : memref <1 x3 x3 x2 xf32 , strided <[40 , 10 , 2 , 1 ], offset : ?>>,
113
- %idx0 : index ,
114
- %idx1 : index ) -> vector <2 x2 xf32 > {
113
+ %idx_1 : index ,
114
+ %idx_2 : index ) -> vector <2 x2 xf32 > {
115
115
116
116
%c0 = arith.constant 0 : index
117
117
%cst_1 = arith.constant 0.000000e+00 : f32
118
- %8 = vector.transfer_read %arg [%c0 , %idx0 , %idx1 , %c0 ], %cst_1 {in_bounds = [true , true ]} :
118
+ %8 = vector.transfer_read %arg [%c0 , %idx_1 , %idx_2 , %c0 ], %cst_1 {in_bounds = [true , true ]} :
119
119
memref <1 x3 x3 x2 xf32 , strided <[40 , 10 , 2 , 1 ], offset : ?>>, vector <2 x2 xf32 >
120
120
return %8 : vector <2 x2 xf32 >
121
121
}
122
122
123
123
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
124
124
125
125
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
126
- // CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
126
+ // CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]]
127
+ // CHECK-SAME: : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
127
128
// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
128
129
129
130
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
130
131
// CHECK-128B: memref.collapse_shape
131
132
132
133
// -----
133
134
134
- // The input memref has a dynamic trailing shape and hence is not flattened.
135
- // TODO: This case could be supported via memref.dim
135
+ // The leading dynamic shapes don't affect whether this example is flattenable
136
+ // or not. Indeed, those dynamic shapes are not candidates for flattening anyway.
136
137
137
- func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes (
138
+ func.func @transfer_read_leading_dynamic_dims (
139
+ %arg : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>,
140
+ %idx_1 : index ,
141
+ %idx_2 : index ) -> vector <8 x4 xi8 > {
142
+
143
+ %c0_i8 = arith.constant 0 : i8
144
+ %c0 = arith.constant 0 : index
145
+ %result = vector.transfer_read %arg [%idx_1 , %idx_2 , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} :
146
+ memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, vector <8 x4 xi8 >
147
+ return %result : vector <8 x4 xi8 >
148
+ }
149
+
150
+ // CHECK-LABEL: func @transfer_read_leading_dynamic_dims
151
+ // CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
152
+ // CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
153
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
154
+ // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
155
+ // CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
156
+ // CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
157
+ // CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
158
+ // CHECK-SAME: {in_bounds = [true]}
159
+ // CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
160
+ // CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
161
+ // CHECK: return %[[VEC2D]] : vector<8x4xi8>
162
+
163
+ // CHECK-128B-LABEL: func @transfer_read_leading_dynamic_dims
164
+ // CHECK-128B: memref.collapse_shape
165
+
166
+ // -----
167
+
168
+ // One of the dims to be flattened is dynamic - not supported ATM.
169
+
170
+ func.func @negative_transfer_read_dynamic_dim_to_flatten (
138
171
%idx_1: index ,
139
172
%idx_2: index ,
140
173
%m_in: memref <1 x?x4 x6 xi32 >) -> vector <1 x2 x6 xi32 > {
@@ -146,11 +179,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
146
179
return %v : vector <1 x2 x6 xi32 >
147
180
}
148
181
149
- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
182
+ // CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
150
183
// CHECK-NOT: memref.collapse_shape
151
184
// CHECK-NOT: vector.shape_cast
152
185
153
- // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
186
+ // CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
154
187
// CHECK-128B-NOT: memref.collapse_shape
155
188
156
189
// -----
@@ -326,11 +359,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
326
359
func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices (
327
360
%value : vector <2 x2 xf32 >,
328
361
%subview : memref <1 x3 x3 x2 xf32 , strided <[40 , 10 , 2 , 1 ], offset : ?>>,
329
- %idx0 : index ,
330
- %idx1 : index ) {
362
+ %idx_1 : index ,
363
+ %idx_2 : index ) {
331
364
332
365
%c0 = arith.constant 0 : index
333
- vector.transfer_write %value , %subview [%c0 , %idx0 , %idx1 , %c0 ] {in_bounds = [true , true ]} : vector <2 x2 xf32 >, memref <1 x3 x3 x2 xf32 , strided <[40 , 10 , 2 , 1 ], offset : ?>>
366
+ vector.transfer_write %value , %subview [%c0 , %idx_1 , %idx_2 , %c0 ] {in_bounds = [true , true ]} : vector <2 x2 xf32 >, memref <1 x3 x3 x2 xf32 , strided <[40 , 10 , 2 , 1 ], offset : ?>>
334
367
return
335
368
}
336
369
@@ -345,10 +378,40 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
345
378
346
379
// -----
347
380
348
- // The input memref has a dynamic trailing shape and hence is not flattened.
349
- // TODO: This case could be supported via memref.dim
381
+ // The leading dynamic shapes don't affect whether this example is flattenable
382
+ // or not. Indeed, those dynamic shapes are not candidates for flattening anyway.
383
+
384
+ func.func @transfer_write_leading_dynamic_dims (
385
+ %vec : vector <8 x4 xi8 >,
386
+ %arg : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>,
387
+ %idx_1 : index ,
388
+ %idx_2 : index ) {
389
+
390
+ %c0 = arith.constant 0 : index
391
+ vector.transfer_write %vec , %arg [%idx_1 , %idx_2 , %c0 , %c0 ] {in_bounds = [true , true ]} :
392
+ vector <8 x4 xi8 >, memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>
393
+ return
394
+ }
395
+
396
+ // CHECK-LABEL: func @transfer_write_leading_dynamic_dims
397
+ // CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
398
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
399
+ // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
400
+ // CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
401
+ // CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
402
+ // CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
403
+ // CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
404
+ // CHECK-SAME: {in_bounds = [true]}
405
+ // CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
406
+
407
+ // CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
408
+ // CHECK-128B: memref.collapse_shape
350
409
351
- func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes (
410
+ // -----
411
+
412
+ // One of the dims to be flattened is dynamic - not supported ATM.
413
+
414
+ func.func @negative_transfer_write_dynamic_to_flatten (
352
415
%idx_1: index ,
353
416
%idx_2: index ,
354
417
%vec : vector <1 x2 x6 xi32 >,
@@ -361,11 +424,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
361
424
return
362
425
}
363
426
364
- // CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
427
+ // CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
365
428
// CHECK-NOT: memref.collapse_shape
366
429
// CHECK-NOT: vector.shape_cast
367
430
368
- // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
431
+ // CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
369
432
// CHECK-128B-NOT: memref.collapse_shape
370
433
371
434
// -----
@@ -434,56 +497,10 @@ func.func @transfer_write_non_contiguous_src(
434
497
// -----
435
498
436
499
///----------------------------------------------------------------------------------------
437
- /// TODO: Categorize + re-format
500
+ /// [Pattern: DropUnitDimFromElementwiseOps]
501
+ /// TODO: Move to a dedicated file - there's no "flattening" in the following tests
438
502
///----------------------------------------------------------------------------------------
439
503
440
- func.func @transfer_read_flattenable_with_dynamic_dims_and_indices (%arg0 : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, %arg1 : index , %arg2 : index ) -> vector <8 x4 xi8 > {
441
- %c0_i8 = arith.constant 0 : i8
442
- %c0 = arith.constant 0 : index
443
- %result = vector.transfer_read %arg0 [%arg1 , %arg2 , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, vector <8 x4 xi8 >
444
- return %result : vector <8 x4 xi8 >
445
- }
446
-
447
- // CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices
448
- // CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
449
- // CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
450
- // CHECK: %[[C0:.+]] = arith.constant 0 : index
451
- // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
452
- // CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
453
- // CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
454
- // CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
455
- // CHECK-SAME: {in_bounds = [true]}
456
- // CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
457
- // CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
458
- // CHECK: return %[[VEC2D]] : vector<8x4xi8>
459
-
460
- // CHECK-128B-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices(
461
- // CHECK-128B: memref.collapse_shape
462
-
463
- // -----
464
-
465
- func.func @transfer_write_flattenable_with_dynamic_dims_and_indices (%vec : vector <8 x4 xi8 >, %dst : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, %arg1 : index , %arg2 : index ) {
466
- %c0 = arith.constant 0 : index
467
- vector.transfer_write %vec , %dst [%arg1 , %arg2 , %c0 , %c0 ] {in_bounds = [true , true ]} : vector <8 x4 xi8 >, memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>
468
- return
469
- }
470
-
471
- // CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices
472
- // CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
473
- // CHECK: %[[C0:.+]] = arith.constant 0 : index
474
- // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
475
- // CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
476
- // CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
477
- // CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
478
- // CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
479
- // CHECK-SAME: {in_bounds = [true]}
480
- // CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
481
-
482
- // CHECK-128B-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices(
483
- // CHECK-128B: memref.collapse_shape
484
-
485
- // -----
486
-
487
504
func.func @fold_unit_dim_add_basic (%arg0 : vector <1 x8 xi32 >) -> vector <1 x8 xi32 > {
488
505
%add = arith.addi %arg0 , %arg0 : vector <1 x8 xi32 >
489
506
return %add : vector <1 x8 xi32 >
0 commit comments