|
1 | 1 | // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
|
2 | 2 |
|
| 3 | +// TODO: Replace %arg0 with %vec |
| 4 | + |
3 | 5 | ///----------------------------------------------------------------------------------------
|
4 | 6 | /// vector.transfer_write -> vector.transpose + vector.transfer_write
|
5 | 7 | /// [Pattern: TransferWritePermutationLowering]
|
@@ -31,6 +33,27 @@ func.func @xfer_write_transposing_permutation_map(
|
31 | 33 | return
|
32 | 34 | }
|
33 | 35 |
|
| 36 | +// Even with out-of-bounds, it is safe to apply this pattern |
| 37 | +// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_out_of_bounds |
| 38 | +// CHECK-SAME: %[[ARG_0:.*]]: vector<4x8xi16>, |
| 39 | +// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x?xi16>) { |
| 40 | +// CHECK: %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x8xi16> to vector<8x4xi16> |
| 41 | +// CHECK: vector.transfer_write |
| 42 | +// CHECK-NOT: permutation_map |
| 43 | +// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}} {in_bounds = [false, false]} : vector<8x4xi16>, memref<2x2x?x?xi16> |
| 44 | +func.func @xfer_write_transposing_permutation_map_out_of_bounds( |
| 45 | + %arg0: vector<4x8xi16>, |
| 46 | + %mem: memref<2x2x?x?xi16>) { |
| 47 | + |
| 48 | + %c0 = arith.constant 0 : index |
| 49 | + vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] { |
| 50 | + in_bounds = [false, false], |
| 51 | + permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)> |
| 52 | + } : vector<4x8xi16>, memref<2x2x?x?xi16> |
| 53 | + |
| 54 | + return |
| 55 | +} |
| 56 | + |
34 | 57 | // CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_with_mask_scalable
|
35 | 58 | // CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
|
36 | 59 | // CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x4xi16>,
|
@@ -83,19 +106,44 @@ func.func @xfer_write_transposing_permutation_map_masked(
|
83 | 106 | /// * vector.broadcast + vector.transpose + vector.transfer_write with a map
|
84 | 107 | /// which _is_ a permutation of a minor identity
|
85 | 108 |
|
86 |
| -// CHECK-LABEL: func @permutation_with_mask_xfer_write_fixed_width( |
87 |
| -// CHECK: %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32> |
88 |
| -// CHECK: %[[mask:.*]] = arith.constant dense<[true, false, true, false, true, true, true]> : vector<7xi1> |
89 |
| -// CHECK: %[[b:.*]] = vector.broadcast %[[mask]] : vector<7xi1> to vector<1x7xi1> |
90 |
| -// CHECK: %[[tp:.*]] = vector.transpose %[[b]], [1, 0] : vector<1x7xi1> to vector<7x1xi1> |
91 |
| -// CHECK: vector.transfer_write %[[vec]], %{{.*}}[%{{.*}}, %{{.*}}], %[[tp]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32> |
92 |
| -func.func @permutation_with_mask_xfer_write_fixed_width(%mem : memref<?x?xf32>, %base1 : index, |
93 |
| - %base2 : index) { |
94 |
| - |
95 |
| - %fn1 = arith.constant -2.0 : f32 |
96 |
| - %vf0 = vector.splat %fn1 : vector<7xf32> |
97 |
| - %mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1> |
98 |
| - vector.transfer_write %vf0, %mem[%base1, %base2], %mask |
| 109 | +// CHECK-LABEL: func.func @xfer_write_non_transposing_permutation_map( |
| 110 | +// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>, |
| 111 | +// CHECK-SAME: %[[VEC:.*]]: vector<7xf32>, |
| 112 | +// CHECK-SAME: %[[BASE_1:.*]]: index, %[[BASE_2:.*]]: index) { |
| 113 | +// CHECK: %[[BC:.*]] = vector.broadcast %[[VEC]] : vector<7xf32> to vector<1x7xf32> |
| 114 | +// CHECK: %[[TR:.*]] = vector.transpose %[[BC]], [1, 0] : vector<1x7xf32> to vector<7x1xf32> |
| 115 | +// CHECK: vector.transfer_write %[[TR]], %[[MEM]]{{\[}}%[[BASE_1]], %[[BASE_2]]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32> |
| 116 | +func.func @xfer_write_non_transposing_permutation_map( |
| 117 | + %mem : memref<?x?xf32>, |
| 118 | + %arg0 : vector<7xf32>, |
| 119 | + %base1 : index, |
| 120 | + %base2 : index) { |
| 121 | + |
| 122 | + vector.transfer_write %arg0, %mem[%base1, %base2] |
| 123 | + {permutation_map = affine_map<(d0, d1) -> (d0)>} |
| 124 | + : vector<7xf32>, memref<?x?xf32> |
| 125 | + return |
| 126 | +} |
| 127 | + |
| 128 | +// The broadcast dimension is in bounds, so the transformation is safe |
| 129 | +// CHECK-LABEL: func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds( |
| 130 | +// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>, |
| 131 | +// CHECK-SAME: %[[VEC:.*]]: vector<7xf32>, |
| 132 | +// CHECK-SAME: %[[BASE_1:.*]]: index, %[[BASE_2:.*]]: index, |
| 133 | +// CHECK-SAME: %[[MASK:.*]]: vector<7xi1>) { |
| 134 | +// CHECK: %[[BC_VEC:.*]] = vector.broadcast %[[VEC]] : vector<7xf32> to vector<1x7xf32> |
| 135 | +// CHECK: %[[BC_MASK:.*]] = vector.broadcast %[[MASK]] : vector<7xi1> to vector<1x7xi1> |
| 136 | +// CHECK: %[[TR_MASK:.*]] = vector.transpose %[[BC_MASK]], [1, 0] : vector<1x7xi1> to vector<7x1xi1> |
| 137 | +// CHECK: %[[TR_VEC:.*]] = vector.transpose %[[BC_VEC]], [1, 0] : vector<1x7xf32> to vector<7x1xf32> |
| 138 | +// CHECK: vector.transfer_write %[[TR_VEC]], %[[MEM]]{{\[}}%[[BASE_1]], %[[BASE_2]]], %[[TR_MASK]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32> |
| 139 | +func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds( |
| 140 | + %mem : memref<?x?xf32>, |
| 141 | + %arg0 : vector<7xf32>, |
| 142 | + %base1 : index, |
| 143 | + %base2 : index, |
| 144 | + %mask : vector<7xi1>) { |
| 145 | + |
| 146 | + vector.transfer_write %arg0, %mem[%base1, %base2], %mask |
99 | 147 | {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [false]}
|
100 | 148 | : vector<7xf32>, memref<?x?xf32>
|
101 | 149 | return
|
|
0 commit comments