@@ -618,6 +618,66 @@ struct VectorInterleaveOpConvert final
618
618
}
619
619
};
620
620
621
+ struct VectorDeinterleaveOpConvert final
622
+ : public OpConversionPattern<vector::DeinterleaveOp> {
623
+ using OpConversionPattern::OpConversionPattern;
624
+
625
+ LogicalResult
626
+ matchAndRewrite (vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
627
+ ConversionPatternRewriter &rewriter) const override {
628
+
629
+ // Check the result vector type.
630
+ VectorType oldResultType = deinterleaveOp.getResultVectorType ();
631
+ Type newResultType = getTypeConverter ()->convertType (oldResultType);
632
+ if (!newResultType)
633
+ return rewriter.notifyMatchFailure (deinterleaveOp,
634
+ " unsupported result vector type" );
635
+
636
+ Location loc = deinterleaveOp->getLoc ();
637
+
638
+ // Deinterleave the indices.
639
+ Value sourceVector = adaptor.getSource ();
640
+ VectorType sourceType = deinterleaveOp.getSourceVectorType ();
641
+ int n = sourceType.getNumElements ();
642
+
643
+ // Output vectors of size 1 are converted to scalars by the type converter.
644
+ // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
645
+ // use `spirv::CompositeExtractOp`.
646
+ if (n == 2 ) {
647
+ auto elem0 = rewriter.create <spirv::CompositeExtractOp>(
648
+ loc, newResultType, sourceVector, rewriter.getI32ArrayAttr ({0 }));
649
+
650
+ auto elem1 = rewriter.create <spirv::CompositeExtractOp>(
651
+ loc, newResultType, sourceVector, rewriter.getI32ArrayAttr ({1 }));
652
+
653
+ rewriter.replaceOp (deinterleaveOp, {elem0, elem1});
654
+ return success ();
655
+ }
656
+
657
+ // Indices for `shuffleEven` (result 0).
658
+ auto seqEven = llvm::seq<int64_t >(n / 2 );
659
+ auto indicesEven =
660
+ llvm::map_to_vector (seqEven, [](int i) { return i * 2 ; });
661
+
662
+ // Indices for `shuffleOdd` (result 1).
663
+ auto seqOdd = llvm::seq<int64_t >(n / 2 );
664
+ auto indicesOdd =
665
+ llvm::map_to_vector (seqOdd, [](int i) { return i * 2 + 1 ; });
666
+
667
+ // Create two SPIR-V shuffles.
668
+ auto shuffleEven = rewriter.create <spirv::VectorShuffleOp>(
669
+ loc, newResultType, sourceVector, sourceVector,
670
+ rewriter.getI32ArrayAttr (indicesEven));
671
+
672
+ auto shuffleOdd = rewriter.create <spirv::VectorShuffleOp>(
673
+ loc, newResultType, sourceVector, sourceVector,
674
+ rewriter.getI32ArrayAttr (indicesOdd));
675
+
676
+ rewriter.replaceOp (deinterleaveOp, {shuffleEven, shuffleOdd});
677
+ return success ();
678
+ }
679
+ };
680
+
621
681
struct VectorLoadOpConverter final
622
682
: public OpConversionPattern<vector::LoadOp> {
623
683
using OpConversionPattern::OpConversionPattern;
@@ -862,9 +922,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
862
922
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
863
923
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
864
924
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
865
- VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter ,
866
- VectorStoreOpConverter>(typeConverter, patterns. getContext (),
867
- PatternBenefit (1 ));
925
+ VectorInterleaveOpConvert, VectorDeinterleaveOpConvert ,
926
+ VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
927
+ typeConverter, patterns. getContext (), PatternBenefit (1 ));
868
928
869
929
// Make sure that the more specialized dot product pattern has higher benefit
870
930
// than the generic one that extracts all elements.
0 commit comments