Skip to content

Commit 16d890c

Browse files
authored
[mlir][ArmNeon] Adds Arm Neon SMMLA, UMMLA, and USMMLA Intrinsics (#80511)
This adds the SMMLA, UMMLA, and USMMLA intrinsics to Neon dialect bringing it in line with the SVE dialect. These ops enable matrix multiply-accumulate instructions with two 2x8 matrix inputs of respective signedness into a 2x2 32-bit integer accumulator. This is equivalent to performing an 8-way dot product per destination element. Op details: https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=mmla
1 parent e976385 commit 16d890c

File tree

4 files changed

+242
-2
lines changed

4 files changed

+242
-2
lines changed

mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ def ArmNeon_Dialect : Dialect {
3030
// to the LLVMDialect (ops or types).
3131
}
3232

33+
//===----------------------------------------------------------------------===//
34+
// ArmNeon type definition
35+
//===----------------------------------------------------------------------===//
36+
37+
class NeonVectorOfLength<int length, Type elementType> : ShapedContainerType<
38+
[elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorTypePred]>,
39+
"a vector with length " # length,
40+
"::mlir::VectorType">;
41+
3342
//===----------------------------------------------------------------------===//
3443
// ArmNeon op definitions
3544
//===----------------------------------------------------------------------===//
@@ -120,6 +129,99 @@ def SdotOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot",[1], [
120129
"$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
121130
}
122131

132+
def SmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"smmla",[1], [
133+
Pure,
134+
AllTypesMatch<["src1", "src2"]>,
135+
AllTypesMatch<["acc", "res"]>,
136+
]> {
137+
let summary = "Matrix-matrix multiply and accumulate op";
138+
let description = [{
139+
SMMLA: Signed integer matrix multiply-accumulate.
140+
141+
Signed 8-bit integer matrix multiply-accumulate. This instruction multiplies
142+
the 2x8 matrix of signed 8-bit integer values in the first source vector by
143+
the 8x2 matrix of signed 8-bit integer values in the second source vector.
144+
The resulting 2x2 32-bit integer matrix product is destructively added to
145+
the 32-bit integer matrix accumulator in the destination vector. This is
146+
equivalent to performing an 8-way dot product per destination element.
147+
148+
Source:
149+
https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=smmla
150+
}];
151+
// Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
152+
let arguments = (ins
153+
NeonVectorOfLength<4, I32>:$acc,
154+
NeonVectorOfLength<16, I8>:$src1,
155+
NeonVectorOfLength<16, I8>:$src2
156+
);
157+
let results = (outs NeonVectorOfLength<4, I32>:$res);
158+
let assemblyFormat =
159+
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
160+
}
161+
162+
def UmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"ummla",[1], [
163+
Pure,
164+
AllTypesMatch<["src1", "src2"]>,
165+
AllTypesMatch<["acc", "res"]>,
166+
]> {
167+
let summary = "Unsigned matrix-matrix multiply and accumulate op";
168+
let description = [{
169+
UMMLA: Unsigned integer matrix multiply-accumulate.
170+
171+
Unsigned 8-bit integer matrix multiply-accumulate. This instruction
172+
multiplies the 2x8 matrix of unsigned 8-bit integer values in the first
173+
source vector by the 8x2 matrix of unsigned 8-bit integer values in the
174+
second source vector. The resulting 2x2 32-bit integer matrix product is
175+
destructively added to the 32-bit integer matrix accumulator in the
176+
destination vector. This is equivalent to performing an 8-way dot product
177+
per destination element.
178+
179+
Source:
180+
https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=ummla
181+
}];
182+
// Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
183+
let arguments = (ins
184+
NeonVectorOfLength<4, I32>:$acc,
185+
NeonVectorOfLength<16, I8>:$src1,
186+
NeonVectorOfLength<16, I8>:$src2
187+
);
188+
let results = (outs NeonVectorOfLength<4, I32>:$res);
189+
let assemblyFormat =
190+
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
191+
}
192+
193+
def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
194+
Pure,
195+
AllTypesMatch<["src1", "src2"]>,
196+
AllTypesMatch<["acc", "res"]>,
197+
]> {
198+
let summary = "Unsigned and signed matrix-matrix multiply and accumulate op";
199+
let description = [{
200+
USMMLA: Unsigned and signed integer matrix multiply-accumulate.
201+
202+
Unsigned and signed 8-bit integer matrix multiply-accumulate. This
203+
instruction multiplies the 2x8 matrix of unsigned 8-bit integer values in
204+
the first source vector by the 8x2 matrix of signed 8-bit integer values in
205+
the second source vector. The resulting 2x2 32-bit integer matrix product is
206+
destructively added to the 32-bit integer matrix accumulator in the
207+
destination vector. This is equivalent to performing an 8-way dot product
208+
per destination element.
209+
210+
211+
Source:
212+
https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=usmmla
213+
}];
214+
// Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
215+
let arguments = (ins
216+
NeonVectorOfLength<4, I32>:$acc,
217+
NeonVectorOfLength<16, I8>:$src1,
218+
NeonVectorOfLength<16, I8>:$src2
219+
);
220+
let results = (outs NeonVectorOfLength<4, I32>:$res);
221+
let assemblyFormat =
222+
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
223+
}
224+
123225
class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
124226
: Op</*dialect=*/ArmNeon_Dialect,
125227
/*opName=*/"2d." # mnemonic,

mlir/test/Dialect/ArmNeon/invalid.mlir

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,63 @@ func.func @b_has_2_rows_but_a_has_length_4(%a : vector<4xi32>, %b : vector<2x4xi
3131
%0 = arm_neon.2d.sdot %a, %b, %b : vector<2x4xi8>, vector<2x4xi8> to vector<4xi32>
3232
return %0 : vector<4xi32>
3333
}
34+
35+
// -----
36+
37+
func.func @smmla_invalid_input_types(%a: vector<4xi32>,
38+
%b: vector<16xi4>,
39+
%c: vector<16xi4>) -> vector<4xi32> {
40+
// expected-error@+1 {{op operand #1 must be a vector with length 16 of 8-bit signless integer values, but got 'vector<16xi4>'}}
41+
%0 = arm_neon.intr.smmla %a, %b, %c : vector<16xi4> to vector<4xi32>
42+
return %0 : vector<4xi32>
43+
}
44+
45+
// -----
46+
47+
func.func @smmla_invalid_dimensions(%a: vector<8xi32>,
48+
%b: vector<32xi8>,
49+
%c: vector<32xi8>) -> vector<8xi32> {
50+
// expected-error@+1 {{op operand #0 must be a vector with length 4 of 32-bit signless integer values, but got 'vector<8xi32>'}}
51+
%0 = arm_neon.intr.smmla %a, %b, %c : vector<32xi8> to vector<8xi32>
52+
return %0 : vector<8xi32>
53+
}
54+
55+
// -----
56+
57+
func.func @ummla_invalid_input_types(%a: vector<4xi32>,
58+
%b: vector<16xi4>,
59+
%c: vector<16xi4>) -> vector<4xi32> {
60+
// expected-error@+1 {{op operand #1 must be a vector with length 16 of 8-bit signless integer values, but got 'vector<16xi4>'}}
61+
%0 = arm_neon.intr.ummla %a, %b, %c : vector<16xi4> to vector<4xi32>
62+
return %0 : vector<4xi32>
63+
}
64+
65+
// -----
66+
67+
func.func @ummla_invalid_dimensions(%a: vector<8xi32>,
68+
%b: vector<32xi8>,
69+
%c: vector<32xi8>) -> vector<8xi32> {
70+
// expected-error@+1 {{op operand #0 must be a vector with length 4 of 32-bit signless integer values, but got 'vector<8xi32>'}}
71+
%0 = arm_neon.intr.ummla %a, %b, %c : vector<32xi8> to vector<8xi32>
72+
return %0 : vector<8xi32>
73+
}
74+
75+
// -----
76+
77+
func.func @usmmla_invalid_input_types(%a: vector<4xi32>,
78+
%b: vector<16xi4>,
79+
%c: vector<16xi4>) -> vector<4xi32> {
80+
// expected-error@+1 {{op operand #1 must be a vector with length 16 of 8-bit signless integer values, but got 'vector<16xi4>'}}
81+
%0 = arm_neon.intr.usmmla %a, %b, %c : vector<16xi4> to vector<4xi32>
82+
return %0 : vector<4xi32>
83+
}
84+
85+
// -----
86+
87+
func.func @usmmla_invalid_dimensions(%a: vector<8xi32>,
88+
%b: vector<32xi8>,
89+
%c: vector<32xi8>) -> vector<8xi32> {
90+
// expected-error@+1 {{op operand #0 must be a vector with length 4 of 32-bit signless integer values, but got 'vector<8xi32>'}}
91+
%0 = arm_neon.intr.usmmla %a, %b, %c : vector<32xi8> to vector<8xi32>
92+
return %0 : vector<8xi32>
93+
}

mlir/test/Dialect/ArmNeon/roundtrip.mlir

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
1+
// RUN: mlir-opt -verify-diagnostics -split-input-file %s | mlir-opt | FileCheck %s
22

33
// CHECK-LABEL: arm_neon_smull
44
func.func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
@@ -19,9 +19,44 @@ func.func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
1919
return %0, %1, %2 : vector<8xi16>, vector<4xi32>, vector<2xi64>
2020
}
2121

22+
// -----
23+
2224
// CHECK-LABEL: arm_neon_sdot
2325
func.func @arm_neon_sdot(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
2426
// CHECK: arm_neon.intr.sdot {{.*}}: vector<8xi8>, vector<8xi8> to vector<2xi32>
2527
%0 = arm_neon.intr.sdot %a, %b, %c : vector<8xi8>, vector<8xi8> to vector<2xi32>
2628
return %0 : vector<2xi32>
2729
}
30+
31+
// -----
32+
33+
// CHECK-LABEL: arm_neon_smmla
34+
func.func @arm_neon_smmla(%a: vector<16xi8>,
35+
%b: vector<16xi8>,
36+
%c: vector<4xi32>) -> vector<4xi32> {
37+
// CHECK: arm_neon.intr.smmla {{.*}}: vector<16xi8> to vector<4xi32>
38+
%0 = arm_neon.intr.smmla %c, %a, %b : vector<16xi8> to vector<4xi32>
39+
return %0 : vector<4xi32>
40+
}
41+
42+
// -----
43+
44+
// CHECK-LABEL: arm_neon_ummla
45+
func.func @arm_neon_ummla(%a: vector<16xi8>,
46+
%b: vector<16xi8>,
47+
%c: vector<4xi32>) -> vector<4xi32> {
48+
// CHECK: arm_neon.intr.ummla {{.*}}: vector<16xi8> to vector<4xi32>
49+
%0 = arm_neon.intr.ummla %c, %a, %b : vector<16xi8> to vector<4xi32>
50+
return %0 : vector<4xi32>
51+
}
52+
53+
// -----
54+
55+
// CHECK-LABEL: arm_neon_usmmla
56+
func.func @arm_neon_usmmla(%a: vector<16xi8>,
57+
%b: vector<16xi8>,
58+
%c: vector<4xi32>) -> vector<4xi32> {
59+
// CHECK: arm_neon.intr.usmmla {{.*}}: vector<16xi8> to vector<4xi32>
60+
%0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi8> to vector<4xi32>
61+
return %0 : vector<4xi32>
62+
}

mlir/test/Target/LLVMIR/arm-neon.mlir

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
1+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
22

33
// CHECK-LABEL: arm_neon_smull
44
llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)> {
@@ -24,6 +24,8 @@ llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.str
2424
llvm.return %8 : !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)>
2525
}
2626

27+
// -----
28+
2729
// CHECK-LABEL: arm_neon_sdot_8_i8i8
2830
llvm.func @arm_neon_sdot_8_i8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
2931
// CHECK: %[[V0:.*]] = call <2 x i32> @llvm.aarch64.neon.sdot.v2i32.v8i8(<2 x i32> %{{.*}}, <8 x i8> %{{.*}}, <8 x i8> %{{.*}})
@@ -32,10 +34,51 @@ llvm.func @arm_neon_sdot_8_i8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<
3234
llvm.return %0 : vector<2xi32>
3335
}
3436

37+
// -----
38+
3539
// CHECK-LABEL: arm_neon_sdot_16_i8i8
3640
llvm.func @arm_neon_sdot_16_i8i8(%a: vector<4xi32>, %b: vector<16xi8>, %c: vector<16xi8>) -> vector<4xi32> {
3741
// CHECK: %[[V0:.*]] = call <4 x i32> @llvm.aarch64.neon.sdot.v4i32.v16i8(<4 x i32> %{{.*}}, <16 x i8> %{{.*}}, <16 x i8> %{{.*}})
3842
// CHECK-NEXT: ret <4 x i32>
3943
%0 = arm_neon.intr.sdot %a, %b, %c : vector<16xi8>, vector<16xi8> to vector<4xi32>
4044
llvm.return %0 : vector<4xi32>
4145
}
46+
47+
// -----
48+
49+
// CHECK-LABEL: arm_neon_smmla
50+
llvm.func @arm_neon_smmla(%arg0: vector<16xi8>,
51+
%arg1: vector<16xi8>,
52+
%arg2: vector<4xi32>) -> vector<4xi32> {
53+
// CHECK: <4 x i32> @llvm.aarch64.neon.smmla.v4i32.v16i8(<4 x i32
54+
%0 = "arm_neon.intr.smmla"(%arg2, %arg0, %arg1) :
55+
(vector<4xi32>, vector<16xi8>, vector<16xi8>)
56+
-> vector<4xi32>
57+
llvm.return %0 : vector<4xi32>
58+
}
59+
60+
// -----
61+
62+
// CHECK-LABEL: arm_neon_ummla
63+
llvm.func @arm_neon_ummla(%arg0: vector<16xi8>,
64+
%arg1: vector<16xi8>,
65+
%arg2: vector<4xi32>) -> vector<4xi32> {
66+
// CHECK: <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8(<4 x i32
67+
%0 = "arm_neon.intr.ummla"(%arg2, %arg0, %arg1) :
68+
(vector<4xi32>, vector<16xi8>, vector<16xi8>)
69+
-> vector<4xi32>
70+
llvm.return %0 : vector<4xi32>
71+
}
72+
73+
// -----
74+
75+
// CHECK-LABEL: arm_neon_usmmla
76+
llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
77+
%arg1: vector<16xi8>,
78+
%arg2: vector<4xi32>) -> vector<4xi32> {
79+
// CHECK: <4 x i32> @llvm.aarch64.neon.usmmla.v4i32.v16i8(<4 x i32
80+
%0 = "arm_neon.intr.usmmla"(%arg2, %arg0, %arg1) :
81+
(vector<4xi32>, vector<16xi8>, vector<16xi8>)
82+
-> vector<4xi32>
83+
llvm.return %0 : vector<4xi32>
84+
}

0 commit comments

Comments
 (0)