Skip to content

Commit 49f06d9

Browse files
[fixup] Handle implicit sign-extend of LHS and RHS
1 parent 20e2729 commit 49f06d9

File tree

3 files changed

+119
-19
lines changed

3 files changed

+119
-19
lines changed

mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,30 @@ using namespace mlir;
3030
using namespace mlir::arm_sve;
3131

3232
namespace {
33-
// Check if the given value is a result of the operation `T` (which must be
34-
// sign- or zero- extend) from i8 to i32. Return the value before the extension.
33+
// Get the LHS or RHS operand of a vector contract. Handle two cases:
34+
// * if the operand is a sign- or zero- extend operation of type `T` from i8
35+
// to i32, return the value before the extension, otherwise
36+
// * if the operand is of i8 type and the operation is sign-extend, return the
37+
// operand itself.
38+
//
39+
// This way we handle both explicit sign- or zero- extension and implicit
40+
// sign-extension.
3541
template <typename T>
36-
std::optional<Value> extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
42+
std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
3743

3844
static_assert(llvm::is_one_of<T, arith::ExtSIOp, arith::ExtUIOp>::value,
3945
"Must be instantiated with either sign- or zero- extension op");
4046

4147
auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
42-
if (!extOp)
48+
if (!extOp) {
49+
if constexpr (std::is_same<T, arith::ExtSIOp>::value) {
50+
auto vTy = cast<VectorType>(v.getType());
51+
if (vTy.getElementType() != i8Ty)
52+
return {};
53+
return v;
54+
}
4355
return {};
56+
}
4457

4558
auto inOp = extOp.getIn();
4659
auto inTy = dyn_cast<VectorType>(inOp.getType());
@@ -178,26 +191,26 @@ class LowerContractionToSVEI8MMPattern
178191
// operands are supported, but they are lowered to different operations.
179192
// Determine which is the appropriate operation to lower to.
180193
MMLA mmlaOp = MMLA::Signed;
181-
auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
194+
auto maybeLhs = getExtOperand<arith::ExtSIOp>(
182195
op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
183196
if (!maybeLhs) {
184197
mmlaOp = MMLA::Unsigned;
185-
maybeLhs = extractExtOperand<arith::ExtUIOp>(
198+
maybeLhs = getExtOperand<arith::ExtUIOp>(
186199
op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
187200
}
188201
if (!maybeLhs)
189202
return rewriter.notifyMatchFailure(
190203
op, "LHS is not a sign- or zero- extended i8");
191204

192-
auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
205+
auto maybeRhs = getExtOperand<arith::ExtSIOp>(
193206
op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
194207
if (maybeRhs) {
195208
if (mmlaOp == MMLA::Unsigned)
196209
mmlaOp = MMLA::Mixed;
197210
} else {
198211
if (mmlaOp == MMLA::Signed)
199212
mmlaOp = MMLA::MixedSwapped;
200-
maybeRhs = extractExtOperand<arith::ExtUIOp>(
213+
maybeRhs = getExtOperand<arith::ExtUIOp>(
201214
op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
202215
}
203216
if (!maybeRhs)

mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir

Lines changed: 96 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' | FileCheck %s
22

3-
#packed_maps = [
4-
affine_map<(d0, d1, d2) -> (d0, d2)>,
5-
affine_map<(d0, d1, d2) -> (d1, d2)>,
6-
affine_map<(d0, d1, d2) -> (d0, d1)>
7-
]
3+
#attrs = {
4+
indexing_maps = [
5+
affine_map<(d0, d1, d2) -> (d0, d2)>,
6+
affine_map<(d0, d1, d2) -> (d1, d2)>,
7+
affine_map<(d0, d1, d2) -> (d0, d1)>
8+
],
9+
iterator_types = ["parallel", "parallel", "reduction"],
10+
kind = #vector.kind<add>
11+
}
812

913
// CHECK-LABEL: @test_vector_contract_to_smmla
1014

@@ -85,10 +89,93 @@ func.func @test_vector_contract_to_smmla(%lhs: vector<4x8xi8>,
8589

8690
%0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
8791
%1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
88-
%2 = vector.contract {indexing_maps = #packed_maps,
89-
iterator_types = ["parallel", "parallel", "reduction"],
90-
kind = #vector.kind<add>} %0, %1, %acc
92+
%2 = vector.contract #attrs %0, %1, %acc
9193
: vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
9294

9395
return %2 : vector<4x[4]xi32>
9496
}
97+
98+
// CHECK-LABEL: @test_vector_contract_to_smmla_implicit_sext
99+
100+
// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
101+
// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
102+
// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
103+
// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
104+
// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9]+]][0] : vector<16xi8> into vector<[16]xi8>
105+
106+
// Replicate across the entire length of the scalable vector
107+
// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
108+
109+
// Same for LHS rows 2 and 3
110+
// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
111+
// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
112+
// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
113+
// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
114+
// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
115+
116+
// Extract sub-tiles from the RHS
117+
// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
118+
// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
119+
// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
120+
121+
// Extract accumulator rows 0 and 1 and pack (into "registers")
122+
// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
123+
// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>>
124+
// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64>
125+
// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64>
126+
// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
127+
// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32>
128+
// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32>
129+
// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32>
130+
131+
// Same for accumulator rows 2 and 3.
132+
// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>>
133+
// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>>
134+
// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64>
135+
// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64>
136+
// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
137+
// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32>
138+
// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32>
139+
// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32>
140+
141+
// Do the sub-tile matrix multiplications
142+
// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.smmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
143+
// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.smmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
144+
// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.smmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
145+
// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.smmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
146+
147+
// Unpack (from "registers") and insert in the output result rows 0 and 1
148+
// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32>
149+
// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32>
150+
// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64>
151+
// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
152+
// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
153+
// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
154+
// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32>
155+
// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32>
156+
// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
157+
// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>>
158+
159+
// Same for result rows 2 and 3
160+
// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
161+
// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32>
162+
// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64>
163+
// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
164+
// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
165+
// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
166+
// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32>
167+
// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32>
168+
// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>>
169+
// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>>
170+
171+
// Test a variant where the sign-extension of the operands is
172+
// implicit. The output is identical to the one of the previous test.
173+
func.func @test_vector_contract_to_smmla_implicit_sext(%lhs: vector<4x8xi8>,
174+
%rhs: vector<[4]x8xi8>,
175+
%acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> {
176+
177+
%0 = vector.contract #attrs %lhs, %rhs, %acc
178+
: vector<4x8xi8>, vector<[4]x8xi8> into vector<4x[4]xi32>
179+
180+
return %0 : vector<4x[4]xi32>
181+
}

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func.func @main() {
4141
%acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
4242
%acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
4343
%acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
44-
44+
4545
vector.print str "ACC:\n"
4646
%acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
4747
%acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
@@ -91,7 +91,7 @@ func.func @main() {
9191
vector.print %rhs1 : vector<[16]xi8>
9292

9393
%rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
94-
94+
9595
// Matrix multiplication
9696
%0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32>
9797
%1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>

0 commit comments

Comments
 (0)