[mlir][ArmNeon] Adds Arm Neon SMMLA, UMMLA, and USMMLA Intrinsics #80511
Conversation
Looks great! Some nits
@llvm/pr-subscribers-mlir-neon @llvm/pr-subscribers-mlir

Author: Kojo Acquah (KoolJBlack)

Changes

This adds the SMMLA, UMMLA, and USMMLA intrinsics to the Neon dialect, bringing it in line with the SVE dialect. These ops enable matrix multiply-accumulate instructions with two 2x8 matrix inputs of the respective signedness, accumulating into a 2x2 32-bit integer accumulator. This is equivalent to performing an 8-way dot product per destination element.

Op details: https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=mmla

Full diff: https://github.com/llvm/llvm-project/pull/80511.diff

4 Files Affected:
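Before the diff, a quick schematic of what all three ops compute (the indexing below is my summary; the Arm intrinsics page linked above is authoritative): with the 2x2 i32 accumulator `a` and the two 16xi8 inputs `b` and `c`, each read as a row-major 2x8 matrix,

$$\mathrm{res}[i][j] = a[i][j] + \sum_{k=0}^{7} b[8i + k]\cdot c[8j + k], \qquad i, j \in \{0, 1\}$$

SMMLA treats both inputs as signed, UMMLA treats both as unsigned, and USMMLA treats `b` as unsigned and `c` as signed.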
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index e298963f3d19f..c515a858ee8a1 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -120,6 +120,99 @@ def SdotOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot",[1], [
"$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
}
+def SmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"smmla",[1], [
+ Pure,
+ AllTypesMatch<["b", "c"]>,
+ AllTypesMatch<["a", "res"]>,
+ ]> {
+ let summary = "Matrix-matrix multiply and accumulate op";
+ let description = [{
+ SMMLA: Signed integer matrix multiply-accumulate.
+
+ Signed 8-bit integer matrix multiply-accumulate. This instruction multiplies
+ the 2x8 matrix of signed 8-bit integer values in the first source vector by
+ the 8x2 matrix of signed 8-bit integer values in the second source vector.
+ The resulting 2x2 32-bit integer matrix product is destructively added to
+ the 32-bit integer matrix accumulator in the destination vector. This is
+ equivalent to performing an 8-way dot product per destination element.
+
+ Source:
+ https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=smmla
+ }];
+ // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+ let arguments = (ins
+ VectorOfLengthAndType<[4], [I32]>:$a,
+ VectorOfLengthAndType<[16], [I8]>:$b,
+ VectorOfLengthAndType<[16], [I8]>:$c
+ );
+ let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
+ let assemblyFormat =
+ "$a `,` $b `,` $c attr-dict `:` type($b) `to` type($res)";
+}
+
+def UmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"ummla",[1], [
+ Pure,
+ AllTypesMatch<["b", "c"]>,
+ AllTypesMatch<["a", "res"]>,
+ ]> {
+ let summary = "Unsinged matrix-matrix multiply and accumulate op";
+ let description = [{
+ UMMLA: Unsigned integer matrix multiply-accumulate.
+
+ Unsigned 8-bit integer matrix multiply-accumulate. This instruction
+ multiplies the 2x8 matrix of unsigned 8-bit integer values in the first
+ source vector by the 8x2 matrix of unsigned 8-bit integer values in the
+ second source vector. The resulting 2x2 32-bit integer matrix product is
+ destructively added to the 32-bit integer matrix accumulator in the
+ destination vector. This is equivalent to performing an 8-way dot product
+ per destination element.
+
+ Source:
+ https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=ummla
+ }];
+ // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+ let arguments = (ins
+ VectorOfLengthAndType<[4], [I32]>:$a,
+ VectorOfLengthAndType<[16], [I8]>:$b,
+ VectorOfLengthAndType<[16], [I8]>:$c
+ );
+ let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
+ let assemblyFormat =
+ "$a `,` $b `,` $c attr-dict `:` type($b) `to` type($res)";
+}
+
+def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
+ Pure,
+ AllTypesMatch<["b", "c"]>,
+ AllTypesMatch<["a", "res"]>,
+ ]> {
+ let summary = "Unsignged and signed matrix-matrix multiply and accumulate op";
+ let description = [{
+ USMMLA: Unsigned and signed integer matrix multiply-accumulate.
+
+ Unsigned and signed 8-bit integer matrix multiply-accumulate. This
+ instruction multiplies the 2x8 matrix of unsigned 8-bit integer values in
+ the first source vector by the 8x2 matrix of signed 8-bit integer values in
+ the second source vector. The resulting 2x2 32-bit integer matrix product is
+ destructively added to the 32-bit integer matrix accumulator in the
+ destination vector. This is equivalent to performing an 8-way dot product
+ per destination element.
+
+ Source:
+ https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=usmmla
+ }];
+ // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+ let arguments = (ins
+ VectorOfLengthAndType<[4], [I32]>:$a,
+ VectorOfLengthAndType<[16], [I8]>:$b,
+ VectorOfLengthAndType<[16], [I8]>:$c
+ );
+ let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
+ let assemblyFormat =
+ "$a `,` $b `,` $c attr-dict `:` type($b) `to` type($res)";
+}
+
class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
: Op</*dialect=*/ArmNeon_Dialect,
/*opName=*/"2d." # mnemonic,
diff --git a/mlir/test/Dialect/ArmNeon/invalid.mlir b/mlir/test/Dialect/ArmNeon/invalid.mlir
index 62caf04160020..3ad763e7b4982 100644
--- a/mlir/test/Dialect/ArmNeon/invalid.mlir
+++ b/mlir/test/Dialect/ArmNeon/invalid.mlir
@@ -31,3 +31,63 @@ func.func @b_has_2_rows_but_a_has_length_4(%a : vector<4xi32>, %b : vector<2x4xi
%0 = arm_neon.2d.sdot %a, %b, %b : vector<2x4xi8>, vector<2x4xi8> to vector<4xi32>
return %0 : vector<4xi32>
}
+
+// -----
+
+func.func @smmla_invalid_input_types(%a: vector<16xi4>,
+ %b: vector<16xi4>,
+ %c: vector<4xi32>) -> vector<4xi32> {
+ // expected-error@+1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
+ %0 = arm_neon.intr.smmla %c, %a, %b : vector<16xi4> to vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// -----
+
+func.func @smmla_invalid_dimensions(%a: vector<32xi8>,
+ %b: vector<32xi8>,
+ %c: vector<8xi32>) -> vector<8xi32> {
+ // expected-error@+1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
+ %0 = arm_neon.intr.smmla %c, %a, %b : vector<32xi8> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @ummla_invalid_input_types(%a: vector<16xi4>,
+ %b: vector<16xi4>,
+ %c: vector<4xi32>) -> vector<4xi32> {
+ // expected-error@+1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
+ %0 = arm_neon.intr.ummla %c, %a, %b : vector<16xi4> to vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// -----
+
+func.func @ummla_invalid_dimensions(%a: vector<32xi8>,
+ %b: vector<32xi8>,
+ %c: vector<8xi32>) -> vector<8xi32> {
+ // expected-error@+1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
+ %0 = arm_neon.intr.ummla %c, %a, %b : vector<32xi8> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @usmmla_invalid_input_types(%a: vector<16xi4>,
+ %b: vector<16xi4>,
+ %c: vector<4xi32>) -> vector<4xi32> {
+ // expected-error@+1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
+ %0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi4> to vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// -----
+
+func.func @usmmla_invalid_dimensions(%a: vector<32xi8>,
+ %b: vector<32xi8>,
+ %c: vector<8xi32>) -> vector<8xi32> {
+ // expected-error@+1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
+ %0 = arm_neon.intr.usmmla %c, %a, %b : vector<32xi8> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
diff --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
index 704bfe8c084a5..30afe325a482c 100644
--- a/mlir/test/Dialect/ArmNeon/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt -verify-diagnostics -split-input-file %s | mlir-opt | FileCheck %s
// CHECK-LABEL: arm_neon_smull
func.func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
@@ -25,3 +25,36 @@ func.func @arm_neon_sdot(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>)
%0 = arm_neon.intr.sdot %a, %b, %c : vector<8xi8>, vector<8xi8> to vector<2xi32>
return %0 : vector<2xi32>
}
+
+// -----
+
+// CHECK-LABEL: arm_neon_smmla
+func.func @arm_neon_smmla(%a: vector<16xi8>,
+ %b: vector<16xi8>,
+ %c: vector<4xi32>) -> vector<4xi32> {
+ // CHECK: arm_neon.intr.smmla {{.*}}: vector<16xi8> to vector<4xi32>
+ %0 = arm_neon.intr.smmla %c, %a, %b : vector<16xi8> to vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_neon_ummla
+func.func @arm_neon_ummla(%a: vector<16xi8>,
+ %b: vector<16xi8>,
+ %c: vector<4xi32>) -> vector<4xi32> {
+ // CHECK: arm_neon.intr.ummla {{.*}}: vector<16xi8> to vector<4xi32>
+ %0 = arm_neon.intr.ummla %c, %a, %b : vector<16xi8> to vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_neon_usmmla
+func.func @arm_neon_usmmla(%a: vector<16xi8>,
+ %b: vector<16xi8>,
+ %c: vector<4xi32>) -> vector<4xi32> {
+ // CHECK: arm_neon.intr.usmmla {{.*}}: vector<16xi8> to vector<4xi32>
+ %0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi8> to vector<4xi32>
+ return %0 : vector<4xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-neon.mlir b/mlir/test/Target/LLVMIR/arm-neon.mlir
index f4716fe58f203..e5b37ea3c8a5d 100644
--- a/mlir/test/Target/LLVMIR/arm-neon.mlir
+++ b/mlir/test/Target/LLVMIR/arm-neon.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
// CHECK-LABEL: arm_neon_smull
llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)> {
@@ -39,3 +39,45 @@ llvm.func @arm_neon_sdot_16_i8i8(%a: vector<4xi32>, %b: vector<16xi8>, %c: vecto
%0 = arm_neon.intr.sdot %a, %b, %c : vector<16xi8>, vector<16xi8> to vector<4xi32>
llvm.return %0 : vector<4xi32>
}
+
+// -----
+
+// CHECK-LABEL: define <4 x i32> @arm_neon_smmla
+llvm.func @arm_neon_smmla(%arg0: vector<16xi8>,
+ %arg1: vector<16xi8>,
+ %arg2: vector<4xi32>)
+ -> vector<4xi32> {
+ // CHECK: <4 x i32> @llvm.aarch64.neon.smmla.v4i32.v16i8(<4 x i32
+ %0 = "arm_neon.intr.smmla"(%arg2, %arg0, %arg1) :
+ (vector<4xi32>, vector<16xi8>, vector<16xi8>)
+ -> vector<4xi32>
+ llvm.return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: define <4 x i32> @arm_neon_ummla
+llvm.func @arm_neon_ummla(%arg0: vector<16xi8>,
+ %arg1: vector<16xi8>,
+ %arg2: vector<4xi32>)
+ -> vector<4xi32> {
+ // CHECK: <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8(<4 x i32
+ %0 = "arm_neon.intr.ummla"(%arg2, %arg0, %arg1) :
+ (vector<4xi32>, vector<16xi8>, vector<16xi8>)
+ -> vector<4xi32>
+ llvm.return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: define <4 x i32> @arm_neon_usmmla
+llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
+ %arg1: vector<16xi8>,
+ %arg2: vector<4xi32>)
+ -> vector<4xi32> {
+ // CHECK: <4 x i32> @llvm.aarch64.neon.usmmla.v4i32.v16i8(<4 x i32
+ %0 = "arm_neon.intr.usmmla"(%arg2, %arg0, %arg1) :
+ (vector<4xi32>, vector<16xi8>, vector<16xi8>)
+ -> vector<4xi32>
+ llvm.return %0 : vector<4xi32>
+}
Thanks for working on this! Please see my comments inline.
LGTM!
Please address Ben's comment before landing. Thanks again for implementing this 🙏🏻
LGTM (other than one little thing)
This adds the SMMLA, UMMLA, and USMMLA intrinsics to the Neon dialect, bringing it in line with the SVE dialect.
These ops enable matrix multiply-accumulate instructions with two 2x8 matrix inputs of the respective signedness, accumulating into a 2x2 32-bit integer accumulator. This is equivalent to performing an 8-way dot product per destination element.
Op details:
https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=mmla
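For anyone wanting to try these out, a minimal usage sketch (it mirrors the roundtrip tests in the diff above; the function name is only illustrative):

```mlir
// The accumulator comes first, followed by the two 16xi8 inputs, matching the
// assembly format `$a `,` $b `,` $c : type($b) to type($res)` defined in the .td file.
func.func @smmla_example(%acc: vector<4xi32>, %lhs: vector<16xi8>,
                         %rhs: vector<16xi8>) -> vector<4xi32> {
  // %acc holds the 2x2 i32 accumulator; %lhs and %rhs each hold a 2x8 i8 matrix.
  %0 = arm_neon.intr.smmla %acc, %lhs, %rhs : vector<16xi8> to vector<4xi32>
  return %0 : vector<4xi32>
}
```

Lowering to LLVM IR turns this into a call to `@llvm.aarch64.neon.smmla.v4i32.v16i8`, as exercised in the arm-neon.mlir test above.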