Skip to content

Commit 57863a4

Browse files
committed
[MLIR][Vector] Fix transferOps optimization inside maskOp
1 parent b247776 commit 57863a4

File tree

4 files changed

+61
-0
lines changed

4 files changed

+61
-0
lines changed

mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,17 @@ LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
100100
SmallVector<Value> indices(readOp.getIndices().begin(),
101101
readOp.getIndices().end());
102102
SmallVector<Value> sourceIndices;
103+
// If the transfer_read is nested inside a vector.mask op, create the index
104+
// computations before the mask op so that no extra ops end up inside it.
105+
if (isa<vector::MaskOp>(readOp->getParentOp()))
106+
rewriter.setInsertionPoint(readOp->getParentOp());
103107
affine::resolveIndicesIntoOpWithOffsetsAndStrides(
104108
rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
105109
extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
106110
indices, sourceIndices);
107111

112+
// Restore the insertion point so the new transfer_read replaces readOp in place.
113+
rewriter.setInsertionPoint(readOp);
108114
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
109115
readOp, readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices,
110116
AffineMapAttr::get(expandDimsToRank(

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ struct TransferReadPermutationLowering
9898
// TODO: support 0-d corner case.
9999
if (op.getTransferRank() == 0)
100100
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
101+
if (isa<vector::MaskOp>(op->getParentOp()))
102+
return rewriter.notifyMatchFailure(
103+
op, "Cannot expand transfer read inside a Mask Op");
101104

102105
SmallVector<unsigned> permutation;
103106
AffineMap map = op.getPermutationMap();
@@ -173,6 +176,9 @@ struct TransferWritePermutationLowering
173176
// TODO: support 0-d corner case.
174177
if (op.getTransferRank() == 0)
175178
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
179+
if (isa<vector::MaskOp>(op->getParentOp()))
180+
return rewriter.notifyMatchFailure(
181+
op, "Cannot expand transfer write inside a Mask Op");
176182

177183
SmallVector<unsigned> permutation;
178184
AffineMap map = op.getPermutationMap();
@@ -239,6 +245,9 @@ struct TransferWriteNonPermutationLowering
239245
// TODO: support 0-d corner case.
240246
if (op.getTransferRank() == 0)
241247
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
248+
if (isa<vector::MaskOp>(op->getParentOp()))
249+
return rewriter.notifyMatchFailure(
250+
op, "Cannot expand transfer write inside a Mask Op");
242251

243252
SmallVector<unsigned> permutation;
244253
AffineMap map = op.getPermutationMap();

mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,18 @@ func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32
111111
%1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
112112
return %1 : tensor<?x?x12xf32>
113113
}
114+
115+
// CHECK-LABEL: func @masked_transfer_read_of_extract_slice
116+
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
117+
// CHECK-DAG: %[[m:.*]] = vector.create_mask{{.*}} : vector<5x6xi1>
118+
// CHECK-DAG: %[[a:.*]] = affine.apply {{.*}}[[s1]]
119+
// CHECK: vector.mask %[[m]] { vector.transfer_read %[[t]]{{.*}}: tensor<?x?xf32>, vector<5x6xf32> } : vector<5x6xi1> -> vector<5x6xf32>
120+
func.func @masked_transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
121+
%c3 = arith.constant 3 : index
122+
%c4 = arith.constant 4 : index
123+
%cst = arith.constant 0.0 : f32
124+
%0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
125+
%mask = vector.create_mask %c3, %c4 : vector<5x6xi1>
126+
%1 = vector.mask %mask {vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>} : vector<5x6xi1> -> vector<5x6xf32>
127+
return %1 : vector<5x6xf32>
128+
}

mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,37 @@ func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16
5959

6060
return
6161
}
62+
63+
64+
#map = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
65+
#map1 = affine_map<(d0, d1) -> (d0, 0, d1)>
66+
// CHECK-LABEL: func @masked_permutation_transfer_read
67+
// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x1xf32>,
68+
// CHECK-SAME: %[[ARG_1:.*]]: vector<4x1xi1>
69+
// CHECK: vector.transfer_read %[[ARG_0]]{{.*}}: tensor<?x1xf32>, vector<4x4x1xf32> } : vector<4x1xi1> -> vector<4x4x1xf32>
70+
func.func @masked_permutation_transfer_read(%arg0: tensor<?x1xf32>, %mask : vector<4x1xi1>) {
71+
%cst = arith.constant 0.000000e+00 : f32
72+
%c0 = arith.constant 0 : index
73+
%3 = vector.mask %mask { vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = #map1} : tensor<?x1xf32>, vector<4x4x1xf32> } : vector<4x1xi1> -> vector<4x4x1xf32>
74+
call @dostuff(%3) : (vector<4x4x1xf32>) -> ()
75+
return
76+
}
77+
func.func private @dostuff(vector<4x4x1xf32>)
78+
79+
80+
// CHECK-LABEL: func @masked_permutation_transfer_write
81+
// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
82+
// CHECK-SAME: %[[ARG_1:.*]]: vector<16xf32>,
83+
// CHECK-SAME: %[[IDX:.*]]: index,
84+
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>
85+
// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]][%[[IDX]], %[[IDX]]] {{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
86+
// CHECK: return %[[RES]]
87+
func.func @masked_permutation_transfer_write(%t: tensor<?x?xf32>, %val: vector<16xf32>, %idx: index, %m0: vector<16xi1>) -> tensor<?x?xf32> {
88+
%r = vector.mask %m0 { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
89+
return %r : tensor<?x?xf32>
90+
}
91+
92+
6293
module attributes {transform.with_named_sequence} {
6394
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
6495
%f = transform.structured.match ops{["func.func"]} in %module_op

0 commit comments

Comments
 (0)