Skip to content

[mlir][ArmNeon] Adds Arm Neon SMMLA, UMMLA, and USMMLA Intrinsics #80511

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ def ArmNeon_Dialect : Dialect {
// to the LLVMDialect (ops or types).
}

//===----------------------------------------------------------------------===//
// ArmNeon type definition
//===----------------------------------------------------------------------===//

// Type constraint for a `::mlir::VectorType` whose shape is exactly `[length]`
// (a single dimension), that satisfies `IsFixedVectorTypePred`, and whose
// element type is `elementType`. The description string is "a vector with
// length <length>".
class NeonVectorOfLength<int length, Type elementType> : ShapedContainerType<
[elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorTypePred]>,
"a vector with length " # length,
"::mlir::VectorType">;

//===----------------------------------------------------------------------===//
// ArmNeon op definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -120,6 +129,99 @@ def SdotOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot",[1], [
"$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
}

// `arm_neon.intr.smmla`: signed 8-bit integer matrix multiply-accumulate.
// Traits require `src1` and `src2` to share one type and `acc` and `res` to
// share one type, which is why the assembly format only spells out
// type($src1) and type($res).
def SmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"smmla",[1], [
Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "res"]>,
]> {
let summary = "Matrix-matrix multiply and accumulate op";
let description = [{
SMMLA: Signed integer matrix multiply-accumulate.

Signed 8-bit integer matrix multiply-accumulate. This instruction multiplies
the 2x8 matrix of signed 8-bit integer values in the first source vector by
the 8x2 matrix of signed 8-bit integer values in the second source vector.
The resulting 2x2 32-bit integer matrix product is destructively added to
the 32-bit integer matrix accumulator in the destination vector. This is
equivalent to performing an 8-way dot product per destination element.

Source:
https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=smmla
}];
// Supports (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> (vector<4xi32>):
// the accumulator comes first, followed by the two 8-bit source vectors.
let arguments = (ins
NeonVectorOfLength<4, I32>:$acc,
NeonVectorOfLength<16, I8>:$src1,
NeonVectorOfLength<16, I8>:$src2
);
let results = (outs NeonVectorOfLength<4, I32>:$res);
let assemblyFormat =
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}

// `arm_neon.intr.ummla`: unsigned 8-bit integer matrix multiply-accumulate.
// Traits require `src1` and `src2` to share one type and `acc` and `res` to
// share one type, which is why the assembly format only spells out
// type($src1) and type($res).
def UmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"ummla",[1], [
                Pure,
                AllTypesMatch<["src1", "src2"]>,
                AllTypesMatch<["acc", "res"]>,
              ]> {
  let summary = "Unsigned matrix-matrix multiply and accumulate op";
  let description = [{
    UMMLA: Unsigned integer matrix multiply-accumulate.

    Unsigned 8-bit integer matrix multiply-accumulate. This instruction
    multiplies the 2x8 matrix of unsigned 8-bit integer values in the first
    source vector by the 8x2 matrix of unsigned 8-bit integer values in the
    second source vector. The resulting 2x2 32-bit integer matrix product is
    destructively added to the 32-bit integer matrix accumulator in the
    destination vector. This is equivalent to performing an 8-way dot product
    per destination element.

    Source:
    https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=ummla
  }];
  // Supports (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> (vector<4xi32>):
  // the accumulator comes first, followed by the two 8-bit source vectors.
  let arguments = (ins
          NeonVectorOfLength<4, I32>:$acc,
          NeonVectorOfLength<16, I8>:$src1,
          NeonVectorOfLength<16, I8>:$src2
  );
  let results = (outs NeonVectorOfLength<4, I32>:$res);
  let assemblyFormat =
    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}

// `arm_neon.intr.usmmla`: mixed-sign 8-bit integer matrix multiply-accumulate
// (unsigned first source, signed second source). Traits require `src1` and
// `src2` to share one type and `acc` and `res` to share one type, which is why
// the assembly format only spells out type($src1) and type($res).
def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
                Pure,
                AllTypesMatch<["src1", "src2"]>,
                AllTypesMatch<["acc", "res"]>,
              ]> {
  let summary = "Unsigned and signed matrix-matrix multiply and accumulate op";
  let description = [{
    USMMLA: Unsigned and signed integer matrix multiply-accumulate.

    Unsigned and signed 8-bit integer matrix multiply-accumulate. This
    instruction multiplies the 2x8 matrix of unsigned 8-bit integer values in
    the first source vector by the 8x2 matrix of signed 8-bit integer values in
    the second source vector. The resulting 2x2 32-bit integer matrix product is
    destructively added to the 32-bit integer matrix accumulator in the
    destination vector. This is equivalent to performing an 8-way dot product
    per destination element.


    Source:
    https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=usmmla
  }];
  // Supports (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> (vector<4xi32>):
  // the accumulator comes first, followed by the two 8-bit source vectors.
  let arguments = (ins
          NeonVectorOfLength<4, I32>:$acc,
          NeonVectorOfLength<16, I8>:$src1,
          NeonVectorOfLength<16, I8>:$src2
  );
  let results = (outs NeonVectorOfLength<4, I32>:$res);
  let assemblyFormat =
    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}

class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
: Op</*dialect=*/ArmNeon_Dialect,
/*opName=*/"2d." # mnemonic,
Expand Down
60 changes: 60 additions & 0 deletions mlir/test/Dialect/ArmNeon/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,63 @@ func.func @b_has_2_rows_but_a_has_length_4(%a : vector<4xi32>, %b : vector<2x4xi
%0 = arm_neon.2d.sdot %a, %b, %b : vector<2x4xi8>, vector<2x4xi8> to vector<4xi32>
return %0 : vector<4xi32>
}

// -----

// Negative test: smmla is given 16 x i4 source vectors instead of 16 x i8, so
// verification must fail on operand #1 (the first source vector).
func.func @smmla_invalid_input_types(%a: vector<4xi32>,
%b: vector<16xi4>,
%c: vector<16xi4>) -> vector<4xi32> {
// expected-error@+1 {{op operand #1 must be a vector with length 16 of 8-bit signless integer values, but got 'vector<16xi4>'}}
%0 = arm_neon.intr.smmla %a, %b, %c : vector<16xi4> to vector<4xi32>
return %0 : vector<4xi32>
}

// -----

// Negative test: smmla is given vectors of twice the supported length
// (8 x i32 accumulator, 32 x i8 sources), so verification must fail on
// operand #0 (the accumulator).
func.func @smmla_invalid_dimensions(%a: vector<8xi32>,
%b: vector<32xi8>,
%c: vector<32xi8>) -> vector<8xi32> {
// expected-error@+1 {{op operand #0 must be a vector with length 4 of 32-bit signless integer values, but got 'vector<8xi32>'}}
%0 = arm_neon.intr.smmla %a, %b, %c : vector<32xi8> to vector<8xi32>
return %0 : vector<8xi32>
}

// -----

// Negative test: ummla is given 16 x i4 source vectors instead of 16 x i8, so
// verification must fail on operand #1 (the first source vector).
func.func @ummla_invalid_input_types(%a: vector<4xi32>,
%b: vector<16xi4>,
%c: vector<16xi4>) -> vector<4xi32> {
// expected-error@+1 {{op operand #1 must be a vector with length 16 of 8-bit signless integer values, but got 'vector<16xi4>'}}
%0 = arm_neon.intr.ummla %a, %b, %c : vector<16xi4> to vector<4xi32>
return %0 : vector<4xi32>
}

// -----

// Negative test: ummla is given vectors of twice the supported length
// (8 x i32 accumulator, 32 x i8 sources), so verification must fail on
// operand #0 (the accumulator).
func.func @ummla_invalid_dimensions(%a: vector<8xi32>,
%b: vector<32xi8>,
%c: vector<32xi8>) -> vector<8xi32> {
// expected-error@+1 {{op operand #0 must be a vector with length 4 of 32-bit signless integer values, but got 'vector<8xi32>'}}
%0 = arm_neon.intr.ummla %a, %b, %c : vector<32xi8> to vector<8xi32>
return %0 : vector<8xi32>
}

// -----

// Negative test: usmmla is given 16 x i4 source vectors instead of 16 x i8, so
// verification must fail on operand #1 (the first source vector).
func.func @usmmla_invalid_input_types(%a: vector<4xi32>,
%b: vector<16xi4>,
%c: vector<16xi4>) -> vector<4xi32> {
// expected-error@+1 {{op operand #1 must be a vector with length 16 of 8-bit signless integer values, but got 'vector<16xi4>'}}
%0 = arm_neon.intr.usmmla %a, %b, %c : vector<16xi4> to vector<4xi32>
return %0 : vector<4xi32>
}

// -----

// Negative test: usmmla is given vectors of twice the supported length
// (8 x i32 accumulator, 32 x i8 sources), so verification must fail on
// operand #0 (the accumulator).
func.func @usmmla_invalid_dimensions(%a: vector<8xi32>,
%b: vector<32xi8>,
%c: vector<32xi8>) -> vector<8xi32> {
// expected-error@+1 {{op operand #0 must be a vector with length 4 of 32-bit signless integer values, but got 'vector<8xi32>'}}
%0 = arm_neon.intr.usmmla %a, %b, %c : vector<32xi8> to vector<8xi32>
return %0 : vector<8xi32>
}
37 changes: 36 additions & 1 deletion mlir/test/Dialect/ArmNeon/roundtrip.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
// RUN: mlir-opt -verify-diagnostics -split-input-file %s | mlir-opt | FileCheck %s

// CHECK-LABEL: arm_neon_smull
func.func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
Expand All @@ -19,9 +19,44 @@ func.func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
return %0, %1, %2 : vector<8xi16>, vector<4xi32>, vector<2xi64>
}

// -----

// Round-trip test: the sdot op in custom assembly form (2 x i32 accumulator,
// 8 x i8 sources) must parse and print back with the same type signature.
// CHECK-LABEL: arm_neon_sdot
func.func @arm_neon_sdot(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
// CHECK: arm_neon.intr.sdot {{.*}}: vector<8xi8>, vector<8xi8> to vector<2xi32>
%0 = arm_neon.intr.sdot %a, %b, %c : vector<8xi8>, vector<8xi8> to vector<2xi32>
return %0 : vector<2xi32>
}

// -----

// Round-trip test: the smmla op in custom assembly form must parse and print
// back unchanged. Note the operand order at the call site: the accumulator %c
// is passed first, then the two 16 x i8 source vectors.
// CHECK-LABEL: arm_neon_smmla
func.func @arm_neon_smmla(%a: vector<16xi8>,
%b: vector<16xi8>,
%c: vector<4xi32>) -> vector<4xi32> {
// CHECK: arm_neon.intr.smmla {{.*}}: vector<16xi8> to vector<4xi32>
%0 = arm_neon.intr.smmla %c, %a, %b : vector<16xi8> to vector<4xi32>
return %0 : vector<4xi32>
}

// -----

// Round-trip test: the ummla op in custom assembly form must parse and print
// back unchanged. Note the operand order at the call site: the accumulator %c
// is passed first, then the two 16 x i8 source vectors.
// CHECK-LABEL: arm_neon_ummla
func.func @arm_neon_ummla(%a: vector<16xi8>,
%b: vector<16xi8>,
%c: vector<4xi32>) -> vector<4xi32> {
// CHECK: arm_neon.intr.ummla {{.*}}: vector<16xi8> to vector<4xi32>
%0 = arm_neon.intr.ummla %c, %a, %b : vector<16xi8> to vector<4xi32>
return %0 : vector<4xi32>
}

// -----

// Round-trip test: the usmmla op in custom assembly form must parse and print
// back unchanged. Note the operand order at the call site: the accumulator %c
// is passed first, then the two 16 x i8 source vectors.
// CHECK-LABEL: arm_neon_usmmla
func.func @arm_neon_usmmla(%a: vector<16xi8>,
%b: vector<16xi8>,
%c: vector<4xi32>) -> vector<4xi32> {
// CHECK: arm_neon.intr.usmmla {{.*}}: vector<16xi8> to vector<4xi32>
%0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi8> to vector<4xi32>
return %0 : vector<4xi32>
}
45 changes: 44 additions & 1 deletion mlir/test/Target/LLVMIR/arm-neon.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s

// CHECK-LABEL: arm_neon_smull
llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)> {
Expand All @@ -24,6 +24,8 @@ llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.str
llvm.return %8 : !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)>
}

// -----

// CHECK-LABEL: arm_neon_sdot_8_i8i8
llvm.func @arm_neon_sdot_8_i8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
// CHECK: %[[V0:.*]] = call <2 x i32> @llvm.aarch64.neon.sdot.v2i32.v8i8(<2 x i32> %{{.*}}, <8 x i8> %{{.*}}, <8 x i8> %{{.*}})
Expand All @@ -32,10 +34,51 @@ llvm.func @arm_neon_sdot_8_i8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<
llvm.return %0 : vector<2xi32>
}

// -----

// Translation test: sdot on 16 x i8 sources with a 4 x i32 accumulator must
// become a call to the llvm.aarch64.neon.sdot.v4i32.v16i8 intrinsic whose
// result is returned directly.
// CHECK-LABEL: arm_neon_sdot_16_i8i8
llvm.func @arm_neon_sdot_16_i8i8(%a: vector<4xi32>, %b: vector<16xi8>, %c: vector<16xi8>) -> vector<4xi32> {
// CHECK: %[[V0:.*]] = call <4 x i32> @llvm.aarch64.neon.sdot.v4i32.v16i8(<4 x i32> %{{.*}}, <16 x i8> %{{.*}}, <16 x i8> %{{.*}})
// CHECK-NEXT: ret <4 x i32>
%0 = arm_neon.intr.sdot %a, %b, %c : vector<16xi8>, vector<16xi8> to vector<4xi32>
llvm.return %0 : vector<4xi32>
}

// -----

// Translation test: smmla, written in generic op form with the accumulator
// %arg2 as the first operand, must become a call to the
// llvm.aarch64.neon.smmla.v4i32.v16i8 intrinsic.
// CHECK-LABEL: arm_neon_smmla
llvm.func @arm_neon_smmla(%arg0: vector<16xi8>,
%arg1: vector<16xi8>,
%arg2: vector<4xi32>) -> vector<4xi32> {
// CHECK: <4 x i32> @llvm.aarch64.neon.smmla.v4i32.v16i8(<4 x i32
%0 = "arm_neon.intr.smmla"(%arg2, %arg0, %arg1) :
(vector<4xi32>, vector<16xi8>, vector<16xi8>)
-> vector<4xi32>
llvm.return %0 : vector<4xi32>
}

// -----

// Translation test: ummla, written in generic op form with the accumulator
// %arg2 as the first operand, must become a call to the
// llvm.aarch64.neon.ummla.v4i32.v16i8 intrinsic.
// CHECK-LABEL: arm_neon_ummla
llvm.func @arm_neon_ummla(%arg0: vector<16xi8>,
%arg1: vector<16xi8>,
%arg2: vector<4xi32>) -> vector<4xi32> {
// CHECK: <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8(<4 x i32
%0 = "arm_neon.intr.ummla"(%arg2, %arg0, %arg1) :
(vector<4xi32>, vector<16xi8>, vector<16xi8>)
-> vector<4xi32>
llvm.return %0 : vector<4xi32>
}

// -----

// Translation test: usmmla, written in generic op form with the accumulator
// %arg2 as the first operand, must become a call to the
// llvm.aarch64.neon.usmmla.v4i32.v16i8 intrinsic.
// CHECK-LABEL: arm_neon_usmmla
llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
%arg1: vector<16xi8>,
%arg2: vector<4xi32>) -> vector<4xi32> {
// CHECK: <4 x i32> @llvm.aarch64.neon.usmmla.v4i32.v16i8(<4 x i32
%0 = "arm_neon.intr.usmmla"(%arg2, %arg0, %arg1) :
(vector<4xi32>, vector<16xi8>, vector<16xi8>)
-> vector<4xi32>
llvm.return %0 : vector<4xi32>
}