-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[MLIR] Add apply_patterns.vector.arm_sve.lower_contraction TD Op #140572
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/momchil-velikov/vector-contract-i8mm
Are you sure you want to change the base?
[MLIR] Add apply_patterns.vector.arm_sve.lower_contraction TD Op #140572
Conversation
@llvm/pr-subscribers-mlir-vector Author: Momchil Velikov (momchil-velikov) ChangesPatch is 73.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140572.diff 12 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h
new file mode 100644
index 0000000000000..7f22cd1fe6435
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h
@@ -0,0 +1,31 @@
+//===- ArmSVEVectorTransformOps.h - Vector transform ops --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
+#define MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// ArmSVE Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arm_sve {
+void registerTransformDialectExtension(DialectRegistry ®istry);
+
+} // namespace arm_sve
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
new file mode 100644
index 0000000000000..81b59340f3b0d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
@@ -0,0 +1,26 @@
+//===- ArmSVEVectorTransformOps.td - Arm SVE transform ops--*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef ARMSVE_VECTOR_TRANSFORM_OPS
+#define ARMSVE_VECTOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+
+def ApplyArmSVELowerContractionPatternsOp
+ : Op<Transform_Dialect, "apply_patterns.vector.arm_sve.lower_contraction",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector contraction-like operations should be lowered to
+ finer-grained vector primitives using the ArmSVE dialect.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+#endif // ARMSVE_VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..ce8d8fea7f188
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS ArmSVEVectorTransformOps.td)
+mlir_tablegen(ArmSVEVectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(ArmSVEVectorTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRArmSVEVectorTransformOpsIncGen)
+
+add_mlir_doc(ArmSVEVectorTransformOps ArmSVEVectorTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 37e4904cb48ed..767c7099accbb 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -34,6 +34,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
@@ -106,6 +107,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) {
transform::registerLoopExtension(registry);
transform::registerPDLExtension(registry);
vector::registerTransformDialectExtension(registry);
+ arm_sve::registerTransformDialectExtension(registry);
// Translation extensions need to be registered by calling
// `registerAllToLLVMIRTranslations` (see All.h).
diff --git a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
new file mode 100644
index 0000000000000..b2ca4fc1eaa8c
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
@@ -0,0 +1,54 @@
+//===- ArmSVEVectorTransformOps.cpp - Implementation transform ops -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
+
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Apply...PatternsOp
+//===----------------------------------------------------------------------===//
+
+void transform::ApplyArmSVELowerContractionPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class ArmSVEVectorTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ ArmSVEVectorTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ ArmSVEVectorTransformDialectExtension)
+
+ ArmSVEVectorTransformDialectExtension() {
+ declareGeneratedDialect<arm_sve::ArmSVEDialect>();
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc"
+
+void mlir::arm_sve::registerTransformDialectExtension(
+ DialectRegistry ®istry) {
+ registry.addExtensions<ArmSVEVectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..8771826e08913
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_dialect_library(MLIRArmSVEVectorTransformOps
+ ArmSVEVectorTransformOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE/TransformOps
+
+ DEPENDS
+ MLIRArmSVEVectorTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRVectorDialect
+ MLIRTransformDialect
+ MLIRArmSVEDialect
+ MLIRArmSVETransforms
+ )
+
\ No newline at end of file
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
index af0cb37e2d249..3991038761e8d 100644
--- a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
#attrs = {
indexing_maps = [
@@ -12,77 +12,82 @@
// CHECK-LABEL: @test_vector_contract_to_smmla
+// CHECK-SAME: %[[LHS:arg0]]: vector<4x8xi8>
+// CHECK-SAME: %[[RHS:arg1]]: vector<[4]x8xi8>
+// CHECK-SAME: %[[ACC:arg2]]: vector<4x[4]xi32>
+
+// CHECK: [[P0:[0-9]+]] = ub.poison : vector<[8]xi32>
+// CHECK-NEXT: [[P1:[0-9]+]] = ub.poison : vector<4x[4]xi32>
+// CHECK-NEXT: [[P2:[0-9]+]] = ub.poison : vector<[16]xi8>
+
// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
-// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
-// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK: %[[T3:[0-9]+]] = vector.extract %[[LHS]][0] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T4:[0-9]+]] = vector.extract %[[LHS]][1] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T5:[0-9]+]] = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8>
+// CHECK-NEXT: %[[T6:[0-9]+]] = vector.scalable.insert %[[T5]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8>
-// Replicate across the entire length of the scalabale vector
-// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[LHS_0:[0-9]+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[16]xi8>
-// Same for LHS rows 2 and 4
-// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
-// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
-// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T8:[0-9]+]] = vector.extract %[[LHS]][2] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = vector.extract %[[LHS]][3] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T10:[0-9]+]] = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8>
+// CHECK-NEXT: %[[T11:[0-9]+]] = vector.scalable.insert %[[T10]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[LHS_1:[0-9]+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[16]xi8>
// Extract sub-tiles from the RHS
-// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
-// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
-// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T13:[0-9]+]] = vector.shape_cast %[[RHS]] : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[RHS_0:[0-9]+]] = vector.scalable.extract %[[T13]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[RHS_1:[0-9]+]] = vector.scalable.extract %[[T13]][16] : vector<[16]xi8> from vector<[32]xi8>
// Extract accumulator rows 0 and 1 and pack (into "registers")
-// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
-// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32>
-// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32>
-// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32>
-
-// Same for accumulator rows 2 and 3.
-// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
-// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32>
-// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32>
-// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.extract %[[ACC]][0] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T17:[0-9]+]] = vector.extract %[[ACC]][1] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T18:[0-9]+]] = vector.bitcast %[[T16]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T19:[0-9]+]] = vector.bitcast %[[T17]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T20:[0-9]+]] = vector.interleave %[[T18]], %[[T19]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT: %[[T21:[0-9]+]] = vector.bitcast %[[T20]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_0:[0-9]+]] = vector.scalable.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_1:[0-9]+]] = vector.scalable.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Same for accumulator rows 2 and 3
+// CHECK-NEXT: %[[T24:[0-9]+]] = vector.extract %[[ACC]][2] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T25:[0-9]+]] = vector.extract %[[ACC]][3] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T26:[0-9]+]] = vector.bitcast %[[T24]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T27:[0-9]+]] = vector.bitcast %[[T25]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T28:[0-9]+]] = vector.interleave %[[T26]], %[[T27]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT: %[[T29:[0-9]+]] = vector.bitcast %[[T28]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_2:[0-9]+]] = vector.scalable.extract %[[T29]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_3:[0-9]+]] = vector.scalable.extract %[[T29]][4] : vector<[4]xi32> from vector<[8]xi32>
// Do the sub-tile matrix multiplications
-// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.smmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.smmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.smmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.smmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-
-// Unpack (from "registers") and insert in the output result rows 0 and 1
-// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64>
-// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[PACK_RES_00:[0-9]+]] = arm_sve.smmla %[[ACC_0]], %[[LHS_0]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_01:[0-9]+]] = arm_sve.smmla %[[ACC_1]], %[[LHS_0]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_10:[0-9]+]] = arm_sve.smmla %[[ACC_2]], %[[LHS_1]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_11:[0-9]+]] = arm_sve.smmla %[[ACC_3]], %[[LHS_1]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32>
+
+// Unpack (from "registers") and insert in the output result rows 0 and 1
+// CHECK-NEXT: %[[T36:[0-9]+]] = vector.scalable.insert %[[PACK_RES_00]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T37:[0-9]+]] = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T38:[0-9]+]] = vector.bitcast %[[T37]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %res1, %res2 = vector.deinterleave %[[T38]] : vector<[4]xi64> -> vector<[2]xi64>
+// CHECK-NEXT: %[[UNPACK_RES_0:[0-9]+]] = vector.bitcast %res1 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[UNPACK_RES_1:[0-9]+]] = vector.bitcast %res2 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_0:[0-9]+]] = vector.insert %[[UNPACK_RES_0]], %[[P1]] [0] : vector<[4]xi32> into vector<4x[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_1:[0-9]+]] = vector.insert %[[UNPACK_RES_1]], %[[TMP_OUT_0]] [1] : vector<[4]xi32> into vector<4x[4]xi32>
// Same for result rows 2 and 3
-// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64>
-// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>>
-
+// CHECK-NEXT: %[[T43...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Momchil Velikov (momchil-velikov) ChangesPatch is 73.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140572.diff 12 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h
new file mode 100644
index 0000000000000..7f22cd1fe6435
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h
@@ -0,0 +1,31 @@
+//===- ArmSVEVectorTransformOps.h - Vector transform ops --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
+#define MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// ArmSVE Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arm_sve {
+void registerTransformDialectExtension(DialectRegistry ®istry);
+
+} // namespace arm_sve
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
new file mode 100644
index 0000000000000..81b59340f3b0d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
@@ -0,0 +1,26 @@
+//===- ArmSVEVectorTransformOps.td - Arm SVE transform ops--*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef ARMSVE_VECTOR_TRANSFORM_OPS
+#define ARMSVE_VECTOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+
+def ApplyArmSVELowerContractionPatternsOp
+ : Op<Transform_Dialect, "apply_patterns.vector.arm_sve.lower_contraction",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector contraction-like operations should be lowered to
+ finer-grained vector primitives using the ArmSVE dialect.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+#endif // ARMSVE_VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..ce8d8fea7f188
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS ArmSVEVectorTransformOps.td)
+mlir_tablegen(ArmSVEVectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(ArmSVEVectorTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRArmSVEVectorTransformOpsIncGen)
+
+add_mlir_doc(ArmSVEVectorTransformOps ArmSVEVectorTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 37e4904cb48ed..767c7099accbb 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -34,6 +34,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
@@ -106,6 +107,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) {
transform::registerLoopExtension(registry);
transform::registerPDLExtension(registry);
vector::registerTransformDialectExtension(registry);
+ arm_sve::registerTransformDialectExtension(registry);
// Translation extensions need to be registered by calling
// `registerAllToLLVMIRTranslations` (see All.h).
diff --git a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
new file mode 100644
index 0000000000000..b2ca4fc1eaa8c
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
@@ -0,0 +1,54 @@
+//===- ArmSVEVectorTransformOps.cpp - Implementation transform ops -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
+
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Apply...PatternsOp
+//===----------------------------------------------------------------------===//
+
+void transform::ApplyArmSVELowerContractionPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class ArmSVEVectorTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ ArmSVEVectorTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ ArmSVEVectorTransformDialectExtension)
+
+ ArmSVEVectorTransformDialectExtension() {
+ declareGeneratedDialect<arm_sve::ArmSVEDialect>();
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc"
+
+void mlir::arm_sve::registerTransformDialectExtension(
+ DialectRegistry ®istry) {
+ registry.addExtensions<ArmSVEVectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..8771826e08913
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_dialect_library(MLIRArmSVEVectorTransformOps
+ ArmSVEVectorTransformOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE/TransformOps
+
+ DEPENDS
+ MLIRArmSVEVectorTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRVectorDialect
+ MLIRTransformDialect
+ MLIRArmSVEDialect
+ MLIRArmSVETransforms
+ )
+
\ No newline at end of file
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
index af0cb37e2d249..3991038761e8d 100644
--- a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
#attrs = {
indexing_maps = [
@@ -12,77 +12,82 @@
// CHECK-LABEL: @test_vector_contract_to_smmla
+// CHECK-SAME: %[[LHS:arg0]]: vector<4x8xi8>
+// CHECK-SAME: %[[RHS:arg1]]: vector<[4]x8xi8>
+// CHECK-SAME: %[[ACC:arg2]]: vector<4x[4]xi32>
+
+// CHECK: [[P0:[0-9]+]] = ub.poison : vector<[8]xi32>
+// CHECK-NEXT: [[P1:[0-9]+]] = ub.poison : vector<4x[4]xi32>
+// CHECK-NEXT: [[P2:[0-9]+]] = ub.poison : vector<[16]xi8>
+
// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
-// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
-// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK: %[[T3:[0-9]+]] = vector.extract %[[LHS]][0] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T4:[0-9]+]] = vector.extract %[[LHS]][1] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T5:[0-9]+]] = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8>
+// CHECK-NEXT: %[[T6:[0-9]+]] = vector.scalable.insert %[[T5]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8>
-// Replicate across the entire length of the scalabale vector
-// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[LHS_0:[0-9]+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[16]xi8>
-// Same for LHS rows 2 and 4
-// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
-// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
-// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T8:[0-9]+]] = vector.extract %[[LHS]][2] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = vector.extract %[[LHS]][3] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T10:[0-9]+]] = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8>
+// CHECK-NEXT: %[[T11:[0-9]+]] = vector.scalable.insert %[[T10]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[LHS_1:[0-9]+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[16]xi8>
// Extract sub-tiles from the RHS
-// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
-// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
-// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T13:[0-9]+]] = vector.shape_cast %[[RHS]] : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[RHS_0:[0-9]+]] = vector.scalable.extract %[[T13]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[RHS_1:[0-9]+]] = vector.scalable.extract %[[T13]][16] : vector<[16]xi8> from vector<[32]xi8>
// Extract accumulator rows 0 and 1 and pack (into "registers")
-// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
-// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32>
-// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32>
-// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32>
-
-// Same for accumulator rows 2 and 3.
-// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
-// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32>
-// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32>
-// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.extract %[[ACC]][0] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T17:[0-9]+]] = vector.extract %[[ACC]][1] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T18:[0-9]+]] = vector.bitcast %[[T16]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T19:[0-9]+]] = vector.bitcast %[[T17]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T20:[0-9]+]] = vector.interleave %[[T18]], %[[T19]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT: %[[T21:[0-9]+]] = vector.bitcast %[[T20]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_0:[0-9]+]] = vector.scalable.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_1:[0-9]+]] = vector.scalable.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Same for accumulator rows 2 and 3
+// CHECK-NEXT: %[[T24:[0-9]+]] = vector.extract %[[ACC]][2] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T25:[0-9]+]] = vector.extract %[[ACC]][3] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T26:[0-9]+]] = vector.bitcast %[[T24]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T27:[0-9]+]] = vector.bitcast %[[T25]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T28:[0-9]+]] = vector.interleave %[[T26]], %[[T27]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT: %[[T29:[0-9]+]] = vector.bitcast %[[T28]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_2:[0-9]+]] = vector.scalable.extract %[[T29]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_3:[0-9]+]] = vector.scalable.extract %[[T29]][4] : vector<[4]xi32> from vector<[8]xi32>
// Do the sub-tile matrix multiplications
-// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.smmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.smmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.smmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.smmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-
-// Unpack (from "registers") and insert in the output result rows 0 and 1
-// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64>
-// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[PACK_RES_00:[0-9]+]] = arm_sve.smmla %[[ACC_0]], %[[LHS_0]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_01:[0-9]+]] = arm_sve.smmla %[[ACC_1]], %[[LHS_0]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_10:[0-9]+]] = arm_sve.smmla %[[ACC_2]], %[[LHS_1]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_11:[0-9]+]] = arm_sve.smmla %[[ACC_3]], %[[LHS_1]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32>
+
+// Unpack (from "registers") and insert in the output result rows 0 and 1
+// CHECK-NEXT: %[[T36:[0-9]+]] = vector.scalable.insert %[[PACK_RES_00]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T37:[0-9]+]] = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T38:[0-9]+]] = vector.bitcast %[[T37]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %res1, %res2 = vector.deinterleave %[[T38]] : vector<[4]xi64> -> vector<[2]xi64>
+// CHECK-NEXT: %[[UNPACK_RES_0:[0-9]+]] = vector.bitcast %res1 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[UNPACK_RES_1:[0-9]+]] = vector.bitcast %res2 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_0:[0-9]+]] = vector.insert %[[UNPACK_RES_0]], %[[P1]] [0] : vector<[4]xi32> into vector<4x[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_1:[0-9]+]] = vector.insert %[[UNPACK_RES_1]], %[[TMP_OUT_0]] [1] : vector<[4]xi32> into vector<4x[4]xi32>
// Same for result rows 2 and 3
-// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64>
-// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>>
-
+// CHECK-NEXT: %[[T43...
[truncated]
|
@llvm/pr-subscribers-mlir-sve Author: Momchil Velikov (momchil-velikov) ChangesPatch is 73.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140572.diff 12 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h
new file mode 100644
index 0000000000000..7f22cd1fe6435
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h
@@ -0,0 +1,31 @@
+//===- ArmSVEVectorTransformOps.h - Vector transform ops --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
+#define MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// ArmSVE Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arm_sve {
+void registerTransformDialectExtension(DialectRegistry ®istry);
+
+} // namespace arm_sve
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
new file mode 100644
index 0000000000000..81b59340f3b0d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
@@ -0,0 +1,26 @@
+//===- ArmSVEVectorTransformOps.td - Arm SVE transform ops--*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef ARMSVE_VECTOR_TRANSFORM_OPS
+#define ARMSVE_VECTOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+
+def ApplyArmSVELowerContractionPatternsOp
+ : Op<Transform_Dialect, "apply_patterns.vector.arm_sve.lower_contraction",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector contraction-like operations should be lowered to
+ finer-grained vector primitives using the ArmSVE dialect.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+#endif // ARMSVE_VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..ce8d8fea7f188
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS ArmSVEVectorTransformOps.td)
+mlir_tablegen(ArmSVEVectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(ArmSVEVectorTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRArmSVEVectorTransformOpsIncGen)
+
+add_mlir_doc(ArmSVEVectorTransformOps ArmSVEVectorTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 37e4904cb48ed..767c7099accbb 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -34,6 +34,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
@@ -106,6 +107,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) {
transform::registerLoopExtension(registry);
transform::registerPDLExtension(registry);
vector::registerTransformDialectExtension(registry);
+ arm_sve::registerTransformDialectExtension(registry);
// Translation extensions need to be registered by calling
// `registerAllToLLVMIRTranslations` (see All.h).
diff --git a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
new file mode 100644
index 0000000000000..b2ca4fc1eaa8c
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
@@ -0,0 +1,54 @@
+//===- ArmSVEVectorTransformOps.cpp - Implementation transform ops -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
+
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Apply...PatternsOp
+//===----------------------------------------------------------------------===//
+
+void transform::ApplyArmSVELowerContractionPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class ArmSVEVectorTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ ArmSVEVectorTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ ArmSVEVectorTransformDialectExtension)
+
+ ArmSVEVectorTransformDialectExtension() {
+ declareGeneratedDialect<arm_sve::ArmSVEDialect>();
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc"
+
+void mlir::arm_sve::registerTransformDialectExtension(
+ DialectRegistry ®istry) {
+ registry.addExtensions<ArmSVEVectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..8771826e08913
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_dialect_library(MLIRArmSVEVectorTransformOps
+ ArmSVEVectorTransformOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE/TransformOps
+
+ DEPENDS
+ MLIRArmSVEVectorTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRVectorDialect
+ MLIRTransformDialect
+ MLIRArmSVEDialect
+ MLIRArmSVETransforms
+ )
+
\ No newline at end of file
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
index af0cb37e2d249..3991038761e8d 100644
--- a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
#attrs = {
indexing_maps = [
@@ -12,77 +12,82 @@
// CHECK-LABEL: @test_vector_contract_to_smmla
+// CHECK-SAME: %[[LHS:arg0]]: vector<4x8xi8>
+// CHECK-SAME: %[[RHS:arg1]]: vector<[4]x8xi8>
+// CHECK-SAME: %[[ACC:arg2]]: vector<4x[4]xi32>
+
+// CHECK: [[P0:[0-9]+]] = ub.poison : vector<[8]xi32>
+// CHECK-NEXT: [[P1:[0-9]+]] = ub.poison : vector<4x[4]xi32>
+// CHECK-NEXT: [[P2:[0-9]+]] = ub.poison : vector<[16]xi8>
+
// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
-// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
-// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK: %[[T3:[0-9]+]] = vector.extract %[[LHS]][0] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T4:[0-9]+]] = vector.extract %[[LHS]][1] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T5:[0-9]+]] = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8>
+// CHECK-NEXT: %[[T6:[0-9]+]] = vector.scalable.insert %[[T5]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8>
-// Replicate across the entire length of the scalabale vector
-// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[LHS_0:[0-9]+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[16]xi8>
-// Same for LHS rows 2 and 4
-// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
-// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
-// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T8:[0-9]+]] = vector.extract %[[LHS]][2] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = vector.extract %[[LHS]][3] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T10:[0-9]+]] = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8>
+// CHECK-NEXT: %[[T11:[0-9]+]] = vector.scalable.insert %[[T10]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[LHS_1:[0-9]+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[16]xi8>
// Extract sub-tiles from the RHS
-// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
-// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
-// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T13:[0-9]+]] = vector.shape_cast %[[RHS]] : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[RHS_0:[0-9]+]] = vector.scalable.extract %[[T13]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[RHS_1:[0-9]+]] = vector.scalable.extract %[[T13]][16] : vector<[16]xi8> from vector<[32]xi8>
// Extract accumulator rows 0 and 1 and pack (into "registers")
-// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
-// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32>
-// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32>
-// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32>
-
-// Same for accumulator rows 2 and 3.
-// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
-// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32>
-// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32>
-// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.extract %[[ACC]][0] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T17:[0-9]+]] = vector.extract %[[ACC]][1] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T18:[0-9]+]] = vector.bitcast %[[T16]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T19:[0-9]+]] = vector.bitcast %[[T17]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T20:[0-9]+]] = vector.interleave %[[T18]], %[[T19]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT: %[[T21:[0-9]+]] = vector.bitcast %[[T20]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_0:[0-9]+]] = vector.scalable.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_1:[0-9]+]] = vector.scalable.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Same for accumulator rows 2 and 3
+// CHECK-NEXT: %[[T24:[0-9]+]] = vector.extract %[[ACC]][2] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T25:[0-9]+]] = vector.extract %[[ACC]][3] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T26:[0-9]+]] = vector.bitcast %[[T24]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T27:[0-9]+]] = vector.bitcast %[[T25]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T28:[0-9]+]] = vector.interleave %[[T26]], %[[T27]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT: %[[T29:[0-9]+]] = vector.bitcast %[[T28]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_2:[0-9]+]] = vector.scalable.extract %[[T29]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_3:[0-9]+]] = vector.scalable.extract %[[T29]][4] : vector<[4]xi32> from vector<[8]xi32>
// Do the sub-tile matrix multiplications
-// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.smmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.smmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.smmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.smmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-
-// Unpack (from "registers") and insert in the output result rows 0 and 1
-// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64>
-// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[PACK_RES_00:[0-9]+]] = arm_sve.smmla %[[ACC_0]], %[[LHS_0]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_01:[0-9]+]] = arm_sve.smmla %[[ACC_1]], %[[LHS_0]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_10:[0-9]+]] = arm_sve.smmla %[[ACC_2]], %[[LHS_1]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_11:[0-9]+]] = arm_sve.smmla %[[ACC_3]], %[[LHS_1]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32>
+
+// Unpack (from "registers") and insert in the output result rows 0 and 1
+// CHECK-NEXT: %[[T36:[0-9]+]] = vector.scalable.insert %[[PACK_RES_00]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T37:[0-9]+]] = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T38:[0-9]+]] = vector.bitcast %[[T37]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %res1, %res2 = vector.deinterleave %[[T38]] : vector<[4]xi64> -> vector<[2]xi64>
+// CHECK-NEXT: %[[UNPACK_RES_0:[0-9]+]] = vector.bitcast %res1 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[UNPACK_RES_1:[0-9]+]] = vector.bitcast %res2 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_0:[0-9]+]] = vector.insert %[[UNPACK_RES_0]], %[[P1]] [0] : vector<[4]xi32> into vector<4x[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_1:[0-9]+]] = vector.insert %[[UNPACK_RES_1]], %[[TMP_OUT_0]] [1] : vector<[4]xi32> into vector<4x[4]xi32>
// Same for result rows 2 and 3
-// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64>
-// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>>
-
+// CHECK-NEXT: %[[T43...
[truncated]
|
|
||
transform.yield | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: trailing newline please (may require configuring your editor appropriately).
251f93e
to
4ab7765
Compare
65feafd
to
36108bb
Compare
No description provided.