@@ -519,7 +519,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
519
519
520
520
auto origElements = valueToStore.getType ().getNumElements ();
521
521
// Note, per-element-alignment was already verified above.
522
- bool isFullyAligned = origElements % emulatedPerContainerElem == 0 ;
522
+ bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0 ;
523
523
524
524
auto stridedMetadata =
525
525
rewriter.create <memref::ExtractStridedMetadataOp>(loc, op.getBase ());
@@ -535,8 +535,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
535
535
getAsOpFoldResult (adaptor.getIndices ()));
536
536
537
537
std::optional<int64_t > foldedNumFrontPadElems =
538
- isFullyAligned ? 0
539
- : getConstantIntValue (linearizedInfo.intraDataOffset );
538
+ isDivisibleInSize ? 0
539
+ : getConstantIntValue (linearizedInfo.intraDataOffset );
540
540
541
541
if (!foldedNumFrontPadElems) {
542
542
return rewriter.notifyMatchFailure (
@@ -554,7 +554,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
554
554
// need unaligned emulation because the store address is aligned and the
555
555
// source is a whole byte.
556
556
bool emulationRequiresPartialStores =
557
- !isFullyAligned || *foldedNumFrontPadElems != 0 ;
557
+ !isDivisibleInSize || *foldedNumFrontPadElems != 0 ;
558
558
if (!emulationRequiresPartialStores) {
559
559
// Basic case: storing full bytes.
560
560
auto numElements = origElements / emulatedPerContainerElem;
@@ -881,7 +881,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
881
881
882
882
auto origElements = op.getVectorType ().getNumElements ();
883
883
// Note, per-element-alignment was already verified above.
884
- bool isFullyAligned = origElements % emulatedPerContainerElem == 0 ;
884
+ bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0 ;
885
885
886
886
auto stridedMetadata =
887
887
rewriter.create <memref::ExtractStridedMetadataOp>(loc, op.getBase ());
@@ -897,8 +897,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
897
897
getAsOpFoldResult (adaptor.getIndices ()));
898
898
899
899
std::optional<int64_t > foldedIntraVectorOffset =
900
- isFullyAligned ? 0
901
- : getConstantIntValue (linearizedInfo.intraDataOffset );
900
+ isDivisibleInSize ? 0
901
+ : getConstantIntValue (linearizedInfo.intraDataOffset );
902
902
903
903
// Always load enough elements which can cover the original elements.
904
904
int64_t maxintraDataOffset =
@@ -915,7 +915,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
915
915
result = dynamicallyExtractSubVector (
916
916
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
917
917
linearizedInfo.intraDataOffset , origElements);
918
- } else if (!isFullyAligned ) {
918
+ } else if (!isDivisibleInSize ) {
919
919
result = staticallyExtractSubvector (
920
920
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
921
921
}
@@ -1002,7 +1002,7 @@ struct ConvertVectorMaskedLoad final
1002
1002
auto origType = op.getVectorType ();
1003
1003
auto origElements = origType.getNumElements ();
1004
1004
// Note, per-element-alignment was already verified above.
1005
- bool isFullyAligned = origElements % emulatedPerContainerElem == 0 ;
1005
+ bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0 ;
1006
1006
1007
1007
auto stridedMetadata =
1008
1008
rewriter.create <memref::ExtractStridedMetadataOp>(loc, op.getBase ());
@@ -1017,8 +1017,8 @@ struct ConvertVectorMaskedLoad final
1017
1017
getAsOpFoldResult (adaptor.getIndices ()));
1018
1018
1019
1019
std::optional<int64_t > foldedIntraVectorOffset =
1020
- isFullyAligned ? 0
1021
- : getConstantIntValue (linearizedInfo.intraDataOffset );
1020
+ isDivisibleInSize ? 0
1021
+ : getConstantIntValue (linearizedInfo.intraDataOffset );
1022
1022
1023
1023
int64_t maxIntraDataOffset =
1024
1024
foldedIntraVectorOffset.value_or (emulatedPerContainerElem - 1 );
@@ -1042,7 +1042,7 @@ struct ConvertVectorMaskedLoad final
1042
1042
passthru = dynamicallyInsertSubVector (
1043
1043
rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset ,
1044
1044
origElements);
1045
- } else if (!isFullyAligned ) {
1045
+ } else if (!isDivisibleInSize ) {
1046
1046
passthru = staticallyInsertSubvector (rewriter, loc, passthru, emptyVector,
1047
1047
*foldedIntraVectorOffset);
1048
1048
}
@@ -1070,7 +1070,7 @@ struct ConvertVectorMaskedLoad final
1070
1070
mask = dynamicallyInsertSubVector (rewriter, loc, mask, emptyMask,
1071
1071
linearizedInfo.intraDataOffset ,
1072
1072
origElements);
1073
- } else if (!isFullyAligned ) {
1073
+ } else if (!isDivisibleInSize ) {
1074
1074
mask = staticallyInsertSubvector (rewriter, loc, op.getMask (), emptyMask,
1075
1075
*foldedIntraVectorOffset);
1076
1076
}
@@ -1081,7 +1081,7 @@ struct ConvertVectorMaskedLoad final
1081
1081
result = dynamicallyExtractSubVector (
1082
1082
rewriter, loc, result, op.getPassThru (),
1083
1083
linearizedInfo.intraDataOffset , origElements);
1084
- } else if (!isFullyAligned ) {
1084
+ } else if (!isDivisibleInSize ) {
1085
1085
result = staticallyExtractSubvector (
1086
1086
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1087
1087
}
@@ -1091,6 +1091,38 @@ struct ConvertVectorMaskedLoad final
1091
1091
}
1092
1092
};
1093
1093
1094
+ // / Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`.
1095
+ // /
1096
+ // / "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
1097
+ // / vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
1098
+ // / (a multi-byte scalar, e.g. i16), where N is some integer.
1099
+ // /
1100
+ // / Put differently, this method checks whether this would be valid:
1101
+ // /
1102
+ // / vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
1103
+ // /
1104
+ // / EXAMPLES:
1105
+ // / * vector<4xi4> -> i16 - yes (N = 1)
1106
+ // / * vector<4xi4> -> i8 - yes (N = 2)
1107
+ // / * vector<3xi4> -> i8 - no (N would have to be 1.5)
1108
+ // / * vector<3xi2> -> i16 - no (N would have to be 0.375)
1109
+ static bool fitsInMultiByteContainerTy (VectorType subByteVecTy,
1110
+ Type multiByteScalarTy) {
1111
+ assert ((isa<IntegerType, FloatType>(multiByteScalarTy)) && " Not scalar!" );
1112
+
1113
+ int subByteBits = subByteVecTy.getElementType ().getIntOrFloatBitWidth ();
1114
+ int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth ();
1115
+
1116
+ assert (subByteBits < 8 && " Not a sub-byte scalar type!" );
1117
+ assert (multiByteBits % 8 == 0 && " Not a multi-byte scalar type!" );
1118
+ assert (multiByteBits % subByteBits == 0 && " Unalagined element types!" );
1119
+
1120
+ int elemsPerMultiByte = multiByteBits / subByteBits;
1121
+
1122
+ // TODO: This is a bit too restrictive for vectors of rank > 1.
1123
+ return subByteVecTy.getShape ().back () % elemsPerMultiByte == 0 ;
1124
+ }
1125
+
1094
1126
// ===----------------------------------------------------------------------===//
1095
1127
// ConvertVectorTransferRead
1096
1128
// ===----------------------------------------------------------------------===//
@@ -1127,7 +1159,8 @@ struct ConvertVectorTransferRead final
1127
1159
auto origElements = op.getVectorType ().getNumElements ();
1128
1160
1129
1161
// Note, per-element-alignment was already verified above.
1130
- bool isFullyAligned = origElements % emulatedPerContainerElem == 0 ;
1162
+ bool isDivisibleInSize =
1163
+ fitsInMultiByteContainerTy (op.getVectorType (), containerElemTy);
1131
1164
1132
1165
auto newPadding = rewriter.create <arith::ExtUIOp>(loc, containerElemTy,
1133
1166
adaptor.getPadding ());
@@ -1146,8 +1179,8 @@ struct ConvertVectorTransferRead final
1146
1179
getAsOpFoldResult (adaptor.getIndices ()));
1147
1180
1148
1181
std::optional<int64_t > foldedIntraVectorOffset =
1149
- isFullyAligned ? 0
1150
- : getConstantIntValue (linearizedInfo.intraDataOffset );
1182
+ isDivisibleInSize ? 0
1183
+ : getConstantIntValue (linearizedInfo.intraDataOffset );
1151
1184
1152
1185
int64_t maxIntraDataOffset =
1153
1186
foldedIntraVectorOffset.value_or (emulatedPerContainerElem - 1 );
@@ -1171,7 +1204,7 @@ struct ConvertVectorTransferRead final
1171
1204
result = dynamicallyExtractSubVector (rewriter, loc, bitCast, zeros,
1172
1205
linearizedInfo.intraDataOffset ,
1173
1206
origElements);
1174
- } else if (!isFullyAligned ) {
1207
+ } else if (!isDivisibleInSize ) {
1175
1208
result = staticallyExtractSubvector (
1176
1209
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1177
1210
}
@@ -1428,41 +1461,69 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
1428
1461
return commonConversionPrecondition (rewriter, preconditionType, op);
1429
1462
}
1430
1463
1431
- // / Verify that `subByteVecType` and `dstType` are aligned. Alignment
1432
- // / means that:
1433
- // / 1. The `dstType` element type is a multiple of the
1434
- // / `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
1435
- // / is not supported). Let this multiple be `N`.
1436
- // / 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
1437
- // / multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
1438
- // / not supported).
1464
+ // / Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
1465
+ // /
1466
+ // / Alignment means that `subByteVecTy` can be packed into a vector of
1467
+ // / `containerTy` elements. More specifically:
1468
+ // / 1. The bit-width of `containerTy` is a multiple of the
1469
+ // / bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
1470
+ // / this multiple is 4.
1471
+ // / 2. The multiple from 1. above divides evenly the number of the (trailing)
1472
+ // / elements in `subByteVecTy`.
1473
+ // /
1474
+ // / EXAMPLE 1:
1475
+ // / `subByteVecTy = vector<4xi4>`, and
1476
+ // / `containerTy = i16`
1477
+ // /
1478
+ // / 4 ( = 16 / 4) divides evenly 4, hence both conditions are _met_.
1479
+ // /
1480
+ // / EXAMPLE 2:
1481
+ // / `subByteVecTy = vector<3xi4>`, and
1482
+ // / `containerTy = i16`
1483
+ // /
1484
+ // / 4 ( = 16 / 4) _does not_ divide evenly 3, hence the conditions are _not met_.
1485
+ // /
1486
+ // / EXAMPLE 3:
1487
+ // / `subByteVecTy = vector<3xi3>`, and
1488
+ // / `containerTy = i16`
1489
+ // /
1490
+ // / 16 _is not_ a multiple of 3, hence the conditions are _not met_.
1439
1491
// /
1440
1492
// / NOTE: This method assumes that common conversion preconditions are met. In
1441
- // / particular, the element type of `dstType ` is assumed to be a multi-byte
1442
- // / type (e.g. i8, i16, i32).
1493
+ // / particular, `containerTy ` is assumed to be a
1494
+ // / multi-byte scalar type (e.g., i8, i16, i32).
1443
1495
static LogicalResult alignedConversionPrecondition (PatternRewriter &rewriter,
1444
- VectorType subByteVecType ,
1445
- VectorType dstType ,
1496
+ VectorType subByteVecTy ,
1497
+ Type containerTy ,
1446
1498
Operation *op) {
1447
- if (!subByteVecType || !dstType)
1448
- return rewriter.notifyMatchFailure (op, " Not a supported aligned case" );
1449
- unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth ();
1450
- unsigned dstElemBitwidth = dstType.getElementTypeBitWidth ();
1499
+ assert (containerTy.isIntOrFloat () &&
1500
+ " container element type is not a scalar" );
1451
1501
1452
- if (dstElemBitwidth < 8 )
1453
- return rewriter.notifyMatchFailure (
1454
- op, " the bitwidth of dstType must be greater than or equal to 8" );
1455
- if (dstElemBitwidth % srcElemBitwidth != 0 )
1456
- return rewriter.notifyMatchFailure (op, " unaligned cases are not supported" );
1457
- if (srcElemBitwidth != 2 && srcElemBitwidth != 4 )
1502
+ // TODO: This is validating the inputs rather than checking the conditions
1503
+ // documented above. Replace with an assert.
1504
+ if (!subByteVecTy)
1505
+ return rewriter.notifyMatchFailure (op, " not a vector!" );
1506
+
1507
+ unsigned subByteBits = subByteVecTy.getElementTypeBitWidth ();
1508
+ unsigned containerBits = containerTy.getIntOrFloatBitWidth ();
1509
+
1510
+ // Enforced by the common pre-conditions.
1511
+ assert (containerBits % 8 == 0 && " Not a multi-byte scalar type!" );
1512
+
1513
+ // TODO: Add support for other widths (when/if needed)
1514
+ if (subByteBits != 2 && subByteBits != 4 )
1458
1515
return rewriter.notifyMatchFailure (
1459
- op, " only src bitwidth of 2 or 4 is supported at this moment" );
1516
+ op, " only 2-bit and 4-bit sub-byte type is supported at this moment" );
1517
+
1518
+ // Condition 1 ("per-element" alignment)
1519
+ if (containerBits % subByteBits != 0 )
1520
+ return rewriter.notifyMatchFailure (op, " unalagined element types" );
1460
1521
1461
- const int numSrcElemsPerByte = 8 / srcElemBitwidth;
1462
- if ((subByteVecType. getShape (). back () % numSrcElemsPerByte) != 0 )
1522
+ // Condition 2 ("full" alignment)
1523
+ if (! fitsInMultiByteContainerTy (subByteVecTy, containerTy) )
1463
1524
return rewriter.notifyMatchFailure (
1464
- op, " the trailing dimension of the input vector of sub-bytes must be a "
1465
- " multiple of 8 / <sub -byte-width> " );
1525
+ op, " not possible to fit this sub-byte vector type into a vector of "
1526
+ " the given multi -byte type " );
1466
1527
1467
1528
return success ();
1468
1529
}
@@ -1899,8 +1960,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
1899
1960
return failure ();
1900
1961
1901
1962
// Check general alignment preconditions.
1902
- if (failed (alignedConversionPrecondition (rewriter, srcVecType, dstVecType,
1903
- conversionOp)))
1963
+ if (failed (alignedConversionPrecondition (
1964
+ rewriter, srcVecType,
1965
+ /* containerTy=*/ rewriter.getI8Type (), conversionOp)))
1904
1966
return failure ();
1905
1967
1906
1968
// Perform the rewrite.
@@ -1964,8 +2026,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
1964
2026
1965
2027
// Check general alignment preconditions. We invert the src/dst type order
1966
2028
// to reuse the existing precondition logic.
1967
- if (failed (alignedConversionPrecondition (rewriter, dstVecType, srcVecType,
1968
- truncOp)))
2029
+ if (failed (alignedConversionPrecondition (
2030
+ rewriter, dstVecType,
2031
+ /* containerTy=*/ rewriter.getI8Type (), truncOp)))
1969
2032
return failure ();
1970
2033
1971
2034
// Create a new iX -> i8 truncation op.
0 commit comments