Skip to content

Commit 1c3f8c8

Browse files
committed
Move test to vector-transfer-permutation-lowering.mlir
1 parent ee51517 commit 1c3f8c8

File tree

3 files changed

+47
-15
lines changed

3 files changed

+47
-15
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ struct TransferOpReduceRank
333333
// TODO: support 0-d corner case.
334334
if (op.getTransferRank() == 0)
335335
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
336+
// TODO: support masked case.
336337
if (maskOp)
337338
return rewriter.notifyMatchFailure(op, "Masked case not supported");
338339

mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,49 @@ module attributes {transform.with_named_sequence} {
187187
transform.yield
188188
}
189189
}
190+
191+
// -----
192+
193+
194+
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
195+
// CHECK: func.func @transfer_read_reduce_rank_scalable(
196+
// CHECK-SAME: %[[ARG_0:.*]]: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
197+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
198+
// CHECK: %[[TFR:.*]] = vector.transfer_read %arg0[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
199+
// CHECK: %[[BC:.*]] = vector.broadcast %[[TFR]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
200+
// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
201+
func.func @transfer_read_reduce_rank_scalable(%mem: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
202+
%c0 = arith.constant 0 : index
203+
%cst_0 = arith.constant 0.000000e+00 : f32
204+
%1 = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0
205+
{in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
206+
: memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>
207+
return %1 : vector<8x[4]x2x3xf32>
208+
}
209+
210+
// Masked case not supported.
211+
// CHECK-LABEL: func.func @masked_transfer_read_reduce_rank(
212+
// CHECK-SAME: %[[ARG_0:.*]]: memref<?x?x?x?xf32>,
213+
// CHECK-SAME: %[[DIM:.*]]: index) -> vector<8x[4]x2x3xf32> {
214+
// CHECK-NOT: vector.broadcast
215+
// CHECK: %[[MASK:.*]] = vector.mask %0 { vector.transfer_read %arg0{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
216+
func.func @masked_transfer_read_reduce_rank(%mem: memref<?x?x?x?xf32>, %dim: index) -> vector<8x[4]x2x3xf32> {
217+
%c0 = arith.constant 0 : index
218+
%cst_0 = arith.constant 0.000000e+00 : f32
219+
%mask = vector.create_mask %dim, %dim: vector<[4]x3xi1>
220+
%res = vector.mask %mask { vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0
221+
{in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
222+
: memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
223+
return %res : vector<8x[4]x2x3xf32>
224+
}
225+
226+
module attributes {transform.with_named_sequence} {
227+
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
228+
%f = transform.structured.match ops{["func.func"]} in %module_op
229+
: (!transform.any_op) -> !transform.any_op
230+
transform.apply_patterns to %f {
231+
transform.apply_patterns.vector.transfer_permutation_patterns
232+
} : !transform.any_op
233+
transform.yield
234+
}
235+
}

mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -219,21 +219,6 @@ func.func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %i : index) -> vecto
219219
return %res : vector<4x4xf32>
220220
}
221221

222-
// CHECK-LABEL: func @masked_transfer_read_reduce_rank_with_broadcast(
223-
// CHECK-SAME: %[[MEM:.*]]: memref<8x8x8x8xf32>,
224-
// CHECK-SAME: %[[MASK:.*]]: vector<4x4xi1>,
225-
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4x4x4x4xf32> {
226-
// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %cst {in_bounds = [true, true, true, true], permutation_map = #map2} : memref<8x8x8x8xf32>, vector<4x4x4x4xf32> } : vector<4x4xi1> -> vector<4x4x4x4xf32>
227-
// CHECK-NEXT: return %[[RES]] : vector<4x4x4x4xf32>
228-
#rank_reducing = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
229-
func.func @masked_transfer_read_reduce_rank_with_broadcast(%mem : memref<8x8x8x8xf32>, %mask : vector<4x4xi1>, %i : index) -> vector<4x4x4x4xf32> {
230-
%cf0 = arith.constant 0.0 : f32
231-
%res = vector.mask %mask {vector.transfer_read %mem[%i, %i, %i, %i], %cf0
232-
{in_bounds = [true, true, true, true], permutation_map = #rank_reducing}
233-
: memref<8x8x8x8xf32>, vector<4x4x4x4xf32>} : vector<4x4xi1> -> vector<4x4x4x4xf32>
234-
return %res : vector<4x4x4x4xf32>
235-
}
236-
237222
// More complex broadcasting case (here a `vector.load` is generated).
238223
// CHECK-LABEL: func @transfer_broadcasting_complex(
239224
// CHECK-SAME: %[[MEM:.*]]: memref<10x20x30x8x8xf32>,

0 commit comments

Comments
 (0)