@@ -6650,13 +6650,40 @@ LogicalResult MaskOp::verify() {
6650
6650
return success ();
6651
6651
}
6652
6652
6653
- // / Folds vector.mask ops with an all-true mask.
6653
+ // / Folds empty `vector.mask` with no passthru operand and with or without
6654
+ // / return values. For example:
6655
+ // /
6656
+ // / %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
6657
+ // / vector<8xi1> -> vector<8xf32>
6658
+ // / %1 = user_op %0 : vector<8xf32>
6659
+ // /
6660
+ // / becomes:
6661
+ // /
6662
+ // / %0 = user_op %a : vector<8xf32>
6663
+ // /
6664
+ static LogicalResult foldEmptyMaskOp (MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
6665
+ SmallVectorImpl<OpFoldResult> &results) {
6666
+ if (!maskOp.isEmpty () || maskOp.hasPassthru ())
6667
+ return failure ();
6668
+
6669
+ Block *block = maskOp.getMaskBlock ();
6670
+ auto terminator = cast<vector::YieldOp>(block->front ());
6671
+ if (terminator.getNumOperands () == 0 ) {
6672
+ // `vector.mask` has no results, just remove the `vector.mask`.
6673
+ return success ();
6674
+ }
6675
+
6676
+ // `vector.mask` has results, propagate the results.
6677
+ llvm::append_range (results, terminator.getOperands ());
6678
+ return success ();
6679
+ }
6680
+
6654
6681
LogicalResult MaskOp::fold (FoldAdaptor adaptor,
6655
6682
SmallVectorImpl<OpFoldResult> &results) {
6656
- MaskFormat maskFormat = getMaskFormat (getMask ());
6657
- if (isEmpty ())
6658
- return failure ();
6683
+ if (succeeded (foldEmptyMaskOp (*this , adaptor, results)))
6684
+ return success ();
6659
6685
6686
+ MaskFormat maskFormat = getMaskFormat (getMask ());
6660
6687
if (maskFormat != MaskFormat::AllTrue)
6661
6688
return failure ();
6662
6689
@@ -6669,37 +6696,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
6669
6696
return success ();
6670
6697
}
6671
6698
6672
- // Elides empty vector.mask operations with or without return values. Propagates
6673
- // the yielded values by the vector.yield terminator, if any, or erases the op,
6674
- // otherwise.
6675
- class ElideEmptyMaskOp : public OpRewritePattern <MaskOp> {
6676
- using OpRewritePattern::OpRewritePattern;
6677
-
6678
- LogicalResult matchAndRewrite (MaskOp maskOp,
6679
- PatternRewriter &rewriter) const override {
6680
- auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation ());
6681
- if (maskingOp.getMaskableOp ())
6682
- return failure ();
6683
-
6684
- if (!maskOp.isEmpty ())
6685
- return failure ();
6686
-
6687
- Block *block = maskOp.getMaskBlock ();
6688
- auto terminator = cast<vector::YieldOp>(block->front ());
6689
- if (terminator.getNumOperands () == 0 )
6690
- rewriter.eraseOp (maskOp);
6691
- else
6692
- rewriter.replaceOp (maskOp, terminator.getOperands ());
6693
-
6694
- return success ();
6695
- }
6696
- };
6697
-
6698
- void MaskOp::getCanonicalizationPatterns (RewritePatternSet &results,
6699
- MLIRContext *context) {
6700
- results.add <ElideEmptyMaskOp>(context);
6701
- }
6702
-
6703
6699
// MaskingOpInterface definitions.
6704
6700
6705
6701
// / Returns the operation masked by this 'vector.mask'.
0 commit comments