@@ -45,6 +45,10 @@ using namespace mlir;
 #define DBGSNL() (llvm::dbgs() << "\n")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
+//===----------------------------------------------------------------------===//
+// Utils
+//===----------------------------------------------------------------------===//
+
 /// Returns a compressed mask for the emulated vector. For example, when
 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
 /// elements span two dest elements), this method compresses `vector<8xi1>`
@@ -300,6 +304,7 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -370,6 +375,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 // ConvertVectorMaskedStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedStore final
     : OpConversionPattern<vector::MaskedStoreOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -481,6 +487,7 @@ struct ConvertVectorMaskedStore final
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -536,7 +543,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % elementsPerContainerType == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -552,9 +560,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
     int64_t maxintraDataOffset =
@@ -571,7 +578,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
       result = dynamicallyExtractSubVector(
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isFullyAligned) {
       result =
           staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                      *foldedIntraVectorOffset, origElements);
@@ -585,6 +592,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 // ConvertVectorMaskedLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedLoad final
     : OpConversionPattern<vector::MaskedLoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -749,6 +757,7 @@ struct ConvertVectorMaskedLoad final
 // ConvertVectorTransferRead
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorTransferRead final
     : OpConversionPattern<vector::TransferReadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -777,7 +786,8 @@ struct ConvertVectorTransferRead final
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % elementsPerContainerType == 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                       adaptor.getPadding());
@@ -796,9 +806,8 @@ struct ConvertVectorTransferRead final
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset =
         foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
@@ -822,7 +831,7 @@ struct ConvertVectorTransferRead final
       result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                            linearizedInfo.intraDataOffset,
                                            origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isFullyAligned) {
       result =
           staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                      *foldedIntraVectorOffset, origElements);
@@ -1506,33 +1515,34 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 /// LLVM to scramble with peephole optimizations. Templated to choose between
 /// signed and unsigned conversions.
 ///
-/// For example (signed):
+/// EXAMPLE 1 (signed):
 ///    arith.extsi %in : vector<8xi4> to vector<8xi32>
-/// is rewriten as
-///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///   %1 = arith.shli %0, 4 : vector<4xi8>
-///   %2 = arith.shrsi %1, 4 : vector<4xi8>
-///   %3 = arith.shrsi %0, 4 : vector<4xi8>
-///   %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///   %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+///    is rewritten as:
+///      %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///      %1 = arith.shli %0, 4 : vector<4xi8>
+///      %2 = arith.shrsi %1, 4 : vector<4xi8>
+///      %3 = arith.shrsi %0, 4 : vector<4xi8>
+///      %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///      %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
 ///
+/// EXAMPLE 2 (fp):
 ///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
-/// is rewriten as
-///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///   %1 = arith.shli %0, 4 : vector<4xi8>
-///   %2 = arith.shrsi %1, 4 : vector<4xi8>
-///   %3 = arith.shrsi %0, 4 : vector<4xi8>
-///   %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///   %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+///    is rewritten as:
+///      %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///      %1 = arith.shli %0, 4 : vector<4xi8>
+///      %2 = arith.shrsi %1, 4 : vector<4xi8>
+///      %3 = arith.shrsi %0, 4 : vector<4xi8>
+///      %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///      %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
-/// Example (unsigned):
+/// EXAMPLE 3 (unsigned):
 ///    arith.extui %in : vector<8xi4> to vector<8xi32>
-/// is rewritten as
-///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///   %1 = arith.andi %0, 15 : vector<4xi8>
-///   %2 = arith.shrui %0, 4 : vector<4xi8>
-///   %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
-///   %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+///    is rewritten as:
+///      %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///      %1 = arith.andi %0, 15 : vector<4xi8>
+///      %2 = arith.shrui %0, 4 : vector<4xi8>
+///      %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
+///      %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
 ///
 template <typename ConversionOpType, bool isSigned>
 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
@@ -1542,8 +1552,8 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
                                 PatternRewriter &rewriter) const override {
     // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
-    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
-    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
 
     if (failed(
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
@@ -1583,15 +1593,16 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
 ///
 /// For example:
 ///   arith.trunci %in : vector<8xi32> to vector<8xi4>
-/// is rewriten as
 ///
-///   %cst = arith.constant dense<15> : vector<4xi8>
-///   %cst_0 = arith.constant dense<4> : vector<4xi8>
-///   %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
-///   %2 = arith.andi %0, %cst : vector<4xi8>
-///   %3 = arith.shli %1, %cst_0 : vector<4xi8>
-///   %4 = arith.ori %2, %3 : vector<4xi8>
-///   %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
+///   is rewritten as:
+///
+///     %cst = arith.constant dense<15> : vector<4xi8>
+///     %cst_0 = arith.constant dense<4> : vector<4xi8>
+///     %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
+///     %2 = arith.andi %0, %cst : vector<4xi8>
+///     %3 = arith.shli %1, %cst_0 : vector<4xi8>
+///     %4 = arith.ori %2, %3 : vector<4xi8>
+///     %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
 ///
 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
   using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
@@ -1635,10 +1646,11 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
 
 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
 /// perform the transpose on wider (byte) element types.
-/// For example:
+///
+/// EXAMPLE:
 ///   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
 ///
-/// is rewritten as:
+///   is rewritten as:
 ///
 ///   %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
 ///   %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
@@ -1686,6 +1698,7 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 // Public Interface Definition
 //===----------------------------------------------------------------------===//
 
+// The emulated type is inferred from the converted memref type.
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns) {
@@ -1698,22 +1711,26 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
 
 void vector::populateVectorNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
+  // TODO: Document what the emulated type is.
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
                RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                     benefit);
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
+  // The emulated type is always i8.
   patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
+  // The emulated type is always i8.
   patterns
       .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
            RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
           patterns.getContext(), benefit.getBenefit() + 1);
 }
 
+// The emulated type is always i8.
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
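
For context, the populate* entry points touched in the last hunk are typically wired into a pass roughly as follows. This is a minimal sketch, assuming the standard greedy rewrite driver; the wrapper function name runNarrowTypeRewrites is illustrative and not part of this patch — only the populate* calls come from the file above.

// Minimal sketch (assumed setup, not part of this commit): register the
// narrow-type rewrite patterns and apply them greedily to a region of IR.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static void runNarrowTypeRewrites(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Sub-byte ext/trunc rewrites (the aligned cases emulate on i8).
  mlir::vector::populateVectorNarrowTypeRewritePatterns(patterns);
  // Sub-byte transpose rewrite (also emulated on i8).
  mlir::vector::populateVectorTransposeNarrowTypeRewritePatterns(patterns);
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}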