Skip to content

[mlir][ArmSME] Provide descriptions and summaries for types #70920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/docs/Dialects/ArmSME.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# 'ArmSME' Dialect

[TOC]

Basic dialect to target Arm SME architectures This dialect contains the
definitions necessary to target Arm SME scalable matrix operations.

Expand Down
38 changes: 36 additions & 2 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,45 @@ def ArmSME_Dialect : Dialect {
// ArmSME type definitions
//===----------------------------------------------------------------------===//

// FIXME: This allows types that are not SVE vectors, e.g. vector<[16]xi128>.
def SVEVector : ScalableVectorOfRankAndLengthAndType<
[1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
[1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>
{
let summary = "a vector type that matches the size of a SVE vector";
let description = [{
Possible vector types:

Integer elements:

* `vector<[16]xi8>`
* `vector<[8]xi16>`
* `vector<[4]xi32>`
* `vector<[2]xi64>`
* `vector<[1]xi128>`

Floating point elements:

* `vector<[8]xf16>`
* `vector<[8]xbf16>`
* `vector<[4]xf32>`
* `vector<[2]xf64>`
}];
}

def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
[1], [16, 8, 4, 2, 1], [I1]>;
[1], [16, 8, 4, 2, 1], [I1]>
{
let summary = "a vector type that matches the size of a SVE predicate";
let description = [{
Possible vector types:

* `vector<[16]xi1>`
* `vector<[8]xi1>`
* `vector<[4]xi1>`
* `vector<[2]xi1>`
* `vector<[1]xi1>`
}];
}


#endif // ARMSME
48 changes: 44 additions & 4 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,47 @@ def nxnxv4f32 : SMETileType<F32, [4, 4 ], "vector<[4]x[4]xf32>">;
def nxnxv2f64 : SMETileType<F64, [2, 2 ], "vector<[2]x[2]xf64>">;

def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64],
"a vector type that fits into a SME tile">
{
let description = [{
Possible vector types:

Integer elements:

* `vector<[16]x[16]xi8>`
* `vector<[8]x[8]xi16>`
* `vector<[4]x[4]xi32>`
* `vector<[2]x[2]xi64>`
* `vector<[1]x[1]xi128>`

Floating point elements:

* `vector<[8]x[8]xf16>`
* `vector<[8]x[8]xbf16>`
* `vector<[4]x[4]xf32>`
* `vector<[2]x[2]xf64>`
}];
}

def TileID : AnyTypeOf<[I8, I16, I32, I64, I128],
"an identifier of a virtual tile (of a size) within the ZA storage">
{
let description = [{
The tile ID is an 8, 16, 32, 64, or 128-bit signless integer. The value of
the integer indicates the tile to use, and the bit size indicates the size
of tile. The number of tiles available and the element types of those depend
on the size. This is summarised below:

| Tile ID Type | Possible Tile IDs | Tile Vector Types |
|--------------|---------------------|-------------------------------------------------------------------------|
| `i8` | 0 | `vector<[16]x[16]xi8>` |
| `i16` | 0 and 1 | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` |
| `i32` | 0 to 3 (inclusive) | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` |
| `i64` | 0 to 7 (inclusive) | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` |
| `i128` | 0 to 15 (inclusive) | `vector<[1]x[1]xi128>` |
}];
}

// A type constraint that verifies the bitwidth of the scalar integer returned
// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
Expand Down Expand Up @@ -145,7 +185,7 @@ def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthM
Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
the cast away if it comes from a `arm_sme.cast_vector_to_tile`.
}];
let arguments = (ins AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
let arguments = (ins TileID:$tile_id);
let results = (outs SMETile:$vector);
let assemblyFormat =
"$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
Expand Down Expand Up @@ -181,7 +221,7 @@ def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthM
the cast away if it comes from a `arm_sme.cast_tile_to_vector`.
}];
let arguments = (ins SMETile:$vector);
let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
let results = (outs TileID:$tile_id);
let assemblyFormat =
"$vector attr-dict `:` type($vector) `to` type($tile_id)";
let hasCanonicalizeMethod = 1;
Expand Down Expand Up @@ -217,7 +257,7 @@ def GetTileID : ArmSME_Op<"get_tile_id"> {
```
}];

let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
let results = (outs TileID:$tile_id);
let assemblyFormat = "attr-dict `:` type($tile_id)";
}

Expand Down
16 changes: 8 additions & 8 deletions mlir/test/Dialect/ArmSME/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,39 @@ func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> v
// -----

func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) -> vector<[16]xi8> {
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]xi8>
return %0 : vector<[16]xi8>
}

// -----

func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vector<[16]x[16]xi4> {
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x[16]xi4>'}}
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x[16]xi4>'}}
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi4>
return %0 : vector<[16]x[16]xi4>
}

// -----

func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile_id : i8) -> vector<16x[16]xi8> {
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<16x[16]xi8>'}}
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<16x[16]xi8>'}}
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<16x[16]xi8>
return %0 : vector<16x[16]xi8>
}

// -----

func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile_id : i8) -> vector<[16]x16xi8> {
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x16xi8>'}}
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x16xi8>'}}
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x16xi8>
return %0 : vector<[16]x16xi8>
}

// -----

func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[4]x[16]xi8>'}}
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[4]x[16]xi8>'}}
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
return %0 : vector<[4]x[16]xi8>
}
Expand All @@ -67,7 +67,7 @@ func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1
// -----

func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -> i8 {
// expected-error@+1 {{op operand #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
// expected-error@+1 {{op operand #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
%0 = arm_sme.cast_vector_to_tile %vector : vector<[16]xi8> to i8
return %0 : i8
}
Expand All @@ -79,7 +79,7 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -
// -----

func.func @arm_sme_get_tile_id__bad_type() -> i1 {
// expected-error@+1 {{op result #0 must be 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer}}
// expected-error@+1 {{op result #0 must be an identifier of a virtual tile (of a size) within the ZA storage}}
%0 = arm_sme.get_tile_id : i1
return %0 : i1
}
Expand Down Expand Up @@ -172,7 +172,7 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask :

func.func @arm_sme_outerproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: vector<[2]xi16>) -> vector<[2]x[2]xi16>
{
// expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[2]x[2]xi16>'}}
// expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[2]x[2]xi16>'}}
%0 = arm_sme.outerproduct %vecA, %vecB : vector<[2]xi16>, vector<[2]xi16>
return %0 : vector<[2]x[2]xi16>
}
Expand Down