Skip to content

Commit b2b6eea

Browse files
committed
Implement LowerVectorToArmNeon
1 parent b2ca23a commit b2b6eea

File tree

7 files changed

+225
-13
lines changed

7 files changed

+225
-13
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- Transforms.h - ArmNeon Dialect Transformation Entrypoints -*- C++ -*-===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef MLIR_DIALECT_ARMNEON_TRANSFORMS_H
11+
#define MLIR_DIALECT_ARMNEON_TRANSFORMS_H
12+
13+
namespace mlir {

// Forward declaration so this header is self-contained: it only needs the
// name of the pattern set, not its definition.
class RewritePatternSet;

namespace arm_neon {
/// Adds to `patterns` the rewrite patterns that lower `vector.contract`
/// operations to ArmNeon dialect operations.
void populateLowerVectorToArmNeonPatterns(RewritePatternSet &patterns);
} // namespace arm_neon

} // namespace mlir
20+
21+
#endif // MLIR_DIALECT_ARMNEON_TRANSFORMS_H

mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
3434
MLIRVectorToLLVM
3535

3636
MLIRArmNeonDialect
37+
MLIRArmNeonTransforms
3738
MLIRArmSMEDialect
3839
MLIRArmSMETransforms
3940
MLIRArmSVEDialect
Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,2 @@
1-
add_mlir_dialect_library(MLIRArmNeonDialect
2-
IR/ArmNeonDialect.cpp
3-
4-
ADDITIONAL_HEADER_DIRS
5-
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
6-
7-
DEPENDS
8-
MLIRArmNeonIncGen
9-
10-
LINK_LIBS PUBLIC
11-
MLIRIR
12-
MLIRSideEffectInterfaces
13-
)
1+
# Dialect definition (operations, types).
add_subdirectory(IR)
# Rewrite patterns that target the dialect.
add_subdirectory(Transforms)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# ArmNeon dialect library: built from ArmNeonDialect.cpp, after the generated
# includes (MLIRArmNeonIncGen) are available.
add_mlir_dialect_library(MLIRArmNeonDialect
  ArmNeonDialect.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon

  DEPENDS
  MLIRArmNeonIncGen

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRSideEffectInterfaces
  )
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# ArmNeon transformation library: patterns lowering vector ops to ArmNeon ops.
# LowerVectorToArmNeon.cpp creates arith.extsi operations, so the Arith dialect
# is linked explicitly rather than relied upon transitively.
add_mlir_dialect_library(MLIRArmNeonTransforms
  LowerVectorToArmNeon.cpp

  DEPENDS
  MLIRArmNeonIncGen

  LINK_LIBS PUBLIC
  MLIRArithDialect
  MLIRArmNeonDialect
  MLIRFuncDialect
  MLIRIR
  MLIRLLVMCommonConversion
  MLIRLLVMDialect
  MLIRVectorDialect
  )
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
//===- LowerVectorToArmNeon.cpp - Lower 'arm_neon.intr.smmla' ops ---------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This file implements lowering patterns from vector.contract to
11+
// arm_neon.intr.smmla
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
17+
#include "mlir/Dialect/ArmNeon/Transforms.h"
18+
#include "mlir/Dialect/Func/IR/FuncOps.h"
19+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
21+
#include "mlir/IR/PatternMatch.h"
22+
#include "mlir/Support/LogicalResult.h"
23+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24+
25+
#define DEBUG_TYPE "arm-neon-vector-lowering"
26+
27+
using namespace mlir;
28+
using namespace mlir::arm_neon;
29+
30+
namespace {
31+
32+
// Return the shaped type with new element type.
33+
static Type matchContainerType(Type element, Type container) {
34+
if (auto shapedTy = dyn_cast<ShapedType>(container))
35+
return shapedTy.clone(element);
36+
37+
return element;
38+
}
39+
40+
// Lowering from vector::contractOp directly to the arm neon
41+
// intrinsic.
42+
class LowerVectorToArmNeonPattern
43+
: public OpRewritePattern<vector::ContractionOp> {
44+
public:
45+
using OpRewritePattern::OpRewritePattern;
46+
LogicalResult matchAndRewrite(vector::ContractionOp op,
47+
PatternRewriter &rewriter) const override {
48+
Location loc = op.getLoc();
49+
Value lhs = op.getLhs();
50+
Value rhs = op.getRhs();
51+
Value res = op.getAcc();
52+
53+
// Check index maps represent M N K and aren't transposed.
54+
auto indexingMaps = op.getIndexingMapsArray();
55+
if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
56+
return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
57+
affineMap.getNumResults() != 2;
58+
})) {
59+
return failure();
60+
}
61+
62+
// Check iterator types for contract
63+
auto iteratorTypes = op.getIteratorTypesArray();
64+
if (iteratorTypes.size() != 3 ||
65+
iteratorTypes[0] != vector::IteratorType::parallel ||
66+
iteratorTypes[1] != vector::IteratorType::parallel ||
67+
iteratorTypes[2] != vector::IteratorType::reduction) {
68+
return failure();
69+
}
70+
71+
// Check the tile size by mapping the dimensions of the contract
72+
// -- Tile size: [2, 2, 8]
73+
// Infer tile sizes from operands. Check required tile size
74+
// Note: RHS is not transposed
75+
mlir::VectorType lhsType = op.getLhsType();
76+
mlir::VectorType rhsType = op.getRhsType();
77+
auto dimM = lhsType.getDimSize(0);
78+
auto dimN = rhsType.getDimSize(0);
79+
auto dimK = lhsType.getDimSize(1);
80+
if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8) {
81+
return failure();
82+
}
83+
84+
// Check two extsi inputs Rhs Lhs
85+
arith::ExtSIOp origLhsExtOp;
86+
arith::ExtSIOp origRhsExtOp;
87+
if (!(origLhsExtOp =
88+
dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp())) ||
89+
!(origRhsExtOp =
90+
dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp()))) {
91+
return failure();
92+
}
93+
94+
arith::ExtSIOp extsiLhs;
95+
arith::ExtSIOp extsiRhs;
96+
// Match any iX to i32 for X<8 then turn into an i8 output. Feed into
97+
// following neon instruction. Check inputs for extsi are <=i8
98+
if (auto lhsExtInType =
99+
origLhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
100+
if (lhsExtInType.getElementTypeBitWidth() <= 8) {
101+
// Target lhs type with i8. This is likely redundant
102+
Type targetLhsExtTy =
103+
matchContainerType(rewriter.getI8Type(), lhsExtInType);
104+
extsiLhs = rewriter.create<arith::ExtSIOp>(loc, targetLhsExtTy,
105+
origLhsExtOp.getIn());
106+
}
107+
}
108+
if (auto rhsExtInType =
109+
origRhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
110+
if (rhsExtInType.getElementTypeBitWidth() <= 8) {
111+
// Target rhs type with i8
112+
Type targetRhsExtTy =
113+
matchContainerType(rewriter.getI8Type(), rhsExtInType);
114+
extsiRhs = rewriter.create<arith::ExtSIOp>(loc, targetRhsExtTy,
115+
origRhsExtOp.getIn());
116+
}
117+
}
118+
119+
if (!extsiLhs || !extsiRhs) {
120+
return failure();
121+
}
122+
123+
// Collapse to 1D vectors required by smmla intrinsic
124+
auto collapsedInputType = VectorType::get(
125+
{16}, extsiLhs.getType().cast<ShapedType>().getElementType());
126+
auto collapsedOutputType =
127+
VectorType::get({4}, res.getType().cast<ShapedType>().getElementType());
128+
auto collapsedLhs = rewriter.create<vector::ShapeCastOp>(
129+
extsiLhs.getLoc(), collapsedInputType, extsiLhs);
130+
auto collapsedRhs = rewriter.create<vector::ShapeCastOp>(
131+
extsiRhs.getLoc(), collapsedInputType, extsiRhs);
132+
auto collapsedRes = rewriter.create<vector::ShapeCastOp>(
133+
res.getLoc(), collapsedOutputType, res);
134+
135+
// Replace the contract with a neon op
136+
auto smmlaOp = rewriter.create<arm_neon::SmmlaOp>(
137+
op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
138+
collapsedRhs);
139+
140+
// Reshape output back to 2D
141+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
142+
smmlaOp);
143+
return success();
144+
}
145+
};
146+
147+
} // namespace
148+
149+
/// Registers LowerVectorToArmNeonPattern in `patterns` with benefit 1, using
/// the context owned by the pattern set.
void mlir::arm_neon::populateLowerVectorToArmNeonPatterns(
    RewritePatternSet &patterns) {
  patterns.add<LowerVectorToArmNeonPattern>(patterns.getContext(),
                                            /*benefit=*/1);
}

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1929,6 +1929,27 @@ cc_library(
19291929
],
19301930
)
19311931

1932+
# Transformations lowering vector ops to the ArmNeon dialect.
cc_library(
    name = "ArmNeonTransforms",
    srcs = ["lib/Dialect/ArmNeon/Transforms/LowerVectorToArmNeon.cpp"],
    hdrs = ["include/mlir/Dialect/ArmNeon/Transforms.h"],
    includes = ["include"],
    deps = [
        ":ArithDialect",
        ":ArmNeonDialect",
        ":ArmNeonIncGen",
        ":FuncDialect",
        ":IR",
        ":LLVMDialect",
        ":SideEffectInterfaces",
        ":Support",
        ":Transforms",
        ":VectorDialect",
        "//llvm:Core",
        "//llvm:Support",
    ],
)
1952+
19321953
gentbl_cc_library(
19331954
name = "ArmNeonConversionIncGen",
19341955
tbl_outs = [

0 commit comments

Comments
 (0)