[mlir][Vector] Add patterns for efficient unsigned i4 -> i8 conversion emulation #89131


Merged

merged 2 commits into llvm:main from i4_unsigned_emulation on May 1, 2024

Conversation

KoolJBlack
Contributor

This PR builds on #79494 with an additional path for efficient unsigned i4 -> i8 type extension for 1D/2D operations. This will impact any i4 -> i8/i16/i32/i64 unsigned extensions as well as sitofp i4 -> f8/f16/f32/f64.
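
To see how these rewrites get picked up, the sketch below shows one way a downstream pass could apply them through `populateVectorNarrowTypeRewritePatterns`, the entry point this PR extends. The pass name and structure are illustrative only, not part of this change:

```cpp
// Minimal sketch (not from this PR): a test pass that applies the narrow-type
// rewrite patterns, including the new unsigned i4 -> i8 path, to a function.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
struct TestNarrowTypeEmulationPass
    : mlir::PassWrapper<TestNarrowTypeEmulationPass,
                        mlir::OperationPass<mlir::func::FuncOp>> {
  void runOnOperation() override {
    mlir::RewritePatternSet patterns(&getContext());
    // The aligned sub-byte rewrites are registered inside this populate
    // function with a higher pattern benefit, so they are tried first.
    mlir::vector::populateVectorNarrowTypeRewritePatterns(patterns);
    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                        std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
```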

@KoolJBlack KoolJBlack marked this pull request as ready for review April 17, 2024 20:04
@llvmbot
Member

llvmbot commented Apr 17, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Kojo Acquah (KoolJBlack)

Changes

This PR builds on #79494 with an additional path for efficient unsigned i4 -> i8 type extension for 1D/2D operations. This will impact any i4 -> i8/i16/i32/i64 unsigned extensions as well as sitofp i4 -> f8/f16/f32/f64.


Full diff: https://github.com/llvm/llvm-project/pull/89131.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+78)
  • (modified) mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir (+44-1)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index dc6f126aae4c87..53c5fb4dbc1da2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -880,6 +880,38 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend the i4 elements using shifts & masking. Low i4 elements of each
+  //  byte are placed in one vector and the high i4 elements in another vector.
+  constexpr unsigned char lowBitsMask = 15; // Equivalent to [0000IIII] bit mask
+  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
+  Value low = rewriter.create<arith::AndIOp>(loc, i8Vector.getType(), i8Vector,
+                                             lowBitsMaskValues);
+  constexpr int8_t highBitsToShift = 4;
+  auto highShiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
+  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, highShiftValues);
+
+  // 3. Interleave low and high i8 elements.
+  return rewriter.create<vector::InterleaveOp>(loc, low, high);
+}
+
 /// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
 /// that take advantage of high-level information to avoid leaving LLVM to
 /// scramble with peephole optimizations.
@@ -1099,6 +1131,50 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
   }
 };
 
+/// Rewrite the i4 -> i8 part of any unsigned conversion into a sequence of
+/// shuffles and bitwise ops that take advantage of high-level information to
+/// avoid leaving LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///    arith.extui %in : vector<8xi4> to vector<8xi32>
+///      is rewritten as
+///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///        %1 = arith.andi %0, 15 : vector<4xi8>
+///        %2 = arith.shrsi %0, 4 : vector<4xi8>
+///        %3 = vector.interleave %1, %2 : vector<4xi8>
+///        %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+///
+template <typename ConversionOpType>
+struct RewriteAlignedSubByteIntUnsignedExt
+    : OpRewritePattern<ConversionOpType> {
+  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
+                                PatternRewriter &rewriter) const override {
+    // Verify the preconditions.
+    Value srcValue = conversionOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    if (failed(
+            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
+      return failure();
+
+    // Check general alignment preconditions.
+    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+                                             conversionOp)))
+      return failure();
+
+    // Perform the rewrite.
+    Value subByteExt =
+        rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+
+    // Finalize the rewrite.
+    rewriter.replaceOpWithNewOp<ConversionOpType>(
+        conversionOp, conversionOp.getType(), subByteExt);
+    return success();
+  }
+};
+
 /// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
 /// bitwise ops that take advantage of high-level information to avoid leaving
 /// LLVM to scramble with peephole optimizations.
@@ -1233,6 +1309,8 @@ void vector::populateVectorNarrowTypeRewritePatterns(
                RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
+  patterns.add<RewriteAlignedSubByteIntUnsignedExt<arith::ExtUIOp>>(
+      patterns.getContext(), benefit.getBenefit() + 1);
 }
 
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 8f0148119806c9..6d2b49889a3392 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -324,6 +324,50 @@ func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
   return %0 : vector<16x8xi7>
 }
 
+// CHECK-LABEL: func.func @aligned_extui(
+func.func @aligned_extui(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK-SAME:                             %[[VAL_0:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK:           %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<4xi8>
+// CHECK:           %[[VAL_5:.*]] = arith.shrsi %[[VAL_3]], %[[VAL_1]] : vector<4xi8>
+// CHECK:           %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<4xi8>
+// CHECK:           %[[VAL_7:.*]] = arith.extui %[[VAL_6]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extui %a : vector<8xi4> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+
+// CHECK-LABEL: func.func @aligned_extui_2d(
+func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME:                                %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<15> : vector<8x16xi8>
+// CHECK:           %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<8x16xi8>
+// CHECK:           %[[VAL_5:.*]] = arith.shrsi %[[VAL_3]], %[[VAL_1]] : vector<8x16xi8>
+// CHECK:           %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<8x16xi8>
+// CHECK:           %[[VAL_7:.*]] = arith.extui %[[VAL_6]] : vector<8x32xi8> to vector<8x32xi32>
+// CHECK:           return %[[VAL_7]] : vector<8x32xi32>
+  %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
+
+// CHECK-LABEL: func.func @aligned_extui_base_case(
+func.func @aligned_extui_base_case(%a: vector<8xi4>) -> vector<8xi8> {
+// CHECK-SAME:                                       %[[VAL_0:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK:           %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<4xi8>
+// CHECK:           %[[VAL_5:.*]] = arith.shrsi %[[VAL_3]], %[[VAL_1]] : vector<4xi8>
+// CHECK:           %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<4xi8>
+  %0 = arith.extui %a : vector<8xi4> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op
@@ -335,4 +379,3 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-
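
As a plain-scalar reference for what the mask/shift/interleave sequence above computes per byte, here is an editorial sketch (the function name and scalar form are not from the PR; it assumes the even-indexed i4 element sits in the low nibble of each byte, matching the low/high interleave order used by the pattern):

```cpp
#include <cstdint>
#include <vector>

// Editorial scalar reference (not part of the PR) for unsigned i4 -> i8
// extension of a packed buffer: each byte is assumed to hold the even-indexed
// i4 element in its low nibble and the odd-indexed element in its high nibble.
std::vector<uint8_t> extUiI4ToI8(const std::vector<uint8_t> &packed) {
  std::vector<uint8_t> out;
  out.reserve(packed.size() * 2);
  for (uint8_t b : packed) {
    out.push_back(b & 0x0F);        // low i4 element, zero-extended
    out.push_back((b >> 4) & 0x0F); // high i4 element, zero-extended
  }
  return out;
}
```

The vector pattern does the same thing in bulk: the `andi` keeps the low nibbles, the shift moves the high nibbles down, and `vector.interleave` restores the original element order before the final extension to the destination width.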

Contributor

@dcaballe dcaballe left a comment

LG! A few comments

Contributor

@banach-space banach-space left a comment

Nice!

@KoolJBlack KoolJBlack force-pushed the i4_unsigned_emulation branch from 673182a to 1b83c81 on April 26, 2024 18:52
@KoolJBlack KoolJBlack merged commit 6dfaecf into llvm:main May 1, 2024
4 checks passed
@KoolJBlack KoolJBlack deleted the i4_unsigned_emulation branch May 1, 2024 16:35
bjacob added a commit to iree-org/llvm-project that referenced this pull request May 9, 2024
ScottTodd added a commit to iree-org/iree that referenced this pull request May 10, 2024
…800a3 (#17330)

* torch-mlir integrated at bce800a.
* llvm-project integrated at 2083e97e plus local changes:
  * Reverted llvm/llvm-project#89131 locally: while this change is good in
    its own right, the `vector.interleave` ops that it generates (instead of
    `vector.shuffle`) are not handled by some GPU codegen lowerings.
    * Filed #17346.
  * Cherry-picked Bazel build fix: llvm/llvm-project#91654
* Several e2e tests have been temporarily disabled; follow-up work is
  needed to reenable them: #17344

---------

Co-authored-by: MaheshRavishankar <[email protected]>
Co-authored-by: Scott Todd <[email protected]>
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request May 10, 2024
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request May 10, 2024
bjacob added a commit that referenced this pull request May 13, 2024
…le in VectorToSPIRV (#91800)

Context: iree-org/iree#17346.

Test IREE integrate showing it's fixing the problem it's intended to
fix, i.e. it allows IREE to drop its local revert of
#89131:

iree-org/iree#17359

This is added to VectorToSPIRV because SPIRV doesn't currently handle
`vector.interleave` (see motivating context above).

This is limited to 1D, non-scalable vectors.
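
For background on why the 1D, fixed-length case is easy to handle: a `vector.interleave` of two n-element operands is equivalent to a `vector.shuffle` with the mask [0, n, 1, n+1, ...], since shuffle indices run over the concatenation of both operands. The helper below is an editorial sketch of that mask computation, not code from the referenced patch:

```cpp
#include <cstdint>
#include <vector>

// Editorial sketch: build the shuffle mask that makes a vector.shuffle
// equivalent to a 1D vector.interleave of two n-element operands, i.e.
// result[2*i] = lhs[i] (index i) and result[2*i + 1] = rhs[i] (index n + i).
std::vector<int64_t> interleaveShuffleMask(int64_t n) {
  std::vector<int64_t> mask;
  mask.reserve(2 * n);
  for (int64_t i = 0; i < n; ++i) {
    mask.push_back(i);     // take element i from the first operand
    mask.push_back(n + i); // take element i from the second operand
  }
  return mask;
}
```

For the `vector<4xi8>` operands produced by this PR's rewrite, that mask is [0, 4, 1, 5, 2, 6, 3, 7].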
bjacob added a commit that referenced this pull request May 13, 2024
…le in VectorToSPIRV (#92012)

This is the second attempt at merging #91800, which bounced due to a
linker error apparently caused by an undeclared dependency.
`MLIRVectorToSPIRV` needed to depend on `MLIRVectorTransforms`. In fact
that was a preexisting issue already flagged by the tool in
https://discourse.llvm.org/t/ninja-can-now-check-for-missing-cmake-dependencies-on-generated-files/74344.

Context: iree-org/iree#17346.

Test IREE integrate showing it's fixing the problem it's intended to
fix, i.e. it allows IREE to drop its local revert of
#89131:

iree-org/iree#17359

This is added to VectorToSPIRV because SPIRV doesn't currently handle
`vector.interleave` (see motivating context above).

This is limited to 1D, non-scalable vectors.
bjacob added a commit to iree-org/iree that referenced this pull request May 14, 2024
This allows dropping our existing local-revert of
llvm/llvm-project#89131 and cherry-pick of
llvm/llvm-project#91654 which we had introduced
in the earlier integrate #17330.

This locally reverts llvm/llvm-project#90802
because it causes numerical errors, reported at
llvm/llvm-project#90802 (comment).
bangtianliu pushed a commit to bangtianliu/iree that referenced this pull request Jun 5, 2024
…800a3 (iree-org#17330)

* torch-mlir integrated at bce800a.
* llvm-project integrated at 2083e97e plus local changes:
  * Reverted llvm/llvm-project#89131 locally: while this change is good in
    its own right, the `vector.interleave` ops that it generates (instead of
    `vector.shuffle`) are not handled by some GPU codegen lowerings.
    * Filed iree-org#17346.
  * Cherry-picked Bazel build fix: llvm/llvm-project#91654
* Several e2e tests have been temporarily disabled; follow-up work is
  needed to reenable them: iree-org#17344

---------

Co-authored-by: MaheshRavishankar <[email protected]>
Co-authored-by: Scott Todd <[email protected]>
bangtianliu pushed a commit to bangtianliu/iree that referenced this pull request Jun 5, 2024
This allows dropping our existing local-revert of
llvm/llvm-project#89131 and cherry-pick of
llvm/llvm-project#91654 which we had introduced
in the earlier integrate iree-org#17330.

This locally reverts llvm/llvm-project#90802
because it causes numerical errors, reported at
llvm/llvm-project#90802 (comment).
LLITCHEV pushed a commit to LLITCHEV/iree that referenced this pull request Jul 30, 2024
…800a3 (iree-org#17330)

* torch-mlir integrated at bce800a.
* llvm-project integrated at 2083e97e plus local changes:
  * Reverted llvm/llvm-project#89131 locally: while this change is good in
    its own right, the `vector.interleave` ops that it generates (instead of
    `vector.shuffle`) are not handled by some GPU codegen lowerings.
    * Filed iree-org#17346.
  * Cherry-picked Bazel build fix: llvm/llvm-project#91654
* Several e2e tests have been temporarily disabled; follow-up work is
  needed to reenable them: iree-org#17344

---------

Co-authored-by: MaheshRavishankar <[email protected]>
Co-authored-by: Scott Todd <[email protected]>
Signed-off-by: Lubo Litchev <[email protected]>
LLITCHEV pushed a commit to LLITCHEV/iree that referenced this pull request Jul 30, 2024
This allows dropping our existing local-revert of
llvm/llvm-project#89131 and cherry-pick of
llvm/llvm-project#91654 which we had introduced
in the earlier integrate iree-org#17330.

This locally reverts llvm/llvm-project#90802
because it causes numerical errors, reported at
llvm/llvm-project#90802 (comment).

Signed-off-by: Lubo Litchev <[email protected]>
banach-space pushed a commit that referenced this pull request Nov 11, 2024
…115485)

This PR just adds the patterns from #89131 for arith::UIToFPOp. It also does
some slight renaming and moving of the tests for better readability.
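
The registration presumably mirrors the one added in this PR, instantiated for a second op type; a rough fragment inside `populateVectorNarrowTypeRewritePatterns` might look like the following (sketch only; the exact form in #115485 may differ):

```cpp
// Sketch only; see #115485 for the actual change.
patterns.add<RewriteAlignedSubByteIntUnsignedExt<arith::ExtUIOp>,
             RewriteAlignedSubByteIntUnsignedExt<arith::UIToFPOp>>(
    patterns.getContext(), benefit.getBenefit() + 1);
```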
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024
…lvm#115485)

This PR just adds the patterns from llvm#89131 for arith::UIToFPOp. It also
does some slight renaming and moving of the tests for better readability.