-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir] [vector] Add linearization pattern for vector.create_mask #138214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
3a83e2d
203ec82
24b0739
2b8a653
8e8de7a
528f913
c2c1a22
c5b2e81
8fca9c1
2b7e06e
db484c9
7b714ce
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -445,6 +445,64 @@ struct LinearizeVectorSplat final | |
} | ||
}; | ||
|
||
/// Conversion pattern that rewrites a vector.create_mask so that it produces | ||
/// a mask of the type chosen by the type converter (a linearized vector). | ||
/// Only 2D masks with a unit outer dimension are supported; any other shape | ||
/// is reported as a match failure. | ||
nbpatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/// For example, | ||
/// vector.create_mask %d0, %d1 : vector<1x4xi1> | ||
/// is rewritten to (SSA names are illustrative): | ||
/// %zero = arith.constant 0 : index | ||
/// %off = arith.cmpi sle, %d0, %zero : index | ||
/// %cond = vector.splat %off : vector<4xi1> | ||
/// %none = arith.constant dense<false> : vector<4xi1> | ||
/// %mask = vector.create_mask %d1 : vector<4xi1> | ||
/// %out_1d = arith.select %cond, %none, %mask : vector<4xi1>, vector<4xi1> | ||
/// The vector.shape_cast from vector<4xi1> back to vector<1x4xi1> that shows | ||
/// up in the final IR is not created by this pattern. | ||
/// NOTE(review): presumably it is a materialization inserted by the | ||
/// conversion framework - confirm. | ||
struct LinearizeVectorCreateMask final | ||
: OpConversionPattern<vector::CreateMaskOp> { | ||
using OpConversionPattern::OpConversionPattern; | ||
|
||
/// Forwards the type converter, context and benefit (default 1) to the | ||
/// base OpConversionPattern. | ||
LinearizeVectorCreateMask(const TypeConverter &typeConverter, | ||
MLIRContext *context, PatternBenefit benefit = 1) | ||
: OpConversionPattern(typeConverter, context, benefit) {} | ||
|
||
/// Replaces `createMaskOp` with an arith.select between an all-zero mask | ||
/// and a create_mask of the converted type built from the last operand. | ||
/// Returns a match failure when the mask is not 2D, when its outer | ||
/// dimension is not 1, or when the result type cannot be converted. | ||
LogicalResult | ||
matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
// Shape checks: only rank-2 masks with a leading dimension of 1 pass. | ||
auto srcTy = createMaskOp.getType(); | ||
nbpatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
auto srcShape = srcTy.getShape(); | ||
if (srcShape.size() != 2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider adding |
||
return rewriter.notifyMatchFailure(createMaskOp, | ||
"only 2D mask is supported."); | ||
|
||
if (srcShape[0] != 1) | ||
return rewriter.notifyMatchFailure( | ||
createMaskOp, "only unit outer dimension is supported."); | ||
|
||
// Result type after conversion; a null type means the type converter | ||
// could not handle the source type. | ||
auto dstTy = getTypeConverter()->convertType(srcTy); | ||
if (!dstTy) | ||
return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type."); | ||
|
||
// Compare the first operand with 0 using a signed comparison. If it is | ||
// less than or equal to 0 the result is the zero mask; otherwise the | ||
// first operand is dropped and the mask is built from the second operand | ||
// alone. The i1 comparison result is splatted to the converted type so it | ||
// can be used as the select condition. | ||
auto firstOperand = adaptor.getOperands().front(); | ||
auto zero = | ||
rewriter.create<mlir::arith::ConstantIndexOp>(createMaskOp.getLoc(), 0); | ||
nbpatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
auto isZeroOrNegative = rewriter.create<mlir::arith::CmpIOp>( | ||
createMaskOp.getLoc(), mlir::arith::CmpIPredicate::sle, firstOperand, | ||
zero); | ||
auto isZeroOrNegativeSplat = rewriter.create<mlir::vector::SplatOp>( | ||
createMaskOp.getLoc(), dstTy, isZeroOrNegative); | ||
|
||
// Use a select operation to choose between the masks. Both candidates | ||
// (the zero constant and the new create_mask) are always created; the | ||
// select yields the zero mask when the condition is true. | ||
auto zeroMask = rewriter.create<mlir::arith::ConstantOp>( | ||
nbpatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
createMaskOp.getLoc(), dstTy, rewriter.getZeroAttr(dstTy)); | ||
auto newMask = rewriter.create<mlir::vector::CreateMaskOp>( | ||
createMaskOp.getLoc(), dstTy, adaptor.getOperands().back()); | ||
auto result = rewriter.create<mlir::arith::SelectOp>( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We may want to use some There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for updating the test @nbpatel to use non-constant operands. nit: my guess is that most of the time, the unit dimension will have extent which is the constant 1, and so |
||
createMaskOp.getLoc(), isZeroOrNegativeSplat, zeroMask, newMask); | ||
|
||
// The select result stands in for the original create_mask. | ||
rewriter.replaceOp(createMaskOp, result.getResult()); | ||
nbpatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
/// Return true if the operation `op` does not support scalable vectors and | ||
|
@@ -530,9 +588,10 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter, | |
// Registers the base linearization patterns on `patterns`, constructing each | ||
// one with `typeConverter` and the pattern set's context. `target` is not | ||
// used in the lines shown here. | ||
void mlir::vector::populateVectorLinearizeBasePatterns( | ||
const TypeConverter &typeConverter, const ConversionTarget &target, | ||
RewritePatternSet &patterns) { | ||
// NOTE(review): the +/- diff markers were lost in this rendering. Going by | ||
// the hunk header (-530,9 +588,10), the next three lines are the removed | ||
// registration and the four lines after them are its replacement, which | ||
// additionally registers LinearizeVectorCreateMask. | ||
patterns.add<LinearizeConstantLike, LinearizeVectorizable, | ||
LinearizeVectorBitCast, LinearizeVectorSplat>( | ||
typeConverter, patterns.getContext()); | ||
patterns | ||
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast, | ||
LinearizeVectorSplat, LinearizeVectorCreateMask>( | ||
typeConverter, patterns.getContext()); | ||
} | ||
|
||
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -345,3 +345,41 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> { | |
%0 = vector.splat %arg0 : vector<4x[2]xi32> | ||
return %0 : vector<4x[2]xi32> | ||
} | ||
|
||
// ----- | ||
// ALL-LABEL: linearize_create_mask | ||
// Fixed-size case: a vector<1x16xi1> create_mask with constant operands 0 | ||
// and 20. The expected output compares the first operand against zero, | ||
// builds a 1-D create_mask from the second operand, selects between an | ||
// all-false constant and that mask, and shape_casts the result back to 2-D. | ||
// The first operand is 0 here, so the sle compare is true and the select | ||
// picks the all-false constant. | ||
func.func @linearize_create_mask() -> vector<1x16xi1> { | ||
// CHECK: %[[C0:.*]] = arith.constant 0 : index | ||
// CHECK: %[[C20:.*]] = arith.constant 20 : index | ||
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index | ||
// CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index | ||
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1> | ||
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1> | ||
// CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1> | ||
// CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1> | ||
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1> | ||
// CHECK: return %[[CAST]] : vector<1x16xi1> | ||
%c0 = arith.constant 0 : index | ||
%c20 = arith.constant 20 : index | ||
nbpatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
%0 = vector.create_mask %c0, %c20 : vector<1x16xi1> | ||
return %0 : vector<1x16xi1> | ||
} | ||
|
||
// ----- | ||
// ALL-LABEL: linearize_scalable_create_mask | ||
// Scalable case: the same input as the fixed-size test, but with result type | ||
// vector<1x[16]xi1>. The expected op sequence is identical, with every 1-D | ||
// intermediate typed vector<[16]xi1>. | ||
func.func @linearize_scalable_create_mask() -> vector<1x[16]xi1> { | ||
nbpatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// CHECK: %[[C0:.*]] = arith.constant 0 : index | ||
// CHECK: %[[C20:.*]] = arith.constant 20 : index | ||
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index | ||
// CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index | ||
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<[16]xi1> | ||
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<[16]xi1> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. most of these checks should be folded with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Alternatively, use an |
||
// CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<[16]xi1> | ||
// CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<[16]xi1>, vector<[16]xi1> | ||
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<[16]xi1> to vector<1x[16]xi1> | ||
// CHECK: return %[[CAST]] : vector<1x[16]xi1> | ||
%c0 = arith.constant 0 : index | ||
%c20 = arith.constant 20 : index | ||
%0 = vector.create_mask %c0, %c20 : vector<1x[16]xi1> | ||
return %0 : vector<1x[16]xi1> | ||
} |
Uh oh!
There was an error while loading. Please reload this page.