@@ -880,6 +880,38 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend the i4 elements using a mask and a logical shift: the low i4
+  //    elements of each byte go in one vector, the high i4 elements in another.
+  constexpr unsigned char lowBitsMask = 15; // Equivalent to [00001111] bit mask
+  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
+  Value low = rewriter.create<arith::AndIOp>(loc, i8Vector.getType(), i8Vector,
+                                             lowBitsMaskValues);
+  constexpr int8_t highBitsToShift = 4;
+  auto highShiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
+  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
+
+  // 3. Interleave low and high i8 elements.
+  return rewriter.create<vector::InterleaveOp>(loc, low, high);
+}
+
 /// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
 /// that take advantage of high-level information to avoid leaving LLVM to
 /// scramble with peephole optimizations.
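For intuition, step 2 of rewriteI4ToI8UnsignedExt can be modeled on plain scalars: the mask recovers the low nibble of every byte, the logical shift recovers the high nibble, and the interleave restores the original i4 element order. A minimal standalone sketch of that model (plain C++, independent of MLIR; the function name is illustrative, and it assumes the little-endian nibble order the bitcast produces, i.e. the first i4 element of each pair sits in the low nibble):

#include <cassert>
#include <cstdint>
#include <vector>

// Scalar model of the rewrite above: each input byte packs two i4 values.
std::vector<uint8_t> extendI4ToI8Unsigned(const std::vector<uint8_t> &bytes) {
  std::vector<uint8_t> result;
  for (uint8_t b : bytes) {
    result.push_back(b & 0x0F); // arith.andi %0, 15: low nibble, zero-extended
    result.push_back(b >> 4);   // arith.shrui %0, 4: high nibble, zero-extended
  }
  // The alternating push order models vector.interleave %low, %high.
  return result;
}

int main() {
  // 0x2F packs the i4 pair (low = 15, high = 2); 0x81 packs (1, 8).
  assert((extendI4ToI8Unsigned({0x2F, 0x81}) ==
          std::vector<uint8_t>{15, 2, 1, 8}));
  return 0;
}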
@@ -1099,6 +1131,50 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
   }
 };
 
+/// Rewrite the i4 -> i8 part of any unsigned conversion into a sequence of
+/// shuffles and bitwise ops that take advantage of high-level information to
+/// avoid leaving LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///    arith.extui %in : vector<8xi4> to vector<8xi32>
+///      is rewritten as
+///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///        %1 = arith.andi %0, 15 : vector<4xi8>
+///        %2 = arith.shrui %0, 4 : vector<4xi8>
+///        %3 = vector.interleave %1, %2 : vector<4xi8>
+///        %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+///
+template <typename ConversionOpType>
+struct RewriteAlignedSubByteIntUnsignedExt
+    : OpRewritePattern<ConversionOpType> {
+  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
+                                PatternRewriter &rewriter) const override {
+    // Verify the preconditions.
+    Value srcValue = conversionOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    if (failed(
+            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
+      return failure();
+
+    // Check general alignment preconditions.
+    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+                                             conversionOp)))
+      return failure();
+
+    // Perform the rewrite.
+    Value subByteExt =
+        rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+
+    // Finalize the rewrite.
+    rewriter.replaceOpWithNewOp<ConversionOpType>(
+        conversionOp, conversionOp.getType(), subByteExt);
+    return success();
+  }
+};
+
 /// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
 /// bitwise ops that take advantage of high-level information to avoid leaving
 /// LLVM to scramble with peephole optimizations.
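One subtlety worth calling out: the unsigned pattern must use arith.shrui, not the arith.shrsi used by the signed variant. With an arithmetic shift, any high nibble whose top bit is set would be sign-extended across the byte before the final arith.extui, corrupting the unsigned values 8 through 15. A scalar illustration (plain C++, not part of the patch; right-shifting a negative signed value is defined to be arithmetic since C++20, and older compilers implement it that way in practice):

#include <cstdint>
#include <cstdio>

int main() {
  uint8_t packed = 0xF2; // packs the i4 pair (low = 2, high = 15)
  uint8_t good = packed >> 4;                    // logical shift: 0x0F, i.e. 15
  int8_t bad = static_cast<int8_t>(packed) >> 4; // arithmetic: 0xFF, i.e. -1
  // After the final zero-extension to i32, `good` yields 15 as intended,
  // while `bad` would yield 255 instead of the unsigned i4 value 15.
  std::printf("shrui: %d, shrsi: %d\n", good, bad);
  return 0;
}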
@@ -1233,6 +1309,8 @@ void vector::populateVectorNarrowTypeRewritePatterns(
                RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
+  patterns.add<RewriteAlignedSubByteIntUnsignedExt<arith::ExtUIOp>>(
+      patterns.getContext(), benefit.getBenefit() + 1);
 }
 
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
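Nothing beyond this registration is needed to activate the rewrite: any pass that calls populateVectorNarrowTypeRewritePatterns picks up the new pattern. A sketch of driving these patterns with MLIR's greedy rewrite driver (the helper name is hypothetical; the populate and apply functions are existing MLIR APIs):

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: run the narrow-type rewrites, including the new
// unsigned i4 -> i8 extension, on every op nested under `root`.
static LogicalResult applyNarrowTypeRewrites(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorNarrowTypeRewritePatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}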