Skip to content

Commit f2d824e

Browse files
bjacobMaheshRavishankar
authored andcommitted
Revert "[mlir][Vector] Add patterns for efficient unsigned i4 -> i8 conversion emulation (llvm#89131)"
This reverts commit 6dfaecf.
1 parent 04ce103 commit f2d824e

File tree

2 files changed

+9
-100
lines changed

2 files changed

+9
-100
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 8 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -880,38 +880,6 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
880880
return rewriter.create<vector::InterleaveOp>(loc, low, high);
881881
}
882882

883-
/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
884-
/// bitwise ops that take advantage of high-level information to avoid leaving
885-
/// LLVM to scramble with peephole optimizations.
886-
static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
887-
Value srcValue) {
888-
VectorType srcVecType = cast<VectorType>(srcValue.getType());
889-
assert(srcVecType.getElementType().isSignlessInteger(4) &&
890-
"Expected i4 type");
891-
892-
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
893-
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
894-
constexpr int64_t i4Toi8BitwidthFactor = 2;
895-
i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
896-
auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
897-
Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
898-
899-
// 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
900-
// byte are placed in one vector and the high i4 elements in another vector.
901-
constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
902-
auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
903-
loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
904-
Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
905-
lowBitsMaskValues);
906-
constexpr int8_t highBitsToShift = 4;
907-
auto highShiftValues = rewriter.create<arith::ConstantOp>(
908-
loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
909-
Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
910-
911-
// 3. Interleave low and high i8 elements.
912-
return rewriter.create<vector::InterleaveOp>(loc, low, high);
913-
}
914-
915883
/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
916884
/// that take advantage of high-level information to avoid leaving LLVM to
917885
/// scramble with peephole optimizations.
@@ -1080,10 +1048,9 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
10801048

10811049
/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
10821050
/// bitwise ops that take advantage of high-level information to avoid leaving
1083-
/// LLVM to scramble with peephole optimizations. Templated to choose between
1084-
/// signed and unsigned conversions.
1051+
/// LLVM to scramble with peephole optimizations.
10851052
///
1086-
/// For example (signed):
1053+
/// For example:
10871054
/// arith.extsi %in : vector<8xi4> to vector<8xi32>
10881055
/// is rewriten as
10891056
/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
@@ -1102,17 +1069,8 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
11021069
/// %4 = vector.interleave %2, %3 : vector<4xi8>
11031070
/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
11041071
///
1105-
/// Example (unsigned):
1106-
/// arith.extui %in : vector<8xi4> to vector<8xi32>
1107-
/// is rewritten as
1108-
/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1109-
/// %1 = arith.andi %0, 15 : vector<4xi8>
1110-
/// %2 = arith.shrui %0, 4 : vector<4xi8>
1111-
/// %3 = vector.interleave %1, %2 : vector<4xi8>
1112-
/// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
1113-
///
1114-
template <typename ConversionOpType, bool isSigned>
1115-
struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
1072+
template <typename ConversionOpType>
1073+
struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
11161074
using OpRewritePattern<ConversionOpType>::OpRewritePattern;
11171075

11181076
LogicalResult matchAndRewrite(ConversionOpType conversionOp,
@@ -1121,7 +1079,6 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
11211079
Value srcValue = conversionOp.getIn();
11221080
auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
11231081
auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
1124-
11251082
if (failed(
11261083
commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
11271084
return failure();
@@ -1132,14 +1089,8 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
11321089
return failure();
11331090

11341091
// Perform the rewrite.
1135-
Value subByteExt;
1136-
if (isSigned) {
1137-
subByteExt =
1138-
rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
1139-
} else {
1140-
subByteExt =
1141-
rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
1142-
}
1092+
Value subByteExt =
1093+
rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
11431094

11441095
// Finalize the rewrite.
11451096
rewriter.replaceOpWithNewOp<ConversionOpType>(
@@ -1278,12 +1229,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
12781229

12791230
// Patterns for aligned cases. We set higher priority as they are expected to
12801231
// generate better performance for aligned cases.
1281-
patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
1282-
RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
1232+
patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
1233+
RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
12831234
RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
12841235
benefit.getBenefit() + 1);
1285-
patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
1286-
patterns.getContext(), benefit.getBenefit() + 1);
12871236
}
12881237

12891238
void vector::populateVectorTransposeNarrowTypeRewritePatterns(

mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -324,47 +324,6 @@ func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
324324
return %0 : vector<16x8xi7>
325325
}
326326

327-
// CHECK-LABEL: func.func @aligned_extui(
328-
func.func @aligned_extui(%a: vector<8xi4>) -> vector<8xi32> {
329-
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
330-
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
331-
// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
332-
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
333-
// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
334-
// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
335-
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
336-
// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
337-
%0 = arith.extui %a : vector<8xi4> to vector<8xi32>
338-
return %0 : vector<8xi32>
339-
}
340-
341-
// CHECK-LABEL: func.func @aligned_extui_2d(
342-
func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
343-
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
344-
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
345-
// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
346-
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
347-
// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
348-
// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
349-
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
350-
// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
351-
%0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
352-
return %0 : vector<8x32xi32>
353-
}
354-
355-
// CHECK-LABEL: func.func @aligned_extui_base_case(
356-
func.func @aligned_extui_base_case(%a: vector<8xi4>) -> vector<8xi8> {
357-
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
358-
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
359-
// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
360-
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
361-
// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
362-
// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
363-
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
364-
%0 = arith.extui %a : vector<8xi4> to vector<8xi8>
365-
return %0 : vector<8xi8>
366-
}
367-
368327
module attributes {transform.with_named_sequence} {
369328
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
370329
%f = transform.structured.match ops{["func.func"]} in %module_op
@@ -376,3 +335,4 @@ module attributes {transform.with_named_sequence} {
376335
transform.yield
377336
}
378337
}
338+

0 commit comments

Comments
 (0)