Skip to content

Commit 251f93e

Browse files
[MLIR] Add apply_patterns.vector.arm_sve.lower_contraction TD Op
1 parent 65feafd commit 251f93e

File tree

12 files changed

+512
-290
lines changed

12 files changed

+512
-290
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(IR)
22
add_subdirectory(Transforms)
3+
add_subdirectory(TransformOps)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===- ArmSVEVectorTransformOps.h - Vector transform ops --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
10+
#define MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
11+
12+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
13+
#include "mlir/IR/OpImplementation.h"
14+
15+
//===----------------------------------------------------------------------===//
16+
// ArmSVE Vector Transform Operations
17+
//===----------------------------------------------------------------------===//
18+
19+
#define GET_OP_CLASSES
20+
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h.inc"
21+
22+
namespace mlir {
23+
class DialectRegistry;
24+
25+
namespace arm_sve {
26+
void registerTransformDialectExtension(DialectRegistry &registry);
27+
28+
} // namespace arm_sve
29+
} // namespace mlir
30+
31+
#endif // MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- ArmSVEVectorTransformOps.td - Arm SVE transform ops--*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef ARMSVE_VECTOR_TRANSFORM_OPS
9+
#define ARMSVE_VECTOR_TRANSFORM_OPS
10+
11+
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
12+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
13+
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
14+
15+
def ApplyArmSVELowerContractionPatternsOp
16+
: Op<Transform_Dialect, "apply_patterns.vector.arm_sve.lower_contraction",
17+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
18+
let description = [{
19+
Indicates that vector contraction-like operations should be lowered to
20+
finer-grained vector primitives using the ArmSVE dialect.
21+
}];
22+
23+
let assemblyFormat = "attr-dict";
24+
}
25+
26+
#endif // ARMSVE_VECTOR_TRANSFORM_OPS
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS ArmSVEVectorTransformOps.td)
2+
mlir_tablegen(ArmSVEVectorTransformOps.h.inc -gen-op-decls)
3+
mlir_tablegen(ArmSVEVectorTransformOps.cpp.inc -gen-op-defs)
4+
add_public_tablegen_target(MLIRArmSVEVectorTransformOpsIncGen)
5+
6+
add_mlir_doc(ArmSVEVectorTransformOps ArmSVEVectorTransformOps Dialects/ -gen-op-doc)

mlir/include/mlir/InitAllExtensions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
3535
#include "mlir/Dialect/AMX/Transforms.h"
3636
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
37+
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
3738
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
3839
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
3940
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
@@ -106,6 +107,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
106107
transform::registerLoopExtension(registry);
107108
transform::registerPDLExtension(registry);
108109
vector::registerTransformDialectExtension(registry);
110+
arm_sve::registerTransformDialectExtension(registry);
109111

110112
// Translation extensions need to be registered by calling
111113
// `registerAllToLLVMIRTranslations` (see All.h).
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(IR)
22
add_subdirectory(Transforms)
3+
add_subdirectory(TransformOps)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
//===- ArmSVEVectorTransformOps.cpp - Implementation transform ops -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
10+
11+
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
12+
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
13+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
14+
15+
using namespace mlir;
16+
17+
//===----------------------------------------------------------------------===//
18+
// Apply...PatternsOp
19+
//===----------------------------------------------------------------------===//
20+
21+
void transform::ApplyArmSVELowerContractionPatternsOp::populatePatterns(
22+
RewritePatternSet &patterns) {
23+
mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns);
24+
}
25+
26+
//===----------------------------------------------------------------------===//
27+
// Transform op registration
28+
//===----------------------------------------------------------------------===//
29+
30+
namespace {
31+
class ArmSVEVectorTransformDialectExtension
32+
: public transform::TransformDialectExtension<
33+
ArmSVEVectorTransformDialectExtension> {
34+
public:
35+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
36+
ArmSVEVectorTransformDialectExtension)
37+
38+
ArmSVEVectorTransformDialectExtension() {
39+
declareGeneratedDialect<arm_sve::ArmSVEDialect>();
40+
registerTransformOps<
41+
#define GET_OP_LIST
42+
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc"
43+
>();
44+
}
45+
};
46+
} // namespace
47+
48+
#define GET_OP_CLASSES
49+
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc"
50+
51+
void mlir::arm_sve::registerTransformDialectExtension(
52+
DialectRegistry &registry) {
53+
registry.addExtensions<ArmSVEVectorTransformDialectExtension>();
54+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
add_mlir_dialect_library(MLIRArmSVEVectorTransformOps
2+
ArmSVEVectorTransformOps.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE/TransformOps
6+
7+
DEPENDS
8+
MLIRArmSVEVectorTransformOpsIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRIR
12+
MLIRLLVMCommonConversion
13+
MLIRLLVMDialect
14+
MLIRVectorDialect
15+
MLIRTransformDialect
16+
MLIRArmSVEDialect
17+
MLIRArmSVETransforms
18+
)
19+

0 commit comments

Comments
 (0)