@@ -322,14 +322,19 @@ struct TransferWriteNonPermutationLowering
322
322
// / %v = vector.transfer_read ...
323
323
// / permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
324
324
// / vector.broadcast %v
325
- struct TransferOpReduceRank : public OpRewritePattern <vector::TransferReadOp> {
326
- using OpRewritePattern::OpRewritePattern;
325
+ struct TransferOpReduceRank
326
+ : public MaskableOpRewritePattern<vector::TransferReadOp> {
327
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
327
328
328
- LogicalResult matchAndRewrite (vector::TransferReadOp op,
329
- PatternRewriter &rewriter) const override {
329
+ FailureOr<mlir::Value>
330
+ matchAndRewriteMaskableOp (vector::TransferReadOp op,
331
+ MaskingOpInterface maskOp,
332
+ PatternRewriter &rewriter) const override {
330
333
// TODO: support 0-d corner case.
331
334
if (op.getTransferRank () == 0 )
332
335
return rewriter.notifyMatchFailure (op, " 0-d corner case not supported" );
336
+ if (maskOp)
337
+ return rewriter.notifyMatchFailure (op, " Masked case not supported" );
333
338
334
339
AffineMap map = op.getPermutationMap ();
335
340
unsigned numLeadingBroadcast = 0 ;
@@ -369,9 +374,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
369
374
op.getLoc (), originalVecType.getElementType (), op.getSource (),
370
375
op.getIndices ());
371
376
}
372
- rewriter. replaceOpWithNewOp <vector::BroadcastOp>(op, originalVecType,
373
- newRead);
374
- return success ();
377
+ return rewriter
378
+ . create <vector::BroadcastOp>(op. getLoc (), originalVecType, newRead)
379
+ . getVector ();
375
380
}
376
381
377
382
SmallVector<int64_t > newShape (
@@ -393,9 +398,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
393
398
op.getLoc (), newReadType, op.getSource (), op.getIndices (),
394
399
AffineMapAttr::get (newMap), op.getPadding (), op.getMask (),
395
400
newInBoundsAttr);
396
- rewriter. replaceOpWithNewOp <vector::BroadcastOp>(op, originalVecType,
397
- newRead);
398
- return success ();
401
+ return rewriter
402
+ . create <vector::BroadcastOp>(op. getLoc (), originalVecType, newRead)
403
+ . getVector ();
399
404
}
400
405
};
401
406
0 commit comments