Commit e8ffbaa (1 parent: d40b31b)
[mlir][Vector] Update VectorEmulateNarrowType.cpp (3/N)
This is PR 3 in a series of N patches aimed at improving "VectorEmulateNarrowType.cpp". This is mainly minor refactoring; no major functional changes are made.

1. Replaces `isUnalignedEmulation` with `isFullyAligned`.

   Note, `isUnalignedEmulation` is always computed after a "per-element-alignment" check:

   ```cpp
   // Check per-element alignment.
   if (newBits % oldBits != 0) {
     return rewriter.notifyMatchFailure(op, "unalagined element types");
   }
   // (...)
   bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
   ```

   Given that `isUnalignedEmulation` captures only one of the two conditions required for "full alignment", it would more accurately be named `isPartiallyUnalignedEmulation`. Instead, I've flipped the condition and renamed the variable `isFullyAligned`:

   ```cpp
   bool isFullyAligned = origElements % elementsPerContainerType == 0;
   ```

2. In addition:
   * Unifies various comments throughout the file (for consistency).
   * Adds new comments throughout the file and adds TODOs where high-level comments are missing.
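As a stand-alone illustration of how the two conditions interact (this snippet is not part of the commit; the concrete widths and element counts are hypothetical), consider emulating `vector<3xi2>` with `i8` containers:

```cpp
#include <cassert>

// Hypothetical illustration, not code from this commit: emulating
// vector<3xi2> with i8 containers. Per-element alignment holds
// (8 % 2 == 0, i.e. four i2 elements fit in one i8), but full alignment
// does not (3 % 4 != 0), so the emulated load must still extract a
// 3-element subvector from the wider container load.
int main() {
  int newBits = 8;      // container element width (i8)
  int oldBits = 2;      // emulated element width (i2)
  int origElements = 3; // number of elements in the source vector

  assert(newBits % oldBits == 0 && "per-element alignment");
  int elementsPerContainerType = newBits / oldBits; // == 4

  bool isFullyAligned = origElements % elementsPerContainerType == 0;
  assert(!isFullyAligned); // vector<3xi2> is only partially aligned

  return 0;
}
```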

1 file changed: 61 additions, 44 deletions
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

```diff
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -45,6 +45,10 @@ using namespace mlir;
 #define DBGSNL() (llvm::dbgs() << "\n")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
+//===----------------------------------------------------------------------===//
+// Utils
+//===----------------------------------------------------------------------===//
+
 /// Returns a compressed mask for the emulated vector. For example, when
 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
 /// elements span two dest elements), this method compresses `vector<8xi1>`
@@ -300,6 +304,7 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -370,6 +375,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 // ConvertVectorMaskedStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedStore final
     : OpConversionPattern<vector::MaskedStoreOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -481,6 +487,7 @@ struct ConvertVectorMaskedStore final
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -536,7 +543,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % elementsPerContainerType == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -552,9 +560,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
     int64_t maxintraDataOffset =
@@ -571,7 +578,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
       result = dynamicallyExtractSubVector(
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
          linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isFullyAligned) {
       result =
           staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                      *foldedIntraVectorOffset, origElements);
@@ -585,6 +592,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 // ConvertVectorMaskedLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedLoad final
     : OpConversionPattern<vector::MaskedLoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -749,6 +757,7 @@ struct ConvertVectorMaskedLoad final
 // ConvertVectorTransferRead
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorTransferRead final
     : OpConversionPattern<vector::TransferReadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -777,7 +786,8 @@ struct ConvertVectorTransferRead final
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % elementsPerContainerType == 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                       adaptor.getPadding());
@@ -796,9 +806,8 @@ struct ConvertVectorTransferRead final
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset =
         foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
@@ -822,7 +831,7 @@ struct ConvertVectorTransferRead final
       result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                            linearizedInfo.intraDataOffset,
                                            origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isFullyAligned) {
       result =
           staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                      *foldedIntraVectorOffset, origElements);
@@ -1506,33 +1515,34 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 /// LLVM to scramble with peephole optimizations. Templated to choose between
 /// signed and unsigned conversions.
 ///
-/// For example (signed):
+/// EXAMPLE 1 (signed):
 ///    arith.extsi %in : vector<8xi4> to vector<8xi32>
-/// is rewriten as
-///         %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///         %1 = arith.shli %0, 4 : vector<4xi8>
-///         %2 = arith.shrsi %1, 4 : vector<4xi8>
-///         %3 = arith.shrsi %0, 4 : vector<4xi8>
-///         %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///         %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+/// is rewriten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.shli %0, 4 : vector<4xi8>
+///    %2 = arith.shrsi %1, 4 : vector<4xi8>
+///    %3 = arith.shrsi %0, 4 : vector<4xi8>
+///    %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///    %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
 ///
+/// EXAMPLE 2 (fp):
 ///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
-/// is rewriten as
-///         %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///         %1 = arith.shli %0, 4 : vector<4xi8>
-///         %2 = arith.shrsi %1, 4 : vector<4xi8>
-///         %3 = arith.shrsi %0, 4 : vector<4xi8>
-///         %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///         %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+/// is rewriten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.shli %0, 4 : vector<4xi8>
+///    %2 = arith.shrsi %1, 4 : vector<4xi8>
+///    %3 = arith.shrsi %0, 4 : vector<4xi8>
+///    %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///    %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
-/// Example (unsigned):
+/// EXAMPLE 3 (unsigned):
 ///    arith.extui %in : vector<8xi4> to vector<8xi32>
-/// is rewritten as
-///         %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///         %1 = arith.andi %0, 15 : vector<4xi8>
-///         %2 = arith.shrui %0, 4 : vector<4xi8>
-///         %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
-///         %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.andi %0, 15 : vector<4xi8>
+///    %2 = arith.shrui %0, 4 : vector<4xi8>
+///    %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
+///    %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
 ///
 template <typename ConversionOpType, bool isSigned>
 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
@@ -1542,8 +1552,8 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
                                 PatternRewriter &rewriter) const override {
     // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
-    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
-    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
 
     if (failed(
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
@@ -1583,15 +1593,16 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
 ///
 /// For example:
 ///    arith.trunci %in : vector<8xi32> to vector<8xi4>
-/// is rewriten as
 ///
-///           %cst = arith.constant dense<15> : vector<4xi8>
-///           %cst_0 = arith.constant dense<4> : vector<4xi8>
-///           %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
-///           %2 = arith.andi %0, %cst : vector<4xi8>
-///           %3 = arith.shli %1, %cst_0 : vector<4xi8>
-///           %4 = arith.ori %2, %3 : vector<4xi8>
-///           %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
+/// is rewriten as:
+///
+///    %cst = arith.constant dense<15> : vector<4xi8>
+///    %cst_0 = arith.constant dense<4> : vector<4xi8>
+///    %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
+///    %2 = arith.andi %0, %cst : vector<4xi8>
+///    %3 = arith.shli %1, %cst_0 : vector<4xi8>
+///    %4 = arith.ori %2, %3 : vector<4xi8>
+///    %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
 ///
 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
   using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
@@ -1635,10 +1646,11 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
 
 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
 /// perform the transpose on wider (byte) element types.
-/// For example:
+///
+/// EXAMPLE:
 ///   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
 ///
-///   is rewritten as:
+/// is rewritten as:
 ///
 ///   %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
 ///   %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
@@ -1686,6 +1698,7 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 // Public Interface Definition
 //===----------------------------------------------------------------------===//
 
+// The emulated type is inferred from the converted memref type.
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns) {
@@ -1698,22 +1711,26 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
 
 void vector::populateVectorNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
+  // TODO: Document what the emulated type is.
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
                RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                     benefit);
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
+  // The emulated type is always i8.
   patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                              benefit.getBenefit() + 1);
+  // The emulated type is always i8.
   patterns
      .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
           RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
          patterns.getContext(), benefit.getBenefit() + 1);
 }
 
+// The emulated type is always i8.
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
```
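For reference, below is a minimal sketch of how the `populate*` entry points touched above are typically wired into a greedy rewrite. Only the two `populateVector*NarrowTypeRewritePatterns` calls come from this file; the surrounding function, the header list, and the choice of driver are illustrative assumptions, not part of this commit.

```cpp
// Illustrative sketch (not from this commit): register the narrow-type
// rewrite patterns and run them with the greedy pattern driver.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult runNarrowTypeRewrites(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Bitcast-based ext rewrites plus the aligned sub-byte ext/trunc
  // rewrites; the aligned variants register themselves with benefit + 1,
  // so the greedy driver prefers them where they apply.
  vector::populateVectorNarrowTypeRewritePatterns(patterns);
  // Sub-byte transpose rewrite (the emulated type is always i8).
  vector::populateVectorTransposeNarrowTypeRewritePatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```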
