@@ -896,17 +896,17 @@ static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
   auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
   Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
 
-  // 2 Extend the i4 elements using shifts & masking. Low i4 elemens of each
-  // byte are place in one vector and the high i4 elements in another vector.
-  constexpr unsigned char lowBitsMask = 15; // Equivalent to [0000IIII] bit mask
+  // 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
+  // byte are placed in one vector and the high i4 elements in another vector.
+  constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
   auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
       loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
-  Value low = rewriter.create<arith::AndIOp>(loc, i8Vector.getType(), i8Vector,
+  Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
                                              lowBitsMaskValues);
   constexpr int8_t highBitsToShift = 4;
   auto highShiftValues = rewriter.create<arith::ConstantOp>(
       loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
-  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, highShiftValues);
+  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
 
   // 3. Interleave low and high i8 elements.
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
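For context on the arith.shrsi -> arith.shrui change above: a minimal, standalone C++ sketch (not part of the patch) of the per-byte nibble extraction that the lowered vector ops perform. It illustrates why the unsigned path needs a logical right shift: an arithmetic shift, which arith.shrsi implies, sign-extends any high nibble whose top bit is set.

// Standalone illustration only: extract the two i4 values packed in one byte.
// For 0xF2 (low nibble 2, high nibble 15), an arithmetic shift on a signed
// byte smears the sign bit and yields -1 instead of 15.
#include <cstdint>
#include <cstdio>

int main() {
  uint8_t packed = 0xF2; // two i4 values in one byte: low = 2, high = 15

  // Low nibble: mask with 0b00001111; identical for signed and unsigned.
  uint8_t low = packed & 0x0F;

  // Unsigned high nibble: logical shift (what arith.shrui corresponds to).
  uint8_t highLogical = packed >> 4;

  // Arithmetic shift on a signed byte (what arith.shrsi corresponds to)
  // sign-extends bit 7 into the result.
  int8_t highArithmetic = static_cast<int8_t>(packed) >> 4;

  std::printf("low = %d, logical high = %d, arithmetic high = %d\n", low,
              highLogical, highArithmetic);
  // Prints: low = 2, logical high = 15, arithmetic high = -1
  return 0;
}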
@@ -1080,9 +1080,10 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 
 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
 /// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
+/// LLVM to scramble with peephole optimizations. Templated to choose between
+/// signed and unsigned conversions.
 ///
-/// For example:
+/// For example (signed):
 ///    arith.extsi %in : vector<8xi4> to vector<8xi32>
 ///      is rewriten as
 ///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
@@ -1101,60 +1102,25 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 ///        %4 = vector.interleave %2, %3 : vector<4xi8>
 ///        %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
-template <typename ConversionOpType>
-struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
-  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
-                                PatternRewriter &rewriter) const override {
-    // Verify the preconditions.
-    Value srcValue = conversionOp.getIn();
-    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
-    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
-    if (failed(
-            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
-      return failure();
-
-    // Check general alignment preconditions.
-    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
-                                             conversionOp)))
-      return failure();
-
-    // Perform the rewrite.
-    Value subByteExt =
-        rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
-
-    // Finalize the rewrite.
-    rewriter.replaceOpWithNewOp<ConversionOpType>(
-        conversionOp, conversionOp.getType(), subByteExt);
-    return success();
-  }
-};
-
-/// Rewrite the i4 -> i8 part of any unsigned conversion into a sequence of
-/// shuffles and bitwise ops that take advantage of high-level information to
-/// avoid leaving LLVM to scramble with peephole optimizations.
-///
-/// For example:
+/// Example (unsigned):
 ///    arith.extui %in : vector<8xi4> to vector<8xi32>
 ///      is rewritten as
 ///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
 ///        %1 = arith.andi %0, 15 : vector<4xi8>
-///        %2 = arith.shrsi %0, 4 : vector<4xi8>
+///        %2 = arith.shrui %0, 4 : vector<4xi8>
 ///        %3 = vector.interleave %1, %2 : vector<4xi8>
-///        %4 = arith.extsi %3 : vector<8xi8> to vector<8xi32>
+///        %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
 ///
-template <typename ConversionOpType>
-struct RewriteAlignedSubByteIntUnsignedExt
-    : OpRewritePattern<ConversionOpType> {
+template <typename ConversionOpType, bool isSigned>
+struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
   using OpRewritePattern<ConversionOpType>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                 PatternRewriter &rewriter) const override {
     // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
-    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
-    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    auto srcVecType = cast<VectorType>(srcValue.getType());
+    auto dstVecType = cast<VectorType>(conversionOp.getType());
     if (failed(
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
       return failure();
@@ -1165,8 +1131,14 @@ struct RewriteAlignedSubByteIntUnsignedExt
       return failure();
 
     // Perform the rewrite.
-    Value subByteExt =
-        rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    Value subByteExt;
+    if (isSigned) {
+      subByteExt =
+          rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    } else {
+      subByteExt =
+          rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+    }
 
     // Finalize the rewrite.
     rewriter.replaceOpWithNewOp<ConversionOpType>(
@@ -1305,11 +1277,11 @@ void vector::populateVectorNarrowTypeRewritePatterns(
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
-  patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
-               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
+  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
+               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
-  patterns.add<RewriteAlignedSubByteIntUnsignedExt<arith::ExtUIOp>>(
+  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
       patterns.getContext(), benefit.getBenefit() + 1);
 }
 
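A rough usage sketch of the registration entry point touched above, assuming the standard greedy rewrite driver and header locations from this era of MLIR; emulateNarrowTypes is a hypothetical helper name, not something defined in the patch.

// Hypothetical driver snippet (not from the patch): populate the narrow-type
// rewrite patterns and apply them greedily to a module.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult emulateNarrowTypes(mlir::ModuleOp module) {
  mlir::MLIRContext *ctx = module.getContext();
  mlir::RewritePatternSet patterns(ctx);

  // Registers, among others, RewriteAlignedSubByteIntExt<arith::ExtSIOp, true>,
  // RewriteAlignedSubByteIntExt<arith::SIToFPOp, true>, and
  // RewriteAlignedSubByteIntExt<arith::ExtUIOp, false>, with the aligned
  // patterns at benefit + 1 so they are tried before the generic ones.
  mlir::vector::populateVectorNarrowTypeRewritePatterns(patterns);

  // Drive the patterns to a fixed point over the module.
  return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}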