Skip to content

Commit 9be1956

Browse files
committed
[mlir][vector] Add tests xfer-permute-lowering (nfc)(2/n)
Adds more tests to: * vector-transfer-permutation-lowering.mlir Specifically, adds tests for: * out-of-bounds access for the `TransferWritePermutationLowering` pattern * in-bounds access for `TransferWriteNonPermutationLowering` + `TransferWritePermutationLowering` Also renames `@permutation_with_mask_xfer_write_fixed_width` as `@xfer_write_non_transposing_permutation_map`. This is a part of a larger effort to make sure that all key cases for patterns under populateVectorTransferPermutationMapLoweringPatterns (*) are tested. I also want to make sure that tests use consistent function and variable names. (*) transform.apply_patterns.vector.transfer_permutation_patterns in TD parlance) Depends on #96031
1 parent 2ba3fe7 commit 9be1956

File tree

1 file changed

+61
-13
lines changed

1 file changed

+61
-13
lines changed

mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
22

3+
// TODO: Replace %arg0 with %vec
4+
35
///----------------------------------------------------------------------------------------
46
/// vector.transfer_write -> vector.transpose + vector.transfer_write
57
/// [Pattern: TransferWritePermutationLowering]
@@ -31,6 +33,27 @@ func.func @xfer_write_transposing_permutation_map(
3133
return
3234
}
3335

36+
// Even with out-of-bounds, it is safe to apply this pattern
37+
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_out_of_bounds
38+
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x8xi16>,
39+
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x?xi16>) {
40+
// CHECK: %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
41+
// CHECK: vector.transfer_write
42+
// CHECK-NOT: permutation_map
43+
// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}} {in_bounds = [false, false]} : vector<8x4xi16>, memref<2x2x?x?xi16>
44+
func.func @xfer_write_transposing_permutation_map_out_of_bounds(
45+
%arg0: vector<4x8xi16>,
46+
%mem: memref<2x2x?x?xi16>) {
47+
48+
%c0 = arith.constant 0 : index
49+
vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] {
50+
in_bounds = [false, false],
51+
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
52+
} : vector<4x8xi16>, memref<2x2x?x?xi16>
53+
54+
return
55+
}
56+
3457
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_with_mask_scalable
3558
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
3659
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x4xi16>,
@@ -83,19 +106,44 @@ func.func @xfer_write_transposing_permutation_map_masked(
83106
/// * vector.broadcast + vector.transpose + vector.transfer_write with a map
84107
/// which _is_ a permutation of a minor identity
85108

86-
// CHECK-LABEL: func @permutation_with_mask_xfer_write_fixed_width(
87-
// CHECK: %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32>
88-
// CHECK: %[[mask:.*]] = arith.constant dense<[true, false, true, false, true, true, true]> : vector<7xi1>
89-
// CHECK: %[[b:.*]] = vector.broadcast %[[mask]] : vector<7xi1> to vector<1x7xi1>
90-
// CHECK: %[[tp:.*]] = vector.transpose %[[b]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
91-
// CHECK: vector.transfer_write %[[vec]], %{{.*}}[%{{.*}}, %{{.*}}], %[[tp]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
92-
func.func @permutation_with_mask_xfer_write_fixed_width(%mem : memref<?x?xf32>, %base1 : index,
93-
%base2 : index) {
94-
95-
%fn1 = arith.constant -2.0 : f32
96-
%vf0 = vector.splat %fn1 : vector<7xf32>
97-
%mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
98-
vector.transfer_write %vf0, %mem[%base1, %base2], %mask
109+
// CHECK-LABEL: func.func @xfer_write_non_transposing_permutation_map(
110+
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
111+
// CHECK-SAME: %[[VEC:.*]]: vector<7xf32>,
112+
// CHECK-SAME: %[[BASE_1:.*]]: index, %[[BASE_2:.*]]: index) {
113+
// CHECK: %[[BC:.*]] = vector.broadcast %[[VEC]] : vector<7xf32> to vector<1x7xf32>
114+
// CHECK: %[[TR:.*]] = vector.transpose %[[BC]], [1, 0] : vector<1x7xf32> to vector<7x1xf32>
115+
// CHECK: vector.transfer_write %[[TR]], %[[MEM]]{{\[}}%[[BASE_1]], %[[BASE_2]]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
116+
func.func @xfer_write_non_transposing_permutation_map(
117+
%mem : memref<?x?xf32>,
118+
%arg0 : vector<7xf32>,
119+
%base1 : index,
120+
%base2 : index) {
121+
122+
vector.transfer_write %arg0, %mem[%base1, %base2]
123+
{permutation_map = affine_map<(d0, d1) -> (d0)>}
124+
: vector<7xf32>, memref<?x?xf32>
125+
return
126+
}
127+
128+
// The broadcast dimension is in bounds, so the transformation is safe
129+
// CHECK-LABEL: func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
130+
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
131+
// CHECK-SAME: %[[VEC:.*]]: vector<7xf32>,
132+
// CHECK-SAME: %[[BASE_1:.*]]: index, %[[BASE_2:.*]]: index,
133+
// CHECK-SAME: %[[MASK:.*]]: vector<7xi1>) {
134+
// CHECK: %[[BC_VEC:.*]] = vector.broadcast %[[VEC]] : vector<7xf32> to vector<1x7xf32>
135+
// CHECK: %[[BC_MASK:.*]] = vector.broadcast %[[MASK]] : vector<7xi1> to vector<1x7xi1>
136+
// CHECK: %[[TR_MASK:.*]] = vector.transpose %[[BC_MASK]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
137+
// CHECK: %[[TR_VEC:.*]] = vector.transpose %[[BC_VEC]], [1, 0] : vector<1x7xf32> to vector<7x1xf32>
138+
// CHECK: vector.transfer_write %[[TR_VEC]], %[[MEM]]{{\[}}%[[BASE_1]], %[[BASE_2]]], %[[TR_MASK]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
139+
func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
140+
%mem : memref<?x?xf32>,
141+
%arg0 : vector<7xf32>,
142+
%base1 : index,
143+
%base2 : index,
144+
%mask : vector<7xi1>) {
145+
146+
vector.transfer_write %arg0, %mem[%base1, %base2], %mask
99147
{permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [false]}
100148
: vector<7xf32>, memref<?x?xf32>
101149
return

0 commit comments

Comments
 (0)