Skip to content

Commit 3bc4083

Browse files
committed
[mlir][Vector] Move vector.mask canonicalization to folders
This MR moves the canonicalization that elides empty `vector.mask` ops to folders.
1 parent 6cac792 commit 3bc4083

File tree

3 files changed

+38
-38
lines changed

3 files changed

+38
-38
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2559,7 +2559,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
25592559
Location loc);
25602560
}];
25612561

2562-
let hasCanonicalizer = 1;
2562+
let hasCanonicalizer = 0;
25632563
let hasFolder = 1;
25642564
let hasCustomAssemblyFormat = 1;
25652565
let hasVerifier = 1;

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6646,13 +6646,42 @@ LogicalResult MaskOp::verify() {
66466646
return success();
66476647
}
66486648

6649-
/// Folds vector.mask ops with an all-true mask.
6649+
/// Folds empty `vector.mask` with no passthru operand and with or without
6650+
/// return values. For example:
6651+
///
6652+
/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
6653+
/// vector<8xi1> -> vector<8xf32>
6654+
/// %1 = user_op %0 : vector<8xf32>
6655+
///
6656+
/// becomes:
6657+
///
6658+
/// %0 = user_op %a : vector<8xf32>
6659+
///
6660+
/// `vector.mask` with a passthru is handled by the canonicalizer.
6661+
///
6662+
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
6663+
SmallVectorImpl<OpFoldResult> &results) {
6664+
if (!maskOp.isEmpty() || maskOp.hasPassthru())
6665+
return failure();
6666+
6667+
Block *block = maskOp.getMaskBlock();
6668+
auto terminator = cast<vector::YieldOp>(block->front());
6669+
if (terminator.getNumOperands() == 0) {
6670+
// `vector.mask` has no results, just remove the `vector.mask`.
6671+
return success();
6672+
}
6673+
6674+
// `vector.mask` has results, propagate the results.
6675+
llvm::append_range(results, terminator.getOperands());
6676+
return success();
6677+
}
6678+
66506679
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
66516680
SmallVectorImpl<OpFoldResult> &results) {
6652-
MaskFormat maskFormat = getMaskFormat(getMask());
6653-
if (isEmpty())
6654-
return failure();
6681+
if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
6682+
return success();
66556683

6684+
MaskFormat maskFormat = getMaskFormat(getMask());
66566685
if (maskFormat != MaskFormat::AllTrue)
66576686
return failure();
66586687

@@ -6665,37 +6694,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
66656694
return success();
66666695
}
66676696

6668-
// Elides empty vector.mask operations with or without return values. Propagates
6669-
// the yielded values by the vector.yield terminator, if any, or erases the op,
6670-
// otherwise.
6671-
class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
6672-
using OpRewritePattern::OpRewritePattern;
6673-
6674-
LogicalResult matchAndRewrite(MaskOp maskOp,
6675-
PatternRewriter &rewriter) const override {
6676-
auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
6677-
if (maskingOp.getMaskableOp())
6678-
return failure();
6679-
6680-
if (!maskOp.isEmpty())
6681-
return failure();
6682-
6683-
Block *block = maskOp.getMaskBlock();
6684-
auto terminator = cast<vector::YieldOp>(block->front());
6685-
if (terminator.getNumOperands() == 0)
6686-
rewriter.eraseOp(maskOp);
6687-
else
6688-
rewriter.replaceOp(maskOp, terminator.getOperands());
6689-
6690-
return success();
6691-
}
6692-
};
6693-
6694-
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6695-
MLIRContext *context) {
6696-
results.add<ElideEmptyMaskOp>(context);
6697-
}
6698-
66996697
// MaskingOpInterface definitions.
67006698

67016699
/// Returns the operation masked by this 'vector.mask'.

mlir/test/Conversion/GPUCommon/lower-vector.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
22

33
module {
4+
// CHECK-LABEL: func @func
5+
// CHECK-SAME: %[[IN:.*]]: vector<11xf32>
46
func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
57
%cst_41 = arith.constant dense<true> : vector<11xi1>
6-
// CHECK: vector.mask
7-
// CHECK-SAME: vector.yield %arg0
8+
// CHECK-NOT: vector.mask
9+
// CHECK: return %[[IN]] : vector<11xf32>
810
%127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
911
return %127 : vector<11xf32>
1012
}

0 commit comments

Comments
 (0)