Skip to content

[MLIR] Add apply_patterns.arm_sve.vector_contract_to_i8mm TD Op #140572

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(TransformOps)
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===- ArmSVEVectorTransformOps.h - Vector transform ops --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
#define MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H

#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"

//===----------------------------------------------------------------------===//
// ArmSVE Vector Transform Operations
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h.inc"

namespace mlir {
class DialectRegistry;

namespace arm_sve {
void registerTransformDialectExtension(DialectRegistry &registry);

} // namespace arm_sve
} // namespace mlir

#endif // MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===- ArmSVEVectorTransformOps.td - Arm SVE transform ops--*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef ARMSVE_VECTOR_TRANSFORM_OPS
#define ARMSVE_VECTOR_TRANSFORM_OPS

include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"

def ApplyArmSVELowerContractionPatternsOp
: Op<Transform_Dialect, "apply_patterns.arm_sve.vector_contract_to_i8mm",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Indicates that vector contraction-like operations should be lowered to
finer-grained vector primitives using the ArmSVE dialect.
}];

let assemblyFormat = "attr-dict";
}

#endif // ARMSVE_VECTOR_TRANSFORM_OPS
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS ArmSVEVectorTransformOps.td)
mlir_tablegen(ArmSVEVectorTransformOps.h.inc -gen-op-decls)
mlir_tablegen(ArmSVEVectorTransformOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRArmSVEVectorTransformOpsIncGen)

add_mlir_doc(ArmSVEVectorTransformOps ArmSVEVectorTransformOps Dialects/ -gen-op-doc)
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllExtensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
Expand Down Expand Up @@ -106,6 +107,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
transform::registerPDLExtension(registry);
vector::registerTransformDialectExtension(registry);
arm_neon::registerTransformDialectExtension(registry);
arm_sve::registerTransformDialectExtension(registry);

// Translation extensions need to be registered by calling
// `registerAllToLLVMIRTranslations` (see All.h).
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmSVE/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(TransformOps)
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//===- ArmSVEVectorTransformOps.cpp - Implementation transform ops -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"

#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyArmSVELowerContractionPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class ArmSVEVectorTransformDialectExtension
: public transform::TransformDialectExtension<
ArmSVEVectorTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
ArmSVEVectorTransformDialectExtension)

ArmSVEVectorTransformDialectExtension() {
declareGeneratedDialect<arm_sve::ArmSVEDialect>();
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc"
>();
}
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc"

void mlir::arm_sve::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<ArmSVEVectorTransformDialectExtension>();
}
18 changes: 18 additions & 0 deletions mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
add_mlir_dialect_library(MLIRArmSVEVectorTransformOps
ArmSVEVectorTransformOps.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE/TransformOps

DEPENDS
MLIRArmSVEVectorTransformOpsIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRVectorDialect
MLIRTransformDialect
MLIRArmSVEDialect
MLIRArmSVETransforms
)
Loading
Loading