Skip to content

Commit fdd245a

Browse files
authored
[MLIR][Vector] Implement transferXXPermutationLowering as MaskableOpRewritePattern (#91987)
* Implements `TransferWritePermutationLowering`, `TransferReadPermutationLowering` and `TransferWriteNonPermutationLowering` as a MaskableOpRewritePattern. Allowing to exit gracefully when such use of a xferOp is inside a `vector::MaskOp` * Updates MaskableOpRewritePattern to handle MemRefs and buffer semantics providing empty `Value()` as a return value for `matchAndRewriteMaskableOp` now represents successful rewriting without value to replace the original op. Split of #90835
1 parent e1c06c3 commit fdd245a

File tree

3 files changed

+130
-25
lines changed

3 files changed

+130
-25
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,14 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
157157
if (failed(newOp))
158158
return failure();
159159

160-
rewriter.replaceOp(rootOp, *newOp);
160+
// Rewriting succeeded but there are no values to replace.
161+
if (rootOp->getNumResults() == 0) {
162+
rewriter.eraseOp(rootOp);
163+
} else {
164+
assert(*newOp != Value() &&
165+
"Cannot replace an op's use with an empty value.");
166+
rewriter.replaceOp(rootOp, *newOp);
167+
}
161168
return success();
162169
}
163170

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,19 @@ namespace {
9090
/// Note that an alternative is to transform it to linalg.transpose +
9191
/// vector.transfer_read to do the transpose in memory instead.
9292
struct TransferReadPermutationLowering
93-
: public OpRewritePattern<vector::TransferReadOp> {
94-
using OpRewritePattern::OpRewritePattern;
93+
: public MaskableOpRewritePattern<vector::TransferReadOp> {
94+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
9595

96-
LogicalResult matchAndRewrite(vector::TransferReadOp op,
97-
PatternRewriter &rewriter) const override {
96+
FailureOr<mlir::Value>
97+
matchAndRewriteMaskableOp(vector::TransferReadOp op,
98+
MaskingOpInterface maskOp,
99+
PatternRewriter &rewriter) const override {
98100
// TODO: support 0-d corner case.
99101
if (op.getTransferRank() == 0)
100102
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
103+
// TODO: Support transfer_read inside MaskOp case.
104+
if (maskOp)
105+
return rewriter.notifyMatchFailure(op, "Masked case not supported");
101106

102107
SmallVector<unsigned> permutation;
103108
AffineMap map = op.getPermutationMap();
@@ -142,9 +147,9 @@ struct TransferReadPermutationLowering
142147

143148
// Transpose result of transfer_read.
144149
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
145-
rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
146-
transposePerm);
147-
return success();
150+
return rewriter
151+
.create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
152+
.getResult();
148153
}
149154
};
150155

@@ -165,14 +170,19 @@ struct TransferReadPermutationLowering
165170
/// %v = vector.transfer_write %tmp ...
166171
/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
167172
struct TransferWritePermutationLowering
168-
: public OpRewritePattern<vector::TransferWriteOp> {
169-
using OpRewritePattern::OpRewritePattern;
173+
: public MaskableOpRewritePattern<vector::TransferWriteOp> {
174+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
170175

171-
LogicalResult matchAndRewrite(vector::TransferWriteOp op,
172-
PatternRewriter &rewriter) const override {
176+
FailureOr<mlir::Value>
177+
matchAndRewriteMaskableOp(vector::TransferWriteOp op,
178+
MaskingOpInterface maskOp,
179+
PatternRewriter &rewriter) const override {
173180
// TODO: support 0-d corner case.
174181
if (op.getTransferRank() == 0)
175182
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
183+
// TODO: Support transfer_write inside MaskOp case.
184+
if (maskOp)
185+
return rewriter.notifyMatchFailure(op, "Masked case not supported");
176186

177187
SmallVector<unsigned> permutation;
178188
AffineMap map = op.getPermutationMap();
@@ -207,11 +217,14 @@ struct TransferWritePermutationLowering
207217
op.getLoc(), op.getVector(), indices);
208218
auto newMap = AffineMap::getMinorIdentityMap(
209219
map.getNumDims(), map.getNumResults(), rewriter.getContext());
210-
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
211-
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
212-
op.getMask(), newInBoundsAttr);
213-
214-
return success();
220+
auto newWrite = rewriter.create<vector::TransferWriteOp>(
221+
op.getLoc(), newVec, op.getSource(), op.getIndices(),
222+
AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
223+
if (newWrite.hasPureTensorSemantics())
224+
return newWrite.getResult();
225+
// In the memref case there's no return value. Use empty value to signal
226+
// success.
227+
return Value();
215228
}
216229
};
217230

@@ -231,14 +244,19 @@ struct TransferWritePermutationLowering
231244
/// vector<1x8x16xf32>
232245
/// ```
233246
struct TransferWriteNonPermutationLowering
234-
: public OpRewritePattern<vector::TransferWriteOp> {
235-
using OpRewritePattern::OpRewritePattern;
247+
: public MaskableOpRewritePattern<vector::TransferWriteOp> {
248+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
236249

237-
LogicalResult matchAndRewrite(vector::TransferWriteOp op,
238-
PatternRewriter &rewriter) const override {
250+
FailureOr<mlir::Value>
251+
matchAndRewriteMaskableOp(vector::TransferWriteOp op,
252+
MaskingOpInterface maskOp,
253+
PatternRewriter &rewriter) const override {
239254
// TODO: support 0-d corner case.
240255
if (op.getTransferRank() == 0)
241256
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
257+
// TODO: Support transfer_write inside MaskOp case.
258+
if (maskOp)
259+
return rewriter.notifyMatchFailure(op, "Masked case not supported");
242260

243261
SmallVector<unsigned> permutation;
244262
AffineMap map = op.getPermutationMap();
@@ -285,10 +303,14 @@ struct TransferWriteNonPermutationLowering
285303
newInBoundsValues.push_back(op.isDimInBounds(i));
286304
}
287305
ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
288-
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
289-
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
290-
newMask, newInBoundsAttr);
291-
return success();
306+
auto newWrite = rewriter.create<vector::TransferWriteOp>(
307+
op.getLoc(), newVec, op.getSource(), op.getIndices(),
308+
AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
309+
if (newWrite.hasPureTensorSemantics())
310+
return newWrite.getResult();
311+
// In the memref case there's no return value. Use empty value to signal
312+
// success.
313+
return Value();
292314
}
293315
};
294316

mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,51 @@ func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %
4646
return
4747
}
4848

49+
// transfer_write in MaskOp case not supported.
50+
// CHECK-LABEL: func @masked_permutation_xfer_write_fixed_width
51+
// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
52+
// CHECK-SAME: %[[ARG_1:.*]]: vector<16xf32>,
53+
// CHECK-SAME: %[[IDX:.*]]: index,
54+
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>
55+
// CHECK-NOT: vector.transpose
56+
// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]]{{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
57+
func.func @masked_permutation_xfer_write_fixed_width(%t: tensor<?x?xf32>, %val: vector<16xf32>, %idx: index, %mask: vector<16xi1>) -> tensor<?x?xf32> {
58+
%r = vector.mask %mask { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
59+
return %r : tensor<?x?xf32>
60+
}
61+
62+
// CHECK-LABEL: func.func @masked_permutation_xfer_write_scalable(
63+
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
64+
// CHECK-SAME: %[[ARG_1:.*]]: tensor<?x?x?x?xf32>,
65+
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>)
66+
// CHECK-SAME: -> tensor<?x?x?x?xf32> {
67+
// CHECK-NOT: vector.transpose
68+
// CHECK: %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]]{{.*}} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
69+
func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: tensor<?x?x?x?xf32>, %mask: vector<4x[8]xi1>) -> tensor<?x?x?x?xf32> {
70+
%c0 = arith.constant 0 : index
71+
%r = vector.mask %mask { vector.transfer_write %arg0, %t[%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
72+
} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
73+
74+
return %r : tensor<?x?x?x?xf32>
75+
}
76+
77+
// transfer_write in MaskOp case not supported.
78+
// CHECK-LABEL: func @masked_non_permutation_xfer_write_fixed_width
79+
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?x?xf32>
80+
// CHECK-SAME: %[[ARG1:.*]]: vector<14x8x16xf32>
81+
// CHECK-SAME: %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
82+
// CHECK-NOT: vector.broadcast
83+
// CHECK: %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
84+
func.func @masked_non_permutation_xfer_write_fixed_width(
85+
%arg0 : tensor<?x?x?x?xf32>,
86+
%v1 : vector<14x8x16xf32>, %dim : index) -> tensor<?x?x?x?xf32> {
87+
%c0 = arith.constant 0 : index
88+
%mask = vector.create_mask %dim, %dim, %dim : vector<14x8x16xi1>
89+
%0 = vector.mask %mask { vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
90+
91+
return %0 : tensor<?x?x?x?xf32>
92+
}
93+
4994
///----------------------------------------------------------------------------------------
5095
/// vector.transfer_read
5196
///----------------------------------------------------------------------------------------
@@ -101,6 +146,37 @@ func.func @permutation_with_mask_xfer_read_scalable(%mem: memref<?x?xf32>, %dim_
101146
return %1 : vector<8x[4]x2xf32>
102147
}
103148

149+
// transfer_read in MaskOp case not supported.
150+
// CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width
151+
// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x1xf32>,
152+
// CHECK-SAME: %[[ARG_1:.*]]: vector<4x1xi1>
153+
// CHECK-NOT: vector.transpose
154+
// CHECK: vector.mask %[[ARG_1]] { vector.transfer_read %[[ARG_0]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
155+
func.func @masked_permutation_xfer_read_fixed_width(%arg0: tensor<?x1xf32>, %mask : vector<4x1xi1>) {
156+
%cst = arith.constant 0.000000e+00 : f32
157+
%c0 = arith.constant 0 : index
158+
%3 = vector.mask %mask { vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = affine_map<(d0, d1) -> (d1, 0, d0)>} : tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
159+
call @test.some_use(%3) : (vector<1x4x4xf32>) -> ()
160+
return
161+
}
162+
func.func private @test.some_use(vector<1x4x4xf32>)
163+
164+
// CHECK-LABEL: func.func @masked_permutation_xfer_read_scalable(
165+
// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
166+
// CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
167+
// CHECK-NOT: vector.transpose
168+
// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]]{{.*}} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
169+
func.func @masked_permutation_xfer_read_scalable(%t: tensor<?x?xf32>, %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
170+
171+
%c0 = arith.constant 0 : index
172+
%cst_0 = arith.constant 0.000000e+00 : f32
173+
174+
%1 = vector.mask %mask { vector.transfer_read %t[%c0, %c0], %cst_0
175+
{in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>}
176+
: tensor<?x?xf32>, vector<8x[4]x2xf32> } :vector<2x[4]xi1> -> vector<8x[4]x2xf32>
177+
return %1 : vector<8x[4]x2xf32>
178+
}
179+
104180
module attributes {transform.with_named_sequence} {
105181
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
106182
%f = transform.structured.match ops{["func.func"]} in %module_op

0 commit comments

Comments
 (0)