Skip to content

Commit f798bf8

Browse files
authored
[mlir][ArmSME] Provide descriptions and summaries for types (#70920)
The auto-generated summaries were hard to read (and pretty unhelpful), a SME tile was: ``` vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values ``` ...and a SVE vector was: ``` of ranks 1scalable vector of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer or 16-bit float or bfloat16 type or 32-bit float or 64-bit float values of length 16/8/4/2/1 ``` Note: The descriptions added here won't yet be shown on the MLIR docs (only the short summaries), but this should be easy to enable like it was for attribute descriptions in #67009. A table of contents (TOC) is also added to the ArmSME docs page to make it easier to navigate.
1 parent b3523d7 commit f798bf8

File tree

4 files changed

+90
-14
lines changed

4 files changed

+90
-14
lines changed

mlir/docs/Dialects/ArmSME.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# 'ArmSME' Dialect
22

3+
[TOC]
4+
35
Basic dialect to target Arm SME architectures This dialect contains the
46
definitions necessary to target Arm SME scalable matrix operations.
57

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,45 @@ def ArmSME_Dialect : Dialect {
4242
// ArmSME type definitions
4343
//===----------------------------------------------------------------------===//
4444

45+
// FIXME: This allows types that are not SVE vectors, e.g. vector<[16]xi128>.
4546
def SVEVector : ScalableVectorOfRankAndLengthAndType<
46-
[1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
47+
[1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>
48+
{
49+
let summary = "a vector type that matches the size of a SVE vector";
50+
let description = [{
51+
Possible vector types:
52+
53+
Integer elements:
54+
55+
* `vector<[16]xi8>`
56+
* `vector<[8]xi16>`
57+
* `vector<[4]xi32>`
58+
* `vector<[2]xi64>`
59+
* `vector<[1]xi128>`
60+
61+
Floating point elements:
62+
63+
* `vector<[8]xf16>`
64+
* `vector<[8]xbf16>`
65+
* `vector<[4]xf32>`
66+
* `vector<[2]xf64>`
67+
}];
68+
}
4769

4870
def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
49-
[1], [16, 8, 4, 2, 1], [I1]>;
71+
[1], [16, 8, 4, 2, 1], [I1]>
72+
{
73+
let summary = "a vector type that matches the size of a SVE predicate";
74+
let description = [{
75+
Possible vector types:
76+
77+
* `vector<[16]xi1>`
78+
* `vector<[8]xi1>`
79+
* `vector<[4]xi1>`
80+
* `vector<[2]xi1>`
81+
* `vector<[1]xi1>`
82+
}];
83+
}
5084

5185

5286
#endif // ARMSME

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,47 @@ def nxnxv4f32 : SMETileType<F32, [4, 4 ], "vector<[4]x[4]xf32>">;
4343
def nxnxv2f64 : SMETileType<F64, [2, 2 ], "vector<[2]x[2]xf64>">;
4444

4545
def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
46-
nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
46+
nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64],
47+
"a vector type that fits into a SME tile">
48+
{
49+
let description = [{
50+
Possible vector types:
51+
52+
Integer elements:
53+
54+
* `vector<[16]x[16]xi8>`
55+
* `vector<[8]x[8]xi16>`
56+
* `vector<[4]x[4]xi32>`
57+
* `vector<[2]x[2]xi64>`
58+
* `vector<[1]x[1]xi128>`
59+
60+
Floating point elements:
61+
62+
* `vector<[8]x[8]xf16>`
63+
* `vector<[8]x[8]xbf16>`
64+
* `vector<[4]x[4]xf32>`
65+
* `vector<[2]x[2]xf64>`
66+
}];
67+
}
68+
69+
def TileID : AnyTypeOf<[I8, I16, I32, I64, I128],
70+
"an identifier of a virtual tile (of a size) within the ZA storage">
71+
{
72+
let description = [{
73+
The tile ID is an 8, 16, 32, 64, or 128-bit signless integer. The value of
74+
the integer indicates the tile to use, and the bit size indicates the size
75+
of tile. The number of tiles available and the element types of those depend
76+
on the size. This is summarised below:
77+
78+
| Tile ID Type | Possible Tile IDs | Tile Vector Types |
79+
|--------------|---------------------|-------------------------------------------------------------------------|
80+
| `i8` | 0 | `vector<[16]x[16]xi8>` |
81+
| `i16` | 0 and 1 | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` |
82+
| `i32` | 0 to 3 (inclusive) | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` |
83+
| `i64` | 0 to 7 (inclusive) | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` |
84+
| `i128` | 0 to 15 (inclusive) | `vector<[1]x[1]xi128>` |
85+
}];
86+
}
4787

4888
// A type constraint that verifies the bitwidth of the scalar integer returned
4989
// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
@@ -160,7 +200,7 @@ def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthM
160200
Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
161201
the cast away if it comes from a `arm_sme.cast_vector_to_tile`.
162202
}];
163-
let arguments = (ins AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
203+
let arguments = (ins TileID:$tile_id);
164204
let results = (outs SMETile:$vector);
165205
let assemblyFormat =
166206
"$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
@@ -196,7 +236,7 @@ def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthM
196236
the cast away if it comes from a `arm_sme.cast_tile_to_vector`.
197237
}];
198238
let arguments = (ins SMETile:$vector);
199-
let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
239+
let results = (outs TileID:$tile_id);
200240
let assemblyFormat =
201241
"$vector attr-dict `:` type($vector) `to` type($tile_id)";
202242
let hasCanonicalizeMethod = 1;
@@ -232,7 +272,7 @@ def GetTileID : ArmSME_Op<"get_tile_id"> {
232272
```
233273
}];
234274

235-
let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
275+
let results = (outs TileID:$tile_id);
236276
let assemblyFormat = "attr-dict `:` type($tile_id)";
237277
}
238278

mlir/test/Dialect/ArmSME/invalid.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,39 +15,39 @@ func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> v
1515
// -----
1616

1717
func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) -> vector<[16]xi8> {
18-
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
18+
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
1919
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]xi8>
2020
return %0 : vector<[16]xi8>
2121
}
2222

2323
// -----
2424

2525
func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vector<[16]x[16]xi4> {
26-
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x[16]xi4>'}}
26+
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x[16]xi4>'}}
2727
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi4>
2828
return %0 : vector<[16]x[16]xi4>
2929
}
3030

3131
// -----
3232

3333
func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile_id : i8) -> vector<16x[16]xi8> {
34-
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<16x[16]xi8>'}}
34+
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<16x[16]xi8>'}}
3535
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<16x[16]xi8>
3636
return %0 : vector<16x[16]xi8>
3737
}
3838

3939
// -----
4040

4141
func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile_id : i8) -> vector<[16]x16xi8> {
42-
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x16xi8>'}}
42+
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x16xi8>'}}
4343
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x16xi8>
4444
return %0 : vector<[16]x16xi8>
4545
}
4646

4747
// -----
4848

4949
func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
50-
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[4]x[16]xi8>'}}
50+
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[4]x[16]xi8>'}}
5151
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
5252
return %0 : vector<[4]x[16]xi8>
5353
}
@@ -67,7 +67,7 @@ func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1
6767
// -----
6868

6969
func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -> i8 {
70-
// expected-error@+1 {{op operand #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
70+
// expected-error@+1 {{op operand #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
7171
%0 = arm_sme.cast_vector_to_tile %vector : vector<[16]xi8> to i8
7272
return %0 : i8
7373
}
@@ -79,7 +79,7 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -
7979
// -----
8080

8181
func.func @arm_sme_get_tile_id__bad_type() -> i1 {
82-
// expected-error@+1 {{op result #0 must be 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer}}
82+
// expected-error@+1 {{op result #0 must be an identifier of a virtual tile (of a size) within the ZA storage}}
8383
%0 = arm_sme.get_tile_id : i1
8484
return %0 : i1
8585
}
@@ -200,7 +200,7 @@ func.func @arm_sme_store_tile_slice__bad_mask_type(%tile : vector<[16]x[16]xi8>,
200200

201201
func.func @arm_sme_outerproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: vector<[2]xi16>) -> vector<[2]x[2]xi16>
202202
{
203-
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[2]x[2]xi16>'}}
203+
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[2]x[2]xi16>'}}
204204
%0 = arm_sme.outerproduct %vecA, %vecB : vector<[2]xi16>, vector<[2]xi16>
205205
return %0 : vector<[2]x[2]xi16>
206206
}

0 commit comments

Comments
 (0)