Skip to content

Commit 32203d5

Browse files
[fixup] Misc changes
-- some commenting -- replace enable_if with a static assert -- return reasons for match failures
1 parent f397467 commit 32203d5

File tree

1 file changed

+75
-34
lines changed

1 file changed

+75
-34
lines changed

mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp

Lines changed: 75 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file implements lowering patterns from vector.contract to
10-
// SVE I8MM operations.
9+
// This file implements lowering patterns from vector.contract to operations
10+
// that map to instructions from the SVE FEAT_I8MM extension.
1111
//
12-
//===---
12+
//===----------------------------------------------------------------------===//
1313

1414
#include "mlir/Dialect/Arith/IR/Arith.h"
1515
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
@@ -33,10 +33,11 @@ namespace {
3333
// Check if the given value is a result of the operation `T` (which must be
3434
// sign- or zero- extend) from i8 to i32. Return the value before the extension.
3535
template <typename T>
36-
inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
37-
std::is_base_of_v<arith::ExtUIOp, T>),
38-
std::optional<Value>>
39-
extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
36+
std::optional<Value> extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
37+
38+
static_assert(llvm::is_one_of<T, arith::ExtSIOp, arith::ExtUIOp>::value,
39+
"Must be instantiated with either sign- or zero- extension op");
40+
4041
auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
4142
if (!extOp)
4243
return {};
@@ -79,6 +80,37 @@ Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
7980
}
8081
}
8182

83+
// Lower a contraction operation that performs a matrix multiplication
84+
// of two 8-bit integer matrix tiles with logical dimensions <Mx8> and <8x[N]>
85+
// for the left-hand side and the right-hand side, respectively,
86+
// yielding a <Mx[N]> 32-bit integer result.
87+
//
88+
// The operand shapes are such that the operands can be evenly split into
89+
// sub-tiles with dimensions as expected by the targeted FEAT_I8MM instructions.
90+
// The intent is that M and N are chosen (by higher level transforms) in such a
91+
// way as to maximise register usage. The main use case we envision as of now is
92+
// MMT4D, thus the RHS operand is expected pre-transposed.
93+
//
94+
// The matrix multiplication is performed by unrolling the usual tiled matrix
95+
// multiplication algorithm using sub-tiles with dimensions <2x8> for the LHS,
96+
// <8x[2]> for the RHS, and <2x[2]> for the result and the input accumulator.
97+
//
98+
// One way to illustrate the operation is as follows:
99+
//
100+
// RHS<8x[N]>: <8x[2]> <8x[2]> ... <8x[2]>
101+
// +-----------------------------
102+
// LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
103+
// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
104+
// ... | ... ... ... ...
105+
// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
106+
//
107+
// The RHS operand is unpacked into N/2 values, each representing a sequence of
108+
// VSCALE number of sub-tiles with dimensions <8x2>.
109+
// The LHS operand is initially unpacked into M/2 values, each representing a
110+
// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
111+
// VSCALE times.
112+
// Multiplying thus replicated LHS sub-tile by the corresponding RHS sub-tile
113+
// correctly computes an entire result sub-tile.
82114
class LowerContractionToSVEI8MMPattern
83115
: public OpRewritePattern<vector::ContractionOp> {
84116
public:
@@ -90,15 +122,11 @@ class LowerContractionToSVEI8MMPattern
90122
mlir::VectorType lhsType = op.getLhsType();
91123
mlir::VectorType rhsType = op.getRhsType();
92124

93-
// For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
94-
// eventually expect from MMT4D. M and N dimensions must be even and at
95-
// least 2.
96-
if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
97-
rhsType.getRank() != 2)
98-
return failure();
99-
100-
if (lhsType.isScalable() || !rhsType.isScalable())
101-
return failure();
125+
// Check the operands have the expected shape. M and N dimensions must be
126+
// even and at least 2.
127+
if (lhsType.getRank() != 2 || rhsType.getRank() != 2 ||
128+
lhsType.isScalable() || !rhsType.isScalable())
129+
return rewriter.notifyMatchFailure(op, "non-matching operand shape");
102130

103131
// M, N, and K are the conventional names for matrix dimensions in the
104132
// context of matrix multiplication.
@@ -108,7 +136,7 @@ class LowerContractionToSVEI8MMPattern
108136

109137
if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
110138
N % 2 != 0 || !rhsType.getScalableDims()[0])
111-
return failure();
139+
return rewriter.notifyMatchFailure(op, "non-matching operand shape");
112140

113141
// Check permutation maps. For now only accept
114142
// lhs: (d0, d1, d2) -> (d0, d2)
@@ -124,28 +152,31 @@ class LowerContractionToSVEI8MMPattern
124152
op.getIndexingMapsArray()[2] !=
125153
AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
126154
op.getContext()))
127-
return failure();
155+
return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
128156

129157
// Check iterator types for matrix multiplication.
130158
auto itTypes = op.getIteratorTypesArray();
131159
if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
132160
itTypes[1] != vector::IteratorType::parallel ||
133161
itTypes[2] != vector::IteratorType::reduction)
134-
return failure();
162+
return rewriter.notifyMatchFailure(
163+
op, "iterator types do not correspond to matrix multiplication");
135164

136165
// Check the combining kind is addition.
137166
if (op.getKind() != vector::CombiningKind::ADD)
138-
return failure();
167+
return rewriter.notifyMatchFailure(op,
168+
"combining kind is not an addition");
139169

140170
// Check the output is a vector of i32 elements.
141-
auto outTy = dyn_cast<VectorType>(op.getType());
171+
auto outTy = dyn_cast<VectorType>(op.getResultType());
142172
if (!outTy || outTy.getElementType() != rewriter.getI32Type())
143-
return failure();
173+
return rewriter.notifyMatchFailure(op,
174+
"output type is not a vector of i32");
144175

145176
// Check inputs are sign-/zero- extensions from i8 to i32. Get the values
146177
// before the extension. All four signed/unsigned combinations for input
147178
// operands are supported, but they are lowered to different operations.
148-
// Determina which is the appropriate operation to lower to.
179+
// Determine which is the appropriate operation to lower to.
149180
MMLA mmlaOp = MMLA::Signed;
150181
auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
151182
op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
@@ -155,7 +186,8 @@ class LowerContractionToSVEI8MMPattern
155186
op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
156187
}
157188
if (!maybeLhs)
158-
return failure();
189+
return rewriter.notifyMatchFailure(
190+
op, "LHS is not a sign- or zero- extended i8");
159191

160192
auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
161193
op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
@@ -169,13 +201,16 @@ class LowerContractionToSVEI8MMPattern
169201
op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
170202
}
171203
if (!maybeRhs)
172-
return failure();
204+
return rewriter.notifyMatchFailure(
205+
op, "RHS is not a sign- or zero- extended i8");
173206

174207
// One-dimensional vector types for arm_sve.*mmla
175-
auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
176-
auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true});
208+
auto nxv16i8 = VectorType::get(/*shape=*/16, rewriter.getI8Type(),
209+
/*scalableDims=*/{true});
210+
auto nxv4i32 = VectorType::get(/*shape=*/4, rewriter.getI32Type(),
211+
/*scalableDims=*/{true});
177212

178-
// Extract LHS sub-tiles.
213+
// Extract LHS sub-tiles with logical shape <2x8>.
179214
SmallVector<Value> lhsTile;
180215
for (int64_t i = 0; i < M; i += 2) {
181216
// Extract two consecutive rows of the LHS tile.
@@ -199,19 +234,25 @@ class LowerContractionToSVEI8MMPattern
199234
// "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
200235
auto RHS = rewriter.create<vector::ShapeCastOp>(
201236
maybeRhs->getLoc(),
202-
VectorType::get(8 * N, rewriter.getI8Type(), {true}), *maybeRhs);
237+
VectorType::get(/*shape=*/8 * N, rewriter.getI8Type(),
238+
/*scalableDims=*/{true}),
239+
*maybeRhs);
203240

204-
// Extract the RHS sub-tiles.
241+
// Extract the RHS sub-tiles with logical shape <8x[2]>.
205242
SmallVector<Value> rhsTile;
206243
for (int64_t j = 0; j < N; j += 2)
207244
rhsTile.push_back(
208245
rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, RHS, j * 8));
209246

210247
// Handy types for packing/unpacking of the accumulator tile.
211-
auto accRowTy = VectorType::get(N, rewriter.getI32Type(), {true});
212-
auto accRowX2Ty = VectorType::get(2 * N, rewriter.getI32Type(), {true});
213-
auto accRow64Ty = VectorType::get(N / 2, rewriter.getI64Type(), {true});
214-
auto accRowX264Ty = VectorType::get(N, rewriter.getI64Type(), {true});
248+
auto accRowTy = VectorType::get(/*shape=*/N, rewriter.getI32Type(),
249+
/*scalableDims=*/{true});
250+
auto accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getI32Type(),
251+
/*scalableDims=*/{true});
252+
auto accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(),
253+
/*scalableDims=*/{true});
254+
auto accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
255+
/*scalableDims=*/{true});
215256

216257
// Extract and pack the ACC sub-tiles.
217258
SmallVector<Value> accTile;

0 commit comments

Comments
 (0)