-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][Vector] Add patterns for efficient unsigned i4 -> i8 conversion emulation #89131
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Kojo Acquah (KoolJBlack) Changes: This PR builds on #79494 with an additional path for efficient unsigned i4 -> i8 type extension for 1D/2D operations. Full diff: https://github.com/llvm/llvm-project/pull/89131.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index dc6f126aae4c87..53c5fb4dbc1da2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -880,6 +880,38 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
+/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// `srcValue` must be a vector of signless i4 (asserted). It is bitcast to an
+/// i8 vector whose innermost dimension is half as long, the low and high
+/// nibble of every byte are zero-extended into two separate i8 vectors, and
+/// the two vectors are interleaved to form the returned value.
+static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend the i4 elements using shifts & masking. Low i4 elements of each
+  // byte are placed in one vector and the high i4 elements in another vector.
+  constexpr unsigned char lowBitsMask = 15; // Equivalent to [00001111] bit mask
+  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
+  Value low = rewriter.create<arith::AndIOp>(loc, i8Vector.getType(), i8Vector,
+                                             lowBitsMaskValues);
+  constexpr int8_t highBitsToShift = 4;
+  auto highShiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
+  // The shift has to be logical (zero-filling). An arithmetic shift would
+  // replicate the byte's top bit, so a high nibble >= 8 (e.g. byte 0xF0) would
+  // produce 0xFF instead of the zero-extended 0x0F.
+  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
+
+  // 3. Interleave low and high i8 elements.
+  return rewriter.create<vector::InterleaveOp>(loc, low, high);
+}
+
/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
/// that take advantage of high-level information to avoid leaving LLVM to
/// scramble with peephole optimizations.
@@ -1099,6 +1131,50 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
}
};
+/// Rewrite the i4 -> i8 part of any unsigned conversion into a sequence of
+/// shuffles and bitwise ops that take advantage of high-level information to
+/// avoid leaving LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///    arith.extui %in : vector<8xi4> to vector<8xi32>
+/// is rewritten as
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.andi %0, 15 : vector<4xi8>         // low nibbles
+///    %2 = (%0 shifted right by 4) : vector<4xi8>   // high nibbles
+///    %3 = vector.interleave %1, %2 : vector<4xi8>
+///    %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+/// where %0..%3 are built by rewriteI4ToI8UnsignedExt and %4 re-creates the
+/// matched conversion op (`ConversionOpType`) on the resulting i8 vector.
+///
+template <typename ConversionOpType>
+struct RewriteAlignedSubByteIntUnsignedExt
+    : OpRewritePattern<ConversionOpType> {
+  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
+
+  /// Matches `conversionOp` when both precondition helpers succeed and
+  /// replaces it with the same conversion applied to the i8 vector returned
+  /// by rewriteI4ToI8UnsignedExt. Returns failure() without modifying the IR
+  /// if either precondition fails.
+  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
+                                PatternRewriter &rewriter) const override {
+    // Verify the preconditions.
+    // NOTE(review): dyn_cast yields a null type for non-vector operands or
+    // results; the precondition helpers are presumably responsible for
+    // rejecting that case - confirm against their definitions.
+    Value srcValue = conversionOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    if (failed(
+            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
+      return failure();
+
+    // Check general alignment preconditions.
+    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+                                             conversionOp)))
+      return failure();
+
+    // Perform the rewrite.
+    Value subByteExt =
+        rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+
+    // Finalize the rewrite: apply the original conversion kind to the i8
+    // vector so the result type of the replaced op is preserved.
+    rewriter.replaceOpWithNewOp<ConversionOpType>(
+        conversionOp, conversionOp.getType(), subByteExt);
+    return success();
+  }
+};
+
/// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations.
@@ -1233,6 +1309,8 @@ void vector::populateVectorNarrowTypeRewritePatterns(
RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
benefit.getBenefit() + 1);
+ patterns.add<RewriteAlignedSubByteIntUnsignedExt<arith::ExtUIOp>>(
+ patterns.getContext(), benefit.getBenefit() + 1);
}
void vector::populateVectorTransposeNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 8f0148119806c9..6d2b49889a3392 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -324,6 +324,50 @@ func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
return %0 : vector<16x8xi7>
}
+// Unsigned i4 -> i32 extension of a 1-D vector. The high nibble of each byte
+// has to be extracted with a logical (zero-filling) shift, arith.shrui, so
+// that values >= 8 are not sign-filled.
+// CHECK-LABEL: func.func @aligned_extui(
+func.func @aligned_extui(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK: %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<4xi8>
+// CHECK: %[[VAL_5:.*]] = arith.shrui %[[VAL_3]], %[[VAL_1]] : vector<4xi8>
+// CHECK: %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<4xi8>
+// CHECK: %[[VAL_7:.*]] = arith.extui %[[VAL_6]] : vector<8xi8> to vector<8xi32>
+// CHECK: return %[[VAL_7]] : vector<8xi32>
+  %0 = arith.extui %a : vector<8xi4> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+
+// Unsigned i4 -> i32 extension of a 2-D vector; only the innermost dimension
+// is halved by the bitcast. The high nibble is extracted with a logical
+// shift (arith.shrui).
+// CHECK-LABEL: func.func @aligned_extui_2d(
+func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<15> : vector<8x16xi8>
+// CHECK: %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<8x16xi8>
+// CHECK: %[[VAL_5:.*]] = arith.shrui %[[VAL_3]], %[[VAL_1]] : vector<8x16xi8>
+// CHECK: %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<8x16xi8>
+// CHECK: %[[VAL_7:.*]] = arith.extui %[[VAL_6]] : vector<8x32xi8> to vector<8x32xi32>
+// CHECK: return %[[VAL_7]] : vector<8x32xi32>
+  %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
+
+// Base case: unsigned i4 -> i8 extension, where the destination element type
+// is already i8. The high nibble is extracted with a logical shift
+// (arith.shrui).
+// CHECK-LABEL: func.func @aligned_extui_base_case(
+func.func @aligned_extui_base_case(%a: vector<8xi4>) -> vector<8xi8> {
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK: %[[VAL_3:.*]] = vector.bitcast %[[VAL_0]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : vector<4xi8>
+// CHECK: %[[VAL_5:.*]] = arith.shrui %[[VAL_3]], %[[VAL_1]] : vector<4xi8>
+// CHECK: %[[VAL_6:.*]] = vector.interleave %[[VAL_4]], %[[VAL_5]] : vector<4xi8>
+  %0 = arith.extui %a : vector<8xi4> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
@@ -335,4 +379,3 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG! A few comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
673182a
to
1b83c81
Compare
…onversion emulation (llvm#89131)" This reverts commit 6dfaecf.
…800a3 (#17330) * torch-mlir integrated at bce800a. * llvm-project integrated at 2083e97e plus local changes: * Reverted llvm/llvm-project#89131 locally: while this change is good in its own right, the `vector.interleave` that it generates (instead of `vector.shuffle`) are not handled by some GPU codegen lowerings. * Filed #17346. * Cherry-picked Bazel build fix: llvm/llvm-project#91654 * Several e2e tests have been temporarily disabled, follow-up work is needed to reenable them: #17344 --------- Co-authored-by: MaheshRavishankar <[email protected]> Co-authored-by: Scott Todd <[email protected]>
…onversion emulation (llvm#89131)" This reverts commit 6dfaecf.
…onversion emulation (llvm#89131)" This reverts commit 6dfaecf.
…le in VectorToSPIRV (#91800) Context: iree-org/iree#17346. Test IREE integrate showing it's fixing the problem it's intended to fix, i.e. it allows IREE to drop its local revert of #89131: iree-org/iree#17359 This is added to VectorToSPIRV because SPIRV doesn't currently handle `vector.interleave` (see motivating context above). This is limited to 1D, non-scalable vectors.
…le in VectorToSPIRV (#92012) This is the second attempt at merging #91800, which bounced due to a linker error apparently caused by an undeclared dependency. `MLIRVectorToSPIRV` needed to depend on `MLIRVectorTransforms`. In fact that was a preexisting issue already flagged by the tool in https://discourse.llvm.org/t/ninja-can-now-check-for-missing-cmake-dependencies-on-generated-files/74344. Context: iree-org/iree#17346. Test IREE integrate showing it's fixing the problem it's intended to fix, i.e. it allows IREE to drop its local revert of #89131: iree-org/iree#17359 This is added to VectorToSPIRV because SPIRV doesn't currently handle `vector.interleave` (see motivating context above). This is limited to 1D, non-scalable vectors.
This allows dropping our existing local-revert of llvm/llvm-project#89131 and cherry-pick of llvm/llvm-project#91654 which we had introduced in the earlier integrate #17330. This locally reverts llvm/llvm-project#90802 because it causes numerical errors, reported at llvm/llvm-project#90802 (comment).
…800a3 (iree-org#17330) * torch-mlir integrated at bce800a. * llvm-project integrated at 2083e97e plus local changes: * Reverted llvm/llvm-project#89131 locally: while this change is good in its own right, the `vector.interleave` that it generates (instead of `vector.shuffle`) are not handled by some GPU codegen lowerings. * Filed iree-org#17346. * Cherry-picked Bazel build fix: llvm/llvm-project#91654 * Several e2e tests have been temporarily disabled, follow-up work is needed to reenable them: iree-org#17344 --------- Co-authored-by: MaheshRavishankar <[email protected]> Co-authored-by: Scott Todd <[email protected]>
This allows dropping our existing local-revert of llvm/llvm-project#89131 and cherry-pick of llvm/llvm-project#91654 which we had introduced in the earlier integrate iree-org#17330. This locally reverts llvm/llvm-project#90802 because it causes numerical errors, reported at llvm/llvm-project#90802 (comment).
…800a3 (iree-org#17330) * torch-mlir integrated at bce800a. * llvm-project integrated at 2083e97e plus local changes: * Reverted llvm/llvm-project#89131 locally: while this change is good in its own right, the `vector.interleave` that it generates (instead of `vector.shuffle`) are not handled by some GPU codegen lowerings. * Filed iree-org#17346. * Cherry-picked Bazel build fix: llvm/llvm-project#91654 * Several e2e tests have been temporarily disabled, follow-up work is needed to reenable them: iree-org#17344 --------- Co-authored-by: MaheshRavishankar <[email protected]> Co-authored-by: Scott Todd <[email protected]> Signed-off-by: Lubo Litchev <[email protected]>
This allows dropping our existing local-revert of llvm/llvm-project#89131 and cherry-pick of llvm/llvm-project#91654 which we had introduced in the earlier integrate iree-org#17330. This locally reverts llvm/llvm-project#90802 because it causes numerical errors, reported at llvm/llvm-project#90802 (comment). Signed-off-by: Lubo Litchev <[email protected]>
…lvm#115485) This pr just adds the patterns from llvm#89131 for the arith::UIToFPOp. Also does some slight renaming and moving of the tests for better readability.
This PR builds on #79494 with an additional path for efficient unsigned i4 -> i8 type extension for 1D/2D operations. This will impact any i4 -> i8/i16/i32/i64 unsigned extensions as well as sitofp i4 -> f8/f16/f32/f64.