-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][ArmSME] Provide descriptions and summaries for types #70920
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) ChangesThe auto-generated summaries were hard to read (and pretty unhelpful), a SME tile was:
...and a SVE vector was:
Note: The descriptions added here won't yet be shown on the MLIR docs (only the short summaries), but this should be easy to enable like it was for attribute descriptions in #67009. A table of contents (TOC) is also added to the ArmSME docs page to make it easier to navigate. Full diff: https://github.com/llvm/llvm-project/pull/70920.diff 4 Files Affected:
diff --git a/mlir/docs/Dialects/ArmSME.md b/mlir/docs/Dialects/ArmSME.md
index ab7c9ffe7aa92f1..505b52938eacc05 100644
--- a/mlir/docs/Dialects/ArmSME.md
+++ b/mlir/docs/Dialects/ArmSME.md
@@ -1,5 +1,7 @@
# 'ArmSME' Dialect
+[TOC]
+
Basic dialect to target Arm SME architectures This dialect contains the
definitions necessary to target Arm SME scalable matrix operations.
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 18b9bd7a107febf..835307d3d6b9786 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -42,11 +42,45 @@ def ArmSME_Dialect : Dialect {
// ArmSME type definitions
//===----------------------------------------------------------------------===//
+// FIXME: This allows types there are not SVE vectors, e.g. vector<[16]xi128>.
def SVEVector : ScalableVectorOfRankAndLengthAndType<
- [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
+ [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>
+{
+ let summary = "a vector type that matches the size of a SVE vector";
+ let description = [{
+ Possible vector types:
+
+ Integer elements:
+
+ * `vector<[16]xi8>`
+ * `vector<[8]xi16>`
+ * `vector<[4]xi32>`
+ * `vector<[2]xi64>`
+ * `vector<[1]xi128>`
+
+ Floating point elements:
+
+ * `vector<[8]xf16>`
+ * `vector<[8]xbf16>`
+ * `vector<[4]xf32>`
+ * `vector<[2]xf64>`
+ }];
+}
def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
- [1], [16, 8, 4, 2, 1], [I1]>;
+ [1], [16, 8, 4, 2, 1], [I1]>
+{
+ let summary = "a vector type that matches the size of a SVE predicate";
+ let description = [{
+ Possible vector types:
+
+ * `vector<[16]xi1>`
+ * `vector<[8]xi1>`
+ * `vector<[4]xi1>`
+ * `vector<[2]xi1>`
+ * `vector<[1]xi1>`
+ }];
+}
#endif // ARMSME
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 37a2257a0015ce7..e57a8acd82de758 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -43,7 +43,47 @@ def nxnxv4f32 : SMETileType<F32, [4, 4 ], "vector<[4]x[4]xf32>">;
def nxnxv2f64 : SMETileType<F64, [2, 2 ], "vector<[2]x[2]xf64>">;
def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
- nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
+ nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64],
+ "a vector type that fits into a SME tile">
+{
+ let description = [{
+ Possible vector types:
+
+ Integer elements:
+
+ * `vector<[16]x[16]xi8>`
+ * `vector<[8]x[8]xi16>`
+ * `vector<[4]x[4]xi32>`
+ * `vector<[2]x[2]xi64>`
+ * `vector<[1]x[1]xi128>`
+
+ Floating point elements:
+
+ * `vector<[8]x[8]xf16>`
+ * `vector<[8]x[8]xbf16>`
+ * `vector<[4]x[4]xf32>`
+ * `vector<[2]x[2]xf64>`
+ }];
+}
+
+def TileID : AnyTypeOf<[I8, I16, I32, I64, I128],
+ "an identifier of a virtual tile (of a size) within the ZA storage">
+{
+ let description = [{
+ The tile ID is an 8, 16, 32, 64, or 128-bit signless integer. The value of
+ the integer indicates the tile to use, and the bit size indicates the size
+ of tile. The number of tiles available and the element types of those depend
+ on the size. This is summarised below:
+
+ | Tile ID Type | Possible Tile IDs | Tile Vector Types |
+ |--------------|---------------------|-------------------------------------------------------------------------|
+ | `i8` | 0 | `vector<[16]x[16]xi8>` |
+ | `i16` | 0 and 1 | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` |
+ | `i32` | 0 to 3 (inclusive) | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` |
+ | `i64` | 0 to 7 (inclusive) | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` |
+ | `i128` | 0 to 15 (inclusive) | `vector<[1]x[1]xi128>` |
+ }];
+}
// A type constraint that verifies the bitwidth of the scalar integer returned
// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
@@ -145,7 +185,7 @@ def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthM
Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
the cast away if it comes from a `arm_sme.cast_vector_to_tile`.
}];
- let arguments = (ins AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+ let arguments = (ins TileID:$tile_id);
let results = (outs SMETile:$vector);
let assemblyFormat =
"$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
@@ -181,7 +221,7 @@ def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthM
the cast away if it comes from a `arm_sme.cast_tile_to_vector`.
}];
let arguments = (ins SMETile:$vector);
- let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+ let results = (outs TileID:$tile_id);
let assemblyFormat =
"$vector attr-dict `:` type($vector) `to` type($tile_id)";
let hasCanonicalizeMethod = 1;
@@ -217,7 +257,7 @@ def GetTileID : ArmSME_Op<"get_tile_id"> {
```
}];
- let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+ let results = (outs TileID:$tile_id);
let assemblyFormat = "attr-dict `:` type($tile_id)";
}
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 1d6386bbf3828fa..666847dc60f51a5 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -15,7 +15,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> v
// -----
func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) -> vector<[16]xi8> {
- // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
+ // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]xi8>
return %0 : vector<[16]xi8>
}
@@ -23,7 +23,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) ->
// -----
func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vector<[16]x[16]xi4> {
- // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x[16]xi4>'}}
+ // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x[16]xi4>'}}
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi4>
return %0 : vector<[16]x[16]xi4>
}
@@ -31,7 +31,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vec
// -----
func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile_id : i8) -> vector<16x[16]xi8> {
- // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<16x[16]xi8>'}}
+ // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<16x[16]xi8>'}}
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<16x[16]xi8>
return %0 : vector<16x[16]xi8>
}
@@ -39,7 +39,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile
// -----
func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile_id : i8) -> vector<[16]x16xi8> {
- // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x16xi8>'}}
+ // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x16xi8>'}}
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x16xi8>
return %0 : vector<[16]x16xi8>
}
@@ -47,7 +47,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile
// -----
func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
- // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[4]x[16]xi8>'}}
+ // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[4]x[16]xi8>'}}
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
return %0 : vector<[4]x[16]xi8>
}
@@ -67,7 +67,7 @@ func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1
// -----
func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -> i8 {
- // expected-error@+1 {{op operand #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
+ // expected-error@+1 {{op operand #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
%0 = arm_sme.cast_vector_to_tile %vector : vector<[16]xi8> to i8
return %0 : i8
}
@@ -79,7 +79,7 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -
// -----
func.func @arm_sme_get_tile_id__bad_type() -> i1 {
- // expected-error@+1 {{op result #0 must be 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer}}
+ // expected-error@+1 {{op result #0 must be an identifier of a virtual tile (of a size) within the ZA storage}}
%0 = arm_sme.get_tile_id : i1
return %0 : i1
}
@@ -172,7 +172,7 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask :
func.func @arm_sme_outerproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: vector<[2]xi16>) -> vector<[2]x[2]xi16>
{
- // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[2]x[2]xi16>'}}
+ // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[2]x[2]xi16>'}}
%0 = arm_sme.outerproduct %vecA, %vecB : vector<[2]xi16>, vector<[2]xi16>
return %0 : vector<[2]x[2]xi16>
}
|
The auto-generated summaries are hard to read (and pretty unhelpful), and SME tile was: ``` vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values ``` ...and an SVE vector: ``` of ranks 1scalable vector of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer or 16-bit float or bfloat16 type or 32-bit float or 64-bit float values of length 16/8/4/2/1 ``` Note: The descriptions added here won't yet be shown on the MLIR docs (only the short summaries), but this should be easy to enable like it was for attribute descriptions in llvm#67009. A table of contents (TOC) is also added to the ArmSME docs page to make it easier to navigate.
895e8ca
to
5ebdff1
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM cheers
Follow on for some types missed in llvm#70920. This also replaces the LDSTPredicate with SVEPredicate (as they are equivalent), and adds a missing rank == 1 checks to the SVE vector types. A FIXME is also added to point out an issue in the MOPVector type constraint.
The auto-generated summaries were hard to read (and pretty unhelpful), a SME tile was:
...and a SVE vector was:
Note: The descriptions added here won't yet be shown on the MLIR docs (only the short summaries), but this should be easy to enable like it was for attribute descriptions in #67009.
A table of contents (TOC) is also added to the ArmSME docs page to make it easier to navigate.