@@ -566,6 +566,64 @@ struct LinearizeVectorSplat final
566
566
}
567
567
};
568
568
569
+ // / This pattern converts the CreateMaskOp to work on a linearized vector.
570
+ // / It currently supports only 2D masks with a unit outer dimension.
571
+ // / Following,
572
+ // / vector.create_mask %arg0, %arg1 : vector<1x4xi1>
573
+ // / is converted to:
574
+ // / %zero = arith.constant 0 : index
575
+ // / %cmpi = arith.cmpi sgt, %arg0, %zero : index
576
+ // / %index = arith.index_cast %cmpi : i1 to index
577
+ // / %mul = arith.andi %index, %arg1 : index
578
+ // / %mask = vector.create_mask %mul : vector<4xi1>
579
+ // / %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
580
+ struct LinearizeVectorCreateMask final
581
+ : OpConversionPattern<vector::CreateMaskOp> {
582
+ using OpConversionPattern::OpConversionPattern;
583
+
584
+ LinearizeVectorCreateMask (const TypeConverter &typeConverter,
585
+ MLIRContext *context, PatternBenefit benefit = 1 )
586
+ : OpConversionPattern(typeConverter, context, benefit) {}
587
+
588
+ LogicalResult
589
+ matchAndRewrite (vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
590
+ ConversionPatternRewriter &rewriter) const override {
591
+ Location loc = createMaskOp.getLoc ();
592
+ VectorType srcTy = createMaskOp.getType ();
593
+ auto srcShape = srcTy.getShape ();
594
+ if (srcShape.size () != 2 )
595
+ return rewriter.notifyMatchFailure (createMaskOp,
596
+ " only 2D mask is supported." );
597
+
598
+ if (srcShape[0 ] != 1 )
599
+ return rewriter.notifyMatchFailure (
600
+ createMaskOp, " only unit outer dimension is supported." );
601
+
602
+ auto dstTy = getTypeConverter ()->convertType (srcTy);
603
+ if (!dstTy)
604
+ return rewriter.notifyMatchFailure (createMaskOp, " cannot convert type." );
605
+
606
+ // Compare the first operand with 0. If it is greater than 0, the
607
+ // corresponding mask element is set to true, otherwise false.
608
+ // The result of the comparison is then multiplied with
609
+ // the second operand of create_mask to get the 1D mask.
610
+ auto firstOperand = adaptor.getOperands ().front ();
611
+ auto zero = rewriter.create <mlir::arith::ConstantIndexOp>(loc, 0 );
612
+ auto isNonZero = rewriter.createOrFold <mlir::arith::CmpIOp>(
613
+ loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
614
+ auto isNonZeroIndex = rewriter.createOrFold <mlir::arith::IndexCastOp>(
615
+ loc, rewriter.getIndexType (), isNonZero);
616
+ auto secondOperand = adaptor.getOperands ().back ();
617
+ auto maskSize = rewriter.createOrFold <mlir::arith::AndIOp>(
618
+ loc, rewriter.getIndexType (), isNonZeroIndex, secondOperand);
619
+
620
+ auto newMask =
621
+ rewriter.create <mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
622
+ rewriter.replaceOp (createMaskOp, newMask);
623
+ return success ();
624
+ }
625
+ };
626
+
569
627
} // namespace
570
628
571
629
// / Return true if the operation `op` does not support scalable vectors and
@@ -651,9 +709,10 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
651
709
void mlir::vector::populateVectorLinearizeBasePatterns (
652
710
const TypeConverter &typeConverter, const ConversionTarget &target,
653
711
RewritePatternSet &patterns) {
654
- patterns.add <LinearizeConstantLike, LinearizeVectorizable,
655
- LinearizeVectorBitCast, LinearizeVectorSplat>(
656
- typeConverter, patterns.getContext ());
712
+ patterns
713
+ .add <LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
714
+ LinearizeVectorSplat, LinearizeVectorCreateMask>(
715
+ typeConverter, patterns.getContext ());
657
716
}
658
717
659
718
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments