Skip to content

Commit d928a67

Browse files
authored
[mlir][Vector] Refactor VectorEmulateNarrowType.cpp (#123529)
This PR refactors `alignedConversionPrecondition` from VectorEmulateNarrowType.cpp and adds new helper hooks. **Update `alignedConversionPrecondition` (1)** This method doesn't require the vector type for the "container" argument. The underlying element type is sufficient. The corresponding argument has been renamed as `containerTy` - this is meant as the multi-byte container element type (`i8`, `i16`, `i32`, etc). With this change, the updated invocations of `alignedConversionPrecondition` (in e.g. `RewriteAlignedSubByteIntExt`) make it clear that the container element type is assumed to be `i8`. **Update alignedConversionPrecondition (2):** The final check in `alignedConversionPrecondition` has been replaced with a new helper method, `isSubByteVecFittable`. This helper hook is now also re-used in `ConvertVectorTransferRead` (to improve code re-use). **Other updates** Extended + unified comments. **Implements**: #123630
1 parent 4e98944 commit d928a67

File tree

1 file changed

+112
-49
lines changed

1 file changed

+112
-49
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 112 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
519519

520520
auto origElements = valueToStore.getType().getNumElements();
521521
// Note, per-element-alignment was already verified above.
522-
bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
522+
bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
523523

524524
auto stridedMetadata =
525525
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -535,8 +535,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
535535
getAsOpFoldResult(adaptor.getIndices()));
536536

537537
std::optional<int64_t> foldedNumFrontPadElems =
538-
isFullyAligned ? 0
539-
: getConstantIntValue(linearizedInfo.intraDataOffset);
538+
isDivisibleInSize ? 0
539+
: getConstantIntValue(linearizedInfo.intraDataOffset);
540540

541541
if (!foldedNumFrontPadElems) {
542542
return rewriter.notifyMatchFailure(
@@ -554,7 +554,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
554554
// need unaligned emulation because the store address is aligned and the
555555
// source is a whole byte.
556556
bool emulationRequiresPartialStores =
557-
!isFullyAligned || *foldedNumFrontPadElems != 0;
557+
!isDivisibleInSize || *foldedNumFrontPadElems != 0;
558558
if (!emulationRequiresPartialStores) {
559559
// Basic case: storing full bytes.
560560
auto numElements = origElements / emulatedPerContainerElem;
@@ -881,7 +881,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
881881

882882
auto origElements = op.getVectorType().getNumElements();
883883
// Note, per-element-alignment was already verified above.
884-
bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
884+
bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
885885

886886
auto stridedMetadata =
887887
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -897,8 +897,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
897897
getAsOpFoldResult(adaptor.getIndices()));
898898

899899
std::optional<int64_t> foldedIntraVectorOffset =
900-
isFullyAligned ? 0
901-
: getConstantIntValue(linearizedInfo.intraDataOffset);
900+
isDivisibleInSize ? 0
901+
: getConstantIntValue(linearizedInfo.intraDataOffset);
902902

903903
// Always load enough elements which can cover the original elements.
904904
int64_t maxintraDataOffset =
@@ -915,7 +915,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
915915
result = dynamicallyExtractSubVector(
916916
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
917917
linearizedInfo.intraDataOffset, origElements);
918-
} else if (!isFullyAligned) {
918+
} else if (!isDivisibleInSize) {
919919
result = staticallyExtractSubvector(
920920
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
921921
}
@@ -1002,7 +1002,7 @@ struct ConvertVectorMaskedLoad final
10021002
auto origType = op.getVectorType();
10031003
auto origElements = origType.getNumElements();
10041004
// Note, per-element-alignment was already verified above.
1005-
bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
1005+
bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
10061006

10071007
auto stridedMetadata =
10081008
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -1017,8 +1017,8 @@ struct ConvertVectorMaskedLoad final
10171017
getAsOpFoldResult(adaptor.getIndices()));
10181018

10191019
std::optional<int64_t> foldedIntraVectorOffset =
1020-
isFullyAligned ? 0
1021-
: getConstantIntValue(linearizedInfo.intraDataOffset);
1020+
isDivisibleInSize ? 0
1021+
: getConstantIntValue(linearizedInfo.intraDataOffset);
10221022

10231023
int64_t maxIntraDataOffset =
10241024
foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1042,7 +1042,7 @@ struct ConvertVectorMaskedLoad final
10421042
passthru = dynamicallyInsertSubVector(
10431043
rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
10441044
origElements);
1045-
} else if (!isFullyAligned) {
1045+
} else if (!isDivisibleInSize) {
10461046
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
10471047
*foldedIntraVectorOffset);
10481048
}
@@ -1070,7 +1070,7 @@ struct ConvertVectorMaskedLoad final
10701070
mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
10711071
linearizedInfo.intraDataOffset,
10721072
origElements);
1073-
} else if (!isFullyAligned) {
1073+
} else if (!isDivisibleInSize) {
10741074
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
10751075
*foldedIntraVectorOffset);
10761076
}
@@ -1081,7 +1081,7 @@ struct ConvertVectorMaskedLoad final
10811081
result = dynamicallyExtractSubVector(
10821082
rewriter, loc, result, op.getPassThru(),
10831083
linearizedInfo.intraDataOffset, origElements);
1084-
} else if (!isFullyAligned) {
1084+
} else if (!isDivisibleInSize) {
10851085
result = staticallyExtractSubvector(
10861086
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
10871087
}
@@ -1091,6 +1091,38 @@ struct ConvertVectorMaskedLoad final
10911091
}
10921092
};
10931093

1094+
/// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`
1095+
///
1096+
/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
1097+
/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
1098+
/// (a multi-byte scalar, e.g. i16), where N is some integer.
1099+
///
1100+
/// Put differently, this method checks whether this would be valid:
1101+
///
1102+
/// vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
1103+
///
1104+
/// EXAMPLES:
1105+
/// * vector<4xi4> -> i16 - yes (N = 1)
1106+
/// * vector<4xi4> -> i8 - yes (N = 2)
1107+
/// * vector<3xi4> -> i8 - no (N would have to be 1.5)
1108+
/// * vector<3xi2> -> i16 - no (N would have to be 0.5)
1109+
static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
1110+
Type multiByteScalarTy) {
1111+
assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
1112+
1113+
int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
1114+
int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
1115+
1116+
assert(subByteBits < 8 && "Not a sub-byte scalar type!");
1117+
assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
1118+
assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");
1119+
1120+
int elemsPerMultiByte = multiByteBits / subByteBits;
1121+
1122+
// TODO: This is a bit too restrictive for vectors rank > 1.
1123+
return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
1124+
}
1125+
10941126
//===----------------------------------------------------------------------===//
10951127
// ConvertVectorTransferRead
10961128
//===----------------------------------------------------------------------===//
@@ -1127,7 +1159,8 @@ struct ConvertVectorTransferRead final
11271159
auto origElements = op.getVectorType().getNumElements();
11281160

11291161
// Note, per-element-alignment was already verified above.
1130-
bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
1162+
bool isDivisibleInSize =
1163+
fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
11311164

11321165
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
11331166
adaptor.getPadding());
@@ -1146,8 +1179,8 @@ struct ConvertVectorTransferRead final
11461179
getAsOpFoldResult(adaptor.getIndices()));
11471180

11481181
std::optional<int64_t> foldedIntraVectorOffset =
1149-
isFullyAligned ? 0
1150-
: getConstantIntValue(linearizedInfo.intraDataOffset);
1182+
isDivisibleInSize ? 0
1183+
: getConstantIntValue(linearizedInfo.intraDataOffset);
11511184

11521185
int64_t maxIntraDataOffset =
11531186
foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1171,7 +1204,7 @@ struct ConvertVectorTransferRead final
11711204
result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
11721205
linearizedInfo.intraDataOffset,
11731206
origElements);
1174-
} else if (!isFullyAligned) {
1207+
} else if (!isDivisibleInSize) {
11751208
result = staticallyExtractSubvector(
11761209
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
11771210
}
@@ -1428,41 +1461,69 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
14281461
return commonConversionPrecondition(rewriter, preconditionType, op);
14291462
}
14301463

1431-
/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
1432-
/// means that:
1433-
/// 1. The `dstType` element type is a multiple of the
1434-
/// `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
1435-
/// is not supported). Let this multiple be `N`.
1436-
/// 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
1437-
/// multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
1438-
/// not supported).
1464+
/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
1465+
///
1466+
/// Alignment means that `subByteVecTy` can be packed into a vector of
1467+
/// `containerTy` elements. More specifically:
1468+
/// 1. The bit-width of `containerTy` is a multiple of the
1469+
/// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
1470+
/// this multiple is 4.
1471+
/// 2. The multiple from 1. above divides evenly the number of the (trailing)
1472+
/// elements in `subByteVecTy`.
1473+
///
1474+
/// EXAMPLE 1:
1475+
/// `subByteVecTy = vector<2xi4>`, and
1476+
/// `containerTy = i16`
1477+
///
1478+
/// 2 divides evenly 4 ( = 16 / 4), hence both conditions are _met_.
1479+
///
1480+
/// EXAMPLE 2:
1481+
/// `subByteVecTy = vector<3xi4>`, and
1482+
/// `containerTy = i16`
1483+
///
1484+
/// 3 _does not_ divide evenly 4 (= 16/4), hence the conditions are _not met_.
1485+
///
1486+
/// EXAMPLE 3:
1487+
/// `subByteVecTy = vector<3xi3>`, and
1488+
/// `containerTy = i16`
1489+
///
1490+
/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
14391491
///
14401492
/// NOTE: This method assumes that common conversion preconditions are met. In
1441-
/// particular, the element type of `dstType` is assumed to be a multi-byte
1442-
/// type (e.g. i8, i16, i32).
1493+
/// particular, `containerTy` is assumed to be a
1494+
/// multi-byte scalar type (e.g., i8, i16, i32).
14431495
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1444-
VectorType subByteVecType,
1445-
VectorType dstType,
1496+
VectorType subByteVecTy,
1497+
Type containerTy,
14461498
Operation *op) {
1447-
if (!subByteVecType || !dstType)
1448-
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
1449-
unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
1450-
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
1499+
assert(containerTy.isIntOrFloat() &&
1500+
"container element type is not a scalar");
14511501

1452-
if (dstElemBitwidth < 8)
1453-
return rewriter.notifyMatchFailure(
1454-
op, "the bitwidth of dstType must be greater than or equal to 8");
1455-
if (dstElemBitwidth % srcElemBitwidth != 0)
1456-
return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
1457-
if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
1502+
// TODO: This is validating the inputs rather than checking the conditions
1503+
// documented above. Replace with an assert.
1504+
if (!subByteVecTy)
1505+
return rewriter.notifyMatchFailure(op, "not a vector!");
1506+
1507+
unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
1508+
unsigned containerBits = containerTy.getIntOrFloatBitWidth();
1509+
1510+
// Enforced by the common pre-conditions.
1511+
assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");
1512+
1513+
// TODO: Add support for other widths (when/if needed)
1514+
if (subByteBits != 2 && subByteBits != 4)
14581515
return rewriter.notifyMatchFailure(
1459-
op, "only src bitwidth of 2 or 4 is supported at this moment");
1516+
op, "only 2-bit and 4-bit sub-byte types are supported at this moment");
1517+
1518+
// Condition 1 ("per-element" alignment)
1519+
if (containerBits % subByteBits != 0)
1520+
return rewriter.notifyMatchFailure(op, "unaligned element types");
14601521

1461-
const int numSrcElemsPerByte = 8 / srcElemBitwidth;
1462-
if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
1522+
// Condition 2 ("full" alignment)
1523+
if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
14631524
return rewriter.notifyMatchFailure(
1464-
op, "the trailing dimension of the input vector of sub-bytes must be a "
1465-
"multiple of 8 / <sub-byte-width>");
1525+
op, "not possible to fit this sub-byte vector type into a vector of "
1526+
"the given multi-byte type");
14661527

14671528
return success();
14681529
}
@@ -1899,8 +1960,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
18991960
return failure();
19001961

19011962
// Check general alignment preconditions.
1902-
if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
1903-
conversionOp)))
1963+
if (failed(alignedConversionPrecondition(
1964+
rewriter, srcVecType,
1965+
/*containerTy=*/rewriter.getI8Type(), conversionOp)))
19041966
return failure();
19051967

19061968
// Perform the rewrite.
@@ -1964,8 +2026,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
19642026

19652027
// Check general alignment preconditions. We invert the src/dst type order
19662028
// to reuse the existing precondition logic.
1967-
if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
1968-
truncOp)))
2029+
if (failed(alignedConversionPrecondition(
2030+
rewriter, dstVecType,
2031+
/*containerTy=*/rewriter.getI8Type(), truncOp)))
19692032
return failure();
19702033

19712034
// Create a new iX -> i8 truncation op.

0 commit comments

Comments
 (0)