6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
//
9
- // This file implements lowering patterns from vector.contract to
10
- // SVE I8MM operations .
9
+ // This file implements lowering patterns from vector.contract to operations
10
+ // that map to instructions from the SVE FEAT_I8MM extension .
11
11
//
12
- // ===---
12
+ // ===----------------------------------------------------------------------===//
13
13
14
14
#include " mlir/Dialect/Arith/IR/Arith.h"
15
15
#include " mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
@@ -33,10 +33,11 @@ namespace {
33
33
// Check if the given value is a result of the operation `T` (which must be
34
34
// sign- or zero- extend) from i8 to i32. Return the value before the extension.
35
35
template <typename T>
36
- inline std::enable_if_t <(std::is_base_of_v<arith::ExtSIOp, T> ||
37
- std::is_base_of_v<arith::ExtUIOp, T>),
38
- std::optional<Value>>
39
- extractExtOperand (Value v, Type i8Ty, Type i32Ty) {
36
+ std::optional<Value> extractExtOperand (Value v, Type i8Ty, Type i32Ty) {
37
+
38
+ static_assert (llvm::is_one_of<T, arith::ExtSIOp, arith::ExtUIOp>::value,
39
+ " Must be instantiated with either sign- or zero- extension op" );
40
+
40
41
auto extOp = dyn_cast_or_null<T>(v.getDefiningOp ());
41
42
if (!extOp)
42
43
return {};
@@ -79,6 +80,37 @@ Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
79
80
}
80
81
}
81
82
83
+ // Lower a contraction operation that performs a matrix multiplication
84
+ // of two 8-bit integer matrix tiles with logical dimensions <Mx8> and <8x[N]>
85
+ // for the left-hand side and the right-hand side, respectively,
86
+ // yielding a <Mx[N]> 32-bit integer result.
87
+ //
88
+ // The operand shapes are such that the operands can be evenly split into
89
+ // sub-tiles with dimensions as expected by the targeted FEAT_I8MM instructions.
90
+ // The intent is that M and N are chosen (by higher level transforms) in such a
91
+ // way as to maximise register usage. The main use case we envision as of now is
92
+ // MMT4D, thus the RHS operand is expected pre-transposed.
93
+ //
94
+ // The matrix multiplication is performed by unrolling the usual tiled matrix
95
+ // multiplication algorithm using sub-tiles with dimensions <2x8> for the LHS,
96
+ // <8x[2]> for the RHS, and <2x[2]> for the result and the input accumulator.
97
+ //
98
+ // One way to illustrate the operation is as follows:
99
+ //
100
+ // RHS<8x[N]>: <8x[2]> <8x[2]> ... <8x[2]>
101
+ // +-----------------------------
102
+ // LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
103
+ // <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
104
+ // ... | ... ... ... ...
105
+ // <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
106
+ //
107
+ // The RHS operand is unpacked into N/2 values, each representing a sequence of
108
+ // VSCALE number of sub-tiles with dimensions <8x2>.
109
+ // The LHS operand is initially unpacked into M/2 values, each representing a
110
+ // sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
111
+ // VSCALE times.
112
+ // Multiplying thus replicated LHS sub-tile by the corresponding RHS sub-tile
113
+ // correctly computes an entire result sub-tile.
82
114
class LowerContractionToSVEI8MMPattern
83
115
: public OpRewritePattern<vector::ContractionOp> {
84
116
public:
@@ -90,15 +122,11 @@ class LowerContractionToSVEI8MMPattern
90
122
mlir::VectorType lhsType = op.getLhsType ();
91
123
mlir::VectorType rhsType = op.getRhsType ();
92
124
93
- // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
94
- // eventually expect from MMT4D. M and N dimensions must be even and at
95
- // least 2.
96
- if (!lhsType.hasRank () || lhsType.getRank () != 2 || !rhsType.hasRank () ||
97
- rhsType.getRank () != 2 )
98
- return failure ();
99
-
100
- if (lhsType.isScalable () || !rhsType.isScalable ())
101
- return failure ();
125
+ // Check the operands have the expected shape. M and N dimensions must be
126
+ // even and at least 2.
127
+ if (lhsType.getRank () != 2 || rhsType.getRank () != 2 ||
128
+ lhsType.isScalable () || !rhsType.isScalable ())
129
+ return rewriter.notifyMatchFailure (op, " non-matching operand shape" );
102
130
103
131
// M, N, and K are the conventional names for matrix dimensions in the
104
132
// context of matrix multiplication.
@@ -108,7 +136,7 @@ class LowerContractionToSVEI8MMPattern
108
136
109
137
if (lhsType.getDimSize (1 ) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
110
138
N % 2 != 0 || !rhsType.getScalableDims ()[0 ])
111
- return failure ( );
139
+ return rewriter. notifyMatchFailure (op, " non-matching operand shape " );
112
140
113
141
// Check permutation maps. For now only accept
114
142
// lhs: (d0, d1, d2) -> (d0, d2)
@@ -124,28 +152,31 @@ class LowerContractionToSVEI8MMPattern
124
152
op.getIndexingMapsArray ()[2 ] !=
125
153
AffineMap::getMultiDimMapWithTargets (3 , ArrayRef{0u , 1u },
126
154
op.getContext ()))
127
- return failure ( );
155
+ return rewriter. notifyMatchFailure (op, " non-matching permutation maps " );
128
156
129
157
// Check iterator types for matrix multiplication.
130
158
auto itTypes = op.getIteratorTypesArray ();
131
159
if (itTypes.size () != 3 || itTypes[0 ] != vector::IteratorType::parallel ||
132
160
itTypes[1 ] != vector::IteratorType::parallel ||
133
161
itTypes[2 ] != vector::IteratorType::reduction)
134
- return failure ();
162
+ return rewriter.notifyMatchFailure (
163
+ op, " iterator types do not correspond to matrix multiplication" );
135
164
136
165
// Check the combining kind is addition.
137
166
if (op.getKind () != vector::CombiningKind::ADD)
138
- return failure ();
167
+ return rewriter.notifyMatchFailure (op,
168
+ " combining kind is not an addition" );
139
169
140
170
// Check the output is a vector of i32 elements.
141
- auto outTy = dyn_cast<VectorType>(op.getType ());
171
+ auto outTy = dyn_cast<VectorType>(op.getResultType ());
142
172
if (!outTy || outTy.getElementType () != rewriter.getI32Type ())
143
- return failure ();
173
+ return rewriter.notifyMatchFailure (op,
174
+ " output type is not a vector of i32" );
144
175
145
176
// Check inputs are sign-/zero- extensions from i8 to i32. Get the values
146
177
// before the extension. All four signed/unsigned combinations for input
147
178
// operands are supported, but they are lowered to different operations.
148
- // Determina which is the appropriate operation to lower to.
179
+ // Determine which is the appropriate operation to lower to.
149
180
MMLA mmlaOp = MMLA::Signed;
150
181
auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
151
182
op.getLhs (), rewriter.getI8Type (), rewriter.getI32Type ());
@@ -155,7 +186,8 @@ class LowerContractionToSVEI8MMPattern
155
186
op.getLhs (), rewriter.getI8Type (), rewriter.getI32Type ());
156
187
}
157
188
if (!maybeLhs)
158
- return failure ();
189
+ return rewriter.notifyMatchFailure (
190
+ op, " LHS is not a sign- or zero- extended i8" );
159
191
160
192
auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
161
193
op.getRhs (), rewriter.getI8Type (), rewriter.getI32Type ());
@@ -169,13 +201,16 @@ class LowerContractionToSVEI8MMPattern
169
201
op.getRhs (), rewriter.getI8Type (), rewriter.getI32Type ());
170
202
}
171
203
if (!maybeRhs)
172
- return failure ();
204
+ return rewriter.notifyMatchFailure (
205
+ op, " RHS is not a sign- or zero- extended i8" );
173
206
174
207
// One-dimensional vector types for arm_sve.*mmla
175
- auto nxv16i8 = VectorType::get (16 , rewriter.getI8Type (), {true });
176
- auto nxv4i32 = VectorType::get (4 , rewriter.getI32Type (), {true });
208
+ auto nxv16i8 = VectorType::get (/* shape=*/ 16 , rewriter.getI8Type (),
209
+ /* scalableDims=*/ {true });
210
+ auto nxv4i32 = VectorType::get (/* shape=*/ 4 , rewriter.getI32Type (),
211
+ /* scalableDims=*/ {true });
177
212
178
- // Extract LHS sub-tiles.
213
+ // Extract LHS sub-tiles with logical shape <2x8>.
179
214
SmallVector<Value> lhsTile;
180
215
for (int64_t i = 0 ; i < M; i += 2 ) {
181
216
// Extract two consecutive rows of the LHS tile.
@@ -199,19 +234,25 @@ class LowerContractionToSVEI8MMPattern
199
234
// "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
200
235
auto RHS = rewriter.create <vector::ShapeCastOp>(
201
236
maybeRhs->getLoc (),
202
- VectorType::get (8 * N, rewriter.getI8Type (), {true }), *maybeRhs);
237
+ VectorType::get (/* shape=*/ 8 * N, rewriter.getI8Type (),
238
+ /* scalableDims=*/ {true }),
239
+ *maybeRhs);
203
240
204
- // Extract the RHS sub-tiles.
241
+ // Extract the RHS sub-tiles with logical shape <8x[2]> .
205
242
SmallVector<Value> rhsTile;
206
243
for (int64_t j = 0 ; j < N; j += 2 )
207
244
rhsTile.push_back (
208
245
rewriter.create <vector::ScalableExtractOp>(loc, nxv16i8, RHS, j * 8 ));
209
246
210
247
// Handy types for packing/unpacking of the accumulator tile.
211
- auto accRowTy = VectorType::get (N, rewriter.getI32Type (), {true });
212
- auto accRowX2Ty = VectorType::get (2 * N, rewriter.getI32Type (), {true });
213
- auto accRow64Ty = VectorType::get (N / 2 , rewriter.getI64Type (), {true });
214
- auto accRowX264Ty = VectorType::get (N, rewriter.getI64Type (), {true });
248
+ auto accRowTy = VectorType::get (/* shape=*/ N, rewriter.getI32Type (),
249
+ /* scalableDims=*/ {true });
250
+ auto accRowX2Ty = VectorType::get (/* shape=*/ 2 * N, rewriter.getI32Type (),
251
+ /* scalableDims=*/ {true });
252
+ auto accRow64Ty = VectorType::get (/* shape=*/ N / 2 , rewriter.getI64Type (),
253
+ /* scalableDims=*/ {true });
254
+ auto accRowX264Ty = VectorType::get (/* shape=*/ N, rewriter.getI64Type (),
255
+ /* scalableDims=*/ {true });
215
256
216
257
// Extract and pack the ACC sub-tiles.
217
258
SmallVector<Value> accTile;
0 commit comments