Skip to content

Commit 1b83c81

Browse files
committed
review comments
1 parent f879bab commit 1b83c81

File tree

2 files changed

+48
-79
lines changed

2 files changed

+48
-79
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 26 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -896,17 +896,17 @@ static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
896896
auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
897897
Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
898898

899-
// 2 Extend the i4 elements using shifts & masking. Low i4 elemens of each
900-
// byte are place in one vector and the high i4 elements in another vector.
901-
constexpr unsigned char lowBitsMask = 15; // Equivalent to [0000IIII] bit mask
899+
// 2. Extend the i4 elements using shifts & masking. Low i4 elements of each
900+
// byte are placed in one vector and the high i4 elements in another vector.
901+
constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
902902
auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
903903
loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
904-
Value low = rewriter.create<arith::AndIOp>(loc, i8Vector.getType(), i8Vector,
904+
Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
905905
lowBitsMaskValues);
906906
constexpr int8_t highBitsToShift = 4;
907907
auto highShiftValues = rewriter.create<arith::ConstantOp>(
908908
loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
909-
Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, highShiftValues);
909+
Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
910910

911911
// 3. Interleave low and high i8 elements.
912912
return rewriter.create<vector::InterleaveOp>(loc, low, high);
@@ -1080,9 +1080,10 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
10801080

10811081
/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
10821082
/// bitwise ops that take advantage of high-level information to avoid leaving
1083-
/// LLVM to scramble with peephole optimizations.
1083+
/// LLVM to scramble with peephole optimizations. Templated to choose between
1084+
/// signed and unsigned conversions.
10841085
///
1085-
/// For example:
1086+
/// For example (signed):
10861087
/// arith.extsi %in : vector<8xi4> to vector<8xi32>
10871088
/// is rewritten as
10881089
/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
@@ -1101,60 +1102,25 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
11011102
/// %4 = vector.interleave %2, %3 : vector<4xi8>
11021103
/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
11031104
///
1104-
template <typename ConversionOpType>
1105-
struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
1106-
using OpRewritePattern<ConversionOpType>::OpRewritePattern;
1107-
1108-
LogicalResult matchAndRewrite(ConversionOpType conversionOp,
1109-
PatternRewriter &rewriter) const override {
1110-
// Verify the preconditions.
1111-
Value srcValue = conversionOp.getIn();
1112-
auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
1113-
auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
1114-
if (failed(
1115-
commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
1116-
return failure();
1117-
1118-
// Check general alignment preconditions.
1119-
if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
1120-
conversionOp)))
1121-
return failure();
1122-
1123-
// Perform the rewrite.
1124-
Value subByteExt =
1125-
rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
1126-
1127-
// Finalize the rewrite.
1128-
rewriter.replaceOpWithNewOp<ConversionOpType>(
1129-
conversionOp, conversionOp.getType(), subByteExt);
1130-
return success();
1131-
}
1132-
};
1133-
1134-
/// Rewrite the i4 -> i8 part of any unsigned conversion into a sequence of
1135-
/// shuffles and bitwise ops that take advantage of high-level information to
1136-
/// avoid leaving LLVM to scramble with peephole optimizations.
1137-
///
1138-
/// For example:
1105+
/// For example (unsigned):
11391106
/// arith.extui %in : vector<8xi4> to vector<8xi32>
11401107
/// is rewritten as
11411108
/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
11421109
/// %1 = arith.andi %0, 15 : vector<4xi8>
1143-
/// %2 = arith.shrsi %0, 4 : vector<4xi8>
1110+
/// %2 = arith.shrui %0, 4 : vector<4xi8>
11441111
/// %3 = vector.interleave %1, %2 : vector<4xi8>
1145-
/// %4 = arith.extsi %3 : vector<8xi8> to vector<8xi32>
1112+
/// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
11461113
///
1147-
template <typename ConversionOpType>
1148-
struct RewriteAlignedSubByteIntUnsignedExt
1149-
: OpRewritePattern<ConversionOpType> {
1114+
template <typename ConversionOpType, bool isSigned>
1115+
struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
11501116
using OpRewritePattern<ConversionOpType>::OpRewritePattern;
11511117

11521118
LogicalResult matchAndRewrite(ConversionOpType conversionOp,
11531119
PatternRewriter &rewriter) const override {
11541120
// Verify the preconditions.
11551121
Value srcValue = conversionOp.getIn();
1156-
auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
1157-
auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
1122+
auto srcVecType = cast<VectorType>(srcValue.getType());
1123+
auto dstVecType = cast<VectorType>(conversionOp.getType());
11581124
if (failed(
11591125
commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
11601126
return failure();
@@ -1165,8 +1131,14 @@ struct RewriteAlignedSubByteIntUnsignedExt
11651131
return failure();
11661132

11671133
// Perform the rewrite.
1168-
Value subByteExt =
1169-
rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
1134+
Value subByteExt;
1135+
if (isSigned) {
1136+
subByteExt =
1137+
rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
1138+
} else {
1139+
subByteExt =
1140+
rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
1141+
}
11701142

11711143
// Finalize the rewrite.
11721144
rewriter.replaceOpWithNewOp<ConversionOpType>(
@@ -1305,11 +1277,11 @@ void vector::populateVectorNarrowTypeRewritePatterns(
13051277

13061278
// Patterns for aligned cases. We set higher priority as they are expected to
13071279
// generate better performance for aligned cases.
1308-
patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
1309-
RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
1280+
patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
1281+
RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
13101282
RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
13111283
benefit.getBenefit() + 1);
1312-
patterns.add<RewriteAlignedSubByteIntUnsignedExt<arith::ExtUIOp>>(
1284+
patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
13131285
patterns.getContext(), benefit.getBenefit() + 1);
13141286
}
13151287

mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -326,44 +326,41 @@ func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
326326

327327
// CHECK-LABEL: func.func @aligned_extui(
328328
func.func @aligned_extui(%a: vector<8xi4>) -> vector<8xi32> {
329-
// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi4>) -> vector<8xi32> {
330-
// CHECK: %[[VAL_1:.*]] = arith.constant dense<4> : vector<4xi8>
331-
// CHECK: %[[VAL_2:.*]] = arith.constant dense<15> : vector<4xi8>
332-
// CHECK: %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8xi4> to vector<4xi8>
333-
// CHECK: %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<4xi8>
334-
// CHECK: %[[VAL_5:.*]] = arith.shrsi %[[VAL_3]], %[[VAL_1]] : vector<4xi8>
335-
// CHECK: %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<4xi8>
336-
// CHECK: %[[VAL_7:.*]] = arith.extui %[[VAL_6]] : vector<8xi8> to vector<8xi32>
329+
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
330+
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
331+
// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
332+
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
333+
// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
334+
// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
335+
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
336+
// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
337337
%0 = arith.extui %a : vector<8xi4> to vector<8xi32>
338338
return %0 : vector<8xi32>
339339
}
340340

341-
342341
// CHECK-LABEL: func.func @aligned_extui_2d(
343342
func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
344343
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
345-
// CHECK: %[[VAL_1:.*]] = arith.constant dense<4> : vector<8x16xi8>
346-
// CHECK: %[[VAL_2:.*]] = arith.constant dense<15> : vector<8x16xi8>
347-
// CHECK: %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
348-
// CHECK: %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<8x16xi8>
349-
// CHECK: %[[VAL_5:.*]] = arith.shrsi %[[VAL_3]], %[[VAL_1]] : vector<8x16xi8>
350-
// CHECK: %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<8x16xi8>
351-
// CHECK: %[[VAL_7:.*]] = arith.extui %[[VAL_6]] : vector<8x32xi8> to vector<8x32xi32>
352-
// CHECK: return %[[VAL_7]] : vector<8x32xi32>
344+
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
345+
// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
346+
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
347+
// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
348+
// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
349+
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
350+
// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
353351
%0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
354352
return %0 : vector<8x32xi32>
355353
}
356354

357-
358355
// CHECK-LABEL: func.func @aligned_extui_base_case(
359356
func.func @aligned_extui_base_case(%a: vector<8xi4>) -> vector<8xi8> {
360-
// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi4>) -> vector<8xi8> {
361-
// CHECK: %[[VAL_1:.*]] = arith.constant dense<4> : vector<4xi8>
362-
// CHECK: %[[VAL_2:.*]] = arith.constant dense<15> : vector<4xi8>
363-
// CHECK: %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8xi4> to vector<4xi8>
364-
// CHECK: %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<4xi8>
365-
// CHECK: %[[VAL_5:.*]] = arith.shrsi %[[VAL_3]], %[[VAL_1]] : vector<4xi8>
366-
// CHECK: %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<4xi8>
357+
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
358+
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
359+
// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
360+
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
361+
// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
362+
// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
363+
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
367364
%0 = arith.extui %a : vector<8xi4> to vector<8xi8>
368365
return %0 : vector<8xi8>
369366
}

0 commit comments

Comments
 (0)