Skip to content

Commit cb6ff74

Browse files
authored
[mlir][ArmNeon] Implements LowerVectorToArmNeon Pattern for SMMLA (llvm#81895)
This patch adds a the `LowerVectorToArmNeonPattern` patterns to the ArmNeon. This pattern inspects `vector.contract` ops that can be 1-1 mapped to an `arm.neon.smmla` intrinsic. The contract ops must be separated into tiles who's inputs must fit that of a single smmla op (`2x8xi32` inputs and `2x2xi32` output). The `vector.contract` inputs must be sign extended from narrow types (<=i8) to be converted. If all conditions are met, an smmla op is inserted with additional `vector.shape_casts` to handle linearizing the input and output dimension.
1 parent e93489c commit cb6ff74

File tree

13 files changed

+338
-13
lines changed

13 files changed

+338
-13
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- Transforms.h - ArmNeon Transformation Entrypoints --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_ARMNEON_TRANSFORMS_H
10+
#define MLIR_DIALECT_ARMNEON_TRANSFORMS_H
11+
12+
namespace mlir {
13+
14+
namespace arm_neon {
15+
void populateLowerContractionToSMMLAPatternPatterns(
16+
RewritePatternSet &patterns);
17+
} // namespace arm_neon
18+
19+
} // namespace mlir
20+
21+
#endif // MLIR_DIALECT_ARMNEON_TRANSFORMS_H

mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
3434
MLIRVectorToLLVM
3535

3636
MLIRArmNeonDialect
37+
MLIRArmNeonTransforms
3738
MLIRArmSMEDialect
3839
MLIRArmSMETransforms
3940
MLIRArmSVEDialect
Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,2 @@
1-
add_mlir_dialect_library(MLIRArmNeonDialect
2-
IR/ArmNeonDialect.cpp
3-
4-
ADDITIONAL_HEADER_DIRS
5-
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
6-
7-
DEPENDS
8-
MLIRArmNeonIncGen
9-
10-
LINK_LIBS PUBLIC
11-
MLIRIR
12-
MLIRSideEffectInterfaces
13-
)
1+
add_subdirectory(IR)
2+
add_subdirectory(Transforms)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
add_mlir_dialect_library(MLIRArmNeonDialect
2+
ArmNeonDialect.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
6+
7+
DEPENDS
8+
MLIRArmNeonIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRIR
12+
MLIRSideEffectInterfaces
13+
)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
add_mlir_dialect_library(MLIRArmNeonTransforms
2+
LowerContractionToSMMLAPattern.cpp
3+
4+
DEPENDS
5+
MLIRArmNeonIncGen
6+
7+
LINK_LIBS PUBLIC
8+
MLIRArmNeonDialect
9+
MLIRFuncDialect
10+
MLIRVectorDialect
11+
MLIRIR
12+
MLIRLLVMCommonConversion
13+
MLIRLLVMDialect
14+
)
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements lowering patterns from vector.contract to
10+
// arm_neon.intr.smmla
11+
//
12+
//===---
13+
14+
#include "mlir/Dialect/Arith/IR/Arith.h"
15+
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
16+
#include "mlir/Dialect/ArmNeon/Transforms.h"
17+
#include "mlir/Dialect/Func/IR/FuncOps.h"
18+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
20+
#include "mlir/IR/PatternMatch.h"
21+
#include "mlir/Support/LogicalResult.h"
22+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23+
24+
#define DEBUG_TYPE "lower-contract-to-arm-neon"
25+
26+
using namespace mlir;
27+
using namespace mlir::arm_neon;
28+
29+
namespace {
30+
31+
/// Return the shaped type with new element type.
32+
static Type matchContainerType(Type element, Type container) {
33+
if (auto shapedTy = dyn_cast<ShapedType>(container)) {
34+
return shapedTy.clone(element);
35+
}
36+
return element;
37+
}
38+
39+
/// Lowering from a single vector::contractOp directly to the arm neon smmla
40+
/// intrinsic. The shapes of the contract and intrinsic must match.
41+
class LowerContractionToSMMLAPattern
42+
: public OpRewritePattern<vector::ContractionOp> {
43+
public:
44+
using OpRewritePattern::OpRewritePattern;
45+
LogicalResult matchAndRewrite(vector::ContractionOp op,
46+
PatternRewriter &rewriter) const override {
47+
Location loc = op.getLoc();
48+
Value lhs = op.getLhs();
49+
Value rhs = op.getRhs();
50+
Value res = op.getAcc();
51+
52+
// Check index maps that represent M N K in contract.
53+
auto indexingMaps = op.getIndexingMapsArray();
54+
if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
55+
return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
56+
affineMap.getNumResults() != 2;
57+
})) {
58+
return failure();
59+
}
60+
61+
// Check iterator types for contract.
62+
auto iteratorTypes = op.getIteratorTypesArray();
63+
if (iteratorTypes.size() != 3 ||
64+
iteratorTypes[0] != vector::IteratorType::parallel ||
65+
iteratorTypes[1] != vector::IteratorType::parallel ||
66+
iteratorTypes[2] != vector::IteratorType::reduction) {
67+
return failure();
68+
}
69+
70+
// Check the tile size by mapping the dimensions of the contract.
71+
mlir::VectorType lhsType = op.getLhsType();
72+
mlir::VectorType rhsType = op.getRhsType();
73+
auto dimM = lhsType.getDimSize(0);
74+
auto dimN = rhsType.getDimSize(0);
75+
auto dimK = lhsType.getDimSize(1);
76+
if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8) {
77+
return failure();
78+
}
79+
80+
// Check two extsi inputs Rhs Lhs for contract.
81+
arith::ExtSIOp origLhsExtOp =
82+
dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp());
83+
arith::ExtSIOp origRhsExtOp =
84+
dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp());
85+
if (!origLhsExtOp || !origRhsExtOp) {
86+
return failure();
87+
}
88+
89+
// Match any iX to i32 for X<8 then turn into an i8 output. Feed into
90+
// following neon instruction. Check inputs for extsi are <=i8
91+
Value extsiLhs;
92+
Value extsiRhs;
93+
if (auto lhsExtInType =
94+
origLhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
95+
if (lhsExtInType.getElementTypeBitWidth() <= 8) {
96+
Type targetLhsExtTy =
97+
matchContainerType(rewriter.getI8Type(), lhsExtInType);
98+
extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
99+
origLhsExtOp.getIn());
100+
}
101+
}
102+
if (auto rhsExtInType =
103+
origRhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
104+
if (rhsExtInType.getElementTypeBitWidth() <= 8) {
105+
Type targetRhsExtTy =
106+
matchContainerType(rewriter.getI8Type(), rhsExtInType);
107+
extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
108+
origRhsExtOp.getIn());
109+
}
110+
}
111+
112+
if (!extsiLhs || !extsiRhs) {
113+
return failure();
114+
}
115+
116+
// Collapse to 1D vectors required by smmla intrinsic
117+
auto collapsedInputType = VectorType::get(
118+
{16}, extsiLhs.getType().cast<ShapedType>().getElementType());
119+
auto collapsedOutputType =
120+
VectorType::get({4}, res.getType().cast<ShapedType>().getElementType());
121+
auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
122+
extsiLhs.getLoc(), collapsedInputType, extsiLhs);
123+
auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
124+
extsiRhs.getLoc(), collapsedInputType, extsiRhs);
125+
auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
126+
res.getLoc(), collapsedOutputType, res);
127+
128+
// Replace the contract with a neon op
129+
auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
130+
op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
131+
collapsedRhs);
132+
133+
// Reshape output back to 2D
134+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
135+
smmlaOp);
136+
return success();
137+
}
138+
};
139+
140+
} // namespace
141+
142+
void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
143+
RewritePatternSet &patterns) {
144+
MLIRContext *context = patterns.getContext();
145+
patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
146+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: mlir-opt -test-lower-to-arm-neon -verify-diagnostics -split-input-file %s | FileCheck %s
2+
3+
// CHECK-LABEL: test_lower_vector_arm_neon_mixed_types
4+
// CHECK-SAME: %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi4>, %[[A2:.*]]: vector<2x2xi32>
5+
// CHECK-DAG: %[[D0:.*]] = arith.extsi %[[A1]] : vector<2x8xi4> to vector<2x8xi8>
6+
// CHECK-DAG: %[[D1:.*]] = vector.shape_cast %[[A0]] : vector<2x8xi8> to vector<16xi8>
7+
// CHECK-DAG: %[[D2:.*]] = vector.shape_cast %[[D0]] : vector<2x8xi8> to vector<16xi8>
8+
// CHECK-DAG: %[[D3:.*]] = vector.shape_cast %[[A2]] : vector<2x2xi32> to vector<4xi32>
9+
// CHECK-DAG: %[[D4:.*]] = arm_neon.intr.smmla %[[D3]], %[[D1]], %[[D2]] : vector<16xi8> to vector<4xi32>
10+
// CHECK-DAG: %[[D5:.*]] = vector.shape_cast %[[D4]] : vector<4xi32> to vector<2x2xi32>
11+
func.func @test_lower_vector_arm_neon_mixed_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
12+
%lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
13+
%rhs_extsi = arith.extsi %rhs : vector<2x8xi4> to vector<2x8xi32>
14+
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
15+
return %res : vector<2x2xi32>
16+
}
17+
18+
// -----
19+
20+
// CHECK-LABEL: test_lower_vector_arm_neon_same_types
21+
// CHECK-SAME: %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi8>, %[[A2:.*]]: vector<2x2xi32>
22+
// CHECK-DAG: %[[D0:.*]] = vector.shape_cast %[[A0]] : vector<2x8xi8> to vector<16xi8>
23+
// CHECK-DAG: %[[D1:.*]] = vector.shape_cast %[[A1]] : vector<2x8xi8> to vector<16xi8>
24+
// CHECK-DAG: %[[D2:.*]] = vector.shape_cast %[[A2]] : vector<2x2xi32> to vector<4xi32>
25+
// CHECK-DAG: %[[D3:.*]] = arm_neon.intr.smmla %[[D2]], %[[D0]], %[[D1]] : vector<16xi8> to vector<4xi32>
26+
// CHECK-DAG: %[[D4:.*]] = vector.shape_cast %[[D3]] : vector<4xi32> to vector<2x2xi32>
27+
func.func @test_lower_vector_arm_neon_same_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
28+
%lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
29+
%rhs_extsi = arith.extsi %rhs : vector<2x8xi8> to vector<2x8xi32>
30+
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
31+
return %res : vector<2x2xi32>
32+
}
33+
34+
// -----
35+
36+
// CHECK-LABEL: test_lower_vector_arm_neon_without_extsi
37+
// CHECK-SAME: %[[A0:.*]]: vector<2x8xi32>, %[[A1:.*]]: vector<2x8xi32>, %[[A2:.*]]: vector<2x2xi32>
38+
// CHECK-DAG: %[[D0:.*]] = vector.contract
39+
func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs: vector<2x8xi32>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
40+
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
41+
return %res : vector<2x2xi32>
42+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Exclude tests from libMLIR.so
2+
add_mlir_library(MLIRArmNeonTestPasses
3+
TestLowerToArmNeon.cpp
4+
5+
EXCLUDE_FROM_LIBMLIR
6+
7+
LINK_LIBS PUBLIC
8+
MLIRArmNeonDialect
9+
MLIRArmNeonTransforms
10+
MLIRIR
11+
MLIRPass
12+
MLIRTransforms
13+
)
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
//===- TestLowerToArmNeon.cpp - Test lowering to ArmNeon as a sink pass -===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass for testing the lowering to ArmNeon as a
10+
// generally usable sink pass.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
15+
#include "mlir/Dialect/ArmNeon/Transforms.h"
16+
#include "mlir/Dialect/Func/IR/FuncOps.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/Pass/Pass.h"
19+
#include "mlir/Pass/PassManager.h"
20+
#include "mlir/Support/LogicalResult.h"
21+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22+
23+
#define PASS_NAME "test-lower-to-arm-neon"
24+
25+
using namespace mlir;
26+
using namespace mlir::arm_neon;
27+
28+
namespace {
29+
struct TestLowerToArmNeon
30+
: public PassWrapper<TestLowerToArmNeon, OperationPass<func::FuncOp>> {
31+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLowerToArmNeon)
32+
33+
StringRef getArgument() const final { return PASS_NAME; }
34+
StringRef getDescription() const final { return "Tests lower to arm Neon."; }
35+
TestLowerToArmNeon() = default;
36+
TestLowerToArmNeon(const TestLowerToArmNeon &pass) = default;
37+
38+
void getDependentDialects(DialectRegistry &registry) const override {
39+
registry.insert<arm_neon::ArmNeonDialect>();
40+
}
41+
42+
void runOnOperation() override;
43+
};
44+
45+
} // namespace
46+
47+
void TestLowerToArmNeon::runOnOperation() {
48+
MLIRContext *context = &getContext();
49+
RewritePatternSet patterns(context);
50+
populateLowerContractionToSMMLAPatternPatterns(patterns);
51+
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
52+
return signalPassFailure();
53+
}
54+
55+
namespace mlir {
56+
namespace test {
57+
58+
void registerTestLowerToArmNeon() { PassRegistration<TestLowerToArmNeon>(); }
59+
60+
} // namespace test
61+
} // namespace mlir

mlir/test/lib/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_subdirectory(Affine)
22
add_subdirectory(Arith)
3+
add_subdirectory(ArmNeon)
34
add_subdirectory(ArmSME)
45
add_subdirectory(Bufferization)
56
add_subdirectory(ControlFlow)

mlir/tools/mlir-opt/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ if(MLIR_INCLUDE_TESTS)
1717
MLIRTestFuncToLLVM
1818
MLIRAffineTransformsTestPasses
1919
MLIRArithTestPasses
20+
MLIRArmNeonTestPasses
2021
MLIRArmSMETestPasses
2122
MLIRBufferizationTestPasses
2223
MLIRControlFlowTestPasses

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ void registerTestLoopFusion();
111111
void registerTestCFGLoopInfoPass();
112112
void registerTestLoopMappingPass();
113113
void registerTestLoopUnrollingPass();
114+
void registerTestLowerToArmNeon();
114115
void registerTestLowerToArmSME();
115116
void registerTestLowerToLLVM();
116117
void registerTestMakeIsolatedFromAbovePass();
@@ -237,6 +238,7 @@ void registerTestPasses() {
237238
mlir::test::registerTestCFGLoopInfoPass();
238239
mlir::test::registerTestLoopMappingPass();
239240
mlir::test::registerTestLoopUnrollingPass();
241+
mlir::test::registerTestLowerToArmNeon();
240242
mlir::test::registerTestLowerToArmSME();
241243
mlir::test::registerTestLowerToLLVM();
242244
mlir::test::registerTestMakeIsolatedFromAbovePass();

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1931,6 +1931,27 @@ cc_library(
19311931
],
19321932
)
19331933

1934+
cc_library(
1935+
name = "ArmNeonTransforms",
1936+
srcs = ["lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp"],
1937+
hdrs = ["include/mlir/Dialect/ArmNeon/Transforms.h"],
1938+
includes = ["include"],
1939+
deps = [
1940+
":ArithDialect",
1941+
":ArmNeonIncGen",
1942+
":ArmNeonDialect",
1943+
":FuncDialect",
1944+
":IR",
1945+
":LLVMDialect",
1946+
":SideEffectInterfaces",
1947+
":Support",
1948+
":VectorDialect",
1949+
":Transforms",
1950+
"//llvm:Core",
1951+
"//llvm:Support",
1952+
],
1953+
)
1954+
19341955
gentbl_cc_library(
19351956
name = "ArmNeonConversionIncGen",
19361957
tbl_outs = [

0 commit comments

Comments
 (0)