Skip to content

Commit 97527f8

Browse files
committed
banach-space comments
1 parent 034d6ed commit 97527f8

File tree

4 files changed

+52
-48
lines changed

4 files changed

+52
-48
lines changed

mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ def SdotOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot",[1], [
122122

123123
def SmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"smmla",[1], [
124124
Pure,
125-
AllTypesMatch<["b", "c"]>,
126-
AllTypesMatch<["a", "res"]>,
125+
AllTypesMatch<["src1", "src2"]>,
126+
AllTypesMatch<["acc", "res"]>,
127127
]> {
128128
let summary = "Matrix-matrix multiply and accumulate op";
129129
let description = [{
@@ -141,19 +141,19 @@ def SmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"smmla",[1], [
141141
}];
142142
// Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
143143
let arguments = (ins
144-
VectorOfLengthAndType<[4], [I32]>:$a,
145-
VectorOfLengthAndType<[16], [I8]>:$b,
146-
VectorOfLengthAndType<[16], [I8]>:$c
144+
VectorOfLengthAndType<[4], [I32]>:$acc,
145+
VectorOfLengthAndType<[16], [I8]>:$src1,
146+
VectorOfLengthAndType<[16], [I8]>:$src2
147147
);
148148
let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
149149
let assemblyFormat =
150-
"$a `,` $b `,` $c attr-dict `:` type($b) `to` type($res)";
150+
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
151151
}
152152

153153
def UmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"ummla",[1], [
154154
Pure,
155-
AllTypesMatch<["b", "c"]>,
156-
AllTypesMatch<["a", "res"]>,
155+
AllTypesMatch<["src1", "src2"]>,
156+
AllTypesMatch<["acc", "res"]>,
157157
]> {
158158
let summary = "Unsinged matrix-matrix multiply and accumulate op";
159159
let description = [{
@@ -172,19 +172,19 @@ def UmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"ummla",[1], [
172172
}];
173173
// Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
174174
let arguments = (ins
175-
VectorOfLengthAndType<[4], [I32]>:$a,
176-
VectorOfLengthAndType<[16], [I8]>:$b,
177-
VectorOfLengthAndType<[16], [I8]>:$c
175+
VectorOfLengthAndType<[4], [I32]>:$acc,
176+
VectorOfLengthAndType<[16], [I8]>:$src1,
177+
VectorOfLengthAndType<[16], [I8]>:$src2
178178
);
179179
let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
180180
let assemblyFormat =
181-
"$a `,` $b `,` $c attr-dict `:` type($b) `to` type($res)";
181+
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
182182
}
183183

184184
def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
185185
Pure,
186-
AllTypesMatch<["b", "c"]>,
187-
AllTypesMatch<["a", "res"]>,
186+
AllTypesMatch<["src1", "src2"]>,
187+
AllTypesMatch<["acc", "res"]>,
188188
]> {
189189
let summary = "Unsignged and signed matrix-matrix multiply and accumulate op";
190190
let description = [{
@@ -204,13 +204,13 @@ def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
204204
}];
205205
// Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
206206
let arguments = (ins
207-
VectorOfLengthAndType<[4], [I32]>:$a,
208-
VectorOfLengthAndType<[16], [I8]>:$b,
209-
VectorOfLengthAndType<[16], [I8]>:$c
207+
VectorOfLengthAndType<[4], [I32]>:$acc,
208+
VectorOfLengthAndType<[16], [I8]>:$src1,
209+
VectorOfLengthAndType<[16], [I8]>:$src2
210210
);
211211
let results = (outs VectorOfLengthAndType<[4], [I32]>:$res);
212212
let assemblyFormat =
213-
"$a `,` $b `,` $c attr-dict `:` type($b) `to` type($res)";
213+
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
214214
}
215215

216216
class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>

mlir/test/Dialect/ArmNeon/invalid.mlir

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,60 +34,60 @@ func.func @b_has_2_rows_but_a_has_length_4(%a : vector<4xi32>, %b : vector<2x4xi
3434

3535
// -----
3636

37-
func.func @smmla_invalid_input_types(%a: vector<16xi4>,
37+
func.func @smmla_invalid_input_types(%a: vector<4xi32>,
3838
%b: vector<16xi4>,
39-
%c: vector<4xi32>) -> vector<4xi32> {
39+
%c: vector<16xi4>) -> vector<4xi32> {
4040
// expected-error@+1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
41-
%0 = arm_neon.intr.smmla %c, %a, %b : vector<16xi4> to vector<4xi32>
41+
%0 = arm_neon.intr.smmla %a, %b, %c : vector<16xi4> to vector<4xi32>
4242
return %0 : vector<4xi32>
4343
}
4444

4545
// -----
4646

47-
func.func @smmla_invalid_dimensions(%a: vector<32xi8>,
47+
func.func @smmla_invalid_dimensions(%a: vector<8xi32>,
4848
%b: vector<32xi8>,
49-
%c: vector<8xi32>) -> vector<8xi32> {
49+
%c: vector<32xi8>) -> vector<8xi32> {
5050
// expected-error@+1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
51-
%0 = arm_neon.intr.smmla %c, %a, %b : vector<32xi8> to vector<8xi32>
51+
%0 = arm_neon.intr.smmla %a, %b, %c : vector<32xi8> to vector<8xi32>
5252
return %0 : vector<8xi32>
5353
}
5454

5555
// -----
5656

57-
func.func @ummla_invalid_input_types(%a: vector<16xi4>,
57+
func.func @ummla_invalid_input_types(%a: vector<4xi32>,
5858
%b: vector<16xi4>,
59-
%c: vector<4xi32>) -> vector<4xi32> {
59+
%c: vector<16xi4>) -> vector<4xi32> {
6060
// expected-error@+1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
61-
%0 = arm_neon.intr.ummla %c, %a, %b : vector<16xi4> to vector<4xi32>
61+
%0 = arm_neon.intr.ummla %a, %b, %c : vector<16xi4> to vector<4xi32>
6262
return %0 : vector<4xi32>
6363
}
6464

6565
// -----
6666

67-
func.func @ummla_invalid_dimensions(%a: vector<32xi8>,
67+
func.func @ummla_invalid_dimensions(%a: vector<8xi32>,
6868
%b: vector<32xi8>,
69-
%c: vector<8xi32>) -> vector<8xi32> {
69+
%c: vector<32xi8>) -> vector<8xi32> {
7070
// expected-error@+1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
71-
%0 = arm_neon.intr.ummla %c, %a, %b : vector<32xi8> to vector<8xi32>
71+
%0 = arm_neon.intr.ummla %a, %b, %c : vector<32xi8> to vector<8xi32>
7272
return %0 : vector<8xi32>
7373
}
7474

7575
// -----
7676

77-
func.func @usmmla_invalid_input_types(%a: vector<16xi4>,
77+
func.func @usmmla_invalid_input_types(%a: vector<4xi32>,
7878
%b: vector<16xi4>,
79-
%c: vector<4xi32>) -> vector<4xi32> {
79+
%c: vector<16xi4>) -> vector<4xi32> {
8080
// expected-error@+1 {{op operand #1 must be vector of 8-bit signless integer values of length 16, but got 'vector<16xi4>'}}
81-
%0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi4> to vector<4xi32>
81+
%0 = arm_neon.intr.usmmla %a, %b, %c : vector<16xi4> to vector<4xi32>
8282
return %0 : vector<4xi32>
8383
}
8484

8585
// -----
8686

87-
func.func @usmmla_invalid_dimensions(%a: vector<32xi8>,
87+
func.func @usmmla_invalid_dimensions(%a: vector<8xi32>,
8888
%b: vector<32xi8>,
89-
%c: vector<8xi32>) -> vector<8xi32> {
89+
%c: vector<32xi8>) -> vector<8xi32> {
9090
// expected-error@+1 {{op operand #0 must be vector of 32-bit signless integer values of length 4, but got 'vector<8xi32>'}}
91-
%0 = arm_neon.intr.usmmla %c, %a, %b : vector<32xi8> to vector<8xi32>
91+
%0 = arm_neon.intr.usmmla %a, %b, %c : vector<32xi8> to vector<8xi32>
9292
return %0 : vector<8xi32>
9393
}

mlir/test/Dialect/ArmNeon/roundtrip.mlir

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ func.func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
1919
return %0, %1, %2 : vector<8xi16>, vector<4xi32>, vector<2xi64>
2020
}
2121

22+
// -----
23+
2224
// CHECK-LABEL: arm_neon_sdot
2325
func.func @arm_neon_sdot(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
2426
// CHECK: arm_neon.intr.sdot {{.*}}: vector<8xi8>, vector<8xi8> to vector<2xi32>
@@ -32,7 +34,7 @@ func.func @arm_neon_sdot(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>)
3234
func.func @arm_neon_smmla(%a: vector<16xi8>,
3335
%b: vector<16xi8>,
3436
%c: vector<4xi32>) -> vector<4xi32> {
35-
// CHECK: arm_neon.intr.smmla {{.*}}: vector<16xi8> to vector<4xi3
37+
// CHECK: arm_neon.intr.smmla {{.*}}: vector<16xi8> to vector<4xi32>
3638
%0 = arm_neon.intr.smmla %c, %a, %b : vector<16xi8> to vector<4xi32>
3739
return %0 : vector<4xi32>
3840
}
@@ -43,7 +45,7 @@ func.func @arm_neon_smmla(%a: vector<16xi8>,
4345
func.func @arm_neon_ummla(%a: vector<16xi8>,
4446
%b: vector<16xi8>,
4547
%c: vector<4xi32>) -> vector<4xi32> {
46-
// CHECK: arm_neon.intr.ummla {{.*}}: vector<16xi8> to vector<4xi3
48+
// CHECK: arm_neon.intr.ummla {{.*}}: vector<16xi8> to vector<4xi32>
4749
%0 = arm_neon.intr.ummla %c, %a, %b : vector<16xi8> to vector<4xi32>
4850
return %0 : vector<4xi32>
4951
}
@@ -54,7 +56,7 @@ func.func @arm_neon_ummla(%a: vector<16xi8>,
5456
func.func @arm_neon_usmmla(%a: vector<16xi8>,
5557
%b: vector<16xi8>,
5658
%c: vector<4xi32>) -> vector<4xi32> {
57-
// CHECK: arm_neon.intr.usmmla {{.*}}: vector<16xi8> to vector<4xi3
59+
// CHECK: arm_neon.intr.usmmla {{.*}}: vector<16xi8> to vector<4xi32>
5860
%0 = arm_neon.intr.usmmla %c, %a, %b : vector<16xi8> to vector<4xi32>
5961
return %0 : vector<4xi32>
6062
}

mlir/test/Target/LLVMIR/arm-neon.mlir

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ llvm.func @arm_neon_smull(%arg0: vector<8xi8>, %arg1: vector<8xi8>) -> !llvm.str
2424
llvm.return %8 : !llvm.struct<(vector<8xi16>, vector<4xi32>, vector<2xi64>)>
2525
}
2626

27+
// -----
28+
2729
// CHECK-LABEL: arm_neon_sdot_8_i8i8
2830
llvm.func @arm_neon_sdot_8_i8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<8xi8>) -> vector<2xi32> {
2931
// CHECK: %[[V0:.*]] = call <2 x i32> @llvm.aarch64.neon.sdot.v2i32.v8i8(<2 x i32> %{{.*}}, <8 x i8> %{{.*}}, <8 x i8> %{{.*}})
@@ -32,6 +34,9 @@ llvm.func @arm_neon_sdot_8_i8i8(%a: vector<2xi32>, %b: vector<8xi8>, %c: vector<
3234
llvm.return %0 : vector<2xi32>
3335
}
3436

37+
// -----
38+
39+
3540
// CHECK-LABEL: arm_neon_sdot_16_i8i8
3641
llvm.func @arm_neon_sdot_16_i8i8(%a: vector<4xi32>, %b: vector<16xi8>, %c: vector<16xi8>) -> vector<4xi32> {
3742
// CHECK: %[[V0:.*]] = call <4 x i32> @llvm.aarch64.neon.sdot.v4i32.v16i8(<4 x i32> %{{.*}}, <16 x i8> %{{.*}}, <16 x i8> %{{.*}})
@@ -42,11 +47,10 @@ llvm.func @arm_neon_sdot_16_i8i8(%a: vector<4xi32>, %b: vector<16xi8>, %c: vecto
4247

4348
// -----
4449

45-
// CHECK-LABEL: define <4 x i32> @arm_neon_smmla
50+
// CHECK-LABEL: arm_neon_smmla
4651
llvm.func @arm_neon_smmla(%arg0: vector<16xi8>,
4752
%arg1: vector<16xi8>,
48-
%arg2: vector<4xi32>)
49-
-> vector<4xi32> {
53+
%arg2: vector<4xi32>) -> vector<4xi32> {
5054
// CHECK: <4 x i32> @llvm.aarch64.neon.smmla.v4i32.v16i8(<4 x i32
5155
%0 = "arm_neon.intr.smmla"(%arg2, %arg0, %arg1) :
5256
(vector<4xi32>, vector<16xi8>, vector<16xi8>)
@@ -56,11 +60,10 @@ llvm.func @arm_neon_smmla(%arg0: vector<16xi8>,
5660

5761
// -----
5862

59-
// CHECK-LABEL: define <4 x i32> @arm_neon_ummla
63+
// CHECK-LABEL: arm_neon_ummla
6064
llvm.func @arm_neon_ummla(%arg0: vector<16xi8>,
6165
%arg1: vector<16xi8>,
62-
%arg2: vector<4xi32>)
63-
-> vector<4xi32> {
66+
%arg2: vector<4xi32>) -> vector<4xi32> {
6467
// CHECK: <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8(<4 x i32
6568
%0 = "arm_neon.intr.ummla"(%arg2, %arg0, %arg1) :
6669
(vector<4xi32>, vector<16xi8>, vector<16xi8>)
@@ -70,11 +73,10 @@ llvm.func @arm_neon_ummla(%arg0: vector<16xi8>,
7073

7174
// -----
7275

73-
// CHECK-LABEL: define <4 x i32> @arm_neon_usmmla
76+
// CHECK-LABEL: arm_neon_usmmla
7477
llvm.func @arm_neon_usmmla(%arg0: vector<16xi8>,
7578
%arg1: vector<16xi8>,
76-
%arg2: vector<4xi32>)
77-
-> vector<4xi32> {
79+
%arg2: vector<4xi32>) -> vector<4xi32> {
7880
// CHECK: <4 x i32> @llvm.aarch64.neon.usmmla.v4i32.v16i8(<4 x i32
7981
%0 = "arm_neon.intr.usmmla"(%arg2, %arg0, %arg1) :
8082
(vector<4xi32>, vector<16xi8>, vector<16xi8>)

0 commit comments

Comments
 (0)