Skip to content

Commit eb9f46c

Browse files
committed
Add comment and move test.
1 parent 3705157 commit eb9f46c

File tree

2 files changed

+20
-20
lines changed

2 files changed

+20
-20
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -646,12 +646,13 @@ struct TransferWriteToVectorStoreLowering
646646
rewriter.create<vector::MaskedStoreOp>(
647647
write.getLoc(), write.getSource(), write.getIndices(),
648648
write.getMask(), write.getVector());
649-
return Value();
650649
} else {
651650
rewriter.create<vector::StoreOp>(write.getLoc(), write.getVector(),
652651
write.getSource(), write.getIndices());
653-
return Value();
654652
}
653+
// There's no return value for StoreOps. Use Value() to signal success to
654+
// matchAndRewrite.
655+
return Value();
655656
}
656657

657658
std::optional<unsigned> maxTransferRank;

mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,23 @@ func.func @transfer_to_load(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32>
5151
return %res : vector<4xf32>
5252
}
5353

54+
// Masked transfer_read/write inside are NOT lowered to vector.load/store
55+
// CHECK-LABEL: func @masked_transfer_to_load(
56+
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
57+
// CHECK-SAME: %[[IDX:.*]]: index,
58+
// CHECK-SAME: %[[MASK:.*]]: vector<4xi1>) -> memref<8x8xf32>
59+
// CHECK-NOT: vector.load
60+
// CHECK-NOT: vector.store
61+
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %arg0[%[[IDX]], %[[IDX]]]{{.*}} : memref<8x8xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
62+
// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1>
63+
64+
func.func @masked_transfer_to_load(%mem : memref<8x8xf32>, %i : index, %mask : vector<4xi1>) -> memref<8x8xf32> {
65+
%cf0 = arith.constant 0.0 : f32
66+
%read = vector.mask %mask {vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>} : vector<4xi1> -> vector<4xf32>
67+
vector.mask %mask {vector.transfer_write %read, %mem[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1>
68+
return %mem : memref<8x8xf32>
69+
}
70+
5471
// n-D results are also supported.
5572
// CHECK-LABEL: func @transfer_2D(
5673
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
@@ -392,24 +409,6 @@ func.func @transfer_2D_masked(%mem : memref<?x?xf32>, %mask : vector<2x4xi1>) ->
392409
return %res : vector<2x4xf32>
393410
}
394411

395-
// Masked transfer_read/write inside are NOT lowered to vector.load/store
396-
// CHECK-LABEL: func @masked_transfer_to_load(
397-
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
398-
// CHECK-SAME: %[[IDX:.*]]: index,
399-
// CHECK-SAME: %[[MASK:.*]]: vector<4xi1>) -> memref<8x8xf32>
400-
// CHECK-NOT: vector.load
401-
// CHECK-NOT: vector.store
402-
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %arg0[%[[IDX]], %[[IDX]]]{{.*}} : memref<8x8xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
403-
// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1>
404-
405-
406-
func.func @masked_transfer_to_load(%mem : memref<8x8xf32>, %i : index, %mask : vector<4xi1>) -> memref<8x8xf32> {
407-
%cf0 = arith.constant 0.0 : f32
408-
%read = vector.mask %mask { vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>} : vector<4xi1> -> vector<4xf32>
409-
vector.mask %mask {vector.transfer_write %read, %mem[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1>
410-
return %mem : memref<8x8xf32>
411-
}
412-
413412
module attributes {transform.with_named_sequence} {
414413
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
415414
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">

0 commit comments

Comments
 (0)