Skip to content

Commit d6f394e

Browse files
authored
[mlir][Vector] Move vector.mask canonicalization to folder (#140324)
This MR moves the canonicalization that elides empty `vector.mask` ops to folders.
1 parent 12c62eb commit d6f394e

File tree

3 files changed

+35
-38
lines changed

3 files changed

+35
-38
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2559,7 +2559,6 @@ def Vector_MaskOp : Vector_Op<"mask", [
25592559
Location loc);
25602560
}];
25612561

2562-
let hasCanonicalizer = 1;
25632562
let hasFolder = 1;
25642563
let hasCustomAssemblyFormat = 1;
25652564
let hasVerifier = 1;

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6650,13 +6650,40 @@ LogicalResult MaskOp::verify() {
66506650
return success();
66516651
}
66526652

6653-
/// Folds vector.mask ops with an all-true mask.
6653+
/// Folds empty `vector.mask` with no passthru operand and with or without
6654+
/// return values. For example:
6655+
///
6656+
/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
6657+
/// vector<8xi1> -> vector<8xf32>
6658+
/// %1 = user_op %0 : vector<8xf32>
6659+
///
6660+
/// becomes:
6661+
///
6662+
/// %0 = user_op %a : vector<8xf32>
6663+
///
6664+
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
6665+
SmallVectorImpl<OpFoldResult> &results) {
6666+
if (!maskOp.isEmpty() || maskOp.hasPassthru())
6667+
return failure();
6668+
6669+
Block *block = maskOp.getMaskBlock();
6670+
auto terminator = cast<vector::YieldOp>(block->front());
6671+
if (terminator.getNumOperands() == 0) {
6672+
// `vector.mask` has no results, just remove the `vector.mask`.
6673+
return success();
6674+
}
6675+
6676+
// `vector.mask` has results, propagate the results.
6677+
llvm::append_range(results, terminator.getOperands());
6678+
return success();
6679+
}
6680+
66546681
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
66556682
SmallVectorImpl<OpFoldResult> &results) {
6656-
MaskFormat maskFormat = getMaskFormat(getMask());
6657-
if (isEmpty())
6658-
return failure();
6683+
if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
6684+
return success();
66596685

6686+
MaskFormat maskFormat = getMaskFormat(getMask());
66606687
if (maskFormat != MaskFormat::AllTrue)
66616688
return failure();
66626689

@@ -6669,37 +6696,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
66696696
return success();
66706697
}
66716698

6672-
// Elides empty vector.mask operations with or without return values. Propagates
6673-
// the yielded values by the vector.yield terminator, if any, or erases the op,
6674-
// otherwise.
6675-
class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
6676-
using OpRewritePattern::OpRewritePattern;
6677-
6678-
LogicalResult matchAndRewrite(MaskOp maskOp,
6679-
PatternRewriter &rewriter) const override {
6680-
auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
6681-
if (maskingOp.getMaskableOp())
6682-
return failure();
6683-
6684-
if (!maskOp.isEmpty())
6685-
return failure();
6686-
6687-
Block *block = maskOp.getMaskBlock();
6688-
auto terminator = cast<vector::YieldOp>(block->front());
6689-
if (terminator.getNumOperands() == 0)
6690-
rewriter.eraseOp(maskOp);
6691-
else
6692-
rewriter.replaceOp(maskOp, terminator.getOperands());
6693-
6694-
return success();
6695-
}
6696-
};
6697-
6698-
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6699-
MLIRContext *context) {
6700-
results.add<ElideEmptyMaskOp>(context);
6701-
}
6702-
67036699
// MaskingOpInterface definitions.
67046700

67056701
/// Returns the operation masked by this 'vector.mask'.

mlir/test/Conversion/GPUCommon/lower-vector.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
22

33
module {
4+
// CHECK-LABEL: func @func
5+
// CHECK-SAME: %[[IN:.*]]: vector<11xf32>
46
func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
57
%cst_41 = arith.constant dense<true> : vector<11xi1>
6-
// CHECK: vector.mask
7-
// CHECK-SAME: vector.yield %arg0
8+
// CHECK-NOT: vector.mask
9+
// CHECK: return %[[IN]] : vector<11xf32>
810
%127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
911
return %127 : vector<11xf32>
1012
}

0 commit comments

Comments
 (0)