Skip to content

[mlir][ArmNeon] Adds Arm Neon SMMLA, UMMLA, and USMMLA Intrinsics #80511

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 6, 2024

Conversation

KoolJBlack
Copy link
Contributor

This adds the SMMLA, UMMLA, and USMMLA intrinsics to Neon dialect bringing it in line with the SVE dialect.

These ops enable matrix multiply-accumulate instructions with two e 2x8 matrix inputs of respective signage into a 2x2 32-bit integer accumulator. This is equivalent to performing an 8-way dot product per destination element.

Op details:
https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=mmla

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! Some nits

@KoolJBlack KoolJBlack marked this pull request as ready for review February 3, 2024 01:09
@KoolJBlack KoolJBlack requested a review from dcaballe February 3, 2024 01:09
@llvmbot
Copy link
Member

llvmbot commented Feb 3, 2024

@llvm/pr-subscribers-mlir-neon
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Kojo Acquah (KoolJBlack)

Changes

This adds the SMMLA, UMMLA, and USMMLA intrinsics to Neon dialect bringing it in line with the SVE dialect.

These ops enable matrix multiply-accumulate instructions with two e 2x8 matrix inputs of respective signage into a 2x2 32-bit integer accumulator. This is equivalent to performing an 8-way dot product per destination element.

Op details:
https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=mmla


Full diff: https://github.com/llvm/llvm-project/pull/80511.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td (+93)
  • (modified) mlir/test/Dialect/ArmNeon/invalid.mlir (+60)
  • (modified) mlir/test/Dialect/ArmNeon/roundtrip.mlir (+34-1)
  • (modified) mlir/test/Target/LLVMIR/arm-neon.mlir (+43-1)
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index e298963f3d19f..c515a858ee8a1 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -120,6 +120,99 @@ def SdotOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot",[1], [
     "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
   }
 
+def SmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"smmla",[1], [
+                Pure,
+                AllTypesMatch<["b", "c"]>,
+                AllTypesMatch<["a", "res"]>,
+              ]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    SMMLA: Signed integer matrix multiply-accumulate.
+
+    Signed 8-bit integer matrix multiply-accumulate. This instruction multiplies
+    the 2x8 matrix of signed 8-bit integer values in the first source vector by
+    the 8x2 matrix of signed 8-bit integer values in the second source vector.
+    The resulting 2x2 32-bit integer matrix product is destructively added to
+    the 32-bit integer matrix accumulator in the destination vector. This is
+    equivalent to performing an 8-way dot product per destination element.
+
+    Source:
+    https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=smmla
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+          VectorOfLengthAndType<[4], [I32]>:$a,
+          VectorOfLengthAndType<[16], [I8]>:$b,
+          VectorOfLengthAndType<[16], [I8]>:$c
+  );
+  let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
+  let assemblyFormat =
+    "$a `,` $b `,` $c attr-dict `:` type($b) `to` type($res)";
+}
+
+def UmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"ummla",[1], [
+                Pure,
+                AllTypesMatch<["b", "c"]>,
+                AllTypesMatch<["a", "res"]>,
+              ]> {
+  let summary = "Unsinged matrix-matrix multiply and accumulate op";
+  let description = [{
+    UMMLA: Signed integer matrix multiply-accumulate.
+
+    Unsigned 8-bit integer matrix multiply-accumulate. This instruction
+    multiplies the 2x8 matrix of unsigned 8-bit integer values in the first
+    source vector by the 8x2 matrix of unsigned 8-bit integer values in the
+    second source vector. The resulting 2x2 32-bit integer matrix product is
+    destructively added to the 32-bit integer matrix accumulator in the
+    destination vector. This is equivalent to performing an 8-way dot product
+    per destination element.
+
+    Source:
+    https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=ummla
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+          VectorOfLengthAndType<[4], [I32]>:$a,
+          VectorOfLengthAndType<[16], [I8]>:$b,
+          VectorOfLengthAndType<[16], [I8]>:$c
+  );
+  let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
+  let assemblyFormat =
+    "$a `,` $b `,` $c attr-dict `:` type($b) `to` type($res)";
+}
+
+def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
+                Pure,
+                AllTypesMatch<["b", "c"]>,
+                AllTypesMatch<["a", "res"]>,
+              ]> {
+  let summary = "Unsignged and signed matrix-matrix multiply and accumulate op";
+  let description = [{
+    USMMLA: Signed integer matrix multiply-accumulate.
+
+    Unsigned and signed 8-bit integer matrix multiply-accumulate. This
+    instruction multiplies the 2x8 matrix of unsigned 8-bit integer values in
+    the first source vector by the 8x2 matrix of signed 8-bit integer values in
+    the second source vector. The resulting 2x2 32-bit integer matrix product is
+    destructively added to the 32-bit integer matrix accumulator in the
+    destination vector. This is equivalent to performing an 8-way dot product
+     per destination element.
+
+
+    Source:
+    https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=usmmla
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+          VectorOfLengthAndType<[4], [I32]>:$a,
+          VectorOfLengthAndType<[16], [I8]>:$b,
+          VectorOfLengthAndType<[16], [I8]>:$c
+  );
+  let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
+  let assemblyFormat =
+    "$a `,` $b `,` $c attr-dict `:` type($b) `to` type($res)";
+}
+
 class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
     : Op</*dialect=*/ArmNeon_Dialect,
          /*opName=*/"2d." # mnemonic,
diff --git a/mlir/test/Dialect/ArmNeon/invalid.mlir b/mlir/test/Dialect/ArmNeon/invalid.mlir
index 62caf04160020..3ad763e7b4982 100644
--- a/mlir/test/Dialect/ArmNeon/invalid.mlir
+++ b/mlir/test/Dialect/ArmNeon/invalid.mlir
@@ -31,3 +31,63 @@ func.func @b_has_2_rows_but_a_has_length_4(%a : vector<4xi32>, %b : vector<2x4xi
     %0 = arm_neon.2d.sdot %a, %b, %b : vector<2x4xi8>, vector<2x4xi8> to vector<4xi32>
     return %0 : vector<4xi32>
 }
+
+// -----
+
+func.func @smmla_invalid_input_types(%a: vector<16xi4>,
+                                     %b: vector<16xi4>,
+                                     %c: vector<4xi32>) -> vector<4xi32> {
+  // expected-error@+1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
+  %0 = arm_neon.intr.smmla %c, %a, %b : vector<16xi4> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+func.func @smmla_invalid_dimensions(%a: vector<32xi8>,
+                                    %b: vector<32xi8>,
+                                    %c: vector<8xi32>) -> vector<8xi32> {
+  // expected-error@+1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
+  %0 = arm_neon.intr.smmla %c, %a, %b : vector<32xi8> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @ummla_invalid_input_types(%a: vector<16xi4>,
+                                     %b: vector<16xi4>,
+                                     %c: vector<4xi32>) -> vector<4xi32> {
+  // expected-error@+1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
+  %0 = arm_neon.intr.ummla %c, %a, %b : vector<16xi4> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+func.func @ummla_invalid_dimensions(%a: vector<32xi8>,
+                                    %b: vector<32xi8>,
+                                    %c: vector<8xi32>) -> vector<8xi32> {
+  // expected-error@+1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
+  %0 = arm_neon.intr.ummla %c, %a, %b : vector<32xi8> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @usmmla_invalid_input_types(%a: vector<16xi4>,
+                                      %b: vector<16xi4>,
+                                      %c: vector<4xi32>) -> vector<4xi32> {
+  // expected-error@+1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
+  %0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi4> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+func.func @usmmla_invalid_dimensions(%a: vector<32xi8>,
+                                     %b: vector<32xi8>,
+                                     %c: vector<8xi32>) -> vector<8xi32> {
+  // expected-error@+1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
+  %0 = arm_neon.intr.usmmla %c, %a, %b : vector<32xi8> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
diff --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
index 704bfe8c084a5..30afe325a482c 100644
--- a/mlir/test/Dialect/ArmNeon/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt -verify-diagnostics -split-input-file %s | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: arm_neon_smull
 func.func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
@@ -25,3 +25,36 @@ func.func @arm_neon_sdot(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>)
   %0 = arm_neon.intr.sdot %a, %b, %c : vector<8xi8>, vector<8xi8> to vector<2xi32>
   return %0 : vector<2xi32>
 }
+
+// -----
+
+// CHECK-LABEL: arm_neon_smmla
+func.func @arm_neon_smmla(%a: vector<16xi8>,
+                          %b: vector<16xi8>,
+                          %c: vector<4xi32>) -> vector<4xi32> {
+  // CHECK: arm_neon.intr.smmla {{.*}}: vector<16xi8> to vector<4xi3
+  %0 = arm_neon.intr.smmla %c, %a, %b : vector<16xi8> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_neon_ummla
+func.func @arm_neon_ummla(%a: vector<16xi8>,
+                          %b: vector<16xi8>,
+                          %c: vector<4xi32>) -> vector<4xi32> {
+  // CHECK: arm_neon.intr.ummla {{.*}}: vector<16xi8> to vector<4xi3
+  %0 = arm_neon.intr.ummla %c, %a, %b : vector<16xi8> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_neon_usmmla
+func.func @arm_neon_usmmla(%a: vector<16xi8>,
+                            %b: vector<16xi8>,
+                            %c: vector<4xi32>) -> vector<4xi32> {
+  // CHECK: arm_neon.intr.usmmla {{.*}}: vector<16xi8> to vector<4xi3
+  %0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi8> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-neon.mlir b/mlir/test/Target/LLVMIR/arm-neon.mlir
index f4716fe58f203..e5b37ea3c8a5d 100644
--- a/mlir/test/Target/LLVMIR/arm-neon.mlir
+++ b/mlir/test/Target/LLVMIR/arm-neon.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file  %s | FileCheck %s
 
 // CHECK-LABEL: arm_neon_smull
 llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)> {
@@ -39,3 +39,45 @@ llvm.func @arm_neon_sdot_16_i8i8(%a: vector<4xi32>, %b: vector<16xi8>, %c: vecto
   %0 = arm_neon.intr.sdot %a, %b, %c : vector<16xi8>, vector<16xi8> to vector<4xi32>
   llvm.return %0 : vector<4xi32>
 }
+
+// -----
+
+// CHECK-LABEL: define <4 x i32> @arm_neon_smmla
+llvm.func @arm_neon_smmla(%arg0: vector<16xi8>,
+                          %arg1: vector<16xi8>,
+                          %arg2: vector<4xi32>)
+                         -> vector<4xi32> {
+  // CHECK: <4 x i32> @llvm.aarch64.neon.smmla.v4i32.v16i8(<4 x i32
+  %0 = "arm_neon.intr.smmla"(%arg2, %arg0, %arg1) :
+    (vector<4xi32>, vector<16xi8>, vector<16xi8>)
+        -> vector<4xi32>
+  llvm.return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: define <4 x i32> @arm_neon_ummla
+llvm.func @arm_neon_ummla(%arg0: vector<16xi8>,
+                          %arg1: vector<16xi8>,
+                          %arg2: vector<4xi32>)
+                         -> vector<4xi32> {
+  // CHECK: <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8(<4 x i32
+  %0 = "arm_neon.intr.ummla"(%arg2, %arg0, %arg1) :
+    (vector<4xi32>, vector<16xi8>, vector<16xi8>)
+        -> vector<4xi32>
+  llvm.return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: define <4 x i32> @arm_neon_usmmla
+llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
+                          %arg1: vector<16xi8>,
+                          %arg2: vector<4xi32>)
+                         -> vector<4xi32> {
+  // CHECK: <4 x i32> @llvm.aarch64.neon.usmmla.v4i32.v16i8(<4 x i32
+  %0 = "arm_neon.intr.usmmla"(%arg2, %arg0, %arg1) :
+    (vector<4xi32>, vector<16xi8>, vector<16xi8>)
+        -> vector<4xi32>
+  llvm.return %0 : vector<4xi32>
+}

@KoolJBlack KoolJBlack changed the title Adds Arm Neon SMMLA, UMMLA, and USMMLA Intrinsics [mlir][ArmNeon] Adds Arm Neon SMMLA, UMMLA, and USMMLA Intrinsics Feb 3, 2024
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this! Please see my comments inline.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Please address Ben's comment before landing. Thanks again for implementing this 🙏🏻

@KoolJBlack KoolJBlack requested a review from MacDue February 6, 2024 17:27
Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM (other than one little thing)

@dcaballe dcaballe merged commit 16d890c into llvm:main Feb 6, 2024
@KoolJBlack KoolJBlack deleted the arm_neon_roundtrip branch February 6, 2024 22:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants