Skip to content

Commit 5261216

Browse files
committed
[mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N)
This is PR 4 in a series of N patches aimed at improving "VectorEmulateNarrowType.cpp". This is mainly minor refactoring, no major functional changes are made/added. 1. Update `alignedConversionPrecondition` (1): This method didn't require the vector type for the "destination" argument. The underlying element type is sufficient. The corresponding argument has been renamed as `multiByteScalarTy` - this is meant as the multi-byte emulated type (`i8`, `i16`, `i32`, etc). 2. Update `alignedConversionPrecondition` (2): In #121298, we replaced `dstElemBitwidt` in this calculation: ```cpp const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth; ``` with the hard-coded value of 8: ```cpp const int numSrcElemsPerDestElem = 8 / srcElemBitwidth; ``` That was correct as for the patterns for which this hook was/is used: * `RewriteAlignedSubByteIntExt`, * `RewriteAlignedSubByteIntTrunc`. The destination type (or, more precisely, the emulated type) was always `i8`. In this PR, I am switching back to a more generic approach - the calculation should take into account the bit-width of the emulated type. Note that at the call sites I am passing `i8` as the emulated type, so the end-result is effectively identical. However, the intent is clearer, i.e., the underlying value is 8 because the emulated type happens to be `i8` (as opposed using a magic number). 3. Update alignedConversionPrecondition (3): The final check has been replaced with a new helper method, `isSubByteVecFittable`. This new method is also re-used within the code and hopefully will allow us more code re-use moving forward (to avoid re-implementing the same condition). 4. Update alignedConversionPrecondition (4): NEXT STEPS: We need to clarify the meaning of "source" and "destination" types. Currently the usage is ambiguous. For example, for this `arith.extsi` Op, `vector<8xi2>` and `vector<8xi32>` are the "source" and "destination" types, respectively: ```mlir %0 = arith.extsi %arg0 : vector<8xi2> to vector<8xi32> } ``` However, patterns like `RewriteAlignedSubByteIntExt` introduce `vector.bitcast` Ops like this: ```mlir %bitcast = vector.bitcast %arg0 : vector<8xi2> to vector<2xi8> ``` I've noticed that we tend to mix `vector<2xi8>` and `vector<8xi32>` as the destination types and that should be clarified.
1 parent 78f690b commit 5261216

File tree

1 file changed

+102
-32
lines changed

1 file changed

+102
-32
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 102 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,6 +1020,38 @@ struct ConvertVectorMaskedLoad final
10201020
}
10211021
};
10221022

1023+
/// Check whether `subByteVecTy` fits wthin a vector of `multiByteScalarTy`
1024+
///
1025+
/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
1026+
/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
1027+
/// (a multi-byte scalar, e.g. i16), where N is some integer.
1028+
///
1029+
/// Put differently, this method checks whether this would be valid:
1030+
///
1031+
/// vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
1032+
///
1033+
/// EXAMPLES:
1034+
/// * vector<4xi4> -> i16 - yes (N = 1)
1035+
/// * vector<4xi4> -> i8 - yes (N = 2)
1036+
/// * vector<3xi4> -> i8 - no (N would have to be 1.5)
1037+
/// * vector<3xi2> -> i16 - no (N would have to be 0.5)
1038+
static bool isSubByteVecFittable(VectorType subByteVecTy,
1039+
Type multiByteScalarTy) {
1040+
assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
1041+
1042+
int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
1043+
int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
1044+
1045+
assert(subByteBits < 8 && "Not a sub-byte scalar type!");
1046+
assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
1047+
assert(multiByteBits % subByteBits == 0 && "Unalagined element types!");
1048+
1049+
int elemsPerMultiByte = multiByteBits / subByteBits;
1050+
1051+
// TODO: This is a bit too restrictive for vectors rank > 1.
1052+
return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
1053+
}
1054+
10231055
//===----------------------------------------------------------------------===//
10241056
// ConvertVectorTransferRead
10251057
//===----------------------------------------------------------------------===//
@@ -1056,7 +1088,8 @@ struct ConvertVectorTransferRead final
10561088
auto origElements = op.getVectorType().getNumElements();
10571089

10581090
// Note, per-element-alignment was already verified above.
1059-
bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
1091+
bool isFullyAligned =
1092+
isSubByteVecFittable(op.getVectorType(), containerElemTy);
10601093

10611094
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
10621095
adaptor.getPadding());
@@ -1357,41 +1390,76 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
13571390
return commonConversionPrecondition(rewriter, preconditionType, op);
13581391
}
13591392

1360-
/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
1361-
/// means that:
1362-
/// 1. The `dstType` element type is a multiple of the
1363-
/// `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
1364-
/// is not supported). Let this multiple be `N`.
1365-
/// 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
1366-
/// multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
1367-
/// not supported).
1393+
/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
1394+
///
1395+
/// Alignment means that `subByteVecTy` can be packed into a vector of
1396+
/// `containerTy` elements. More specifically:
1397+
/// 1. The bit-width of `containerTy` is a multiple of the
1398+
/// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
1399+
/// this multiple is 4.
1400+
/// 2. The multiple from 1. above divides evenly the number of the (trailing)
1401+
/// elements in `subByteVecTy`.
1402+
///
1403+
/// EXAMPLE 1:
1404+
/// `subByteVecTy = vector<2xi4>`, and
1405+
/// `containerTy = i16`
1406+
///
1407+
/// 2 divides evenly 4 ( = 16 / 4), hence both conditions are _met_.
1408+
///
1409+
/// EXAMPLE 2:
1410+
/// `subByteVecTy = vector<3xi4>`, and
1411+
/// `containerTy = i16`
1412+
///
1413+
/// 3 _does not_ divide evenly 4 (= 16/4), hence the conditions are _not met_.
1414+
///
1415+
/// EXAMPLE 3:
1416+
/// `subByteVecTy = vector<3xi3>`, and
1417+
/// `containerTy = i16`
1418+
///
1419+
/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
13681420
///
13691421
/// NOTE: This method assumes that common conversion preconditions are met. In
1370-
/// particular, the element type of `dstType` is assumed to be a multi-byte
1371-
/// type (e.g. i8, i16, i32).
1422+
/// particular, `containerTy` is assumed to be a
1423+
/// multi-byte scalar type (e.g., i8, i16, i32).
13721424
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1373-
VectorType subByteVecType,
1374-
VectorType dstType,
1425+
VectorType subByteVecTy,
1426+
Type containerTy,
13751427
Operation *op) {
1376-
if (!subByteVecType || !dstType)
1377-
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
1378-
unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
1379-
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
1428+
// TODO: This is validating the inputs rather than checking the conditions
1429+
// documented above. Replace with an assert.
1430+
if (!subByteVecTy)
1431+
return rewriter.notifyMatchFailure(op, "not a vector!");
13801432

1381-
if (dstElemBitwidth < 8)
1382-
return rewriter.notifyMatchFailure(
1383-
op, "the bitwidth of dstType must be greater than or equal to 8");
1384-
if (dstElemBitwidth % srcElemBitwidth != 0)
1385-
return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
1386-
if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
1433+
// TODO: This is validating the inputs rather than checking the conditions
1434+
// documented above. Replace with an assert.
1435+
if (!containerTy.isIntOrFloat())
1436+
return rewriter.notifyMatchFailure(op, "not a scalar!");
1437+
1438+
unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
1439+
unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
1440+
1441+
// Enforced by the common pre-conditions.
1442+
assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
1443+
1444+
// TODO: Remove this condition - the assert above (and
1445+
// commonConversionPrecondtion) takes care of that.
1446+
if (multiByteBits < 8)
1447+
return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!");
1448+
1449+
// TODO: Add support other widths (when/if needed)
1450+
if (subByteBits != 2 && subByteBits != 4)
13871451
return rewriter.notifyMatchFailure(
1388-
op, "only src bitwidth of 2 or 4 is supported at this moment");
1452+
op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
1453+
1454+
// Condition 1.
1455+
if (multiByteBits % subByteBits != 0)
1456+
return rewriter.notifyMatchFailure(op, "unalagined element types");
13891457

1390-
const int numSrcElemsPerByte = 8 / srcElemBitwidth;
1391-
if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
1458+
// Condition 2.
1459+
if (!isSubByteVecFittable(subByteVecTy, containerTy))
13921460
return rewriter.notifyMatchFailure(
1393-
op, "the trailing dimension of the input vector of sub-bytes must be a "
1394-
"multiple of 8 / <sub-byte-width>");
1461+
op, "not possible to fit this sub-byte vector type into a vector of "
1462+
"the given multi-byte type");
13951463

13961464
return success();
13971465
}
@@ -1828,8 +1896,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
18281896
return failure();
18291897

18301898
// Check general alignment preconditions.
1831-
if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
1832-
conversionOp)))
1899+
Type containerType = rewriter.getI8Type();
1900+
if (failed(alignedConversionPrecondition(rewriter, srcVecType,
1901+
containerType, conversionOp)))
18331902
return failure();
18341903

18351904
// Perform the rewrite.
@@ -1893,8 +1962,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
18931962

18941963
// Check general alignment preconditions. We invert the src/dst type order
18951964
// to reuse the existing precondition logic.
1896-
if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
1897-
truncOp)))
1965+
Type containerType = rewriter.getI8Type();
1966+
if (failed(alignedConversionPrecondition(rewriter, dstVecType,
1967+
containerType, truncOp)))
18981968
return failure();
18991969

19001970
// Create a new iX -> i8 truncation op.

0 commit comments

Comments
 (0)