Skip to content

Commit 673182a

Browse files
committed
unsigned emulation for i4
1 parent 652bcf6 commit 673182a

File tree

2 files changed

+122
-1
lines changed

2 files changed

+122
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,38 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
880880
return rewriter.create<vector::InterleaveOp>(loc, low, high);
881881
}
882882

883+
/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
884+
/// bitwise ops that take advantage of high-level information to avoid leaving
885+
/// LLVM to scramble with peephole optimizations.
886+
static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
887+
Value srcValue) {
888+
VectorType srcVecType = cast<VectorType>(srcValue.getType());
889+
assert(srcVecType.getElementType().isSignlessInteger(4) &&
890+
"Expected i4 type");
891+
892+
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
893+
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
894+
constexpr int64_t i4Toi8BitwidthFactor = 2;
895+
i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
896+
auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
897+
Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
898+
899+
// 2 Extend the i4 elements using shifts & masking. Low i4 elemens of each
900+
// byte are place in one vector and the high i4 elements in another vector.
901+
constexpr unsigned char lowBitsMask = 15; // Equivalent to [0000IIII] bit mask
902+
auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
903+
loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
904+
Value low = rewriter.create<arith::AndIOp>(loc, i8Vector.getType(), i8Vector,
905+
lowBitsMaskValues);
906+
constexpr int8_t highBitsToShift = 4;
907+
auto highShiftValues = rewriter.create<arith::ConstantOp>(
908+
loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
909+
Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, highShiftValues);
910+
911+
// 3. Interleave low and high i8 elements.
912+
return rewriter.create<vector::InterleaveOp>(loc, low, high);
913+
}
914+
883915
/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
884916
/// that take advantage of high-level information to avoid leaving LLVM to
885917
/// scramble with peephole optimizations.
@@ -1099,6 +1131,50 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
10991131
}
11001132
};
11011133

1134+
/// Rewrite the i4 -> i8 part of any unsigned conversion into a sequence of
1135+
/// shuffles and bitwise ops that take advantage of high-level information to
1136+
/// avoid leaving LLVM to scramble with peephole optimizations.
1137+
///
1138+
/// For example:
1139+
/// arith.extui %in : vector<8xi4> to vector<8xi32>
1140+
/// is rewritten as
1141+
/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1142+
/// %1 = arith.andi %0, 15 : vector<4xi8>
1143+
/// %2 = arith.shrsi %0, 4 : vector<4xi8>
1144+
/// %3 = vector.interleave %1, %2 : vector<4xi8>
1145+
/// %4 = arith.extsi %3 : vector<8xi8> to vector<8xi32>
1146+
///
1147+
template <typename ConversionOpType>
1148+
struct RewriteAlignedSubByteIntUnsignedExt
1149+
: OpRewritePattern<ConversionOpType> {
1150+
using OpRewritePattern<ConversionOpType>::OpRewritePattern;
1151+
1152+
LogicalResult matchAndRewrite(ConversionOpType conversionOp,
1153+
PatternRewriter &rewriter) const override {
1154+
// Verify the preconditions.
1155+
Value srcValue = conversionOp.getIn();
1156+
auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
1157+
auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
1158+
if (failed(
1159+
commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
1160+
return failure();
1161+
1162+
// Check general alignment preconditions.
1163+
if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
1164+
conversionOp)))
1165+
return failure();
1166+
1167+
// Perform the rewrite.
1168+
Value subByteExt =
1169+
rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
1170+
1171+
// Finalize the rewrite.
1172+
rewriter.replaceOpWithNewOp<ConversionOpType>(
1173+
conversionOp, conversionOp.getType(), subByteExt);
1174+
return success();
1175+
}
1176+
};
1177+
11021178
/// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
11031179
/// bitwise ops that take advantage of high-level information to avoid leaving
11041180
/// LLVM to scramble with peephole optimizations.
@@ -1233,6 +1309,8 @@ void vector::populateVectorNarrowTypeRewritePatterns(
12331309
RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
12341310
RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
12351311
benefit.getBenefit() + 1);
1312+
patterns.add<RewriteAlignedSubByteIntUnsignedExt<arith::ExtUIOp>>(
1313+
patterns.getContext(), benefit.getBenefit() + 1);
12361314
}
12371315

12381316
void vector::populateVectorTransposeNarrowTypeRewritePatterns(

mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,50 @@ func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
324324
return %0 : vector<16x8xi7>
325325
}
326326

327+
// Exercises the aligned `arith.extui` rewrite on a 1-D vector<8xi4> widened to
// i32. The expected output is: bitcast to vector<4xi8>, mask with 15, shift
// right by 4, interleave, then an i8 -> i32 `arith.extui`.
// NOTE(review): the shift expected below is `arith.shrsi` (arithmetic shift).
// For a byte whose top bit is set this fills the upper bits with ones, which
// is not a zero-extension of the high nibble -- confirm whether `arith.shrui`
// is the intended op.
// CHECK-LABEL: func.func @aligned_extui(
328+
func.func @aligned_extui(%a: vector<8xi4>) -> vector<8xi32> {
329+
// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi4>) -> vector<8xi32> {
330+
// CHECK: %[[VAL_1:.*]] = arith.constant dense<4> : vector<4xi8>
331+
// CHECK: %[[VAL_2:.*]] = arith.constant dense<15> : vector<4xi8>
332+
// CHECK: %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8xi4> to vector<4xi8>
333+
// CHECK: %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<4xi8>
334+
// CHECK: %[[VAL_5:.*]] = arith.shrsi %[[VAL_3]], %[[VAL_1]] : vector<4xi8>
335+
// CHECK: %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<4xi8>
336+
// CHECK: %[[VAL_7:.*]] = arith.extui %[[VAL_6]] : vector<8xi8> to vector<8xi32>
337+
%0 = arith.extui %a : vector<8xi4> to vector<8xi32>
338+
return %0 : vector<8xi32>
339+
}
340+
341+
342+
// Exercises the aligned `arith.extui` rewrite on a 2-D vector<8x32xi4> widened
// to i32. Only the trailing dimension is halved by the bitcast (8x32xi4 ->
// 8x16xi8); the interleave restores it to 8x32 before the final `arith.extui`.
// NOTE(review): the shift expected below is `arith.shrsi` (arithmetic shift).
// For a byte whose top bit is set this fills the upper bits with ones, which
// is not a zero-extension of the high nibble -- confirm whether `arith.shrui`
// is the intended op.
// CHECK-LABEL: func.func @aligned_extui_2d(
343+
func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
344+
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
345+
// CHECK: %[[VAL_1:.*]] = arith.constant dense<4> : vector<8x16xi8>
346+
// CHECK: %[[VAL_2:.*]] = arith.constant dense<15> : vector<8x16xi8>
347+
// CHECK: %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
348+
// CHECK: %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<8x16xi8>
349+
// CHECK: %[[VAL_5:.*]] = arith.shrsi %[[VAL_3]], %[[VAL_1]] : vector<8x16xi8>
350+
// CHECK: %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<8x16xi8>
351+
// CHECK: %[[VAL_7:.*]] = arith.extui %[[VAL_6]] : vector<8x32xi8> to vector<8x32xi32>
352+
// CHECK: return %[[VAL_7]] : vector<8x32xi32>
353+
%0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
354+
return %0 : vector<8x32xi32>
355+
}
356+
357+
358+
// Exercises the aligned `arith.extui` rewrite when the destination element
// type is already i8 (vector<8xi4> -> vector<8xi8>). The expectations stop at
// the interleave; nothing after it (including the return) is verified.
// NOTE(review): the shift expected below is `arith.shrsi` (arithmetic shift).
// For a byte whose top bit is set this fills the upper bits with ones, which
// is not a zero-extension of the high nibble -- confirm whether `arith.shrui`
// is the intended op.
// CHECK-LABEL: func.func @aligned_extui_base_case(
359+
func.func @aligned_extui_base_case(%a: vector<8xi4>) -> vector<8xi8> {
360+
// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi4>) -> vector<8xi8> {
361+
// CHECK: %[[VAL_1:.*]] = arith.constant dense<4> : vector<4xi8>
362+
// CHECK: %[[VAL_2:.*]] = arith.constant dense<15> : vector<4xi8>
363+
// CHECK: %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8xi4> to vector<4xi8>
364+
// CHECK: %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<4xi8>
365+
// CHECK: %[[VAL_5:.*]] = arith.shrsi %[[VAL_3]], %[[VAL_1]] : vector<4xi8>
366+
// CHECK: %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<4xi8>
367+
%0 = arith.extui %a : vector<8xi4> to vector<8xi8>
368+
return %0 : vector<8xi8>
369+
}
370+
327371
module attributes {transform.with_named_sequence} {
328372
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
329373
%f = transform.structured.match ops{["func.func"]} in %module_op
@@ -335,4 +379,3 @@ module attributes {transform.with_named_sequence} {
335379
transform.yield
336380
}
337381
}
338-

0 commit comments

Comments
 (0)