Skip to content

Commit b1c60ab

Browse files
committed
Implement LowerVectorToArmNeon
1 parent 351f94d commit b1c60ab

File tree

8 files changed

+206
-13
lines changed

8 files changed

+206
-13
lines changed

mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ add_mlir_doc(ArmNeon ArmNeon Dialects/ -gen-dialect-doc -dialect=arm_neon)
44
set(LLVM_TARGET_DEFINITIONS ArmNeon.td)
55
mlir_tablegen(ArmNeonConversions.inc -gen-llvmir-conversions)
66
add_public_tablegen_target(MLIRArmNeonConversionsIncGen)
7+
add_subdirectory(Transforms)

mlir/include/mlir/Dialect/ArmNeon/Transforms/CMakeLists.txt

Whitespace-only changes.
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- Transforms.h - ArmNeon Dialect Transform Entrypoints -----*- C++ -*-===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef MLIR_DIALECT_ARMNEON_TRANSFORMS_H
11+
#define MLIR_DIALECT_ARMNEON_TRANSFORMS_H
12+
13+
namespace mlir {

// Forward declaration so this header is self-contained: only a reference to
// the pattern set is needed here, not its full definition.
class RewritePatternSet;

namespace arm_neon {
/// Adds the vector-to-ArmNeon lowering patterns to `patterns`.
void populateLowerVectorToArmNeonPatterns(RewritePatternSet &patterns);
} // namespace arm_neon

} // namespace mlir
20+
21+
#endif // MLIR_DIALECT_ARMNEON_TRANSFORMS_H

mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
3434
MLIRVectorToLLVM
3535

3636
MLIRArmNeonDialect
37+
MLIRArmNeonTransforms
3738
MLIRArmSMEDialect
3839
MLIRArmSMETransforms
3940
MLIRArmSVEDialect
Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,2 @@
1-
add_mlir_dialect_library(MLIRArmNeonDialect
2-
IR/ArmNeonDialect.cpp
3-
4-
ADDITIONAL_HEADER_DIRS
5-
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
6-
7-
DEPENDS
8-
MLIRArmNeonIncGen
9-
10-
LINK_LIBS PUBLIC
11-
MLIRIR
12-
MLIRSideEffectInterfaces
13-
)
1+
add_subdirectory(IR)
2+
add_subdirectory(Transforms)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
add_mlir_dialect_library(MLIRArmNeonDialect
2+
ArmNeonDialect.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
6+
7+
DEPENDS
8+
MLIRArmNeonIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRIR
12+
MLIRSideEffectInterfaces
13+
)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Transformation patterns for the ArmNeon dialect.
add_mlir_dialect_library(MLIRArmNeonTransforms
  LowerVectorToArmNeon.cpp

  DEPENDS
  MLIRArmNeonIncGen

  LINK_LIBS PUBLIC
  # LowerVectorToArmNeon.cpp creates arith.extsi ops directly, so link the
  # Arith dialect explicitly instead of relying on a transitive dependency.
  MLIRArithDialect
  MLIRArmNeonDialect
  MLIRFuncDialect
  MLIRVectorDialect
  MLIRIR
  MLIRLLVMCommonConversion
  MLIRLLVMDialect
  )
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
//===- LowerVectorToArmNeon.cpp - Lower to 'arm_neon.intr.smmla' ----------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This file implements lowering patterns from vector.contract to
11+
// arm_neon.intr.smmla
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
17+
#include "mlir/Dialect/ArmNeon/Transforms/Transforms.h"
18+
#include "mlir/Dialect/Func/IR/FuncOps.h"
19+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
21+
#include "mlir/IR/PatternMatch.h"
22+
#include "mlir/Support/LogicalResult.h"
23+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24+
25+
#define DEBUG_TYPE "arm-neon-vector-lowering"
26+
27+
using namespace mlir;
28+
using namespace mlir::arm_neon;
29+
30+
namespace {
31+
32+
// Return the shaped type with new element type.
33+
static Type matchContainerType(Type element, Type container) {
34+
if (auto shapedTy = dyn_cast<ShapedType>(container))
35+
return shapedTy.clone(element);
36+
37+
return element;
38+
}
39+
40+
// Lowering from vector::contractOp directly to the arm neon
41+
// intrinsic.
42+
class LowerVectorToArmNeonPattern
43+
: public OpRewritePattern<vector::ContractionOp> {
44+
public:
45+
using OpRewritePattern::OpRewritePattern;
46+
LogicalResult matchAndRewrite(vector::ContractionOp op,
47+
PatternRewriter &rewriter) const override {
48+
Location loc = op.getLoc();
49+
Value lhs = op.getLhs();
50+
Value rhs = op.getRhs();
51+
Value res = op.getAcc();
52+
53+
// Check index maps represent M N K and aren't transposed.
54+
auto indexingMaps = op.getIndexingMapsArray();
55+
if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
56+
return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
57+
affineMap.getNumResults() != 2;
58+
})) {
59+
llvm::dbgs() << "The affine check failed! \n";
60+
return failure();
61+
}
62+
63+
// Check iterator types for contract
64+
auto iteratorTypes = op.getIteratorTypesArray();
65+
if (iteratorTypes.size() != 3 ||
66+
iteratorTypes[0] != vector::IteratorType::parallel ||
67+
iteratorTypes[1] != vector::IteratorType::parallel ||
68+
iteratorTypes[2] != vector::IteratorType::reduction) {
69+
return failure();
70+
}
71+
72+
// Check the tile size by mapping the dimensions of the contract
73+
// -- Tile size: [2, 2, 8]
74+
// Infer tile sizes from operands. Check required tile size
75+
// Note: RHS is not transposed
76+
mlir::VectorType lhsType = op.getLhsType();
77+
mlir::VectorType rhsType = op.getRhsType();
78+
auto dimM = lhsType.getDimSize(0);
79+
auto dimN = rhsType.getDimSize(0);
80+
auto dimK = lhsType.getDimSize(1);
81+
if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8) {
82+
return failure();
83+
}
84+
85+
// Check two extsi inputs Rhs Lhs
86+
arith::ExtSIOp origLhsExtOp;
87+
arith::ExtSIOp origRhsExtOp;
88+
if (!(origLhsExtOp =
89+
dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp())) ||
90+
!(origRhsExtOp =
91+
dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp()))) {
92+
return failure();
93+
}
94+
95+
arith::ExtSIOp extsiLhs;
96+
arith::ExtSIOp extsiRhs;
97+
// Match any iX to i32 for X<8 then turn into an i8 output. Feed into
98+
// following neon instruction. Check inputs for extsi are <=i8
99+
if (auto lhsExtInType =
100+
origLhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
101+
if (lhsExtInType.getElementTypeBitWidth() <= 8) {
102+
// Target lhs type with i8. This is likely redundant
103+
Type targetLhsExtTy =
104+
matchContainerType(rewriter.getI8Type(), lhsExtInType);
105+
extsiLhs = rewriter.create<arith::ExtSIOp>(loc, targetLhsExtTy,
106+
origLhsExtOp.getIn());
107+
}
108+
}
109+
if (auto rhsExtInType =
110+
origRhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
111+
if (rhsExtInType.getElementTypeBitWidth() <= 8) {
112+
// Target rhs type with i8
113+
Type targetRhsExtTy =
114+
matchContainerType(rewriter.getI8Type(), rhsExtInType);
115+
extsiRhs = rewriter.create<arith::ExtSIOp>(loc, targetRhsExtTy,
116+
origRhsExtOp.getIn());
117+
}
118+
}
119+
120+
if (!extsiLhs || !extsiRhs) {
121+
return failure();
122+
}
123+
124+
// Collapse to 1D vectors required by smmla intrinsic
125+
auto collapsedInputType = VectorType::get(
126+
{16}, extsiLhs.getType().cast<ShapedType>().getElementType());
127+
auto collapsedOutputType =
128+
VectorType::get({4}, res.getType().cast<ShapedType>().getElementType());
129+
auto collapsedLhs = rewriter.create<vector::ShapeCastOp>(
130+
extsiLhs.getLoc(), collapsedInputType, extsiLhs);
131+
auto collapsedRhs = rewriter.create<vector::ShapeCastOp>(
132+
extsiRhs.getLoc(), collapsedInputType, extsiRhs);
133+
auto collapsedRes = rewriter.create<vector::ShapeCastOp>(
134+
res.getLoc(), collapsedOutputType, res);
135+
136+
// Replace the contract with a neon op
137+
auto smmlaOp = rewriter.create<arm_neon::SmmlaOp>(
138+
op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
139+
collapsedRhs);
140+
141+
// Reshape output back to 2D
142+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
143+
smmlaOp);
144+
return success();
145+
}
146+
};
147+
148+
} // namespace
149+
150+
/// Registers LowerVectorToArmNeonPattern in `patterns` with benefit 1, using
/// the pattern set's own MLIR context.
void mlir::arm_neon::populateLowerVectorToArmNeonPatterns(
    RewritePatternSet &patterns) {
  patterns.add<LowerVectorToArmNeonPattern>(patterns.getContext(),
                                            /*benefit=*/1);
}

0 commit comments

Comments
 (0)