Skip to content

Commit 6de7022

Browse files
committed
[MLIR][Vector] Implement transferXXPermutationLowering as MaskableOpRewritePattern
1 parent 2b15c4a commit 6de7022

File tree

2 files changed

+122
-24
lines changed

2 files changed

+122
-24
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,19 @@ namespace {
9090
/// Note that an alternative is to transform it to linalg.transpose +
9191
/// vector.transfer_read to do the transpose in memory instead.
9292
struct TransferReadPermutationLowering
93-
: public OpRewritePattern<vector::TransferReadOp> {
94-
using OpRewritePattern::OpRewritePattern;
93+
: public MaskableOpRewritePattern<vector::TransferReadOp> {
94+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
9595

96-
LogicalResult matchAndRewrite(vector::TransferReadOp op,
97-
PatternRewriter &rewriter) const override {
96+
FailureOr<mlir::Value>
97+
matchAndRewriteMaskableOp(vector::TransferReadOp op,
98+
MaskingOpInterface maskOp,
99+
PatternRewriter &rewriter) const override {
98100
// TODO: support 0-d corner case.
99101
if (op.getTransferRank() == 0)
100102
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
103+
// TODO: Support transfer_read inside MaskOp case.
104+
if (maskOp)
105+
return rewriter.notifyMatchFailure(op, "Masked case not supported");
101106

102107
SmallVector<unsigned> permutation;
103108
AffineMap map = op.getPermutationMap();
@@ -142,9 +147,9 @@ struct TransferReadPermutationLowering
142147

143148
// Transpose result of transfer_read.
144149
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
145-
rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
146-
transposePerm);
147-
return success();
150+
return rewriter
151+
.create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
152+
.getResult();
148153
}
149154
};
150155

@@ -165,14 +170,19 @@ struct TransferReadPermutationLowering
165170
/// %v = vector.transfer_write %tmp ...
166171
/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
167172
struct TransferWritePermutationLowering
168-
: public OpRewritePattern<vector::TransferWriteOp> {
169-
using OpRewritePattern::OpRewritePattern;
173+
: public MaskableOpRewritePattern<vector::TransferWriteOp> {
174+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
170175

171-
LogicalResult matchAndRewrite(vector::TransferWriteOp op,
172-
PatternRewriter &rewriter) const override {
176+
FailureOr<mlir::Value>
177+
matchAndRewriteMaskableOp(vector::TransferWriteOp op,
178+
MaskingOpInterface maskOp,
179+
PatternRewriter &rewriter) const override {
173180
// TODO: support 0-d corner case.
174181
if (op.getTransferRank() == 0)
175182
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
183+
// TODO: Support transfer_write inside MaskOp case.
184+
if (maskOp)
185+
return rewriter.notifyMatchFailure(op, "Masked case not supported");
176186

177187
SmallVector<unsigned> permutation;
178188
AffineMap map = op.getPermutationMap();
@@ -207,11 +217,11 @@ struct TransferWritePermutationLowering
207217
op.getLoc(), op.getVector(), indices);
208218
auto newMap = AffineMap::getMinorIdentityMap(
209219
map.getNumDims(), map.getNumResults(), rewriter.getContext());
210-
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
211-
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
212-
op.getMask(), newInBoundsAttr);
213-
214-
return success();
220+
return rewriter
221+
.create<vector::TransferWriteOp>(
222+
op.getLoc(), newVec, op.getSource(), op.getIndices(),
223+
AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr)
224+
.getResult();
215225
}
216226
};
217227

@@ -231,14 +241,19 @@ struct TransferWritePermutationLowering
231241
/// vector<1x8x16xf32>
232242
/// ```
233243
struct TransferWriteNonPermutationLowering
234-
: public OpRewritePattern<vector::TransferWriteOp> {
235-
using OpRewritePattern::OpRewritePattern;
244+
: public MaskableOpRewritePattern<vector::TransferWriteOp> {
245+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
236246

237-
LogicalResult matchAndRewrite(vector::TransferWriteOp op,
238-
PatternRewriter &rewriter) const override {
247+
FailureOr<mlir::Value>
248+
matchAndRewriteMaskableOp(vector::TransferWriteOp op,
249+
MaskingOpInterface maskOp,
250+
PatternRewriter &rewriter) const override {
239251
// TODO: support 0-d corner case.
240252
if (op.getTransferRank() == 0)
241253
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
254+
// TODO: Support transfer_write inside MaskOp case.
255+
if (maskOp)
256+
return rewriter.notifyMatchFailure(op, "Masked case not supported");
242257

243258
SmallVector<unsigned> permutation;
244259
AffineMap map = op.getPermutationMap();
@@ -285,10 +300,11 @@ struct TransferWriteNonPermutationLowering
285300
newInBoundsValues.push_back(op.isDimInBounds(i));
286301
}
287302
ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
288-
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
289-
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
290-
newMask, newInBoundsAttr);
291-
return success();
303+
return rewriter
304+
.create<vector::TransferWriteOp>(
305+
op.getLoc(), newVec, op.getSource(), op.getIndices(),
306+
AffineMapAttr::get(newMap), newMask, newInBoundsAttr)
307+
.getResult();
292308
}
293309
};
294310

mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,55 @@ func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %
4646
return
4747
}
4848

49+
// transfer_write in MaskOp case not supported.
50+
// CHECK-LABEL: func @masked_permutation_xfer_write_fixed_width
51+
// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
52+
// CHECK-SAME: %[[ARG_1:.*]]: vector<16xf32>,
53+
// CHECK-SAME: %[[IDX:.*]]: index,
54+
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>
55+
// CHECK-NOT: vector.transpose
56+
// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]][%[[IDX]], %[[IDX]]] {{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
57+
// CHECK: return %[[RES]]
58+
func.func @masked_permutation_xfer_write_fixed_width(%t: tensor<?x?xf32>, %val: vector<16xf32>, %idx: index, %mask: vector<16xi1>) -> tensor<?x?xf32> {
59+
%r = vector.mask %mask { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
60+
return %r : tensor<?x?xf32>
61+
}
62+
63+
// CHECK-LABEL: func.func @masked_permutation_xfer_write_scalable(
64+
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
65+
// CHECK-SAME: %[[ARG_1:.*]]: tensor<?x?x?x?xf32>,
66+
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>)
67+
// CHECK-SAME: -> tensor<?x?x?x?xf32> {
68+
// CHECK-NOT: vector.transpose
69+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
70+
// CHECK: %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true], permutation_map = #[[MAP:.*]]} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
71+
// CHECK: return %[[R]] : tensor<?x?x?x?xf32>
72+
func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: tensor<?x?x?x?xf32>, %mask: vector<4x[8]xi1>) -> tensor<?x?x?x?xf32> {
73+
%c0 = arith.constant 0 : index
74+
%r = vector.mask %mask { vector.transfer_write %arg0, %t[%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
75+
} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
76+
77+
return %r : tensor<?x?x?x?xf32>
78+
}
79+
80+
// transfer_write in MaskOp case not supported.
81+
// CHECK-LABEL: func @masked_non_permutation_xfer_write_fixed_width
82+
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?x?xf32>
83+
// CHECK-SAME: %[[ARG1:.*]]: vector<14x8x16xf32>
84+
// CHECK-SAME: %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
85+
// CHECK-NOT: vector.broadcast
86+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
87+
// CHECK: %[[masked1:.*]] = vector.mask %{{.*}} { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}}permutation_map = #[[MAP:.*]]} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
88+
func.func @masked_non_permutation_xfer_write_fixed_width(
89+
%arg0 : tensor<?x?x?x?xf32>,
90+
%v1 : vector<14x8x16xf32>, %dim : index) -> tensor<?x?x?x?xf32> {
91+
%c0 = arith.constant 0 : index
92+
%mask = vector.create_mask %dim, %dim, %dim : vector<14x8x16xi1>
93+
%0 = vector.mask %mask { vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
94+
95+
return %0 : tensor<?x?x?x?xf32>
96+
}
97+
4998
///----------------------------------------------------------------------------------------
5099
/// vector.transfer_read
51100
///----------------------------------------------------------------------------------------
@@ -101,6 +150,39 @@ func.func @permutation_with_mask_xfer_read_scalable(%mem: memref<?x?xf32>, %dim_
101150
return %1 : vector<8x[4]x2xf32>
102151
}
103152

153+
// transfer_read in MaskOp case not supported.
154+
// CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width
155+
// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x1xf32>,
156+
// CHECK-SAME: %[[ARG_1:.*]]: vector<4x1xi1>
157+
// CHECK-NOT: vector.transpose
158+
// CHECK: vector.mask %[[ARG_1]] { vector.transfer_read %[[ARG_0]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
159+
func.func @masked_permutation_xfer_read_fixed_width(%arg0: tensor<?x1xf32>, %mask : vector<4x1xi1>) {
160+
%cst = arith.constant 0.000000e+00 : f32
161+
%c0 = arith.constant 0 : index
162+
%3 = vector.mask %mask { vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = affine_map<(d0, d1) -> (d1, 0, d0)>} : tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
163+
call @test.some_use(%3) : (vector<1x4x4xf32>) -> ()
164+
return
165+
}
166+
func.func private @test.some_use(vector<1x4x4xf32>)
167+
168+
// CHECK-LABEL: func.func @masked_permutation_xfer_read_scalable(
169+
// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
170+
// CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
171+
// CHECK-NOT: vector.transpose
172+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
173+
// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true], permutation_map = #[[MAP:.*]]} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
174+
// CHECK: return %[[T_READ]] : vector<8x[4]x2xf32>
175+
func.func @masked_permutation_xfer_read_scalable(%t: tensor<?x?xf32>, %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
176+
177+
%c0 = arith.constant 0 : index
178+
%cst_0 = arith.constant 0.000000e+00 : f32
179+
180+
%1 = vector.mask %mask { vector.transfer_read %t[%c0, %c0], %cst_0
181+
{in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>}
182+
: tensor<?x?xf32>, vector<8x[4]x2xf32> } :vector<2x[4]xi1> -> vector<8x[4]x2xf32>
183+
return %1 : vector<8x[4]x2xf32>
184+
}
185+
104186
module attributes {transform.with_named_sequence} {
105187
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
106188
%f = transform.structured.match ops{["func.func"]} in %module_op

0 commit comments

Comments
 (0)