@@ -6646,13 +6646,42 @@ LogicalResult MaskOp::verify() {
6646
6646
return success ();
6647
6647
}
6648
6648
6649
- // / Folds vector.mask ops with an all-true mask.
6649
+ // / Folds empty `vector.mask` with no passthru operand and with or without
6650
+ // / return values. For example:
6651
+ // /
6652
+ // / %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
6653
+ // / vector<8xi1> -> vector<8xf32>
6654
+ // / %1 = user_op %0 : vector<8xf32>
6655
+ // /
6656
+ // / becomes:
6657
+ // /
6658
+ // / %0 = user_op %a : vector<8xf32>
6659
+ // /
6660
+ // / `vector.mask` with a passthru is handled by the canonicalizer.
6661
+ // /
6662
+ static LogicalResult foldEmptyMaskOp (MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
6663
+ SmallVectorImpl<OpFoldResult> &results) {
6664
+ if (!maskOp.isEmpty () || maskOp.hasPassthru ())
6665
+ return failure ();
6666
+
6667
+ Block *block = maskOp.getMaskBlock ();
6668
+ auto terminator = cast<vector::YieldOp>(block->front ());
6669
+ if (terminator.getNumOperands () == 0 ) {
6670
+ // `vector.mask` has no results, just remove the `vector.mask`.
6671
+ return success ();
6672
+ }
6673
+
6674
+ // `vector.mask` has results, propagate the results.
6675
+ llvm::append_range (results, terminator.getOperands ());
6676
+ return success ();
6677
+ }
6678
+
6650
6679
LogicalResult MaskOp::fold (FoldAdaptor adaptor,
6651
6680
SmallVectorImpl<OpFoldResult> &results) {
6652
- MaskFormat maskFormat = getMaskFormat (getMask ());
6653
- if (isEmpty ())
6654
- return failure ();
6681
+ if (succeeded (foldEmptyMaskOp (*this , adaptor, results)))
6682
+ return success ();
6655
6683
6684
+ MaskFormat maskFormat = getMaskFormat (getMask ());
6656
6685
if (maskFormat != MaskFormat::AllTrue)
6657
6686
return failure ();
6658
6687
@@ -6665,37 +6694,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
6665
6694
return success ();
6666
6695
}
6667
6696
6668
- // Elides empty vector.mask operations with or without return values. Propagates
6669
- // the yielded values by the vector.yield terminator, if any, or erases the op,
6670
- // otherwise.
6671
- class ElideEmptyMaskOp : public OpRewritePattern <MaskOp> {
6672
- using OpRewritePattern::OpRewritePattern;
6673
-
6674
- LogicalResult matchAndRewrite (MaskOp maskOp,
6675
- PatternRewriter &rewriter) const override {
6676
- auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation ());
6677
- if (maskingOp.getMaskableOp ())
6678
- return failure ();
6679
-
6680
- if (!maskOp.isEmpty ())
6681
- return failure ();
6682
-
6683
- Block *block = maskOp.getMaskBlock ();
6684
- auto terminator = cast<vector::YieldOp>(block->front ());
6685
- if (terminator.getNumOperands () == 0 )
6686
- rewriter.eraseOp (maskOp);
6687
- else
6688
- rewriter.replaceOp (maskOp, terminator.getOperands ());
6689
-
6690
- return success ();
6691
- }
6692
- };
6693
-
6694
- void MaskOp::getCanonicalizationPatterns (RewritePatternSet &results,
6695
- MLIRContext *context) {
6696
- results.add <ElideEmptyMaskOp>(context);
6697
- }
6698
-
6699
6697
// MaskingOpInterface definitions.
6700
6698
6701
6699
// / Returns the operation masked by this 'vector.mask'.
0 commit comments