Skip to content

[mlir][ArmNeon] Adds Arm Neon SMMLA, UMMLA, and USMMLA Intrinsics #80511

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,99 @@ def SdotOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot",[1], [
"$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
}

// Signed 8-bit integer matrix multiply-accumulate intrinsic op ("smmla").
// The traits below require src1/src2 to share one type and acc/res to share
// another.
def SmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"smmla",[1], [
Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "res"]>,
]> {
let summary = "Matrix-matrix multiply and accumulate op";
let description = [{
SMMLA: Signed integer matrix multiply-accumulate.

Signed 8-bit integer matrix multiply-accumulate. This instruction multiplies
the 2x8 matrix of signed 8-bit integer values in the first source vector by
the 8x2 matrix of signed 8-bit integer values in the second source vector.
The resulting 2x2 32-bit integer matrix product is destructively added to
the 32-bit integer matrix accumulator in the destination vector. This is
equivalent to performing an 8-way dot product per destination element.

Source:
https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=smmla
}];
// Operand order is (acc, src1, src2): a vector<4xi32> accumulator followed by
// two vector<16xi8> sources. The result is a vector<4xi32>.
let arguments = (ins
VectorOfLengthAndType<[4], [I32]>:$acc,
VectorOfLengthAndType<[16], [I8]>:$src1,
VectorOfLengthAndType<[16], [I8]>:$src2
);
let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
// Only type($src1) and type($res) are written in the custom assembly;
// presumably the other operand types follow from the AllTypesMatch traits.
let assemblyFormat =
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}

// Unsigned 8-bit integer matrix multiply-accumulate intrinsic op ("ummla").
// The traits below require src1/src2 to share one type and acc/res to share
// another.
def UmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"ummla",[1], [
    Pure,
    AllTypesMatch<["src1", "src2"]>,
    AllTypesMatch<["acc", "res"]>,
  ]> {
  let summary = "Unsigned matrix-matrix multiply and accumulate op";
  let description = [{
    UMMLA: Unsigned integer matrix multiply-accumulate.

    Unsigned 8-bit integer matrix multiply-accumulate. This instruction
    multiplies the 2x8 matrix of unsigned 8-bit integer values in the first
    source vector by the 8x2 matrix of unsigned 8-bit integer values in the
    second source vector. The resulting 2x2 32-bit integer matrix product is
    destructively added to the 32-bit integer matrix accumulator in the
    destination vector. This is equivalent to performing an 8-way dot product
    per destination element.

    Source:
    https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=ummla
  }];
  // Operand order is (acc, src1, src2): a vector<4xi32> accumulator followed
  // by two vector<16xi8> sources. The result is a vector<4xi32>.
  let arguments = (ins
          VectorOfLengthAndType<[4], [I32]>:$acc,
          VectorOfLengthAndType<[16], [I8]>:$src1,
          VectorOfLengthAndType<[16], [I8]>:$src2
  );
  let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
  // Only type($src1) and type($res) are written in the custom assembly;
  // presumably the other operand types follow from the AllTypesMatch traits.
  let assemblyFormat =
    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}

// Mixed unsigned-by-signed 8-bit integer matrix multiply-accumulate intrinsic
// op ("usmmla"). The traits below require src1/src2 to share one type and
// acc/res to share another.
def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
    Pure,
    AllTypesMatch<["src1", "src2"]>,
    AllTypesMatch<["acc", "res"]>,
  ]> {
  let summary = "Unsigned and signed matrix-matrix multiply and accumulate op";
  let description = [{
    USMMLA: Unsigned and signed integer matrix multiply-accumulate.

    Unsigned and signed 8-bit integer matrix multiply-accumulate. This
    instruction multiplies the 2x8 matrix of unsigned 8-bit integer values in
    the first source vector by the 8x2 matrix of signed 8-bit integer values in
    the second source vector. The resulting 2x2 32-bit integer matrix product is
    destructively added to the 32-bit integer matrix accumulator in the
    destination vector. This is equivalent to performing an 8-way dot product
    per destination element.


    Source:
    https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=usmmla
  }];
  // Operand order is (acc, src1, src2): a vector<4xi32> accumulator followed
  // by two vector<16xi8> sources. The result is a vector<4xi32>.
  let arguments = (ins
          VectorOfLengthAndType<[4], [I32]>:$acc,
          VectorOfLengthAndType<[16], [I8]>:$src1,
          VectorOfLengthAndType<[16], [I8]>:$src2
  );
  let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
  // Only type($src1) and type($res) are written in the custom assembly;
  // presumably the other operand types follow from the AllTypesMatch traits.
  let assemblyFormat =
    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}

class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
: Op</*dialect=*/ArmNeon_Dialect,
/*opName=*/"2d." # mnemonic,
Expand Down
60 changes: 60 additions & 0 deletions mlir/test/Dialect/ArmNeon/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,63 @@ func.func @b_has_2_rows_but_a_has_length_4(%a : vector<4xi32>, %b : vector<2x4xi
%0 = arm_neon.2d.sdot %a, %b, %b : vector<2x4xi8>, vector<2x4xi8> to vector<4xi32>
return %0 : vector<4xi32>
}

// -----

// Negative test for smmla: the two source operands are vector<16xi4> instead
// of vector<16xi8>, so the diagnostic below is reported against operand #1.
func.func @smmla_invalid_input_types(%a: vector<4xi32>,
%b: vector<16xi4>,
%c: vector<16xi4>) -> vector<4xi32> {
// expected-error@+1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
%0 = arm_neon.intr.smmla %a, %b, %c : vector<16xi4> to vector<4xi32>
return %0 : vector<4xi32>
}

// -----

// Negative test for smmla: every vector is twice the supported length
// (8xi32 accumulator, 32xi8 sources). The diagnostic below is reported against
// operand #0, the accumulator.
func.func @smmla_invalid_dimensions(%a: vector<8xi32>,
%b: vector<32xi8>,
%c: vector<32xi8>) -> vector<8xi32> {
// expected-error@+1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
%0 = arm_neon.intr.smmla %a, %b, %c : vector<32xi8> to vector<8xi32>
return %0 : vector<8xi32>
}

// -----

// Negative test for ummla: the two source operands are vector<16xi4> instead
// of vector<16xi8>, so the diagnostic below is reported against operand #1.
func.func @ummla_invalid_input_types(%a: vector<4xi32>,
%b: vector<16xi4>,
%c: vector<16xi4>) -> vector<4xi32> {
// expected-error@+1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
%0 = arm_neon.intr.ummla %a, %b, %c : vector<16xi4> to vector<4xi32>
return %0 : vector<4xi32>
}

// -----

// Negative test for ummla: every vector is twice the supported length
// (8xi32 accumulator, 32xi8 sources). The diagnostic below is reported against
// operand #0, the accumulator.
func.func @ummla_invalid_dimensions(%a: vector<8xi32>,
%b: vector<32xi8>,
%c: vector<32xi8>) -> vector<8xi32> {
// expected-error@+1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
%0 = arm_neon.intr.ummla %a, %b, %c : vector<32xi8> to vector<8xi32>
return %0 : vector<8xi32>
}

// -----

// Negative test for usmmla: the two source operands are vector<16xi4> instead
// of vector<16xi8>, so the diagnostic below is reported against operand #1.
func.func @usmmla_invalid_input_types(%a: vector<4xi32>,
%b: vector<16xi4>,
%c: vector<16xi4>) -> vector<4xi32> {
// expected-error@+1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
%0 = arm_neon.intr.usmmla %a, %b, %c : vector<16xi4> to vector<4xi32>
return %0 : vector<4xi32>
}

// -----

// Negative test for usmmla: every vector is twice the supported length
// (8xi32 accumulator, 32xi8 sources). The diagnostic below is reported against
// operand #0, the accumulator.
func.func @usmmla_invalid_dimensions(%a: vector<8xi32>,
%b: vector<32xi8>,
%c: vector<32xi8>) -> vector<8xi32> {
// expected-error@+1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
%0 = arm_neon.intr.usmmla %a, %b, %c : vector<32xi8> to vector<8xi32>
return %0 : vector<8xi32>
}
37 changes: 36 additions & 1 deletion mlir/test/Dialect/ArmNeon/roundtrip.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
// RUN: mlir-opt -verify-diagnostics -split-input-file %s | mlir-opt | FileCheck %s

// CHECK-LABEL: arm_neon_smull
func.func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
Expand All @@ -19,9 +19,44 @@ func.func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
return %0, %1, %2 : vector<8xi16>, vector<4xi32>, vector<2xi64>
}

// -----

// Round-trip test for arm_neon.intr.sdot with a vector<2xi32> accumulator and
// two vector<8xi8> inputs. The printed op must keep the
// "vector<8xi8>, vector<8xi8> to vector<2xi32>" type signature.
// CHECK-LABEL: arm_neon_sdot
func.func @arm_neon_sdot(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
// CHECK: arm_neon.intr.sdot {{.*}}: vector<8xi8>, vector<8xi8> to vector<2xi32>
%0 = arm_neon.intr.sdot %a, %b, %c : vector<8xi8>, vector<8xi8> to vector<2xi32>
return %0 : vector<2xi32>
}

// -----

// Round-trip test for arm_neon.intr.smmla. Note the argument naming: %a and %b
// are the vector<16xi8> sources and %c is the vector<4xi32> accumulator, which
// is passed first to the op.
// CHECK-LABEL: arm_neon_smmla
func.func @arm_neon_smmla(%a: vector<16xi8>,
%b: vector<16xi8>,
%c: vector<4xi32>) -> vector<4xi32> {
// CHECK: arm_neon.intr.smmla {{.*}}: vector<16xi8> to vector<4xi32>
%0 = arm_neon.intr.smmla %c, %a, %b : vector<16xi8> to vector<4xi32>
return %0 : vector<4xi32>
}

// -----

// Round-trip test for arm_neon.intr.ummla. Note the argument naming: %a and %b
// are the vector<16xi8> sources and %c is the vector<4xi32> accumulator, which
// is passed first to the op.
// CHECK-LABEL: arm_neon_ummla
func.func @arm_neon_ummla(%a: vector<16xi8>,
%b: vector<16xi8>,
%c: vector<4xi32>) -> vector<4xi32> {
// CHECK: arm_neon.intr.ummla {{.*}}: vector<16xi8> to vector<4xi32>
%0 = arm_neon.intr.ummla %c, %a, %b : vector<16xi8> to vector<4xi32>
return %0 : vector<4xi32>
}

// -----

// Round-trip test for arm_neon.intr.usmmla. Note the argument naming: %a and
// %b are the vector<16xi8> sources and %c is the vector<4xi32> accumulator,
// which is passed first to the op.
// CHECK-LABEL: arm_neon_usmmla
func.func @arm_neon_usmmla(%a: vector<16xi8>,
%b: vector<16xi8>,
%c: vector<4xi32>) -> vector<4xi32> {
// CHECK: arm_neon.intr.usmmla {{.*}}: vector<16xi8> to vector<4xi32>
%0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi8> to vector<4xi32>
return %0 : vector<4xi32>
}
46 changes: 45 additions & 1 deletion mlir/test/Target/LLVMIR/arm-neon.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s

// CHECK-LABEL: arm_neon_smull
llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)> {
Expand All @@ -24,6 +24,8 @@ llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.str
llvm.return %8 : !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)>
}

// -----

// CHECK-LABEL: arm_neon_sdot_8_i8i8
llvm.func @arm_neon_sdot_8_i8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
// CHECK: %[[V0:.*]] = call <2 x i32> @llvm.aarch64.neon.sdot.v2i32.v8i8(<2 x i32> %{{.*}}, <8 x i8> %{{.*}}, <8 x i8> %{{.*}})
Expand All @@ -32,10 +34,52 @@ llvm.func @arm_neon_sdot_8_i8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<
llvm.return %0 : vector<2xi32>
}

// -----


// Translation test for the 16-lane sdot form: the op must lower to a call to
// the v4i32.v16i8 overload of the aarch64 neon sdot intrinsic, immediately
// followed by the return of the <4 x i32> result.
// CHECK-LABEL: arm_neon_sdot_16_i8i8
llvm.func @arm_neon_sdot_16_i8i8(%a: vector<4xi32>, %b: vector<16xi8>, %c: vector<16xi8>) -> vector<4xi32> {
// CHECK: %[[V0:.*]] = call <4 x i32> @llvm.aarch64.neon.sdot.v4i32.v16i8(<4 x i32> %{{.*}}, <16 x i8> %{{.*}}, <16 x i8> %{{.*}})
// CHECK-NEXT: ret <4 x i32>
%0 = arm_neon.intr.sdot %a, %b, %c : vector<16xi8>, vector<16xi8> to vector<4xi32>
llvm.return %0 : vector<4xi32>
}

// -----

// Translation test for smmla, written in the generic op form. The accumulator
// (%arg2) is passed first, followed by the two vector<16xi8> sources. The
// output must reference the v4i32.v16i8 overload of the aarch64 neon smmla
// intrinsic.
// CHECK-LABEL: arm_neon_smmla
llvm.func @arm_neon_smmla(%arg0: vector<16xi8>,
%arg1: vector<16xi8>,
%arg2: vector<4xi32>) -> vector<4xi32> {
// CHECK: <4 x i32> @llvm.aarch64.neon.smmla.v4i32.v16i8(<4 x i32
%0 = "arm_neon.intr.smmla"(%arg2, %arg0, %arg1) :
(vector<4xi32>, vector<16xi8>, vector<16xi8>)
-> vector<4xi32>
llvm.return %0 : vector<4xi32>
}

// -----

// Translation test for ummla, written in the generic op form. The accumulator
// (%arg2) is passed first, followed by the two vector<16xi8> sources. The
// output must reference the v4i32.v16i8 overload of the aarch64 neon ummla
// intrinsic.
// CHECK-LABEL: arm_neon_ummla
llvm.func @arm_neon_ummla(%arg0: vector<16xi8>,
%arg1: vector<16xi8>,
%arg2: vector<4xi32>) -> vector<4xi32> {
// CHECK: <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8(<4 x i32
%0 = "arm_neon.intr.ummla"(%arg2, %arg0, %arg1) :
(vector<4xi32>, vector<16xi8>, vector<16xi8>)
-> vector<4xi32>
llvm.return %0 : vector<4xi32>
}

// -----

// Translation test for usmmla, written in the generic op form. The accumulator
// (%arg2) is passed first, followed by the two vector<16xi8> sources. The
// output must reference the v4i32.v16i8 overload of the aarch64 neon usmmla
// intrinsic.
// CHECK-LABEL: arm_neon_usmmla
llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
%arg1: vector<16xi8>,
%arg2: vector<4xi32>) -> vector<4xi32> {
// CHECK: <4 x i32> @llvm.aarch64.neon.usmmla.v4i32.v16i8(<4 x i32
%0 = "arm_neon.intr.usmmla"(%arg2, %arg0, %arg1) :
(vector<4xi32>, vector<16xi8>, vector<16xi8>)
-> vector<4xi32>
llvm.return %0 : vector<4xi32>
}