Skip to content

Commit 1c85c71

Browse files
authored
[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (2/n) (#95744)
The main goal of this and subsequent PRs is to unify and categorize tests in: * vector-transfer-flatten.mlir This should make it easier to identify the edge cases being tested (and how they differ), remove duplicates and to add tests for scalable vectors. Below are the main contributions of this PR 1. Two tests duplicated `@transfer_{read|write}_dims_mismatch_non_contiguous_slice`: * `@transfer_{read|write}_dims_mismatch_non_contiguous` and * `@transfer_read_flattenable_negative` duplicated `@transfer_{read|write}_dims_mismatch_non_contiguous_slice`. These tests are removed (the original test is preserved). 2. `@transfer_read_flattenable_negative2` is replaced with two tests with more descriptive names: * `@transfer_read_non_contiguous_src` (for `xfer_read`) and * `@transfer_write_non_contiguous_src` (for `xfer_write`)
1 parent f1075a3 commit 1c85c71

File tree

1 file changed

+44
-72
lines changed

1 file changed

+44
-72
lines changed

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 44 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -131,25 +131,6 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
131131

132132
// -----
133133

134-
func.func @transfer_read_dims_mismatch_non_contiguous(
135-
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
136-
137-
%c0 = arith.constant 0 : index
138-
%cst = arith.constant 0 : i8
139-
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
140-
memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
141-
return %v : vector<2x1x2x2xi8>
142-
}
143-
144-
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
145-
// CHECK-NOT: memref.collapse_shape
146-
// CHECK-NOT: vector.shape_cast
147-
148-
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
149-
// CHECK-128B-NOT: memref.collapse_shape
150-
151-
// -----
152-
153134
// The input memref has a dynamic trailing shape and hence is not flattened.
154135
// TODO: This case could be supported via memref.dim
155136

@@ -214,6 +195,28 @@ func.func @transfer_read_0d(
214195

215196
// -----
216197

198+
// Strides make the input memref non-contiguous, hence non-flattenable.
199+
200+
func.func @transfer_read_non_contiguous_src(
201+
%arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
202+
203+
%c0 = arith.constant 0 : index
204+
%cst = arith.constant 0 : i8
205+
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
206+
memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
207+
return %v : vector<5x4x3x2xi8>
208+
}
209+
210+
// CHECK-LABEL: func.func @transfer_read_non_contiguous_src
211+
// CHECK-NOT: memref.collapse_shape
212+
// CHECK-NOT: vector.shape_cast
213+
214+
// CHECK-128B-LABEL: func @transfer_read_non_contiguous_src
215+
// CHECK-128B-NOT: memref.collapse_shape
216+
// CHECK-128B-NOT: vector.shape_cast
217+
218+
// -----
219+
217220
///----------------------------------------------------------------------------------------
218221
/// vector.transfer_write
219222
/// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
@@ -342,25 +345,6 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
342345

343346
// -----
344347

345-
func.func @transfer_write_dims_mismatch_non_contiguous(
346-
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
347-
%vec : vector<2x1x2x2xi8>) {
348-
349-
%c0 = arith.constant 0 : index
350-
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
351-
vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
352-
return
353-
}
354-
355-
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous
356-
// CHECK-NOT: memref.collapse_shape
357-
// CHECK-NOT: vector.shape_cast
358-
359-
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous(
360-
// CHECK-128B-NOT: memref.collapse_shape
361-
362-
// -----
363-
364348
// The input memref has a dynamic trailing shape and hence is not flattened.
365349
// TODO: This case could be supported via memref.dim
366350

@@ -427,6 +411,28 @@ func.func @transfer_write_0d(
427411

428412
// -----
429413

414+
// The strides make the input memref non-contiguous, hence non-flattenable.
415+
416+
func.func @transfer_write_non_contiguous_src(
417+
%arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>,
418+
%vec : vector<5x4x3x2xi8>) {
419+
420+
%c0 = arith.constant 0 : index
421+
vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] :
422+
vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>
423+
return
424+
}
425+
426+
// CHECK-LABEL: func.func @transfer_write_non_contiguous_src
427+
// CHECK-NOT: memref.collapse_shape
428+
// CHECK-NOT: vector.shape_cast
429+
430+
// CHECK-128B-LABEL: func @transfer_write_non_contiguous_src
431+
// CHECK-128B-NOT: memref.collapse_shape
432+
// CHECK-128B-NOT: vector.shape_cast
433+
434+
// -----
435+
430436
///----------------------------------------------------------------------------------------
431437
/// TODO: Categorize + re-format
432438
///----------------------------------------------------------------------------------------
@@ -478,40 +484,6 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto
478484

479485
// -----
480486

481-
func.func @transfer_read_flattenable_negative(
482-
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x2x2x2xi8> {
483-
%c0 = arith.constant 0 : index
484-
%cst = arith.constant 0 : i8
485-
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
486-
memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x2x2x2xi8>
487-
return %v : vector<2x2x2x2xi8>
488-
}
489-
490-
// CHECK-LABEL: func @transfer_read_flattenable_negative
491-
// CHECK: vector.transfer_read {{.*}} vector<2x2x2x2xi8>
492-
493-
// CHECK-128B-LABEL: func @transfer_read_flattenable_negative(
494-
// CHECK-128B-NOT: memref.collapse_shape
495-
496-
// -----
497-
498-
func.func @transfer_read_flattenable_negative2(
499-
%arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
500-
%c0 = arith.constant 0 : index
501-
%cst = arith.constant 0 : i8
502-
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
503-
memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
504-
return %v : vector<5x4x3x2xi8>
505-
}
506-
507-
// CHECK-LABEL: func @transfer_read_flattenable_negative2
508-
// CHECK: vector.transfer_read {{.*}} vector<5x4x3x2xi8>
509-
510-
// CHECK-128B-LABEL: func @transfer_read_flattenable_negative2(
511-
// CHECK-128B-NOT: memref.collapse_shape
512-
513-
// -----
514-
515487
func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
516488
%add = arith.addi %arg0, %arg0 : vector<1x8xi32>
517489
return %add : vector<1x8xi32>

0 commit comments

Comments
 (0)