Skip to content

Commit 9f858c7

Browse files
authored
[mlir][vector][test] Update tests for vector.xfter_{read|write} (#91943)
Updates tests in "vector-transfer-permutation-lowering.mlir" to make a clearer split into cases for : * xfer_read vs xfer_write * fixed-width vs scalable tests A new test case is added for fixed-width vectors for vector.transfer_read. This is to complement an existing test for scalable vectors. This is in preparation for #90835 and also for adding more tests for scalable vectors.
1 parent 999fb09 commit 9f858c7

File tree

1 file changed

+68
-26
lines changed

1 file changed

+68
-26
lines changed

mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir

Lines changed: 68 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,84 @@
11
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
22

3-
// CHECK-LABEL: func @lower_permutation_with_mask_fixed_width(
3+
///----------------------------------------------------------------------------------------
4+
/// vector.transfer_write
5+
///----------------------------------------------------------------------------------------
6+
/// Input:
7+
/// * vector.transfer_write op with a map which _is not_ the permutation of a
8+
/// minor identity
9+
/// Output:
10+
/// * vector.broadcast + vector.transfer_write with a map which _is_ the permutation of a
11+
/// minor identity
12+
13+
// CHECK-LABEL: func @permutation_with_mask_xfer_write_fixed_width(
414
// CHECK: %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32>
515
// CHECK: %[[mask:.*]] = arith.constant dense<[true, false, true, false, true, true, true]> : vector<7xi1>
616
// CHECK: %[[b:.*]] = vector.broadcast %[[mask]] : vector<7xi1> to vector<1x7xi1>
717
// CHECK: %[[tp:.*]] = vector.transpose %[[b]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
818
// CHECK: vector.transfer_write %[[vec]], %{{.*}}[%{{.*}}, %{{.*}}], %[[tp]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
9-
func.func @lower_permutation_with_mask_fixed_width(%A : memref<?x?xf32>, %base1 : index,
10-
%base2 : index) {
19+
func.func @permutation_with_mask_xfer_write_fixed_width(%mem : memref<?x?xf32>, %base1 : index,
20+
%base2 : index) {
21+
1122
%fn1 = arith.constant -2.0 : f32
1223
%vf0 = vector.splat %fn1 : vector<7xf32>
1324
%mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
14-
vector.transfer_write %vf0, %A[%base1, %base2], %mask
25+
vector.transfer_write %vf0, %mem[%base1, %base2], %mask
1526
{permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [false]}
1627
: vector<7xf32>, memref<?x?xf32>
1728
return
1829
}
1930

20-
// CHECK-LABEL: func.func @permutation_with_mask_scalable(
31+
// CHECK: func.func @permutation_with_mask_xfer_write_scalable(
32+
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
33+
// CHECK-SAME: %[[ARG_1:.*]]: memref<1x4x?x1xi16>,
34+
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
35+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
36+
// CHECK: %[[BCAST_1:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x4x[8]xi16>
37+
// CHECK: %[[BCAST_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x4x[8]xi1>
38+
// CHECK: %[[TRANSPOSE_1:.*]] = vector.transpose %[[BCAST_2]], [1, 2, 0] : vector<1x4x[8]xi1> to vector<4x[8]x1xi1>
39+
// CHECK: %[[TRANSPOSE_2:.*]] = vector.transpose %[[BCAST_1]], [1, 2, 0] : vector<1x4x[8]xi16> to vector<4x[8]x1xi16>
40+
// CHECK: vector.transfer_write %[[TRANSPOSE_2]], %[[ARG_1]]{{.*}}, %[[TRANSPOSE_1]] {in_bounds = [true, true, true]} : vector<4x[8]x1xi16>, memref<1x4x?x1xi16>
41+
func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %mem: memref<1x4x?x1xi16>, %mask: vector<4x[8]xi1>){
42+
%c0 = arith.constant 0 : index
43+
vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
44+
} : vector<4x[8]xi16>, memref<1x4x?x1xi16>
45+
46+
return
47+
}
48+
49+
///----------------------------------------------------------------------------------------
50+
/// vector.transfer_read
51+
///----------------------------------------------------------------------------------------
52+
/// Input:
53+
/// * vector.transfer_read op with a permutation map
54+
/// Output:
55+
/// * vector.transfer_read with a permutation map composed of leading zeros followed by a minor identiy +
56+
/// vector.transpose op
57+
58+
// CHECK-LABEL: func.func @permutation_with_mask_xfer_read_fixed_width(
59+
// CHECK-SAME: %[[ARG_0:.*]]: memref<?x?xf32>,
60+
// CHECK-SAME: %[[IDX_1:.*]]: index,
61+
// CHECK-SAME: %[[IDX_2:.*]]: index) -> vector<8x4x2xf32> {
62+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
63+
// CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
64+
// CHECK: %[[MASK:.*]] = vector.create_mask %[[IDX_2]], %[[IDX_1]] : vector<2x4xi1>
65+
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[ARG_0]]{{\[}}%[[C0]], %[[C0]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32>
66+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x4xf32> to vector<8x2x4xf32>
67+
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x4xf32> to vector<8x4x2xf32>
68+
// CHECK: return %[[TRANSPOSE]] : vector<8x4x2xf32>
69+
func.func @permutation_with_mask_xfer_read_fixed_width(%mem: memref<?x?xf32>, %dim_1: index, %dim_2: index) -> (vector<8x4x2xf32>) {
70+
71+
%c0 = arith.constant 0 : index
72+
%cst_0 = arith.constant 0.000000e+00 : f32
73+
74+
%mask = vector.create_mask %dim_2, %dim_1 : vector<2x4xi1>
75+
%1 = vector.transfer_read %mem[%c0, %c0], %cst_0, %mask
76+
{in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>}
77+
: memref<?x?xf32>, vector<8x4x2xf32>
78+
return %1 : vector<8x4x2xf32>
79+
}
80+
81+
// CHECK-LABEL: func.func @permutation_with_mask_xfer_read_scalable(
2182
// CHECK-SAME: %[[ARG_0:.*]]: memref<?x?xf32>,
2283
// CHECK-SAME: %[[IDX_1:.*]]: index,
2384
// CHECK-SAME: %[[IDX_2:.*]]: index) -> vector<8x[4]x2xf32> {
@@ -28,37 +89,18 @@ func.func @lower_permutation_with_mask_fixed_width(%A : memref<?x?xf32>, %base1
2889
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x[4]xf32> to vector<8x2x[4]xf32>
2990
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x[4]xf32> to vector<8x[4]x2xf32>
3091
// CHECK: return %[[TRANSPOSE]] : vector<8x[4]x2xf32>
31-
// CHECK: }
32-
func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %dim_2: index) -> (vector<8x[4]x2xf32>) {
92+
func.func @permutation_with_mask_xfer_read_scalable(%mem: memref<?x?xf32>, %dim_1: index, %dim_2: index) -> (vector<8x[4]x2xf32>) {
3393

3494
%c0 = arith.constant 0 : index
3595
%cst_0 = arith.constant 0.000000e+00 : f32
3696

3797
%mask = vector.create_mask %dim_2, %dim_1 : vector<2x[4]xi1>
38-
%1 = vector.transfer_read %2[%c0, %c0], %cst_0, %mask
98+
%1 = vector.transfer_read %mem[%c0, %c0], %cst_0, %mask
3999
{in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>}
40100
: memref<?x?xf32>, vector<8x[4]x2xf32>
41101
return %1 : vector<8x[4]x2xf32>
42102
}
43103

44-
// CHECK: func.func @permutation_with_mask_transfer_write_scalable(
45-
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
46-
// CHECK-SAME: %[[ARG_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
47-
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
48-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
49-
// CHECK: %[[BCAST_1:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
50-
// CHECK: %[[BCAST_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
51-
// CHECK: %[[TRANSPOSE_1:.*]] = vector.transpose %[[BCAST_2]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
52-
// CHECK: %[[TRANSPOSE_2:.*]] = vector.transpose %[[BCAST_1]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
53-
// CHECK: vector.transfer_write %[[TRANSPOSE_2]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[TRANSPOSE_1]] {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
54-
// CHECK: return
55-
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask: vector<4x[8]xi1>){
56-
%c0 = arith.constant 0 : index
57-
vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
58-
} : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
59-
60-
return
61-
}
62104
module attributes {transform.with_named_sequence} {
63105
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
64106
%f = transform.structured.match ops{["func.func"]} in %module_op

0 commit comments

Comments
 (0)