@@ -131,25 +131,6 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
131
131
132
132
// -----
133
133
134
- func.func @transfer_read_dims_mismatch_non_contiguous (
135
- %arg : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <2 x1 x2 x2 xi8 > {
136
-
137
- %c0 = arith.constant 0 : index
138
- %cst = arith.constant 0 : i8
139
- %v = vector.transfer_read %arg [%c0 , %c0 , %c0 , %c0 ], %cst :
140
- memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>, vector <2 x1 x2 x2 xi8 >
141
- return %v : vector <2 x1 x2 x2 xi8 >
142
- }
143
-
144
- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
145
- // CHECK-NOT: memref.collapse_shape
146
- // CHECK-NOT: vector.shape_cast
147
-
148
- // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
149
- // CHECK-128B-NOT: memref.collapse_shape
150
-
151
- // -----
152
-
153
134
// The input memref has a dynamic trailing shape and hence is not flattened.
154
135
// TODO: This case could be supported via memref.dim
155
136
@@ -214,6 +195,28 @@ func.func @transfer_read_0d(
214
195
215
196
// -----
216
197
198
+ // Strides make the input memref non-contiguous, hence non-flattenable.
199
+
200
+ func.func @transfer_read_non_contiguous_src (
201
+ %arg : memref <5 x4 x3 x2 xi8 , strided <[24 , 8 , 2 , 1 ], offset : ?>>) -> vector <5 x4 x3 x2 xi8 > {
202
+
203
+ %c0 = arith.constant 0 : index
204
+ %cst = arith.constant 0 : i8
205
+ %v = vector.transfer_read %arg [%c0 , %c0 , %c0 , %c0 ], %cst :
206
+ memref <5 x4 x3 x2 xi8 , strided <[24 , 8 , 2 , 1 ], offset : ?>>, vector <5 x4 x3 x2 xi8 >
207
+ return %v : vector <5 x4 x3 x2 xi8 >
208
+ }
209
+
210
+ // CHECK-LABEL: func.func @transfer_read_non_contiguous_src
211
+ // CHECK-NOT: memref.collapse_shape
212
+ // CHECK-NOT: vector.shape_cast
213
+
214
+ // CHECK-128B-LABEL: func @transfer_read_non_contiguous_src
215
+ // CHECK-128B-NOT: memref.collapse_shape
216
+ // CHECK-128B-NOT: vector.shape_cast
217
+
218
+ // -----
219
+
217
220
///----------------------------------------------------------------------------------------
218
221
/// vector.transfer_write
219
222
/// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
@@ -342,25 +345,6 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
342
345
343
346
// -----
344
347
345
- func.func @transfer_write_dims_mismatch_non_contiguous (
346
- %arg : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
347
- %vec : vector <2 x1 x2 x2 xi8 >) {
348
-
349
- %c0 = arith.constant 0 : index
350
- vector.transfer_write %vec , %arg [%c0 , %c0 , %c0 , %c0 ] :
351
- vector <2 x1 x2 x2 xi8 >, memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>
352
- return
353
- }
354
-
355
- // CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous
356
- // CHECK-NOT: memref.collapse_shape
357
- // CHECK-NOT: vector.shape_cast
358
-
359
- // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous(
360
- // CHECK-128B-NOT: memref.collapse_shape
361
-
362
- // -----
363
-
364
348
// The input memref has a dynamic trailing shape and hence is not flattened.
365
349
// TODO: This case could be supported via memref.dim
366
350
@@ -427,6 +411,28 @@ func.func @transfer_write_0d(
427
411
428
412
// -----
429
413
414
+ // The strides make the input memref non-contiguous, hence non-flattenable.
415
+
416
+ func.func @transfer_write_non_contiguous_src (
417
+ %arg : memref <5 x4 x3 x2 xi8 , strided <[24 , 8 , 2 , 1 ], offset : ?>>,
418
+ %vec : vector <5 x4 x3 x2 xi8 >) {
419
+
420
+ %c0 = arith.constant 0 : index
421
+ vector.transfer_write %vec , %arg [%c0 , %c0 , %c0 , %c0 ] :
422
+ vector <5 x4 x3 x2 xi8 >, memref <5 x4 x3 x2 xi8 , strided <[24 , 8 , 2 , 1 ], offset : ?>>
423
+ return
424
+ }
425
+
426
+ // CHECK-LABEL: func.func @transfer_write_non_contiguous_src
427
+ // CHECK-NOT: memref.collapse_shape
428
+ // CHECK-NOT: vector.shape_cast
429
+
430
+ // CHECK-128B-LABEL: func @transfer_write_non_contiguous_src
431
+ // CHECK-128B-NOT: memref.collapse_shape
432
+ // CHECK-128B-NOT: vector.shape_cast
433
+
434
+ // -----
435
+
430
436
///----------------------------------------------------------------------------------------
431
437
/// TODO: Categorize + re-format
432
438
///----------------------------------------------------------------------------------------
@@ -478,40 +484,6 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto
478
484
479
485
// -----
480
486
481
- func.func @transfer_read_flattenable_negative (
482
- %arg : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <2 x2 x2 x2 xi8 > {
483
- %c0 = arith.constant 0 : index
484
- %cst = arith.constant 0 : i8
485
- %v = vector.transfer_read %arg [%c0 , %c0 , %c0 , %c0 ], %cst :
486
- memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>, vector <2 x2 x2 x2 xi8 >
487
- return %v : vector <2 x2 x2 x2 xi8 >
488
- }
489
-
490
- // CHECK-LABEL: func @transfer_read_flattenable_negative
491
- // CHECK: vector.transfer_read {{.*}} vector<2x2x2x2xi8>
492
-
493
- // CHECK-128B-LABEL: func @transfer_read_flattenable_negative(
494
- // CHECK-128B-NOT: memref.collapse_shape
495
-
496
- // -----
497
-
498
- func.func @transfer_read_flattenable_negative2 (
499
- %arg : memref <5 x4 x3 x2 xi8 , strided <[24 , 8 , 2 , 1 ], offset : ?>>) -> vector <5 x4 x3 x2 xi8 > {
500
- %c0 = arith.constant 0 : index
501
- %cst = arith.constant 0 : i8
502
- %v = vector.transfer_read %arg [%c0 , %c0 , %c0 , %c0 ], %cst :
503
- memref <5 x4 x3 x2 xi8 , strided <[24 , 8 , 2 , 1 ], offset : ?>>, vector <5 x4 x3 x2 xi8 >
504
- return %v : vector <5 x4 x3 x2 xi8 >
505
- }
506
-
507
- // CHECK-LABEL: func @transfer_read_flattenable_negative2
508
- // CHECK: vector.transfer_read {{.*}} vector<5x4x3x2xi8>
509
-
510
- // CHECK-128B-LABEL: func @transfer_read_flattenable_negative2(
511
- // CHECK-128B-NOT: memref.collapse_shape
512
-
513
- // -----
514
-
515
487
func.func @fold_unit_dim_add_basic (%arg0 : vector <1 x8 xi32 >) -> vector <1 x8 xi32 > {
516
488
%add = arith.addi %arg0 , %arg0 : vector <1 x8 xi32 >
517
489
return %add : vector <1 x8 xi32 >
0 commit comments