Skip to content

Commit 5ebdff1

Browse files
committed
[mlir][ArmSME] Provide descriptions and summaries for types
The auto-generated summaries are hard to read (and pretty unhelpful), and SME tile was: ``` vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values ``` ...and an SVE vector: ``` of ranks 1scalable vector of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer or 16-bit float or bfloat16 type or 32-bit float or 64-bit float values of length 16/8/4/2/1 ``` Note: The descriptions added here won't yet be shown on the MLIR docs (only the short summaries), but this should be easy to enable like it was for attribute descriptions in llvm#67009. A table of contents (TOC) is also added to the ArmSME docs page to make it easier to navigate.
1 parent 9ca6bf3 commit 5ebdff1

File tree

4 files changed

+90
-14
lines changed

4 files changed

+90
-14
lines changed

mlir/docs/Dialects/ArmSME.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# 'ArmSME' Dialect
22

3+
[TOC]
4+
35
Basic dialect to target Arm SME architectures This dialect contains the
46
definitions necessary to target Arm SME scalable matrix operations.
57

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,45 @@ def ArmSME_Dialect : Dialect {
4242
// ArmSME type definitions
4343
//===----------------------------------------------------------------------===//
4444

45+
// FIXME: This allows types that are not SVE vectors, e.g. vector<[16]xi128>.
4546
def SVEVector : ScalableVectorOfRankAndLengthAndType<
46-
[1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
47+
[1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>
48+
{
49+
let summary = "a vector type that matches the size of a SVE vector";
50+
let description = [{
51+
Possible vector types:
52+
53+
Integer elements:
54+
55+
* `vector<[16]xi8>`
56+
* `vector<[8]xi16>`
57+
* `vector<[4]xi32>`
58+
* `vector<[2]xi64>`
59+
* `vector<[1]xi128>`
60+
61+
Floating point elements:
62+
63+
* `vector<[8]xf16>`
64+
* `vector<[8]xbf16>`
65+
* `vector<[4]xf32>`
66+
* `vector<[2]xf64>`
67+
}];
68+
}
4769

4870
def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
49-
[1], [16, 8, 4, 2, 1], [I1]>;
71+
[1], [16, 8, 4, 2, 1], [I1]>
72+
{
73+
let summary = "a vector type that matches the size of a SVE predicate";
74+
let description = [{
75+
Possible vector types:
76+
77+
* `vector<[16]xi1>`
78+
* `vector<[8]xi1>`
79+
* `vector<[4]xi1>`
80+
* `vector<[2]xi1>`
81+
* `vector<[1]xi1>`
82+
}];
83+
}
5084

5185

5286
#endif // ARMSME

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,47 @@ def nxnxv4f32 : SMETileType<F32, [4, 4 ], "vector<[4]x[4]xf32>">;
4343
def nxnxv2f64 : SMETileType<F64, [2, 2 ], "vector<[2]x[2]xf64>">;
4444

4545
def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
46-
nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
46+
nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64],
47+
"a vector type that fits into a SME tile">
48+
{
49+
let description = [{
50+
Possible vector types:
51+
52+
Integer elements:
53+
54+
* `vector<[16]x[16]xi8>`
55+
* `vector<[8]x[8]xi16>`
56+
* `vector<[4]x[4]xi32>`
57+
* `vector<[2]x[2]xi64>`
58+
* `vector<[1]x[1]xi128>`
59+
60+
Floating point elements:
61+
62+
* `vector<[8]x[8]xf16>`
63+
* `vector<[8]x[8]xbf16>`
64+
* `vector<[4]x[4]xf32>`
65+
* `vector<[2]x[2]xf64>`
66+
}];
67+
}
68+
69+
def TileID : AnyTypeOf<[I8, I16, I32, I64, I128],
70+
"an identifier of a virtual tile (of a size) within the ZA storage">
71+
{
72+
let description = [{
73+
The tile ID is an 8, 16, 32, 64, or 128-bit signless integer. The value of
74+
the integer indicates the tile to use, and the bit size indicates the size
75+
of tile. The number of tiles available and the element types of those depend
76+
on the size. This is summarised below:
77+
78+
| Tile ID Type | Possible Tile IDs | Tile Vector Types |
79+
|--------------|---------------------|-------------------------------------------------------------------------|
80+
| `i8` | 0 | `vector<[16]x[16]xi8>` |
81+
| `i16` | 0 and 1 | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` |
82+
| `i32` | 0 to 3 (inclusive) | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` |
83+
| `i64` | 0 to 7 (inclusive) | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` |
84+
| `i128` | 0 to 15 (inclusive) | `vector<[1]x[1]xi128>` |
85+
}];
86+
}
4787

4888
// A type constraint that verifies the bitwidth of the scalar integer returned
4989
// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
@@ -145,7 +185,7 @@ def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthM
145185
Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
146186
the cast away if it comes from a `arm_sme.cast_vector_to_tile`.
147187
}];
148-
let arguments = (ins AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
188+
let arguments = (ins TileID:$tile_id);
149189
let results = (outs SMETile:$vector);
150190
let assemblyFormat =
151191
"$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
@@ -181,7 +221,7 @@ def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthM
181221
the cast away if it comes from a `arm_sme.cast_tile_to_vector`.
182222
}];
183223
let arguments = (ins SMETile:$vector);
184-
let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
224+
let results = (outs TileID:$tile_id);
185225
let assemblyFormat =
186226
"$vector attr-dict `:` type($vector) `to` type($tile_id)";
187227
let hasCanonicalizeMethod = 1;
@@ -217,7 +257,7 @@ def GetTileID : ArmSME_Op<"get_tile_id"> {
217257
```
218258
}];
219259

220-
let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
260+
let results = (outs TileID:$tile_id);
221261
let assemblyFormat = "attr-dict `:` type($tile_id)";
222262
}
223263

mlir/test/Dialect/ArmSME/invalid.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,39 +15,39 @@ func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> v
1515
// -----
1616

1717
func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) -> vector<[16]xi8> {
18-
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
18+
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
1919
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]xi8>
2020
return %0 : vector<[16]xi8>
2121
}
2222

2323
// -----
2424

2525
func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vector<[16]x[16]xi4> {
26-
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x[16]xi4>'}}
26+
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x[16]xi4>'}}
2727
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi4>
2828
return %0 : vector<[16]x[16]xi4>
2929
}
3030

3131
// -----
3232

3333
func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile_id : i8) -> vector<16x[16]xi8> {
34-
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<16x[16]xi8>'}}
34+
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<16x[16]xi8>'}}
3535
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<16x[16]xi8>
3636
return %0 : vector<16x[16]xi8>
3737
}
3838

3939
// -----
4040

4141
func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile_id : i8) -> vector<[16]x16xi8> {
42-
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x16xi8>'}}
42+
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x16xi8>'}}
4343
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x16xi8>
4444
return %0 : vector<[16]x16xi8>
4545
}
4646

4747
// -----
4848

4949
func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
50-
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[4]x[16]xi8>'}}
50+
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[4]x[16]xi8>'}}
5151
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
5252
return %0 : vector<[4]x[16]xi8>
5353
}
@@ -67,7 +67,7 @@ func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1
6767
// -----
6868

6969
func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -> i8 {
70-
// expected-error@+1 {{op operand #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
70+
// expected-error@+1 {{op operand #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
7171
%0 = arm_sme.cast_vector_to_tile %vector : vector<[16]xi8> to i8
7272
return %0 : i8
7373
}
@@ -79,7 +79,7 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -
7979
// -----
8080

8181
func.func @arm_sme_get_tile_id__bad_type() -> i1 {
82-
// expected-error@+1 {{op result #0 must be 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer}}
82+
// expected-error@+1 {{op result #0 must be an identifier of a virtual tile (of a size) within the ZA storage}}
8383
%0 = arm_sme.get_tile_id : i1
8484
return %0 : i1
8585
}
@@ -172,7 +172,7 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask :
172172

173173
func.func @arm_sme_outerproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: vector<[2]xi16>) -> vector<[2]x[2]xi16>
174174
{
175-
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[2]x[2]xi16>'}}
175+
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[2]x[2]xi16>'}}
176176
%0 = arm_sme.outerproduct %vecA, %vecB : vector<[2]xi16>, vector<[2]xi16>
177177
return %0 : vector<[2]x[2]xi16>
178178
}

0 commit comments

Comments
 (0)