Skip to content

Commit 4ce12cd

Browse files
committed
introduce NeonVectorOfLength type
1 parent 97527f8 commit 4ce12cd

File tree

3 files changed

+27
-19
lines changed

3 files changed

+27
-19
lines changed

mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ def ArmNeon_Dialect : Dialect {
3030
// to the LLVMDialect (ops or types).
3131
}
3232

33+
//===----------------------------------------------------------------------===//
34+
// ArmNeon type definition
35+
//===----------------------------------------------------------------------===//
36+
37+
class NeonVectorOfLength<int length, Type elementType> : ShapedContainerType<
38+
[elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorTypePred]>,
39+
"Neon vector of " # length # "x" # elementType.summary,
40+
"::mlir::VectorType">;
41+
3342
//===----------------------------------------------------------------------===//
3443
// ArmNeon op definitions
3544
//===----------------------------------------------------------------------===//
@@ -141,11 +150,11 @@ def SmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"smmla",[1], [
141150
}];
142151
// Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
143152
let arguments = (ins
144-
VectorOfLengthAndType<[4], [I32]>:$acc,
145-
VectorOfLengthAndType<[16], [I8]>:$src1,
146-
VectorOfLengthAndType<[16], [I8]>:$src2
153+
NeonVectorOfLength<4, I32>:$acc,
154+
NeonVectorOfLength<16, I8>:$src1,
155+
NeonVectorOfLength<16, I8>:$src2
147156
);
148-
let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
157+
let results = (outs NeonVectorOfLength<4, I32>:$res);
149158
let assemblyFormat =
150159
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
151160
}
@@ -172,11 +181,11 @@ def UmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"ummla",[1], [
172181
}];
173182
// Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
174183
let arguments = (ins
175-
VectorOfLengthAndType<[4], [I32]>:$acc,
176-
VectorOfLengthAndType<[16], [I8]>:$src1,
177-
VectorOfLengthAndType<[16], [I8]>:$src2
184+
NeonVectorOfLength<4, I32>:$acc,
185+
NeonVectorOfLength<16, I8>:$src1,
186+
NeonVectorOfLength<16, I8>:$src2
178187
);
179-
let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
188+
let results = (outs NeonVectorOfLength<4, I32>:$res);
180189
let assemblyFormat =
181190
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
182191
}
@@ -204,11 +213,11 @@ def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
204213
}];
205214
// Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
206215
let arguments = (ins
207-
VectorOfLengthAndType<[4], [I32]>:$acc,
208-
VectorOfLengthAndType<[16], [I8]>:$src1,
209-
VectorOfLengthAndType<[16], [I8]>:$src2
216+
NeonVectorOfLength<4, I32>:$acc,
217+
NeonVectorOfLength<16, I8>:$src1,
218+
NeonVectorOfLength<16, I8>:$src2
210219
);
211-
let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
220+
let results = (outs NeonVectorOfLength<4, I32>:$res);
212221
let assemblyFormat =
213222
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
214223
}

mlir/test/Dialect/ArmNeon/invalid.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func.func @b_has_2_rows_but_a_has_length_4(%a : vector<4xi32>, %b : vector<2x4xi
3737
func.func @smmla_invalid_input_types(%a: vector<4xi32>,
3838
%b: vector<16xi4>,
3939
%c: vector<16xi4>) -> vector<4xi32> {
40-
// expected-error@+1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
40+
// expected-error@+1 {{op operand #1 must be Neon vector of 16x8-bit signless integer of 8-bit signless integer values, but got 'vector<16xi4>}}
4141
%0 = arm_neon.intr.smmla %a, %b, %c : vector<16xi4> to vector<4xi32>
4242
return %0 : vector<4xi32>
4343
}
@@ -47,7 +47,7 @@ func.func @smmla_invalid_input_types(%a: vector<4xi32>,
4747
func.func @smmla_invalid_dimensions(%a: vector<8xi32>,
4848
%b: vector<32xi8>,
4949
%c: vector<32xi8>) -> vector<8xi32> {
50-
// expected-error@+1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
50+
// expected-error@+1 {{op operand #0 must be Neon vector of 4x32-bit signless integer of 32-bit signless integer values, but got 'vector<8xi32>'}}
5151
%0 = arm_neon.intr.smmla %a, %b, %c : vector<32xi8> to vector<8xi32>
5252
return %0 : vector<8xi32>
5353
}
@@ -57,7 +57,7 @@ func.func @smmla_invalid_dimensions(%a: vector<8xi32>,
5757
func.func @ummla_invalid_input_types(%a: vector<4xi32>,
5858
%b: vector<16xi4>,
5959
%c: vector<16xi4>) -> vector<4xi32> {
60-
// expected-error@+1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
60+
// expected-error@+1 {{op operand #1 must be Neon vector of 16x8-bit signless integer of 8-bit signless integer values, but got 'vector<16xi4>'}}
6161
%0 = arm_neon.intr.ummla %a, %b, %c : vector<16xi4> to vector<4xi32>
6262
return %0 : vector<4xi32>
6363
}
@@ -67,7 +67,7 @@ func.func @ummla_invalid_input_types(%a: vector<4xi32>,
6767
func.func @ummla_invalid_dimensions(%a: vector<8xi32>,
6868
%b: vector<32xi8>,
6969
%c: vector<32xi8>) -> vector<8xi32> {
70-
// expected-error@+1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
70+
// expected-error@+1 {{op operand #0 must be Neon vector of 4x32-bit signless integer of 32-bit signless integer values, but got 'vector<8xi32>}}
7171
%0 = arm_neon.intr.ummla %a, %b, %c : vector<32xi8> to vector<8xi32>
7272
return %0 : vector<8xi32>
7373
}
@@ -77,7 +77,7 @@ func.func @ummla_invalid_dimensions(%a: vector<8xi32>,
7777
func.func @usmmla_invalid_input_types(%a: vector<4xi32>,
7878
%b: vector<16xi4>,
7979
%c: vector<16xi4>) -> vector<4xi32> {
80-
// expected-error@+1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
80+
// expected-error@+1 {{op operand #1 must be Neon vector of 16x8-bit signless integer of 8-bit signless integer values, but got 'vector<16xi4}}
8181
%0 = arm_neon.intr.usmmla %a, %b, %c : vector<16xi4> to vector<4xi32>
8282
return %0 : vector<4xi32>
8383
}
@@ -87,7 +87,7 @@ func.func @usmmla_invalid_input_types(%a: vector<4xi32>,
8787
func.func @usmmla_invalid_dimensions(%a: vector<8xi32>,
8888
%b: vector<32xi8>,
8989
%c: vector<32xi8>) -> vector<8xi32> {
90-
// expected-error@+1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
90+
// expected-error@+1 {{op operand #0 must be Neon vector of 4x32-bit signless integer of 32-bit signless integer values, but got 'vector<8xi32>'}}
9191
%0 = arm_neon.intr.usmmla %a, %b, %c : vector<32xi8> to vector<8xi32>
9292
return %0 : vector<8xi32>
9393
}

mlir/test/Target/LLVMIR/arm-neon.mlir

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ llvm.func @arm_neon_sdot_8_i8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<
3636

3737
// -----
3838

39-
4039
// CHECK-LABEL: arm_neon_sdot_16_i8i8
4140
llvm.func @arm_neon_sdot_16_i8i8(%a: vector<4xi32>, %b: vector<16xi8>, %c: vector<16xi8>) -> vector<4xi32> {
4241
// CHECK: %[[V0:.*]] = call <4 x i32> @llvm.aarch64.neon.sdot.v4i32.v16i8(<4 x i32> %{{.*}}, <16 x i8> %{{.*}}, <16 x i8> %{{.*}})

0 commit comments

Comments
 (0)