Skip to content

Commit 1014ee5

Browse files
committed
[mlir][Vector] Move vector.mask canonicalization to folders
This MR moves the canonicalization that elides empty `vector.mask` ops to folders.
1 parent 286ab11 commit 1014ee5

File tree

3 files changed

+38
-38
lines changed

3 files changed

+38
-38
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2554,7 +2554,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
25542554
Location loc);
25552555
}];
25562556

2557-
let hasCanonicalizer = 1;
2557+
let hasCanonicalizer = 0;
25582558
let hasFolder = 1;
25592559
let hasCustomAssemblyFormat = 1;
25602560
let hasVerifier = 1;

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6631,13 +6631,42 @@ LogicalResult MaskOp::verify() {
66316631
return success();
66326632
}
66336633

6634-
/// Folds vector.mask ops with an all-true mask.
6634+
/// Folds empty `vector.mask` with no passthru operand and with or without
6635+
/// return values. For example:
6636+
///
6637+
/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
6638+
/// vector<8xi1> -> vector<8xf32>
6639+
/// %1 = user_op %0 : vector<8xf32>
6640+
///
6641+
/// becomes:
6642+
///
6643+
/// %0 = user_op %a : vector<8xf32>
6644+
///
6645+
/// `vector.mask` with a passthru is handled by the canonicalizer.
6646+
///
6647+
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
6648+
SmallVectorImpl<OpFoldResult> &results) {
6649+
if (!maskOp.isEmpty() || maskOp.hasPassthru())
6650+
return failure();
6651+
6652+
Block *block = maskOp.getMaskBlock();
6653+
auto terminator = cast<vector::YieldOp>(block->front());
6654+
if (terminator.getNumOperands() == 0) {
6655+
// `vector.mask` has no results, just remove the `vector.mask`.
6656+
return success();
6657+
}
6658+
6659+
// `vector.mask` has results, propagate the results.
6660+
llvm::append_range(results, terminator.getOperands());
6661+
return success();
6662+
}
6663+
66356664
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
66366665
SmallVectorImpl<OpFoldResult> &results) {
6637-
MaskFormat maskFormat = getMaskFormat(getMask());
6638-
if (isEmpty())
6639-
return failure();
6666+
if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
6667+
return success();
66406668

6669+
MaskFormat maskFormat = getMaskFormat(getMask());
66416670
if (maskFormat != MaskFormat::AllTrue)
66426671
return failure();
66436672

@@ -6650,37 +6679,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
66506679
return success();
66516680
}
66526681

6653-
// Elides empty vector.mask operations with or without return values. Propagates
6654-
// the yielded values by the vector.yield terminator, if any, or erases the op,
6655-
// otherwise.
6656-
class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
6657-
using OpRewritePattern::OpRewritePattern;
6658-
6659-
LogicalResult matchAndRewrite(MaskOp maskOp,
6660-
PatternRewriter &rewriter) const override {
6661-
auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
6662-
if (maskingOp.getMaskableOp())
6663-
return failure();
6664-
6665-
if (!maskOp.isEmpty())
6666-
return failure();
6667-
6668-
Block *block = maskOp.getMaskBlock();
6669-
auto terminator = cast<vector::YieldOp>(block->front());
6670-
if (terminator.getNumOperands() == 0)
6671-
rewriter.eraseOp(maskOp);
6672-
else
6673-
rewriter.replaceOp(maskOp, terminator.getOperands());
6674-
6675-
return success();
6676-
}
6677-
};
6678-
6679-
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6680-
MLIRContext *context) {
6681-
results.add<ElideEmptyMaskOp>(context);
6682-
}
6683-
66846682
// MaskingOpInterface definitions.
66856683

66866684
/// Returns the operation masked by this 'vector.mask'.

mlir/test/Conversion/GPUCommon/lower-vector.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
22

33
module {
4+
// CHECK-LABEL: func @func
5+
// CHECK-SAME: %[[IN:.*]]: vector<11xf32>
46
func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
57
%cst_41 = arith.constant dense<true> : vector<11xi1>
6-
// CHECK: vector.mask
7-
// CHECK-SAME: vector.yield %arg0
8+
// CHECK-NOT: vector.mask
9+
// CHECK: return %[[IN]] : vector<11xf32>
810
%127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
911
return %127 : vector<11xf32>
1012
}

0 commit comments

Comments
 (0)