[MLIR][ArmSVE] Add initial lowering of vector.contract to SVE *MMLA instructions #135636
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-neon
Author: Momchil Velikov (momchil-velikov)
Changes: Supersedes #135359
Patch is 77.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135636.diff
16 Files Affected:
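For orientation, here is a minimal sketch of the kind of contraction this patch targets, reconstructed from the shape, type and indexing-map checks in the new pattern below; the function name and the concrete sizes (4 rows, a scalable [4] column dimension) are illustrative, not taken from the patch:

#packed_maps = [
  affine_map<(d0, d1, d2) -> (d0, d2)>,
  affine_map<(d0, d1, d2) -> (d1, d2)>,
  affine_map<(d0, d1, d2) -> (d0, d1)>
]

func.func @example_i8mm_contract(%lhs: vector<4x8xi8>, %rhs: vector<[4]x8xi8>,
                                 %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> {
  // Both operands are extended from i8 to i32; the signed/signed combination
  // selects smmla, the other sign combinations select ummla/usmmla.
  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
  %2 = vector.contract {indexing_maps = #packed_maps,
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>}
         %0, %1, %acc : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
  return %2 : vector<4x[4]xi32>
}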
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
+ Option<"armI8MM", "enable-arm-i8mm",
+ "bool", /*default=*/"false",
+ "Enables the use of Arm FEAT_I8MM instructions while lowering "
+ "the vector dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
void populateArmSVELegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+void populateLowerContractionToSVEI8MMPatternPatterns(
+ RewritePatternSet &patterns);
+
/// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
/// intrinsics.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
MLIRVectorToLLVM
MLIRArmNeonDialect
+ MLIRArmNeonTransforms
MLIRArmSVEDialect
MLIRArmSVETransforms
MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
+ if (armI8MM) {
+ if (armNeon)
+ arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+ if (armSVE)
+ populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+ }
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 2a1271dfd6bdf..e807b233aa7aa 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -56,6 +56,9 @@ class LowerContractionToSMMLAPattern
// Avoid 0-D vectors and 1-D rhs:
if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
return failure();
+ // Avoid scalable vectors.
+ if (lhsType.isScalable() || rhsType.isScalable())
+ return failure();
auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
auto dimK = rhsType.getDimSize(1);
@@ -238,5 +241,5 @@ class LowerContractionToSMMLAPattern
void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
- patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
+ patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
}
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index a70c489a51fea..65f98b44b1b69 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRArmSVETransforms
LegalizeForLLVMExport.cpp
LegalizeVectorStorage.cpp
+ LowerContractionToSVEI8MMPattern.cpp
DEPENDS
MLIRArmSVEConversionsIncGen
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
new file mode 100644
index 0000000000000..c0620c71440bc
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -0,0 +1,304 @@
+//===- LowerContractionToSVEI8MMPattern.cpp - Contract to I8MM --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be a
+// sign- or zero-extend from i8 to i32). Return the value before the extension.
+template <typename T>
+inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
+ std::is_base_of_v<arith::ExtUIOp, T>),
+ std::optional<Value>>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+ auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
+ if (!extOp)
+ return {};
+
+ auto inOp = extOp.getIn();
+ auto inTy = dyn_cast<VectorType>(inOp.getType());
+ if (!inTy || inTy.getElementType() != i8Ty)
+ return {};
+
+ auto outTy = dyn_cast<VectorType>(extOp.getType());
+ if (!outTy || outTy.getElementType() != i32Ty)
+ return {};
+
+ return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+ Signed, // smmla
+ Unsigned, // ummla
+ Mixed, // usmmla
+ MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+ switch (op) {
+ case MMLA::Signed:
+ return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
+ case MMLA::Unsigned:
+ return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
+ case MMLA::Mixed:
+ return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
+ case MMLA::MixedSwapped:
+ // The accumulator comes transposed and the result will be transposed
+ // later, so all we have to do here is swap the operands.
+ return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
+ }
+}
+
+class LowerContractionToSVEI8MMPattern
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ mlir::VectorType lhsType = op.getLhsType();
+ mlir::VectorType rhsType = op.getRhsType();
+
+ // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+ // eventually expect from MMT4D. M and N dimensions must be even and at
+ // least 2.
+ if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+ rhsType.getRank() != 2)
+ return failure();
+
+ if (lhsType.isScalable() || !rhsType.isScalable())
+ return failure();
+
+ // M, N, and K are the conventional names for matrix dimensions in the
+ // context of matrix multiplication.
+ auto M = lhsType.getDimSize(0);
+ auto N = rhsType.getDimSize(0);
+ auto K = rhsType.getDimSize(1);
+
+ if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+ N % 2 != 0 || !rhsType.getScalableDims()[0])
+ return failure();
+
+ // Check permutation maps. For now only accept
+ // lhs: (d0, d1, d2) -> (d0, d2)
+ // rhs: (d0, d1, d2) -> (d1, d2)
+ // acc: (d0, d1, d2) -> (d0, d1)
+ // Note: RHS is transposed.
+ if (op.getIndexingMapsArray()[0] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[1] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[2] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+ op.getContext()))
+ return failure();
+
+ // Check iterator types for matrix multiplication.
+ auto itTypes = op.getIteratorTypesArray();
+ if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+ itTypes[1] != vector::IteratorType::parallel ||
+ itTypes[2] != vector::IteratorType::reduction)
+ return failure();
+
+ // Check the combining kind is addition.
+ if (op.getKind() != vector::CombiningKind::ADD)
+ return failure();
+
+ // Check the output is a vector of i32 elements.
+ auto outTy = dyn_cast<VectorType>(op.getType());
+ if (!outTy || outTy.getElementType() != rewriter.getI32Type())
+ return failure();
+
+ // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
+ // before the extension. All four signed/unsigned combinations for input
+ // operands are supported, but they are lowered to different operations.
+    // Determine which is the appropriate operation to lower to.
+ MMLA mmlaOp = MMLA::Signed;
+ auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
+ op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ if (!maybeLhs) {
+ mmlaOp = MMLA::Unsigned;
+ maybeLhs = extractExtOperand<arith::ExtUIOp>(
+ op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ }
+ if (!maybeLhs)
+ return failure();
+
+ auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
+ op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ if (maybeRhs) {
+ if (mmlaOp == MMLA::Unsigned)
+ mmlaOp = MMLA::Mixed;
+ } else {
+ if (mmlaOp == MMLA::Signed)
+ mmlaOp = MMLA::MixedSwapped;
+ maybeRhs = extractExtOperand<arith::ExtUIOp>(
+ op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ }
+ if (!maybeRhs)
+ return failure();
+
+ // One-dimensional vector types for arm_sve.*mmla
+ auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
+ auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true});
+
+ // Extract LHS sub-tiles.
+ SmallVector<Value> lhsTile;
+ for (int64_t i = 0; i < M; i += 2) {
+      // Extract two consecutive rows of the LHS tile.
+ auto r0 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+ ArrayRef<int64_t>{i});
+ auto r1 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+ ArrayRef<int64_t>{i + 1});
+ // Concatenate to obtain a 16 x i8 flattened sub-tile.
+ auto t = rewriter.create<vector::ShuffleOp>(
+ loc, r0, r1,
+ llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
+ 14, 15});
+ // Turn it into a scalable vector.
+ auto s = rewriter.create<vector::ScalableInsertOp>(
+ loc, t, rewriter.create<ub::PoisonOp>(loc, nxv16i8), 0);
+ // Replicate the sub-tile VSCALE times to fill the entire vector.
+ auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
+ lhsTile.push_back(r);
+ }
+
+ // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
+ auto RHS = rewriter.create<vector::ShapeCastOp>(
+ maybeRhs->getLoc(),
+ VectorType::get(8 * N, rewriter.getI8Type(), {true}), *maybeRhs);
+
+ // Extract the RHS sub-tiles.
+ SmallVector<Value> rhsTile;
+ for (int64_t j = 0; j < N; j += 2)
+ rhsTile.push_back(
+ rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, RHS, j * 8));
+
+ // Handy types for packing/unpacking of the accumulator tile.
+ auto accRowTy = VectorType::get(N, rewriter.getI32Type(), {true});
+ auto accRowX2Ty = VectorType::get(2 * N, rewriter.getI32Type(), {true});
+ auto accRow64Ty = VectorType::get(N / 2, rewriter.getI64Type(), {true});
+ auto accRowX264Ty = VectorType::get(N, rewriter.getI64Type(), {true});
+
+ // Extract and pack the ACC sub-tiles.
+ SmallVector<Value> accTile;
+ for (int64_t i = 0; i < M; i += 2) {
+ // Extract two consecutive rows of the accumulator tile.
+ auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+ ArrayRef<int64_t>{i});
+ auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+ ArrayRef<int64_t>{i + 1});
+ Value accTileVec;
+ if (mmlaOp == MMLA::MixedSwapped) {
+ // We need to swap the positions of the LHS and RHS (since we don't have
+ // a signed * unsigned operation), but then each individual 2x2 tile of
+        // the accumulator and (later) the result need to be transposed.
+ accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
+ } else {
+ // Bitcast them to 64-bit elements, so subsequent
+ // interleave/deinterleave work on pairs of 32-bit numbers.
+ auto r0_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
+ auto r1_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
+
+ // Interleave the rows, effectively flattening each 2x2 tile into 4
+ // consecutive elements.
+ auto intr_i64 =
+ rewriter.create<vector::InterleaveOp>(loc, r0_i64, r1_i64);
+
+ // Bitcast back to 32-bit elements.
+ accTileVec =
+ rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intr_i64);
+ }
+ // Extract ACC sub-tiles.
+ for (int64_t j = 0; j < N; j += 2)
+ accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+ loc, nxv4i32, accTileVec, j * 2));
+ }
+
+ // Emit sub-tile matrix multiplications.
+ SmallVector<Value> outTile;
+ for (int64_t i = 0; i < M / 2; ++i)
+ for (int64_t j = 0; j < N / 2; ++j) {
+ Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32,
+ accTile[i * N / 2 + j], lhsTile[i], rhsTile[j]);
+ outTile.push_back(mmla);
+ }
+
+ // Unpack the OUT sub-tiles and insert into the result.
+ Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
+ for (int64_t i = 0; i < M / 2; ++i) {
+ // Collect a number of sub-tiles in a row.
+ Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
+ for (int64_t j = 0; j < N / 2; ++j)
+ row = rewriter.create<vector::ScalableInsertOp>(
+ loc, outTile[i * N / 2 + j], row, j * 4);
+
+ // Unpack the row to obtain two rows of the output. If we have the out
+ // sub-tiles transposed we obtain two consecutive output rows by
+ // separating even and odd elements, i.e. a simple deinterleave.
+      // Otherwise, the deinterleave is by pairs.
+ Value out0, out1;
+ if (mmlaOp == MMLA::MixedSwapped) {
+ auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
+ out0 = tmp.getRes1();
+ out1 = tmp.getRes2();
+ } else {
+ // Deinterleave by pairs.
+ auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
+ auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);
+
+ // Bitcast back into 32-bit elements and insert into the result.
+ out0 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
+ deintr64.getRes1());
+ out1 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
+ deintr64.getRes2());
+ }
+ result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
+ result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
+ RewritePatternSet &patterns) {
+ MLIRContext *context = patterns.getContext();
+ patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
new file mode 100644
index 0000000000000..2535ee9181c13
--- /dev/null
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
@@ -0,0 +1,94 @@
+// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+// CHECK-LABEL: @test_vector_contract_to_smmla
+
+// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
+// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9]+]][0] : vector<16xi8> into vector<[16]xi8>
+
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Extract sub-tiles from the RHS
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+
+// Extract accumulator rows 0 and 1 and pack (into "registers")
+// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extract...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Momchil Velikov (momchil-velikov) ChangesSupersedes #135359 Patch is 77.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135636.diff 16 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
+ Option<"armI8MM", "enable-arm-i8mm",
+ "bool", /*default=*/"false",
+ "Enables the use of Arm FEAT_I8MM instructions while lowering "
+ "the vector dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
void populateArmSVELegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+void populateLowerContractionToSVEI8MMPatternPatterns(
+ RewritePatternSet &patterns);
+
/// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
/// intrinsics.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
MLIRVectorToLLVM
MLIRArmNeonDialect
+ MLIRArmNeonTransforms
MLIRArmSVEDialect
MLIRArmSVETransforms
MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
+ if (armI8MM) {
+ if (armNeon)
+ arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+ if (armSVE)
+ populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+ }
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 2a1271dfd6bdf..e807b233aa7aa 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -56,6 +56,9 @@ class LowerContractionToSMMLAPattern
// Avoid 0-D vectors and 1-D rhs:
if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
return failure();
+ // Avoid scalable vectors.
+ if (lhsType.isScalable() || rhsType.isScalable())
+ return failure();
auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
auto dimK = rhsType.getDimSize(1);
@@ -238,5 +241,5 @@ class LowerContractionToSMMLAPattern
void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
- patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
+ patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
}
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index a70c489a51fea..65f98b44b1b69 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRArmSVETransforms
LegalizeForLLVMExport.cpp
LegalizeVectorStorage.cpp
+ LowerContractionToSVEI8MMPattern.cpp
DEPENDS
MLIRArmSVEConversionsIncGen
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
new file mode 100644
index 0000000000000..c0620c71440bc
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the extension.
+template <typename T>
+inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
+ std::is_base_of_v<arith::ExtUIOp, T>),
+ std::optional<Value>>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+ auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
+ if (!extOp)
+ return {};
+
+ auto inOp = extOp.getIn();
+ auto inTy = dyn_cast<VectorType>(inOp.getType());
+ if (!inTy || inTy.getElementType() != i8Ty)
+ return {};
+
+ auto outTy = dyn_cast<VectorType>(extOp.getType());
+ if (!outTy || outTy.getElementType() != i32Ty)
+ return {};
+
+ return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+ Signed, // smmla
+ Unsigned, // ummla
+ Mixed, // usmmla
+ MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+ switch (op) {
+ case MMLA::Signed:
+ return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
+ case MMLA::Unsigned:
+ return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
+ case MMLA::Mixed:
+ return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
+ case MMLA::MixedSwapped:
+ // The accumulator comes transposed and the result will be transposed
+ // later, so all we have to do here is swap the operands.
+ return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
+ }
+}
+
+class LowerContractionToSVEI8MMPattern
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ mlir::VectorType lhsType = op.getLhsType();
+ mlir::VectorType rhsType = op.getRhsType();
+
+ // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+ // eventually expect from MMT4D. M and N dimensions must be even and at
+ // least 2.
+ if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+ rhsType.getRank() != 2)
+ return failure();
+
+ if (lhsType.isScalable() || !rhsType.isScalable())
+ return failure();
+
+ // M, N, and K are the conventional names for matrix dimensions in the
+ // context of matrix multiplication.
+ auto M = lhsType.getDimSize(0);
+ auto N = rhsType.getDimSize(0);
+ auto K = rhsType.getDimSize(1);
+
+ if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+ N % 2 != 0 || !rhsType.getScalableDims()[0])
+ return failure();
+
+ // Check permutation maps. For now only accept
+ // lhs: (d0, d1, d2) -> (d0, d2)
+ // rhs: (d0, d1, d2) -> (d1, d2)
+ // acc: (d0, d1, d2) -> (d0, d1)
+ // Note: RHS is transposed.
+ if (op.getIndexingMapsArray()[0] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[1] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[2] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+ op.getContext()))
+ return failure();
+
+ // Check iterator types for matrix multiplication.
+ auto itTypes = op.getIteratorTypesArray();
+ if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+ itTypes[1] != vector::IteratorType::parallel ||
+ itTypes[2] != vector::IteratorType::reduction)
+ return failure();
+
+ // Check the combining kind is addition.
+ if (op.getKind() != vector::CombiningKind::ADD)
+ return failure();
+
+ // Check the output is a vector of i32 elements.
+ auto outTy = dyn_cast<VectorType>(op.getType());
+ if (!outTy || outTy.getElementType() != rewriter.getI32Type())
+ return failure();
+
+ // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
+ // before the extension. All four signed/unsigned combinations for input
+ // operands are supported, but they are lowered to different operations.
+ // Determina which is the appropriate operation to lower to.
+ MMLA mmlaOp = MMLA::Signed;
+ auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
+ op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ if (!maybeLhs) {
+ mmlaOp = MMLA::Unsigned;
+ maybeLhs = extractExtOperand<arith::ExtUIOp>(
+ op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ }
+ if (!maybeLhs)
+ return failure();
+
+ auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
+ op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ if (maybeRhs) {
+ if (mmlaOp == MMLA::Unsigned)
+ mmlaOp = MMLA::Mixed;
+ } else {
+ if (mmlaOp == MMLA::Signed)
+ mmlaOp = MMLA::MixedSwapped;
+ maybeRhs = extractExtOperand<arith::ExtUIOp>(
+ op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ }
+ if (!maybeRhs)
+ return failure();
+
+ // One-dimensional vector types for arm_sve.*mmla
+ auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
+ auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true});
+
+ // Extract LHS sub-tiles.
+ SmallVector<Value> lhsTile;
+ for (int64_t i = 0; i < M; i += 2) {
+ // Exract two consective rows of the LHS tile.
+ auto r0 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+ ArrayRef<int64_t>{i});
+ auto r1 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+ ArrayRef<int64_t>{i + 1});
+ // Concatenate to obtain a 16 x i8 flattened sub-tile.
+ auto t = rewriter.create<vector::ShuffleOp>(
+ loc, r0, r1,
+ llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
+ 14, 15});
+ // Turn it into a scalable vector.
+ auto s = rewriter.create<vector::ScalableInsertOp>(
+ loc, t, rewriter.create<ub::PoisonOp>(loc, nxv16i8), 0);
+ // Replicate the sub-tile VSCALE times to fill the entire vector.
+ auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
+ lhsTile.push_back(r);
+ }
+
+ // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
+ auto RHS = rewriter.create<vector::ShapeCastOp>(
+ maybeRhs->getLoc(),
+ VectorType::get(8 * N, rewriter.getI8Type(), {true}), *maybeRhs);
+
+ // Extract the RHS sub-tiles.
+ SmallVector<Value> rhsTile;
+ for (int64_t j = 0; j < N; j += 2)
+ rhsTile.push_back(
+ rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, RHS, j * 8));
+
+ // Handy types for packing/unpacking of the accumulator tile.
+ auto accRowTy = VectorType::get(N, rewriter.getI32Type(), {true});
+ auto accRowX2Ty = VectorType::get(2 * N, rewriter.getI32Type(), {true});
+ auto accRow64Ty = VectorType::get(N / 2, rewriter.getI64Type(), {true});
+ auto accRowX264Ty = VectorType::get(N, rewriter.getI64Type(), {true});
+
+ // Extract and pack the ACC sub-tiles.
+ SmallVector<Value> accTile;
+ for (int64_t i = 0; i < M; i += 2) {
+ // Extract two consecutive rows of the accumulator tile.
+ auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+ ArrayRef<int64_t>{i});
+ auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+ ArrayRef<int64_t>{i + 1});
+ Value accTileVec;
+ if (mmlaOp == MMLA::MixedSwapped) {
+ // We need to swap the positions of the LHS and RHS (since we don't have
+ // a signed * unsigned operation), but then each individual 2x2 tile of
+ // the acumulator and (later) the result need to be transposed.
+ accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
+ } else {
+ // Bitcast them to 64-bit elements, so subsequent
+ // interleave/deinterleave work on pairs of 32-bit numbers.
+ auto r0_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
+ auto r1_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
+
+ // Interleave the rows, effectively flattening each 2x2 tile into 4
+ // consecutive elements.
+ auto intr_i64 =
+ rewriter.create<vector::InterleaveOp>(loc, r0_i64, r1_i64);
+
+ // Bitcast back to 32-bit elements.
+ accTileVec =
+ rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intr_i64);
+ }
+ // Extract ACC sub-tiles.
+ for (int64_t j = 0; j < N; j += 2)
+ accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+ loc, nxv4i32, accTileVec, j * 2));
+ }
+
+ // Emit sub-tile matrix multiplications.
+ SmallVector<Value> outTile;
+ for (int64_t i = 0; i < M / 2; ++i)
+ for (int64_t j = 0; j < N / 2; ++j) {
+ Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32,
+ accTile[i * N / 2 + j], lhsTile[i], rhsTile[j]);
+ outTile.push_back(mmla);
+ }
+
+ // Unpack the OUT sub-tiles and insert into the result.
+ Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
+ for (int64_t i = 0; i < M / 2; ++i) {
+ // Collect a number of sub-tiles in a row.
+ Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
+ for (int64_t j = 0; j < N / 2; ++j)
+ row = rewriter.create<vector::ScalableInsertOp>(
+ loc, outTile[i * N / 2 + j], row, j * 4);
+
+ // Unpack the row to obtain two rows of the output. If we have the out
+ // sub-tiles transposed we obtain two consecutive output rows by
+ // separating even and odd elements, i.e. a simple deinterleave.
+ // Otherwise, the interleave is by pairs.
+ Value out0, out1;
+ if (mmlaOp == MMLA::MixedSwapped) {
+ auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
+ out0 = tmp.getRes1();
+ out1 = tmp.getRes2();
+ } else {
+ // Deinterleave by pairs.
+ auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
+ auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);
+
+ // Bitcast back into 32-bit elements and insert into the result.
+ out0 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
+ deintr64.getRes1());
+ out1 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
+ deintr64.getRes2());
+ }
+ result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
+ result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
+ RewritePatternSet &patterns) {
+ MLIRContext *context = patterns.getContext();
+ patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
new file mode 100644
index 0000000000000..2535ee9181c13
--- /dev/null
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
@@ -0,0 +1,94 @@
+// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+// CHECK-LABEL: @test_vector_contract_to_smmla
+
+// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
+// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8>
+
+// Replicate across the entire length of the scalabale vector
+// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Same for LHS rows 2 and 4
+// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Extract sub-tiles from the RHS
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+
+// Extract accumulator rows 0 and 1 and pack (into "registers")
+// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extract...
[truncated]
|
@llvm/pr-subscribers-mlir-sve Author: Momchil Velikov (momchil-velikov) ChangesSupersedes #135359 Patch is 77.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135636.diff 16 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
+ Option<"armI8MM", "enable-arm-i8mm",
+ "bool", /*default=*/"false",
+ "Enables the use of Arm FEAT_I8MM instructions while lowering "
+ "the vector dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
void populateArmSVELegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+void populateLowerContractionToSVEI8MMPatternPatterns(
+ RewritePatternSet &patterns);
+
/// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
/// intrinsics.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
MLIRVectorToLLVM
MLIRArmNeonDialect
+ MLIRArmNeonTransforms
MLIRArmSVEDialect
MLIRArmSVETransforms
MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
+ if (armI8MM) {
+ if (armNeon)
+ arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+ if (armSVE)
+ populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+ }
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 2a1271dfd6bdf..e807b233aa7aa 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -56,6 +56,9 @@ class LowerContractionToSMMLAPattern
// Avoid 0-D vectors and 1-D rhs:
if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
return failure();
+ // Avoid scalable vectors.
+ if (lhsType.isScalable() || rhsType.isScalable())
+ return failure();
auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
auto dimK = rhsType.getDimSize(1);
@@ -238,5 +241,5 @@ class LowerContractionToSMMLAPattern
void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
- patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
+ patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
}
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index a70c489a51fea..65f98b44b1b69 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRArmSVETransforms
LegalizeForLLVMExport.cpp
LegalizeVectorStorage.cpp
+ LowerContractionToSVEI8MMPattern.cpp
DEPENDS
MLIRArmSVEConversionsIncGen
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
new file mode 100644
index 0000000000000..c0620c71440bc
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -0,0 +1,304 @@
+//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===---
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is a result of the operation `T` (which must be
+// sign- or zero- extend) from i8 to i32. Return the value before the extension.
+template <typename T>
+inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
+ std::is_base_of_v<arith::ExtUIOp, T>),
+ std::optional<Value>>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+ auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
+ if (!extOp)
+ return {};
+
+ auto inOp = extOp.getIn();
+ auto inTy = dyn_cast<VectorType>(inOp.getType());
+ if (!inTy || inTy.getElementType() != i8Ty)
+ return {};
+
+ auto outTy = dyn_cast<VectorType>(extOp.getType());
+ if (!outTy || outTy.getElementType() != i32Ty)
+ return {};
+
+ return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+ Signed, // smmla
+ Unsigned, // ummla
+ Mixed, // usmmla
+ MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+ switch (op) {
+ case MMLA::Signed:
+ return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
+ case MMLA::Unsigned:
+ return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
+ case MMLA::Mixed:
+ return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
+ case MMLA::MixedSwapped:
+ // The accumulator comes transposed and the result will be transposed
+ // later, so all we have to do here is swap the operands.
+ return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
+ }
+}
+
+class LowerContractionToSVEI8MMPattern
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ mlir::VectorType lhsType = op.getLhsType();
+ mlir::VectorType rhsType = op.getRhsType();
+
+ // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+ // eventually expect from MMT4D. M and N dimensions must be even and at
+ // least 2.
+ if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+ rhsType.getRank() != 2)
+ return failure();
+
+ if (lhsType.isScalable() || !rhsType.isScalable())
+ return failure();
+
+ // M, N, and K are the conventional names for matrix dimensions in the
+ // context of matrix multiplication.
+ auto M = lhsType.getDimSize(0);
+ auto N = rhsType.getDimSize(0);
+ auto K = rhsType.getDimSize(1);
+
+ if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+ N % 2 != 0 || !rhsType.getScalableDims()[0])
+ return failure();
+
+ // Check permutation maps. For now only accept
+ // lhs: (d0, d1, d2) -> (d0, d2)
+ // rhs: (d0, d1, d2) -> (d1, d2)
+ // acc: (d0, d1, d2) -> (d0, d1)
+ // Note: RHS is transposed.
+ if (op.getIndexingMapsArray()[0] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[1] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[2] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+ op.getContext()))
+ return failure();
+
+ // Check iterator types for matrix multiplication.
+ auto itTypes = op.getIteratorTypesArray();
+ if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+ itTypes[1] != vector::IteratorType::parallel ||
+ itTypes[2] != vector::IteratorType::reduction)
+ return failure();
+
+ // Check the combining kind is addition.
+ if (op.getKind() != vector::CombiningKind::ADD)
+ return failure();
+
+ // Check the output is a vector of i32 elements.
+ auto outTy = dyn_cast<VectorType>(op.getType());
+ if (!outTy || outTy.getElementType() != rewriter.getI32Type())
+ return failure();
+
+ // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
+ // before the extension. All four signed/unsigned combinations for input
+ // operands are supported, but they are lowered to different operations.
+ // Determina which is the appropriate operation to lower to.
+ MMLA mmlaOp = MMLA::Signed;
+ auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
+ op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ if (!maybeLhs) {
+ mmlaOp = MMLA::Unsigned;
+ maybeLhs = extractExtOperand<arith::ExtUIOp>(
+ op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ }
+ if (!maybeLhs)
+ return failure();
+
+ auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
+ op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ if (maybeRhs) {
+ if (mmlaOp == MMLA::Unsigned)
+ mmlaOp = MMLA::Mixed;
+ } else {
+ if (mmlaOp == MMLA::Signed)
+ mmlaOp = MMLA::MixedSwapped;
+ maybeRhs = extractExtOperand<arith::ExtUIOp>(
+ op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ }
+ if (!maybeRhs)
+ return failure();
+
+ // One-dimensional vector types for arm_sve.*mmla
+ auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
+ auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true});
+
+ // Extract LHS sub-tiles.
+ SmallVector<Value> lhsTile;
+ for (int64_t i = 0; i < M; i += 2) {
+ // Exract two consective rows of the LHS tile.
+ auto r0 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+ ArrayRef<int64_t>{i});
+ auto r1 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+ ArrayRef<int64_t>{i + 1});
+ // Concatenate to obtain a 16 x i8 flattened sub-tile.
+ auto t = rewriter.create<vector::ShuffleOp>(
+ loc, r0, r1,
+ llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
+ 14, 15});
+ // Turn it into a scalable vector.
+ auto s = rewriter.create<vector::ScalableInsertOp>(
+ loc, t, rewriter.create<ub::PoisonOp>(loc, nxv16i8), 0);
+ // Replicate the sub-tile VSCALE times to fill the entire vector.
+ auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
+ lhsTile.push_back(r);
+ }
+
+ // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
+ auto RHS = rewriter.create<vector::ShapeCastOp>(
+ maybeRhs->getLoc(),
+ VectorType::get(8 * N, rewriter.getI8Type(), {true}), *maybeRhs);
+
+ // Extract the RHS sub-tiles.
+ SmallVector<Value> rhsTile;
+ for (int64_t j = 0; j < N; j += 2)
+ rhsTile.push_back(
+ rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, RHS, j * 8));
+
+ // Handy types for packing/unpacking of the accumulator tile.
+ auto accRowTy = VectorType::get(N, rewriter.getI32Type(), {true});
+ auto accRowX2Ty = VectorType::get(2 * N, rewriter.getI32Type(), {true});
+ auto accRow64Ty = VectorType::get(N / 2, rewriter.getI64Type(), {true});
+ auto accRowX264Ty = VectorType::get(N, rewriter.getI64Type(), {true});
+
+ // Extract and pack the ACC sub-tiles.
+ SmallVector<Value> accTile;
+ for (int64_t i = 0; i < M; i += 2) {
+ // Extract two consecutive rows of the accumulator tile.
+ auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+ ArrayRef<int64_t>{i});
+ auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+ ArrayRef<int64_t>{i + 1});
+ Value accTileVec;
+ if (mmlaOp == MMLA::MixedSwapped) {
+ // We need to swap the positions of the LHS and RHS (since we don't have
+ // a signed * unsigned operation), but then each individual 2x2 tile of
+ // the acumulator and (later) the result need to be transposed.
+ accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
+ } else {
+ // Bitcast them to 64-bit elements, so subsequent
+ // interleave/deinterleave work on pairs of 32-bit numbers.
+ auto r0_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
+ auto r1_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
+
+ // Interleave the rows, effectively flattening each 2x2 tile into 4
+ // consecutive elements.
+ auto intr_i64 =
+ rewriter.create<vector::InterleaveOp>(loc, r0_i64, r1_i64);
+
+ // Bitcast back to 32-bit elements.
+ accTileVec =
+ rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intr_i64);
+ }
+ // Extract ACC sub-tiles.
+ for (int64_t j = 0; j < N; j += 2)
+ accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+ loc, nxv4i32, accTileVec, j * 2));
+ }
+
+ // Emit sub-tile matrix multiplications.
+ SmallVector<Value> outTile;
+ for (int64_t i = 0; i < M / 2; ++i)
+ for (int64_t j = 0; j < N / 2; ++j) {
+ Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32,
+ accTile[i * N / 2 + j], lhsTile[i], rhsTile[j]);
+ outTile.push_back(mmla);
+ }
+
+ // Unpack the OUT sub-tiles and insert into the result.
+ Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
+ for (int64_t i = 0; i < M / 2; ++i) {
+ // Collect a number of sub-tiles in a row.
+ Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
+ for (int64_t j = 0; j < N / 2; ++j)
+ row = rewriter.create<vector::ScalableInsertOp>(
+ loc, outTile[i * N / 2 + j], row, j * 4);
+
+ // Unpack the row to obtain two rows of the output. If we have the out
+ // sub-tiles transposed we obtain two consecutive output rows by
+ // separating even and odd elements, i.e. a simple deinterleave.
+ // Otherwise, the interleave is by pairs.
+ Value out0, out1;
+ if (mmlaOp == MMLA::MixedSwapped) {
+ auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
+ out0 = tmp.getRes1();
+ out1 = tmp.getRes2();
+ } else {
+ // Deinterleave by pairs.
+ auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
+ auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);
+
+ // Bitcast back into 32-bit elements and insert into the result.
+ out0 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
+ deintr64.getRes1());
+ out1 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
+ deintr64.getRes2());
+ }
+ result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
+ result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
+ RewritePatternSet &patterns) {
+ MLIRContext *context = patterns.getContext();
+ patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
new file mode 100644
index 0000000000000..2535ee9181c13
--- /dev/null
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
@@ -0,0 +1,94 @@
+// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+
+#packed_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+// CHECK-LABEL: @test_vector_contract_to_smmla
+
+// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
+// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8>
+
+// Replicate across the entire length of the scalabale vector
+// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Same for LHS rows 2 and 4
+// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Extract sub-tiles from the RHS
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+
+// Extract accumulator rows 0 and 1 and pack (into "registers")
+// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extract...
[truncated]
Thanks! This one is a bit longer, so I may need to wait till Thursday before I can review.
One high-level question - would sharing some code between NEON and SVE be possible?
No, I can't see it happening and resulting in less, or simpler, or easier to maintain code.
Thanks Momchil - this is great!
I skimmed through the pattern logic, and it's very neatly written. It's actually quite easy to follow, despite the underlying logic being a bit convoluted - well done! I've left a few minor suggestions, but nothing major.
Also, it seems like we should be able to extend this fairly easily to support NEON as well. Worth thinking about 🙂
Now, overall this patch is quite large, and I’d suggest extracting the end-to-end / integration tests into a separate PR. Additionally, the remaining tests currently use `--convert-vector-to-llvm=`, which lowers all the way to LLVM (i.e., it exercises a lot of patterns). Instead, I’d recommend testing `LowerContractionToSVEI8MMPattern` in isolation and only verifying that the correct sequence of ArmSVE ops (plus some Vector ops) is generated - for example:
(...)
%33 = arm_sve.smmla %23, %7, %15 : vector<[16]xi8> to vector<[4]xi32>
%34 = arm_sve.smmla %24, %7, %16 : vector<[16]xi8> to vector<[4]xi32>
%35 = arm_sve.smmla %31, %13, %15 : vector<[16]xi8> to vector<[4]xi32>
%36 = arm_sve.smmla %32, %13, %16 : vector<[16]xi8> to vector<[4]xi32>
That way, we will:
- reduce noise in the test output (by focusing on a single pattern),
- simplify expected output (fewer ops to match),
- avoid re-testing functionality already covered elsewhere (e.g., the `arm_sve.smmla` → `arm_sve.intr.smmla` lowering).
Btw, this is already looking great, and I know I’m asking for a bit of a rewrite (especially around the tests), but I really think it’ll help with long-term maintainability.
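To make this concrete, a focused test could then get away with a handful of checks against the ArmSVE ops alone - a sketch only (the RUN line / pass plumbing is deliberately left out since the exact setup is still being discussed, and the count simply mirrors the example above):

```mlir
// Hypothetical focused checks for the 4x8 by [4]x8 (RHS transposed) i8 case
// from the example above; CHECK-COUNT / CHECK-NOT are plain FileCheck syntax.
// CHECK-COUNT-4: arm_sve.smmla {{.*}} : vector<[16]xi8> to vector<[4]xi32>
// CHECK-NOT: arm_sve.intr.smmla
```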
Really excited to see this! I'll take a look in the next iteration. Thanks!
-- some commenting -- replace enable_if with a static assert -- return reasons for match failures
They'll come in a separate PR.
A few more questions/suggestions inline. In general, I would appreciate a few more comments with ASCII - these tend to be super helpful :) 🙏🏻
One other high-level comment: why not incorporate the test changes from #140572 in this patch? If #140572 implements the end state (that would be my preference), then let's use that. The TD op can still be upstreamed as a separate change.
All in all, this is already very polished and nicely crafted, well done!
@@ -238,5 +241,5 @@ class LowerContractionToSMMLAPattern
 void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
+  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
Why do we need this change?
@@ -56,6 +56,9 @@ class LowerContractionToSMMLAPattern
     // Avoid 0-D vectors and 1-D rhs:
     if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
       return failure();
+    // Avoid scalable vectors.
Why? Could you add a note that there's a separate SVE path implemented in the SVE dialect?
Eh? This rewrite will not work with scalable vectors regardless of whether there's another rewrite or not.
Should be a separate PR, in principle, but it's hardly worth the faff.
Let me rephrase.
The comment “Avoid scalable vectors.” explains what the code is doing, which is already clear from the code itself. What would be more helpful here is a brief note on why scalable vectors are being avoided. This may be obvious to those familiar with NEON and SVE, but adding a bit of context would make the intent clearer for others as well.
Relatedly, it might be worth mentioning that there's a separate implementation for NEON and SVE. Even within Arm, not everyone is aware that I8MM is available for both.
All I am asking for is a small addition to the comment.
using namespace mlir::arm_sve;

namespace {
// Get the LHS or RHS side operand of a vector contract. Handle two cases
[nit] To dissambiguate.
-// Get the LHS or RHS side operand of a vector contract. Handle two cases
+// Get the LHS or RHS side operand of a `vector.contract`. Handle two cases
// This way we handle both explicit sign- or zero- extension or implicit
// sign-extension.
In the absence of explicit extension, why do we assume sign-extension? I am sure there's an obvious explanation, but I don't remember right now 😅
> If operands and the result have types of different bitwidths, operands are […]
OK, thanks for the reminder!
Could you add a reference to the `vector.contract` docs? That basically explains "why" sign-extension is assumed. Also, that assumption is bound to evolve in the future and we should make it easy to find places where it informs the actual lowering (this might be the only example). Thanks!
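For concreteness, a minimal sketch of the implicit case in question (my own reduced example - the function and value names are made up and not taken from the patch or its tests). The i8 operands feed the i32 accumulator directly, with no `arith.extsi`/`arith.extui`, and it is this implicit promotion that the lowering treats as sign-extension:

```mlir
func.func @implicit_ext_contract(%lhs: vector<4x8xi8>,
                                 %rhs: vector<[4]x8xi8>,
                                 %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> {
  // Same maps as #packed_maps in the tests: A is MxK, B is NxK (transposed),
  // C is MxN.
  %0 = vector.contract {indexing_maps = [
                          affine_map<(d0, d1, d2) -> (d0, d2)>,
                          affine_map<(d0, d1, d2) -> (d1, d2)>,
                          affine_map<(d0, d1, d2) -> (d0, d1)>],
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>}
         %lhs, %rhs, %acc
       : vector<4x8xi8>, vector<[4]x8xi8> into vector<4x[4]xi32>
  return %0 : vector<4x[4]xi32>
}
```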
// Check the operands have the expected shape. M and N dimensions must be
// even and at least 2.
> and at least 2

Note that this is checking that the rank is 2 rather than any of the dimensions being at least "2".
The comment applies to code further down.
Let's just move it closer to the code that it documents.
        N % 2 != 0 || !rhsType.getScalableDims()[0])
      return rewriter.notifyMatchFailure(op, "non-matching operand shape");

    // Check permutation maps. For now only accept
[nit] Add a note that this is a regular matmul with RHS transposed.
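For reference, these are the accepted maps (identical to `#packed_maps` in the test files), annotated under the usual reading d0 = M, d1 = N, d2 = K; the alias names are mine, added just for the annotation. The RHS map indexes B as NxK, so this is a plain matmul C[m][n] += A[m][k] * B[n][k] with the RHS stored transposed:

```mlir
#map_a = affine_map<(d0, d1, d2) -> (d0, d2)>  // LHS: A[m][k], M x K
#map_b = affine_map<(d0, d1, d2) -> (d1, d2)>  // RHS: B[n][k], N x K (transposed)
#map_c = affine_map<(d0, d1, d2) -> (d0, d1)>  // ACC: C[m][n], M x N
```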
    }

    // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
    auto RHS = rewriter.create<vector::ShapeCastOp>(
I like `RHS`, but we should stick with the naming convention and use camelBack. Same for `M`, `N` and `K` :(
      outTile.push_back(mmla);
    }

    // Unpack the OUT sub-tiles and insert into the result.
Would you mind adding some ASCII here?
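For what it's worth, here is how I read the unpacking of one packed row, as a standalone sketch (the names and the fixed `[8]`/`[4]` shapes are mine, chosen for brevity - this is not code from the patch). As I understand the MMLA layout, each smmla result holds a row-major 2x2 sub-tile per 128-bit segment, so reinterpreting the packed row as i64 pairs and deinterleaving splits it into the elements of output rows 2*i and 2*i+1:

```mlir
func.func @unpack_row_pair(%row: vector<[8]xi32>)
    -> (vector<[4]xi32>, vector<[4]xi32>) {
  // View each (c(2i,2j), c(2i,2j+1)) / (c(2i+1,2j), c(2i+1,2j+1)) pair as a
  // single i64 element.
  %row64 = vector.bitcast %row : vector<[8]xi32> to vector<[4]xi64>
  // Even pairs belong to output row 2i, odd pairs to output row 2i+1.
  %even, %odd = vector.deinterleave %row64 : vector<[4]xi64> -> vector<[2]xi64>
  // Back to i32 rows.
  %out0 = vector.bitcast %even : vector<[2]xi64> to vector<[4]xi32>
  %out1 = vector.bitcast %odd : vector<[2]xi64> to vector<[4]xi32>
  return %out0, %out1 : vector<[4]xi32>, vector<[4]xi32>
}
```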
    auto accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
                                        /*scalableDims=*/{true});

    // Extract and pack the ACC sub-tiles.
ASCII pls :)
Supersedes #135359