@@ -537,6 +537,10 @@ class ConversionPattern : public RewritePattern {
537
537
ConversionPatternRewriter &rewriter) const {
538
538
llvm_unreachable (" unimplemented rewrite" );
539
539
}
540
+ virtual void rewrite (Operation *op, ArrayRef<ValueRange> operands,
541
+ ConversionPatternRewriter &rewriter) const {
542
+ rewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
543
+ }
540
544
541
545
// / Hook for derived classes to implement combined matching and rewriting.
542
546
virtual LogicalResult
@@ -547,6 +551,11 @@ class ConversionPattern : public RewritePattern {
547
551
rewrite (op, operands, rewriter);
548
552
return success ();
549
553
}
554
+ virtual LogicalResult
555
+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
556
+ ConversionPatternRewriter &rewriter) const {
557
+ return matchAndRewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
558
+ }
550
559
551
560
// / Attempt to match and rewrite the IR root at the specified operation.
552
561
LogicalResult matchAndRewrite (Operation *op,
@@ -574,6 +583,15 @@ class ConversionPattern : public RewritePattern {
574
583
: RewritePattern(std::forward<Args>(args)...),
575
584
typeConverter (&typeConverter) {}
576
585
586
+ // / Given an array of value ranges, which are the inputs to a 1:N adaptor,
587
+ // / try to extract the single value of each range to construct a the inputs
588
+ // / for a 1:1 adaptor.
589
+ // /
590
+ // / This function produces a fatal error if at least one range has 0 or
591
+ // / more than 1 value: "pattern 'name' does not support 1:N conversion"
592
+ SmallVector<Value>
593
+ getOneToOneAdaptorOperands (ArrayRef<ValueRange> operands) const ;
594
+
577
595
protected:
578
596
// / An optional type converter for use by this pattern.
579
597
const TypeConverter *typeConverter = nullptr ;
@@ -589,6 +607,8 @@ template <typename SourceOp>
589
607
class OpConversionPattern : public ConversionPattern {
590
608
public:
591
609
using OpAdaptor = typename SourceOp::Adaptor;
610
+ using OneToNOpAdaptor =
611
+ typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
592
612
593
613
OpConversionPattern (MLIRContext *context, PatternBenefit benefit = 1 )
594
614
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -607,12 +627,24 @@ class OpConversionPattern : public ConversionPattern {
607
627
auto sourceOp = cast<SourceOp>(op);
608
628
rewrite (sourceOp, OpAdaptor (operands, sourceOp), rewriter);
609
629
}
630
+ void rewrite (Operation *op, ArrayRef<ValueRange> operands,
631
+ ConversionPatternRewriter &rewriter) const final {
632
+ auto sourceOp = cast<SourceOp>(op);
633
+ rewrite (sourceOp, OneToNOpAdaptor (operands, sourceOp), rewriter);
634
+ }
610
635
LogicalResult
611
636
matchAndRewrite (Operation *op, ArrayRef<Value> operands,
612
637
ConversionPatternRewriter &rewriter) const final {
613
638
auto sourceOp = cast<SourceOp>(op);
614
639
return matchAndRewrite (sourceOp, OpAdaptor (operands, sourceOp), rewriter);
615
640
}
641
+ LogicalResult
642
+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
643
+ ConversionPatternRewriter &rewriter) const final {
644
+ auto sourceOp = cast<SourceOp>(op);
645
+ return matchAndRewrite (sourceOp, OneToNOpAdaptor (operands, sourceOp),
646
+ rewriter);
647
+ }
616
648
617
649
// / Rewrite and Match methods that operate on the SourceOp type. These must be
618
650
// / overridden by the derived pattern class.
@@ -623,6 +655,12 @@ class OpConversionPattern : public ConversionPattern {
623
655
ConversionPatternRewriter &rewriter) const {
624
656
llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
625
657
}
658
+ virtual void rewrite (SourceOp op, OneToNOpAdaptor adaptor,
659
+ ConversionPatternRewriter &rewriter) const {
660
+ SmallVector<Value> oneToOneOperands =
661
+ getOneToOneAdaptorOperands (adaptor.getOperands ());
662
+ rewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
663
+ }
626
664
virtual LogicalResult
627
665
matchAndRewrite (SourceOp op, OpAdaptor adaptor,
628
666
ConversionPatternRewriter &rewriter) const {
@@ -631,6 +669,13 @@ class OpConversionPattern : public ConversionPattern {
631
669
rewrite (op, adaptor, rewriter);
632
670
return success ();
633
671
}
672
+ virtual LogicalResult
673
+ matchAndRewrite (SourceOp op, OneToNOpAdaptor adaptor,
674
+ ConversionPatternRewriter &rewriter) const {
675
+ SmallVector<Value> oneToOneOperands =
676
+ getOneToOneAdaptorOperands (adaptor.getOperands ());
677
+ return matchAndRewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
678
+ }
634
679
635
680
private:
636
681
using ConversionPattern::matchAndRewrite;
@@ -656,18 +701,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
656
701
ConversionPatternRewriter &rewriter) const final {
657
702
rewrite (cast<SourceOp>(op), operands, rewriter);
658
703
}
704
+ void rewrite (Operation *op, ArrayRef<ValueRange> operands,
705
+ ConversionPatternRewriter &rewriter) const final {
706
+ rewrite (cast<SourceOp>(op), operands, rewriter);
707
+ }
659
708
LogicalResult
660
709
matchAndRewrite (Operation *op, ArrayRef<Value> operands,
661
710
ConversionPatternRewriter &rewriter) const final {
662
711
return matchAndRewrite (cast<SourceOp>(op), operands, rewriter);
663
712
}
713
+ LogicalResult
714
+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
715
+ ConversionPatternRewriter &rewriter) const final {
716
+ return matchAndRewrite (cast<SourceOp>(op), operands, rewriter);
717
+ }
664
718
665
719
// / Rewrite and Match methods that operate on the SourceOp type. These must be
666
720
// / overridden by the derived pattern class.
667
721
virtual void rewrite (SourceOp op, ArrayRef<Value> operands,
668
722
ConversionPatternRewriter &rewriter) const {
669
723
llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
670
724
}
725
+ virtual void rewrite (SourceOp op, ArrayRef<ValueRange> operands,
726
+ ConversionPatternRewriter &rewriter) const {
727
+ rewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
728
+ }
671
729
virtual LogicalResult
672
730
matchAndRewrite (SourceOp op, ArrayRef<Value> operands,
673
731
ConversionPatternRewriter &rewriter) const {
@@ -676,6 +734,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
676
734
rewrite (op, operands, rewriter);
677
735
return success ();
678
736
}
737
+ virtual LogicalResult
738
+ matchAndRewrite (SourceOp op, ArrayRef<ValueRange> operands,
739
+ ConversionPatternRewriter &rewriter) const {
740
+ return matchAndRewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
741
+ }
679
742
680
743
private:
681
744
using ConversionPattern::matchAndRewrite;
0 commit comments