Skip to content

Commit aa17f5e

Browse files
committed
fixup! [mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N)
Address comments from Alan
1 parent 696d7d2 commit aa17f5e

File tree

1 file changed

+15
-22
lines changed

1 file changed

+15
-22
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,8 +1106,8 @@ struct ConvertVectorMaskedLoad final
11061106
/// * vector<4xi4> -> i8 - yes (N = 2)
11071107
/// * vector<3xi4> -> i8 - no (N would have to be 1.5)
11081108
/// * vector<3xi2> -> i16 - no (N would have to be 0.5)
1109-
static bool isSubByteVecFittable(VectorType subByteVecTy,
1110-
Type multiByteScalarTy) {
1109+
static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
1110+
Type multiByteScalarTy) {
11111111
assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
11121112

11131113
int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
@@ -1160,7 +1160,7 @@ struct ConvertVectorTransferRead final
11601160

11611161
// Note, per-element-alignment was already verified above.
11621162
bool isFullyAligned =
1163-
isSubByteVecFittable(op.getVectorType(), containerElemTy);
1163+
fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
11641164

11651165
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
11661166
adaptor.getPadding());
@@ -1496,38 +1496,31 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
14961496
VectorType subByteVecTy,
14971497
Type containerTy,
14981498
Operation *op) {
1499+
assert(containerTy.isIntOrFloat() &&
1500+
"container element type is not a scalar");
1501+
14991502
// TODO: This is validating the inputs rather than checking the conditions
15001503
// documented above. Replace with an assert.
15011504
if (!subByteVecTy)
15021505
return rewriter.notifyMatchFailure(op, "not a vector!");
15031506

1504-
// TODO: This is validating the inputs rather than checking the conditions
1505-
// documented above. Replace with an assert.
1506-
if (!containerTy.isIntOrFloat())
1507-
return rewriter.notifyMatchFailure(op, "not a scalar!");
1508-
15091507
unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
15101508
unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
15111509

15121510
// Enforced by the common pre-conditions.
15131511
assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
15141512

1515-
// TODO: Remove this condition - the assert above (and
1516-
// commonConversionPrecondtion) takes care of that.
1517-
if (multiByteBits < 8)
1518-
return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!");
1519-
15201513
// TODO: Add support other widths (when/if needed)
15211514
if (subByteBits != 2 && subByteBits != 4)
15221515
return rewriter.notifyMatchFailure(
15231516
op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
15241517

1525-
// Condition 1.
1518+
// Condition 1 ("per-element" alignment)
15261519
if (multiByteBits % subByteBits != 0)
15271520
return rewriter.notifyMatchFailure(op, "unalagined element types");
15281521

1529-
// Condition 2.
1530-
if (!isSubByteVecFittable(subByteVecTy, containerTy))
1522+
// Condition 2 ("full" alignment)
1523+
if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
15311524
return rewriter.notifyMatchFailure(
15321525
op, "not possible to fit this sub-byte vector type into a vector of "
15331526
"the given multi-byte type");
@@ -1967,9 +1960,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
19671960
return failure();
19681961

19691962
// Check general alignment preconditions.
1970-
Type containerType = rewriter.getI8Type();
1971-
if (failed(alignedConversionPrecondition(rewriter, srcVecType,
1972-
containerType, conversionOp)))
1963+
if (failed(alignedConversionPrecondition(
1964+
rewriter, srcVecType,
1965+
/*containerTy=*/rewriter.getI8Type(), conversionOp)))
19731966
return failure();
19741967

19751968
// Perform the rewrite.
@@ -2033,9 +2026,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
20332026

20342027
// Check general alignment preconditions. We invert the src/dst type order
20352028
// to reuse the existing precondition logic.
2036-
Type containerType = rewriter.getI8Type();
2037-
if (failed(alignedConversionPrecondition(rewriter, dstVecType,
2038-
containerType, truncOp)))
2029+
if (failed(alignedConversionPrecondition(
2030+
rewriter, dstVecType,
2031+
/*containerTy=*/rewriter.getI8Type(), truncOp)))
20392032
return failure();
20402033

20412034
// Create a new iX -> i8 truncation op.

0 commit comments

Comments
 (0)