Skip to content

Commit 85e7428

Browse files
authored
[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (3/n) (llvm#95745)
The main goal of this and subsequent PRs is to unify and categorize tests in: * vector-transfer-flatten.mlir This should make it easier to identify the edge cases being tested (and how they differ), remove duplicates and to add tests for scalable vectors. The main contributions of this PR: 1. For consistency with other tests, `@transfer_read_flattenable_with_dynamic_dims_and_indices` is renamed as `@transfer_read_leading_dynamic_dims`. It is also moved near other tests for `xfer_read`, variable names are updated to match other `xfer_read` tests 2. `@transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim` is renamed as `@negative_transfer_read_dynamic_dim_to_flatten` to better highlight that it's a negative test and to contrast it with `@transfer_read_leading_dynamic_dims` (and to emphasise the difference between the two). 3. Similar changes for tests for `xfer_write`. 4. Make sure that we consistently use `%idx_N` (as opposed to `%idxN`). Follow-up for llvm#95743 and llvm#95744
1 parent 46223b5 commit 85e7428

File tree

1 file changed

+82
-65
lines changed

1 file changed

+82
-65
lines changed

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 82 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -110,31 +110,64 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
110110

111111
func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
112112
%arg : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
113-
%idx0 : index,
114-
%idx1 : index) -> vector<2x2xf32> {
113+
%idx_1 : index,
114+
%idx_2 : index) -> vector<2x2xf32> {
115115

116116
%c0 = arith.constant 0 : index
117117
%cst_1 = arith.constant 0.000000e+00 : f32
118-
%8 = vector.transfer_read %arg[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} :
118+
%8 = vector.transfer_read %arg[%c0, %idx_1, %idx_2, %c0], %cst_1 {in_bounds = [true, true]} :
119119
memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
120120
return %8 : vector<2x2xf32>
121121
}
122122

123123
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
124124

125125
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
126-
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
126+
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]]
127+
// CHECK-SAME: : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
127128
// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
128129

129130
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
130131
// CHECK-128B: memref.collapse_shape
131132

132133
// -----
133134

134-
// The input memref has a dynamic trailing shape and hence is not flattened.
135-
// TODO: This case could be supported via memref.dim
135+
// The leading dynamic shapes don't affect whether this example is flattenable
136+
// or not. Indeed, those dynamic shapes are not candidates for flattening anyway.
136137

137-
func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
138+
func.func @transfer_read_leading_dynamic_dims(
139+
%arg : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
140+
%idx_1 : index,
141+
%idx_2 : index) -> vector<8x4xi8> {
142+
143+
%c0_i8 = arith.constant 0 : i8
144+
%c0 = arith.constant 0 : index
145+
%result = vector.transfer_read %arg[%idx_1, %idx_2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} :
146+
memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
147+
return %result : vector<8x4xi8>
148+
}
149+
150+
// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
151+
// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
152+
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
153+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
154+
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
155+
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
156+
// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
157+
// CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
158+
// CHECK-SAME: {in_bounds = [true]}
159+
// CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
160+
// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
161+
// CHECK: return %[[VEC2D]] : vector<8x4xi8>
162+
163+
// CHECK-128B-LABEL: func @transfer_read_leading_dynamic_dims
164+
// CHECK-128B: memref.collapse_shape
165+
166+
// -----
167+
168+
// One of the dims to be flattened is dynamic - not supported ATM.
169+
170+
func.func @negative_transfer_read_dynamic_dim_to_flatten(
138171
%idx_1: index,
139172
%idx_2: index,
140173
%m_in: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
@@ -146,11 +179,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
146179
return %v : vector<1x2x6xi32>
147180
}
148181

149-
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
182+
// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
150183
// CHECK-NOT: memref.collapse_shape
151184
// CHECK-NOT: vector.shape_cast
152185

153-
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
186+
// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
154187
// CHECK-128B-NOT: memref.collapse_shape
155188

156189
// -----
@@ -326,11 +359,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
326359
func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
327360
%value : vector<2x2xf32>,
328361
%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
329-
%idx0 : index,
330-
%idx1 : index) {
362+
%idx_1 : index,
363+
%idx_2 : index) {
331364

332365
%c0 = arith.constant 0 : index
333-
vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
366+
vector.transfer_write %value, %subview[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
334367
return
335368
}
336369

@@ -345,10 +378,40 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
345378

346379
// -----
347380

348-
// The input memref has a dynamic trailing shape and hence is not flattened.
349-
// TODO: This case could be supported via memref.dim
381+
// The leading dynamic shapes don't affect whether this example is flattenable
382+
// or not. Indeed, those dynamic shapes are not candidates for flattening anyway.
383+
384+
func.func @transfer_write_leading_dynamic_dims(
385+
%vec : vector<8x4xi8>,
386+
%arg : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
387+
%idx_1 : index,
388+
%idx_2 : index) {
389+
390+
%c0 = arith.constant 0 : index
391+
vector.transfer_write %vec, %arg[%idx_1, %idx_2, %c0, %c0] {in_bounds = [true, true]} :
392+
vector<8x4xi8>, memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>
393+
return
394+
}
395+
396+
// CHECK-LABEL: func @transfer_write_leading_dynamic_dims
397+
// CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
398+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
399+
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
400+
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
401+
// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
402+
// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
403+
// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
404+
// CHECK-SAME: {in_bounds = [true]}
405+
// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
406+
407+
// CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
408+
// CHECK-128B: memref.collapse_shape
350409

351-
func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
410+
// -----
411+
412+
// One of the dims to be flattened is dynamic - not supported ATM.
413+
414+
func.func @negative_transfer_write_dynamic_to_flatten(
352415
%idx_1: index,
353416
%idx_2: index,
354417
%vec : vector<1x2x6xi32>,
@@ -361,11 +424,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
361424
return
362425
}
363426

364-
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
427+
// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
365428
// CHECK-NOT: memref.collapse_shape
366429
// CHECK-NOT: vector.shape_cast
367430

368-
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
431+
// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
369432
// CHECK-128B-NOT: memref.collapse_shape
370433

371434
// -----
@@ -434,56 +497,10 @@ func.func @transfer_write_non_contiguous_src(
434497
// -----
435498

436499
///----------------------------------------------------------------------------------------
437-
/// TODO: Categorize + re-format
500+
/// [Pattern: DropUnitDimFromElementwiseOps]
501+
/// TODO: Move to a dedicated file - there's no "flattening" in the following tests
438502
///----------------------------------------------------------------------------------------
439503

440-
func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
441-
%c0_i8 = arith.constant 0 : i8
442-
%c0 = arith.constant 0 : index
443-
%result = vector.transfer_read %arg0[%arg1, %arg2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
444-
return %result : vector<8x4xi8>
445-
}
446-
447-
// CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices
448-
// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
449-
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
450-
// CHECK: %[[C0:.+]] = arith.constant 0 : index
451-
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
452-
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
453-
// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
454-
// CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
455-
// CHECK-SAME: {in_bounds = [true]}
456-
// CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
457-
// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
458-
// CHECK: return %[[VEC2D]] : vector<8x4xi8>
459-
460-
// CHECK-128B-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices(
461-
// CHECK-128B: memref.collapse_shape
462-
463-
// -----
464-
465-
func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) {
466-
%c0 = arith.constant 0 : index
467-
vector.transfer_write %vec, %dst[%arg1, %arg2, %c0, %c0] {in_bounds = [true, true]} : vector<8x4xi8>, memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>
468-
return
469-
}
470-
471-
// CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices
472-
// CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
473-
// CHECK: %[[C0:.+]] = arith.constant 0 : index
474-
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
475-
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
476-
// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
477-
// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
478-
// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
479-
// CHECK-SAME: {in_bounds = [true]}
480-
// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
481-
482-
// CHECK-128B-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices(
483-
// CHECK-128B: memref.collapse_shape
484-
485-
// -----
486-
487504
func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
488505
%add = arith.addi %arg0, %arg0 : vector<1x8xi32>
489506
return %add : vector<1x8xi32>

0 commit comments

Comments
 (0)