Skip to content

Commit b3b67eb

Browse files
committed
[mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N)
This is PR 4 in a series of N patches aimed at improving "VectorEmulateNarrowType.cpp". This is mainly minor refactoring, no major functional changes are made/added. 1. Update `alignedConversionPrecondition` (1): This method didn't require the vector type for the "destination" argument. The underlying element type is sufficient. The corresponding argument has been renamed as `multiByteScalarTy` - this is meant as the multi-byte emulated type (`i8`, `i16`, `i32`, etc). 2. Update `alignedConversionPrecondition` (2): In #121298, we replaced `dstElemBitwidt` in this calculation: ```cpp const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth; ``` with the hard-coded value of 8: ```cpp const int numSrcElemsPerDestElem = 8 / srcElemBitwidth; ``` That was correct as for the patterns for which this hook was/is used: * `RewriteAlignedSubByteIntExt`, * `RewriteAlignedSubByteIntTrunc`. The destination type (or, more precisely, the emulated type) was always `i8`. In this PR, I am switching back to a more generic approach - the calculation should take into account the bit-width of the emulated type. Note that at the call sites I am passing `i8` as the emulated type, so the end-result is effectively identical. However, the intent is clearer, i.e., the underlying value is 8 because the emulated type happens to be `i8` (as opposed using a magic number). 3. Update alignedConversionPrecondition (3): The final check has been replaced with a new helper method, `isSubByteVecFittable`. This new method is also re-used within the code and hopefully will allow us more code re-use moving forward (to avoid re-implementing the same condition). 4. Update alignedConversionPrecondition (4): NEXT STEPS: We need to clarify the meaning of "source" and "destination" types. Currently the usage is ambiguous. For example, for this `arith.extsi` Op, `vector<8xi2>` and `vector<8xi32>` are the "source" and "destination" types, respectively: ```mlir %0 = arith.extsi %arg0 : vector<8xi2> to vector<8xi32> } ``` However, patterns like `RewriteAlignedSubByteIntExt` introduce `vector.bitcast` Ops like this: ```mlir %bitcast = vector.bitcast %arg0 : vector<8xi2> to vector<2xi8> ``` I've noticed that we tend to mix `vector<2xi8>` and `vector<8xi32>` as the destination types and that should be clarified.
1 parent a522c22 commit b3b67eb

File tree

1 file changed

+102
-32
lines changed

1 file changed

+102
-32
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 102 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,6 +1091,38 @@ struct ConvertVectorMaskedLoad final
10911091
}
10921092
};
10931093

1094+
/// Check whether `subByteVecTy` fits wthin a vector of `multiByteScalarTy`
1095+
///
1096+
/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
1097+
/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
1098+
/// (a multi-byte scalar, e.g. i16), where N is some integer.
1099+
///
1100+
/// Put differently, this method checks whether this would be valid:
1101+
///
1102+
/// vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
1103+
///
1104+
/// EXAMPLES:
1105+
/// * vector<4xi4> -> i16 - yes (N = 1)
1106+
/// * vector<4xi4> -> i8 - yes (N = 2)
1107+
/// * vector<3xi4> -> i8 - no (N would have to be 1.5)
1108+
/// * vector<3xi2> -> i16 - no (N would have to be 0.5)
1109+
static bool isSubByteVecFittable(VectorType subByteVecTy,
1110+
Type multiByteScalarTy) {
1111+
assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
1112+
1113+
int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
1114+
int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
1115+
1116+
assert(subByteBits < 8 && "Not a sub-byte scalar type!");
1117+
assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
1118+
assert(multiByteBits % subByteBits == 0 && "Unalagined element types!");
1119+
1120+
int elemsPerMultiByte = multiByteBits / subByteBits;
1121+
1122+
// TODO: This is a bit too restrictive for vectors rank > 1.
1123+
return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
1124+
}
1125+
10941126
//===----------------------------------------------------------------------===//
10951127
// ConvertVectorTransferRead
10961128
//===----------------------------------------------------------------------===//
@@ -1127,7 +1159,8 @@ struct ConvertVectorTransferRead final
11271159
auto origElements = op.getVectorType().getNumElements();
11281160

11291161
// Note, per-element-alignment was already verified above.
1130-
bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
1162+
bool isFullyAligned =
1163+
isSubByteVecFittable(op.getVectorType(), containerElemTy);
11311164

11321165
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
11331166
adaptor.getPadding());
@@ -1428,41 +1461,76 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
14281461
return commonConversionPrecondition(rewriter, preconditionType, op);
14291462
}
14301463

1431-
/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
1432-
/// means that:
1433-
/// 1. The `dstType` element type is a multiple of the
1434-
/// `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
1435-
/// is not supported). Let this multiple be `N`.
1436-
/// 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
1437-
/// multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
1438-
/// not supported).
1464+
/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
1465+
///
1466+
/// Alignment means that `subByteVecTy` can be packed into a vector of
1467+
/// `containerTy` elements. More specifically:
1468+
/// 1. The bit-width of `containerTy` is a multiple of the
1469+
/// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
1470+
/// this multiple is 4.
1471+
/// 2. The multiple from 1. above divides evenly the number of the (trailing)
1472+
/// elements in `subByteVecTy`.
1473+
///
1474+
/// EXAMPLE 1:
1475+
/// `subByteVecTy = vector<2xi4>`, and
1476+
/// `containerTy = i16`
1477+
///
1478+
/// 2 divides evenly 4 ( = 16 / 4), hence both conditions are _met_.
1479+
///
1480+
/// EXAMPLE 2:
1481+
/// `subByteVecTy = vector<3xi4>`, and
1482+
/// `containerTy = i16`
1483+
///
1484+
/// 3 _does not_ divide evenly 4 (= 16/4), hence the conditions are _not met_.
1485+
///
1486+
/// EXAMPLE 3:
1487+
/// `subByteVecTy = vector<3xi3>`, and
1488+
/// `containerTy = i16`
1489+
///
1490+
/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
14391491
///
14401492
/// NOTE: This method assumes that common conversion preconditions are met. In
1441-
/// particular, the element type of `dstType` is assumed to be a multi-byte
1442-
/// type (e.g. i8, i16, i32).
1493+
/// particular, `containerTy` is assumed to be a
1494+
/// multi-byte scalar type (e.g., i8, i16, i32).
14431495
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1444-
VectorType subByteVecType,
1445-
VectorType dstType,
1496+
VectorType subByteVecTy,
1497+
Type containerTy,
14461498
Operation *op) {
1447-
if (!subByteVecType || !dstType)
1448-
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
1449-
unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
1450-
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
1499+
// TODO: This is validating the inputs rather than checking the conditions
1500+
// documented above. Replace with an assert.
1501+
if (!subByteVecTy)
1502+
return rewriter.notifyMatchFailure(op, "not a vector!");
14511503

1452-
if (dstElemBitwidth < 8)
1453-
return rewriter.notifyMatchFailure(
1454-
op, "the bitwidth of dstType must be greater than or equal to 8");
1455-
if (dstElemBitwidth % srcElemBitwidth != 0)
1456-
return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
1457-
if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
1504+
// TODO: This is validating the inputs rather than checking the conditions
1505+
// documented above. Replace with an assert.
1506+
if (!containerTy.isIntOrFloat())
1507+
return rewriter.notifyMatchFailure(op, "not a scalar!");
1508+
1509+
unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
1510+
unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
1511+
1512+
// Enforced by the common pre-conditions.
1513+
assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
1514+
1515+
// TODO: Remove this condition - the assert above (and
1516+
// commonConversionPrecondtion) takes care of that.
1517+
if (multiByteBits < 8)
1518+
return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!");
1519+
1520+
// TODO: Add support other widths (when/if needed)
1521+
if (subByteBits != 2 && subByteBits != 4)
14581522
return rewriter.notifyMatchFailure(
1459-
op, "only src bitwidth of 2 or 4 is supported at this moment");
1523+
op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
1524+
1525+
// Condition 1.
1526+
if (multiByteBits % subByteBits != 0)
1527+
return rewriter.notifyMatchFailure(op, "unalagined element types");
14601528

1461-
const int numSrcElemsPerByte = 8 / srcElemBitwidth;
1462-
if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
1529+
// Condition 2.
1530+
if (!isSubByteVecFittable(subByteVecTy, containerTy))
14631531
return rewriter.notifyMatchFailure(
1464-
op, "the trailing dimension of the input vector of sub-bytes must be a "
1465-
"multiple of 8 / <sub-byte-width>");
1532+
op, "not possible to fit this sub-byte vector type into a vector of "
1533+
"the given multi-byte type");
14661534

14671535
return success();
14681536
}
@@ -1899,8 +1967,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
18991967
return failure();
19001968

19011969
// Check general alignment preconditions.
1902-
if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
1903-
conversionOp)))
1970+
Type containerType = rewriter.getI8Type();
1971+
if (failed(alignedConversionPrecondition(rewriter, srcVecType,
1972+
containerType, conversionOp)))
19041973
return failure();
19051974

19061975
// Perform the rewrite.
@@ -1964,8 +2033,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
19642033

19652034
// Check general alignment preconditions. We invert the src/dst type order
19662035
// to reuse the existing precondition logic.
1967-
if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
1968-
truncOp)))
2036+
Type containerType = rewriter.getI8Type();
2037+
if (failed(alignedConversionPrecondition(rewriter, dstVecType,
2038+
containerType, truncOp)))
19692039
return failure();
19702040

19712041
// Create a new iX -> i8 truncation op.

0 commit comments

Comments
 (0)