Skip to content

Commit 1778d3b

Browse files
authored
[mlir] [vector] Add linearization pattern for vector.create_mask (#138214)
This PR is a breakdown [3 / 4] of the PR #136193 The PR adds linearization patterns for vector.create_mask
1 parent 1043810 commit 1778d3b

File tree

3 files changed

+88
-4
lines changed

3 files changed

+88
-4
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,64 @@ struct LinearizeVectorSplat final
566566
}
567567
};
568568

569+
/// This pattern converts the CreateMaskOp to work on a linearized vector.
570+
/// It currently supports only 2D masks with a unit outer dimension.
571+
/// Following,
572+
/// vector.create_mask %arg0, %arg1 : vector<1x4xi1>
573+
/// is converted to:
574+
/// %zero = arith.constant 0 : index
575+
/// %cmpi = arith.cmpi sgt, %arg0, %zero : index
576+
/// %index = arith.index_cast %cmpi : i1 to index
577+
/// %mul = arith.andi %index, %arg1 : index
578+
/// %mask = vector.create_mask %mul : vector<4xi1>
579+
/// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
580+
struct LinearizeVectorCreateMask final
581+
: OpConversionPattern<vector::CreateMaskOp> {
582+
using OpConversionPattern::OpConversionPattern;
583+
584+
LinearizeVectorCreateMask(const TypeConverter &typeConverter,
585+
MLIRContext *context, PatternBenefit benefit = 1)
586+
: OpConversionPattern(typeConverter, context, benefit) {}
587+
588+
LogicalResult
589+
matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
590+
ConversionPatternRewriter &rewriter) const override {
591+
Location loc = createMaskOp.getLoc();
592+
VectorType srcTy = createMaskOp.getType();
593+
auto srcShape = srcTy.getShape();
594+
if (srcShape.size() != 2)
595+
return rewriter.notifyMatchFailure(createMaskOp,
596+
"only 2D mask is supported.");
597+
598+
if (srcShape[0] != 1)
599+
return rewriter.notifyMatchFailure(
600+
createMaskOp, "only unit outer dimension is supported.");
601+
602+
auto dstTy = getTypeConverter()->convertType(srcTy);
603+
if (!dstTy)
604+
return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
605+
606+
// Compare the first operand with 0. If it is greater than 0, the
607+
// corresponding mask element is set to true, otherwise false.
608+
// The result of the comparison is then multiplied with
609+
// the second operand of create_mask to get the 1D mask.
610+
auto firstOperand = adaptor.getOperands().front();
611+
auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
612+
auto isNonZero = rewriter.createOrFold<mlir::arith::CmpIOp>(
613+
loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
614+
auto isNonZeroIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
615+
loc, rewriter.getIndexType(), isNonZero);
616+
auto secondOperand = adaptor.getOperands().back();
617+
auto maskSize = rewriter.createOrFold<mlir::arith::AndIOp>(
618+
loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
619+
620+
auto newMask =
621+
rewriter.create<mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
622+
rewriter.replaceOp(createMaskOp, newMask);
623+
return success();
624+
}
625+
};
626+
569627
} // namespace
570628

571629
/// Return true if the operation `op` does not support scalable vectors and
@@ -651,9 +709,10 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
651709
void mlir::vector::populateVectorLinearizeBasePatterns(
652710
const TypeConverter &typeConverter, const ConversionTarget &target,
653711
RewritePatternSet &patterns) {
654-
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
655-
LinearizeVectorBitCast, LinearizeVectorSplat>(
656-
typeConverter, patterns.getContext());
712+
patterns
713+
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
714+
LinearizeVectorSplat, LinearizeVectorCreateMask>(
715+
typeConverter, patterns.getContext());
657716
}
658717

659718
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,28 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
416416
return %0 : vector<4x[2]xi32>
417417
}
418418

419+
// -----
420+
421+
// CHECK-LABEL: linearize_create_mask
422+
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
423+
func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {
424+
425+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
426+
// CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
427+
// CHECK: %[[INDEXCAST:.*]] = arith.index_cast %[[CMP]] : i1 to index
428+
// CHECK: %[[MULI:.*]] = arith.andi %[[INDEXCAST]], %[[ARG1]] : index
429+
// CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[MULI]] : vector<16xi1>
430+
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK_1D]] : vector<16xi1> to vector<1x16xi1>
431+
// CHECK: return %[[CAST]] : vector<1x16xi1>
432+
%0 = vector.create_mask %arg0, %arg1 : vector<1x16xi1>
433+
return %0 : vector<1x16xi1>
434+
}
435+
436+
// -----
437+
// CHECK-LABEL: linearize_scalable_create_mask
438+
func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vector<1x[16]xi1> {
439+
440+
// CHECK: %[[MASK_1D:.*]] = vector.create_mask {{%.*}} : vector<[16]xi1>
441+
%0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
442+
return %0 : vector<1x[16]xi1>
443+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -973,7 +973,7 @@ struct TestVectorLinearize final
973973
return "Linearizes ND vectors for N >= 2 into 1D vectors";
974974
}
975975
void getDependentDialects(DialectRegistry &registry) const override {
976-
registry.insert<vector::VectorDialect>();
976+
registry.insert<vector::VectorDialect, arith::ArithDialect>();
977977
}
978978

979979
void runOnOperation() override {

0 commit comments

Comments
 (0)