Skip to content

Commit 10b6a10

Browse files
banach-spaceAlexisPerry
authored andcommitted
[mlir][vector] Add tests for xfer-permute-lowering (1/n)(nfc) (llvm#95529)
Adds more tests to "vector-transfer-permutation-lowering.mlir", specifically for the `TransferWritePermutationLowering` pattern - such tests seem to be missing ATM. The following edge cases are covered: * plain fixed-width (supported) * scalable vectors with mask (supported) * plain fixed-width, masked (not supported) This is a part of a larger effort to make sure that all key cases for patterns under `populateVectorTransferPermutationMapLoweringPatterns` (*) are tested. I also want to make sure that tests use consistent function and variable names. (*) `transform.apply_patterns.vector.transfer_permutation_patterns` in TD parlance)
1 parent 4700d4e commit 10b6a10

File tree

1 file changed

+83
-6
lines changed

1 file changed

+83
-6
lines changed

mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir

Lines changed: 83 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,87 @@
11
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
22

33
///----------------------------------------------------------------------------------------
4-
/// vector.transfer_write
4+
/// vector.transfer_write -> vector.transpose + vector.transfer_write
5+
/// [Pattern: TransferWritePermutationLowering]
56
///----------------------------------------------------------------------------------------
6-
/// Input:
7-
/// * vector.transfer_write op with a map which _is not_ the permutation of a
8-
/// minor identity
7+
/// Input:
8+
/// * vector.transfer_write op with a permutation that under a transpose
9+
/// _would be_ a minor identity permutation map
910
/// Output:
10-
/// * vector.broadcast + vector.transfer_write with a map which _is_ the permutation of a
11+
/// * vector.transpose + vector.transfer_write with a permutation map which
12+
/// _is_ a minor identity
13+
14+
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map
15+
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x8xi16>,
16+
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x8x4xi16>) {
17+
// CHECK: %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
18+
// CHECK: vector.transfer_write
19+
// CHECK-NOT: permutation_map
20+
// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16>
21+
func.func @xfer_write_transposing_permutation_map(
22+
%arg0: vector<4x8xi16>,
23+
%mem: memref<2x2x8x4xi16>) {
24+
25+
%c0 = arith.constant 0 : index
26+
vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] {
27+
in_bounds = [true, true],
28+
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
29+
} : vector<4x8xi16>, memref<2x2x8x4xi16>
30+
31+
return
32+
}
33+
34+
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_with_mask_scalable
35+
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
36+
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x4xi16>,
37+
// CHECK-SAME: %[[MASK:.*]]: vector<[8]x4xi1>) {
38+
// CHECK: %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x[8]xi16> to vector<[8]x4xi16>
39+
// CHECK: vector.transfer_write
40+
// CHECK-NOT: permutation_map
41+
// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<[8]x4xi16>, memref<2x2x?x4xi16>
42+
func.func @xfer_write_transposing_permutation_map_with_mask_scalable(
43+
%arg0: vector<4x[8]xi16>,
44+
%mem: memref<2x2x?x4xi16>,
45+
%mask: vector<[8]x4xi1>) {
46+
47+
%c0 = arith.constant 0 : index
48+
vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0], %mask {
49+
in_bounds = [true, true],
50+
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
51+
} : vector<4x[8]xi16>, memref<2x2x?x4xi16>
52+
53+
return
54+
}
55+
56+
// Masked version is not supported
57+
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_masked
58+
// CHECK-NOT: vector.transpose
59+
func.func @xfer_write_transposing_permutation_map_masked(
60+
%arg0: vector<4x8xi16>,
61+
%mem: memref<2x2x8x4xi16>,
62+
%mask: vector<8x4xi1>) {
63+
64+
%c0 = arith.constant 0 : index
65+
vector.mask %mask {
66+
vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] {
67+
in_bounds = [true, true],
68+
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
69+
} : vector<4x8xi16>, memref<2x2x8x4xi16>
70+
} : vector<8x4xi1>
71+
72+
return
73+
}
74+
75+
///----------------------------------------------------------------------------------------
76+
/// vector.transfer_write -> vector.broadcast + vector.transpose + vector.transfer_write
77+
/// [Patterns: TransferWriteNonPermutationLowering + TransferWritePermutationLowering]
78+
///----------------------------------------------------------------------------------------
79+
/// Input:
80+
/// * vector.transfer_write op with a map which _is not_ a permutation of a
1181
/// minor identity
82+
/// Output:
83+
/// * vector.broadcast + vector.transpose + vector.transfer_write with a map
84+
/// which _is_ a permutation of a minor identity
1285

1386
// CHECK-LABEL: func @permutation_with_mask_xfer_write_fixed_width(
1487
// CHECK: %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32>
@@ -94,7 +167,7 @@ func.func @masked_non_permutation_xfer_write_fixed_width(
94167
///----------------------------------------------------------------------------------------
95168
/// vector.transfer_read
96169
///----------------------------------------------------------------------------------------
97-
/// Input:
170+
/// Input:
98171
/// * vector.transfer_read op with a permutation map
99172
/// Output:
100173
/// * vector.transfer_read with a permutation map composed of leading zeros followed by a minor identiy +
@@ -190,6 +263,10 @@ module attributes {transform.with_named_sequence} {
190263

191264
// -----
192265

266+
///----------------------------------------------------------------------------------------
267+
/// vector.transfer_read
268+
///----------------------------------------------------------------------------------------
269+
/// TODO: Review and categorize
193270

194271
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
195272
// CHECK: func.func @transfer_read_reduce_rank_scalable(

0 commit comments

Comments
 (0)