Skip to content

[mlir][ArmSME] Provide descriptions and summaries for types #70920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 2, 2023

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Nov 1, 2023

The auto-generated summaries were hard to read (and pretty unhelpful), a SME tile was:

vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values

...and a SVE vector was:

of ranks 1scalable vector of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer or 16-bit float or bfloat16 type or 32-bit float or 64-bit float values of length 16/8/4/2/1

Note: The descriptions added here won't yet be shown on the MLIR docs (only the short summaries), but this should be easy to enable like it was for attribute descriptions in #67009.

A table of contents (TOC) is also added to the ArmSME docs page to make it easier to navigate.

@llvmbot
Copy link
Member

llvmbot commented Nov 1, 2023

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

The auto-generated summaries were hard to read (and pretty unhelpful), a SME tile was:

vector&lt;[16]x[16]xi8&gt; of 8-bit signless integer values or vector&lt;[8]x[8]xi16&gt; of 16-bit signless integer values or vector&lt;[4]x[4]xi32&gt; of 32-bit signless integer values or vector&lt;[2]x[2]xi64&gt; of 64-bit signless integer values or vector&lt;[1]x[1]xi128&gt; of 128-bit signless integer values or vector&lt;[8]x[8]xf16&gt; of 16-bit float values or vector&lt;[8]x[8]xbf16&gt; of bfloat16 type values or vector&lt;[4]x[4]xf32&gt; of 32-bit float values or vector&lt;[2]x[2]xf64&gt; of 64-bit float values

...and a SVE vector was:

of ranks 1scalable vector of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer or 16-bit float or bfloat16 type or 32-bit float or 64-bit float values of length 16/8/4/2/1

Note: The descriptions added here won't yet be shown on the MLIR docs (only the short summaries), but this should be easy to enable like it was for attribute descriptions in #67009.

A table of contents (TOC) is also added to the ArmSME docs page to make it easier to navigate.


Full diff: https://github.com/llvm/llvm-project/pull/70920.diff

4 Files Affected:

  • (modified) mlir/docs/Dialects/ArmSME.md (+2)
  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td (+36-2)
  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+44-4)
  • (modified) mlir/test/Dialect/ArmSME/invalid.mlir (+8-8)
diff --git a/mlir/docs/Dialects/ArmSME.md b/mlir/docs/Dialects/ArmSME.md
index ab7c9ffe7aa92f1..505b52938eacc05 100644
--- a/mlir/docs/Dialects/ArmSME.md
+++ b/mlir/docs/Dialects/ArmSME.md
@@ -1,5 +1,7 @@
 # 'ArmSME' Dialect
 
+[TOC]
+
 Basic dialect to target Arm SME architectures This dialect contains the
 definitions necessary to target Arm SME scalable matrix operations.
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 18b9bd7a107febf..835307d3d6b9786 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -42,11 +42,45 @@ def ArmSME_Dialect : Dialect {
 // ArmSME type definitions
 //===----------------------------------------------------------------------===//
 
+// FIXME: This allows types there are not SVE vectors, e.g. vector<[16]xi128>.
 def SVEVector : ScalableVectorOfRankAndLengthAndType<
-  [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
+  [1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>
+{
+  let summary = "a vector type that matches the size of a SVE vector";
+  let description = [{
+    Possible vector types:
+
+    Integer elements:
+
+    * `vector<[16]xi8>`
+    * `vector<[8]xi16>`
+    * `vector<[4]xi32>`
+    * `vector<[2]xi64>`
+    * `vector<[1]xi128>`
+
+    Floating point elements:
+
+    * `vector<[8]xf16>`
+    * `vector<[8]xbf16>`
+    * `vector<[4]xf32>`
+    * `vector<[2]xf64>`
+  }];
+}
 
 def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
-  [1], [16, 8, 4, 2, 1], [I1]>;
+  [1], [16, 8, 4, 2, 1], [I1]>
+{
+  let summary = "a vector type that matches the size of a SVE predicate";
+  let description = [{
+    Possible vector types:
+
+    * `vector<[16]xi1>`
+    * `vector<[8]xi1>`
+    * `vector<[4]xi1>`
+    * `vector<[2]xi1>`
+    * `vector<[1]xi1>`
+  }];
+}
 
 
 #endif // ARMSME
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 37a2257a0015ce7..e57a8acd82de758 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -43,7 +43,47 @@ def nxnxv4f32  : SMETileType<F32,  [4,  4 ], "vector<[4]x[4]xf32>">;
 def nxnxv2f64  : SMETileType<F64,  [2,  2 ], "vector<[2]x[2]xf64>">;
 
 def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
-                         nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
+                         nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64],
+                        "a vector type that fits into a SME tile">
+{
+  let description = [{
+    Possible vector types:
+
+    Integer elements:
+
+    * `vector<[16]x[16]xi8>`
+    * `vector<[8]x[8]xi16>`
+    * `vector<[4]x[4]xi32>`
+    * `vector<[2]x[2]xi64>`
+    * `vector<[1]x[1]xi128>`
+
+    Floating point elements:
+
+    * `vector<[8]x[8]xf16>`
+    * `vector<[8]x[8]xbf16>`
+    * `vector<[4]x[4]xf32>`
+    * `vector<[2]x[2]xf64>`
+  }];
+}
+
+def TileID : AnyTypeOf<[I8, I16, I32, I64, I128],
+                       "an identifier of a virtual tile (of a size) within the ZA storage">
+{
+  let description = [{
+    The tile ID is an 8, 16, 32, 64, or 128-bit signless integer. The value of
+    the integer indicates the tile to use, and the bit size indicates the size
+    of tile. The number of tiles available and the element types of those depend
+    on the size. This is summarised below:
+
+    | Tile ID Type | Possible Tile IDs   | Tile Vector Types                                                       |
+    |--------------|---------------------|-------------------------------------------------------------------------|
+    | `i8`         | 0                   | `vector<[16]x[16]xi8>`                                                  |
+    | `i16`        | 0 and 1             | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` |
+    | `i32`        | 0 to 3 (inclusive)  | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>`                          |
+    | `i64`        | 0 to 7 (inclusive)  | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          |
+    | `i128`       | 0 to 15 (inclusive) | `vector<[1]x[1]xi128>`                                                  |
+  }];
+}
 
 // A type constraint that verifies the bitwidth of the scalar integer returned
 // from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
@@ -145,7 +185,7 @@ def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthM
     Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
     the cast away if it comes from a `arm_sme.cast_vector_to_tile`.
   }];
-  let arguments = (ins AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+  let arguments = (ins TileID:$tile_id);
   let results = (outs SMETile:$vector);
   let assemblyFormat =
     "$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
@@ -181,7 +221,7 @@ def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthM
     the cast away if it comes from a `arm_sme.cast_tile_to_vector`.
   }];
   let arguments = (ins SMETile:$vector);
-  let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+  let results = (outs TileID:$tile_id);
   let assemblyFormat =
     "$vector attr-dict `:` type($vector) `to` type($tile_id)";
   let hasCanonicalizeMethod = 1;
@@ -217,7 +257,7 @@ def GetTileID : ArmSME_Op<"get_tile_id"> {
     ```
   }];
 
-  let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+  let results = (outs TileID:$tile_id);
   let assemblyFormat = "attr-dict `:` type($tile_id)";
 }
 
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 1d6386bbf3828fa..666847dc60f51a5 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -15,7 +15,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> v
 // -----
 
 func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) -> vector<[16]xi8> {
-  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
+  // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
   %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]xi8>
   return %0 : vector<[16]xi8>
 }
@@ -23,7 +23,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) ->
 // -----
 
 func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vector<[16]x[16]xi4> {
-  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x[16]xi4>'}}
+  // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x[16]xi4>'}}
   %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi4>
   return %0 : vector<[16]x[16]xi4>
 }
@@ -31,7 +31,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vec
 // -----
 
 func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile_id : i8) -> vector<16x[16]xi8> {
-  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<16x[16]xi8>'}}
+  // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<16x[16]xi8>'}}
   %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<16x[16]xi8>
   return %0 : vector<16x[16]xi8>
 }
@@ -39,7 +39,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile
 // -----
 
 func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile_id : i8) -> vector<[16]x16xi8> {
-  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x16xi8>'}}
+  // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x16xi8>'}}
   %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x16xi8>
   return %0 : vector<[16]x16xi8>
 }
@@ -47,7 +47,7 @@ func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile
 // -----
 
 func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
-  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[4]x[16]xi8>'}}
+  // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[4]x[16]xi8>'}}
   %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
   return %0 : vector<[4]x[16]xi8>
 }
@@ -67,7 +67,7 @@ func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1
 // -----
 
 func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -> i8 {
-  // expected-error@+1 {{op operand #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
+  // expected-error@+1 {{op operand #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
   %0 = arm_sme.cast_vector_to_tile %vector : vector<[16]xi8> to i8
   return %0 : i8
 }
@@ -79,7 +79,7 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -
 // -----
 
 func.func @arm_sme_get_tile_id__bad_type() -> i1 {
-  // expected-error@+1 {{op result #0 must be 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer}}
+  // expected-error@+1 {{op result #0 must be an identifier of a virtual tile (of a size) within the ZA storage}}
   %0 = arm_sme.get_tile_id : i1
   return %0 : i1
 }
@@ -172,7 +172,7 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask :
 
 func.func @arm_sme_outerproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: vector<[2]xi16>) -> vector<[2]x[2]xi16>
 {
-  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[2]x[2]xi16>'}}
+  // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[2]x[2]xi16>'}}
   %0 = arm_sme.outerproduct %vecA, %vecB : vector<[2]xi16>, vector<[2]xi16>
   return %0 : vector<[2]x[2]xi16>
 }

The auto-generated summaries are hard to read (and pretty unhelpful),
and SME tile was:

```
vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values
```

...and an SVE vector:

```
of ranks 1scalable vector of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer or 16-bit float or bfloat16 type or 32-bit float or 64-bit float values of length 16/8/4/2/1
```

Note: The descriptions added here won't yet be shown on the MLIR docs
(only the short summaries), but this should be easy to enable like
it was for attribute descriptions in llvm#67009.

A table of contents (TOC) is also added to the ArmSME docs page to make
it easier to navigate.
@MacDue MacDue force-pushed the arm_sme_docs_cleanup branch from 895e8ca to 5ebdff1 Compare November 1, 2023 12:03
Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM cheers

@MacDue MacDue merged commit f798bf8 into llvm:main Nov 2, 2023
@MacDue MacDue deleted the arm_sme_docs_cleanup branch November 2, 2023 09:04
MacDue added a commit to MacDue/llvm-project that referenced this pull request Nov 2, 2023
Follow on for some types missed in llvm#70920. This also replaces the
LDSTPredicate with SVEPredicate (as they are equivalent), and adds a
missing rank == 1 checks to the SVE vector types.

A FIXME is also added to point out an issue in the MOPVector type
constraint.
MacDue added a commit that referenced this pull request Nov 3, 2023
…#71057)

Follow on for some types missed in #70920. This also replaces the
LDSTPredicate with SVEPredicate (as they are equivalent), and adds a
missing rank == 1 checks to the SVE vector types.

A FIXME is also added to point out an issue in the MOPVector type
constraint.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants