Skip to content

Commit d438852

Browse files
committed
[MLIR][Vector] Implement TransferOpReduceRank as MaskableOpRewritePattern
1 parent: 058e445 · commit: d438852

File tree

2 files changed

+30
-10
lines changed

2 files changed

+30
-10
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 15 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -322,14 +322,19 @@ struct TransferWriteNonPermutationLowering
322322
/// %v = vector.transfer_read ...
323323
/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
324324
/// vector.broadcast %v
325-
struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
326-
using OpRewritePattern::OpRewritePattern;
325+
struct TransferOpReduceRank
326+
: public MaskableOpRewritePattern<vector::TransferReadOp> {
327+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
327328

328-
LogicalResult matchAndRewrite(vector::TransferReadOp op,
329-
PatternRewriter &rewriter) const override {
329+
FailureOr<mlir::Value>
330+
matchAndRewriteMaskableOp(vector::TransferReadOp op,
331+
MaskingOpInterface maskOp,
332+
PatternRewriter &rewriter) const override {
330333
// TODO: support 0-d corner case.
331334
if (op.getTransferRank() == 0)
332335
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
336+
if (maskOp)
337+
return rewriter.notifyMatchFailure(op, "Masked case not supported");
333338

334339
AffineMap map = op.getPermutationMap();
335340
unsigned numLeadingBroadcast = 0;
@@ -369,9 +374,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
369374
op.getLoc(), originalVecType.getElementType(), op.getSource(),
370375
op.getIndices());
371376
}
372-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
373-
newRead);
374-
return success();
377+
return rewriter
378+
.create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
379+
.getVector();
375380
}
376381

377382
SmallVector<int64_t> newShape(
@@ -393,9 +398,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
393398
op.getLoc(), newReadType, op.getSource(), op.getIndices(),
394399
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
395400
newInBoundsAttr);
396-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
397-
newRead);
398-
return success();
401+
return rewriter
402+
.create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
403+
.getVector();
399404
}
400405
};
401406

mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -219,6 +219,21 @@ func.func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %i : index) -> vecto
219219
return %res : vector<4x4xf32>
220220
}
221221

222+
// CHECK-LABEL: func @masked_transfer_read_reduce_rank_with_broadcast(
223+
// CHECK-SAME: %[[MEM:.*]]: memref<8x8x8x8xf32>,
224+
// CHECK-SAME: %[[MASK:.*]]: vector<4x4xi1>,
225+
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4x4x4x4xf32> {
226+
// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %cst {in_bounds = [true, true, true, true], permutation_map = #map2} : memref<8x8x8x8xf32>, vector<4x4x4x4xf32> } : vector<4x4xi1> -> vector<4x4x4x4xf32>
227+
// CHECK-NEXT: return %[[RES]] : vector<4x4x4x4xf32>
228+
#rank_reducing = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
229+
func.func @masked_transfer_read_reduce_rank_with_broadcast(%mem : memref<8x8x8x8xf32>, %mask : vector<4x4xi1>, %i : index) -> vector<4x4x4x4xf32> {
230+
%cf0 = arith.constant 0.0 : f32
231+
%res = vector.mask %mask {vector.transfer_read %mem[%i, %i, %i, %i], %cf0
232+
{in_bounds = [true, true, true, true], permutation_map = #rank_reducing}
233+
: memref<8x8x8x8xf32>, vector<4x4x4x4xf32>} : vector<4x4xi1> -> vector<4x4x4x4xf32>
234+
return %res : vector<4x4x4x4xf32>
235+
}
236+
222237
// More complex broadcasting case (here a `vector.load` is generated).
223238
// CHECK-LABEL: func @transfer_broadcasting_complex(
224239
// CHECK-SAME: %[[MEM:.*]]: memref<10x20x30x8x8xf32>,

0 commit comments

Comments
 (0)