@@ -880,38 +880,6 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
-/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
-/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
-static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
-                                      Value srcValue) {
-  VectorType srcVecType = cast<VectorType>(srcValue.getType());
-  assert(srcVecType.getElementType().isSignlessInteger(4) &&
-         "Expected i4 type");
-
-  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
-  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
-  constexpr int64_t i4Toi8BitwidthFactor = 2;
-  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
-  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
-  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
-
-  // 2. Extend the i4 elements using shifts & masking. Low i4 elements of each
-  //    byte are placed in one vector and the high i4 elements in another vector.
-  constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
-  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
-      loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
-  Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
-                                             lowBitsMaskValues);
-  constexpr int8_t highBitsToShift = 4;
-  auto highShiftValues = rewriter.create<arith::ConstantOp>(
-      loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
-  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
-
-  // 3. Interleave low and high i8 elements.
-  return rewriter.create<vector::InterleaveOp>(loc, low, high);
-}
-
 /// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
 /// that take advantage of high-level information to avoid leaving LLVM to
 /// scramble with peephole optimizations.
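For reference, the helper deleted above unpacked two unsigned i4 values out of every byte with a mask and a logical shift, then interleaved the two halves. Below is a minimal standalone C++ sketch of that byte-level arithmetic; it is not part of the patch, and the function name and sample values are illustrative only.

// Standalone sketch, not part of the patch: scalar equivalent of the deleted
// rewriteI4ToI8UnsignedExt. Each packed byte holds two i4 values; the low
// nibble becomes the even output element and the high nibble the odd one,
// mirroring the andi-15 / shrui-4 / interleave sequence removed above.
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<uint8_t> unpackI4ToU8(const std::vector<uint8_t> &packed) {
  std::vector<uint8_t> out;
  out.reserve(packed.size() * 2);
  for (uint8_t byte : packed) {
    out.push_back(byte & 0x0F); // low i4, zero-extended  (arith.andi  %x, 15)
    out.push_back(byte >> 4);   // high i4, zero-extended (arith.shrui %x, 4)
  }
  return out;
}

int main() {
  for (uint8_t v : unpackI4ToU8({0xA3, 0x7F})) // 0xA3 packs 3 (low) and 10 (high)
    std::printf("%d ", v);                     // prints: 3 10 15 7
  return 0;
}

Emitting the low nibble first matches the vector.interleave(low, high) operand order used by the vector pattern.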
@@ -1080,10 +1048,9 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 
 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
 /// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations. Templated to choose between
-/// signed and unsigned conversions.
+/// LLVM to scramble with peephole optimizations.
 ///
-/// For example (signed):
+/// For example:
 ///    arith.extsi %in : vector<8xi4> to vector<8xi32>
 ///      is rewriten as
 ///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
@@ -1102,17 +1069,8 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 ///        %4 = vector.interleave %2, %3 : vector<4xi8>
 ///        %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
-/// Example (unsigned):
-///    arith.extui %in : vector<8xi4> to vector<8xi32>
-///      is rewritten as
-///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///        %1 = arith.andi %0, 15 : vector<4xi8>
-///        %2 = arith.shrui %0, 4 : vector<4xi8>
-///        %3 = vector.interleave %1, %2 : vector<4xi8>
-///        %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
-///
-template <typename ConversionOpType, bool isSigned>
-struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
+template <typename ConversionOpType>
+struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
   using OpRewritePattern<ConversionOpType>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ConversionOpType conversionOp,
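The retained signed example in the comment above cannot reuse the mask trick from the deleted unsigned path, because masking with 15 zero-extends the low nibble instead of propagating its sign bit. The following standalone C++ sketch (not part of the patch) shows the kind of nibble sign-extension the surviving signed helper relies on; the helper name is made up and two's-complement (C++20) shift semantics are assumed.

// Standalone sketch, not part of the patch: why the signed path needs
// different arithmetic than the deleted unsigned one. The low nibble is
// recovered with shift-left plus arithmetic shift-right; the high nibble
// needs only the arithmetic shift. Helper name is illustrative.
#include <cstdint>
#include <cstdio>
#include <utility>

static std::pair<int8_t, int8_t> signExtendNibbles(uint8_t byte) {
  int8_t low = static_cast<int8_t>(byte << 4) >> 4; // bits 0-3, bit 3 replicated
  int8_t high = static_cast<int8_t>(byte) >> 4;     // bits 4-7, bit 7 replicated
  return {low, high};
}

int main() {
  auto [low, high] = signExtendNibbles(0xA3); // packs 3 (low) and -6 (high)
  std::printf("%d %d\n", low, high);          // prints: 3 -6
  return 0;
}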
@@ -1121,7 +1079,6 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
     Value srcValue = conversionOp.getIn();
     auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
     auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
-
     if (failed(
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
       return failure();
@@ -1132,14 +1089,8 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
       return failure();
 
     // Perform the rewrite.
-    Value subByteExt;
-    if (isSigned) {
-      subByteExt =
-          rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
-    } else {
-      subByteExt =
-          rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
-    }
+    Value subByteExt =
+        rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
 
     // Finalize the rewrite.
     rewriter.replaceOpWithNewOp<ConversionOpType>(
@@ -1278,12 +1229,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
-  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
-               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
+  patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
+               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
-  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
-      patterns.getContext(), benefit.getBenefit() + 1);
 }
 
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
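For completeness, a minimal sketch of how a pass might apply these rewrites after this change, assuming the usual MLIR greedy-driver entry points; the header paths and the wrapper function are illustrative and not taken from this patch.

// Standalone sketch, not part of the patch; header paths are from memory and
// may need adjusting. The extra PatternBenefit shown in the hunk above gives
// the aligned patterns priority during greedy application.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult runNarrowTypeRewrites(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorNarrowTypeRewritePatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}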