@@ -6631,13 +6631,42 @@ LogicalResult MaskOp::verify() {
6631
6631
return success ();
6632
6632
}
6633
6633
6634
- // / Folds vector.mask ops with an all-true mask.
6634
+ // / Folds empty `vector.mask` with no passthru operand and with or without
6635
+ // / return values. For example:
6636
+ // /
6637
+ // / %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
6638
+ // / vector<8xi1> -> vector<8xf32>
6639
+ // / %1 = user_op %0 : vector<8xf32>
6640
+ // /
6641
+ // / becomes:
6642
+ // /
6643
+ // / %0 = user_op %a : vector<8xf32>
6644
+ // /
6645
+ // / `vector.mask` with a passthru is handled by the canonicalizer.
6646
+ // /
6647
+ static LogicalResult foldEmptyMaskOp (MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
6648
+ SmallVectorImpl<OpFoldResult> &results) {
6649
+ if (!maskOp.isEmpty () || maskOp.hasPassthru ())
6650
+ return failure ();
6651
+
6652
+ Block *block = maskOp.getMaskBlock ();
6653
+ auto terminator = cast<vector::YieldOp>(block->front ());
6654
+ if (terminator.getNumOperands () == 0 ) {
6655
+ // `vector.mask` has no results, just remove the `vector.mask`.
6656
+ return success ();
6657
+ }
6658
+
6659
+ // `vector.mask` has results, propagate the results.
6660
+ llvm::append_range (results, terminator.getOperands ());
6661
+ return success ();
6662
+ }
6663
+
6635
6664
LogicalResult MaskOp::fold (FoldAdaptor adaptor,
6636
6665
SmallVectorImpl<OpFoldResult> &results) {
6637
- MaskFormat maskFormat = getMaskFormat (getMask ());
6638
- if (isEmpty ())
6639
- return failure ();
6666
+ if (succeeded (foldEmptyMaskOp (*this , adaptor, results)))
6667
+ return success ();
6640
6668
6669
+ MaskFormat maskFormat = getMaskFormat (getMask ());
6641
6670
if (maskFormat != MaskFormat::AllTrue)
6642
6671
return failure ();
6643
6672
@@ -6650,37 +6679,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
6650
6679
return success ();
6651
6680
}
6652
6681
6653
- // Elides empty vector.mask operations with or without return values. Propagates
6654
- // the yielded values by the vector.yield terminator, if any, or erases the op,
6655
- // otherwise.
6656
- class ElideEmptyMaskOp : public OpRewritePattern <MaskOp> {
6657
- using OpRewritePattern::OpRewritePattern;
6658
-
6659
- LogicalResult matchAndRewrite (MaskOp maskOp,
6660
- PatternRewriter &rewriter) const override {
6661
- auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation ());
6662
- if (maskingOp.getMaskableOp ())
6663
- return failure ();
6664
-
6665
- if (!maskOp.isEmpty ())
6666
- return failure ();
6667
-
6668
- Block *block = maskOp.getMaskBlock ();
6669
- auto terminator = cast<vector::YieldOp>(block->front ());
6670
- if (terminator.getNumOperands () == 0 )
6671
- rewriter.eraseOp (maskOp);
6672
- else
6673
- rewriter.replaceOp (maskOp, terminator.getOperands ());
6674
-
6675
- return success ();
6676
- }
6677
- };
6678
-
6679
- void MaskOp::getCanonicalizationPatterns (RewritePatternSet &results,
6680
- MLIRContext *context) {
6681
- results.add <ElideEmptyMaskOp>(context);
6682
- }
6683
-
6684
6682
// MaskingOpInterface definitions.
6685
6683
6686
6684
// / Returns the operation masked by this 'vector.mask'.
0 commit comments