Skip to content

[mlir][ArmSME] Add optional padding and mask operands to tile_load #69195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 49 additions & 3 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,26 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
let assemblyFormat = "attr-dict `:` type($res)";
}

def TileLoadOp : ArmSME_Op<"tile_load"> {
def TileLoadOp : ArmSME_Op<"tile_load", [
AttrSizedOperandSegments,
OptionalTypesMatchWith<
"padding type matches element type of result",
"result", "padding",
"::llvm::cast<VectorType>($_self).getElementType()"
>,
OptionalTypesMatchWith<
"mask has i1 element type and same shape as result",
"result", "mask",
"VectorType("
"VectorType::Builder("
"::llvm::cast<mlir::VectorType>($_self)"
").setElementType(IntegerType::get($_self.getContext(), 1)))"
>,
PredOpTrait<
"both `padding` and `mask` should be provided or neither",
CPred<"bool(getPadding()) == bool(getMask())">
>,
]> {
let summary = "Tile load operation";
let description = [{
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
Expand All @@ -242,6 +261,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
dimensions, since the operation is scalable, and the element type must be a
scalar that matches the element type of the result.

An optional SSA value `padding`, of the same element type as the MemRef, may
be provided to specify a fallback value for masked-out elements.

An optional SSA value `mask` may be specified to mask out elements read
from the MemRef. The `mask` type is an `i1` vector with a shape that
matches how elements are read from the MemRef. Elements whose corresponding
mask element is `0` are masked out and replaced with `padding`.

If either `padding` or `mask` is specified, both must be specified.

Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
Expand All @@ -256,10 +285,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
```

Example 4: Masked load of a 32-bit float element ZA tile with horizontal layout (default) from memory.
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0], %pad, %mask : memref<?x?xf32>, vector<[4]x[4]xf32>
```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices,
Optional<AnyType>:$padding, Optional<AnyVector>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
);
let results = (outs SMETile:$result);
Expand All @@ -273,9 +308,20 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
}
}];

let builders = [
OpBuilder<(ins "VectorType":$resultType, "Value":$base,
"ValueRange":$indices, "TileSliceLayout":$layout), [{
build($_builder, $_state, resultType, base, indices, {}, {}, layout);
}]>,
OpBuilder<(ins "VectorType":$resultType, "Value":$base,
"ValueRange":$indices), [{
build($_builder, $_state, resultType, base, indices, {}, {}, {});
}]>,
];

let assemblyFormat =
"$base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
"`:` type($base) `,` type($result)";
"$base `[` $indices `]` (`,` $padding `,` $mask^)? (`layout` `` $layout^)?"
"attr-dict `:` type($base) `,` type($result)";
}

def TileStoreOp : ArmSME_Op<"tile_store"> {
Expand Down
53 changes: 53 additions & 0 deletions mlir/test/Dialect/ArmSME/invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics

//===----------------------------------------------------------------------===//
// arm_sme.cast_tile_to_vector
//===----------------------------------------------------------------------===//

// -----

func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> {
Expand Down Expand Up @@ -48,6 +52,10 @@ func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[1
return %0 : vector<[4]x[16]xi8>
}

//===----------------------------------------------------------------------===//
// arm_sme.cast_vector_to_tile
//===----------------------------------------------------------------------===//

// -----

func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 {
Expand All @@ -64,6 +72,10 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -
return %0 : i8
}

//===----------------------------------------------------------------------===//
// arm_sme.get_tile_id
//===----------------------------------------------------------------------===//

// -----

func.func @arm_sme_get_tile_id__bad_type() -> i1 {
Expand All @@ -72,6 +84,10 @@ func.func @arm_sme_get_tile_id__bad_type() -> i1 {
return %0 : i1
}

//===----------------------------------------------------------------------===//
// arm_sme.move_vector_to_tile_slice
//===----------------------------------------------------------------------===//

// -----

func.func @arm_sme_move_vector_to_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
Expand All @@ -90,10 +106,47 @@ func.func @arm_sme_move_vector_to_tile_slice_f32__bad_vector_type(%vector : vect
return %0 : vector<[4]x[4]xf32>
}

//===----------------------------------------------------------------------===//
// arm_sme.move_tile_slice_to_vector
//===----------------------------------------------------------------------===//

// -----

/// Negative test: the requested result type `vector<[2]xf64>` does not match
/// the slice type of the `vector<[4]x[4]xf32>` tile, so verification must fail.
func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
// expected-error@+1 {{op failed to verify that type of 'result' matches type of 'tile' slice}}
%0 = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
return %0 : vector<[2]xf64>
}

//===----------------------------------------------------------------------===//
// arm_sme.tile_load
//===----------------------------------------------------------------------===//

// -----

/// Negative test: `%pad` is declared as `f32`, but the padding of a tile_load
/// producing `vector<[2]x[2]xf64>` is expected to be `f64`. The mismatch is
/// reported as a conflicting use of `%pad`, with a note on the function
/// signature where `%pad` was first given its type.
func.func @arm_sme_tile_load__bad_padding_type(%src : memref<?x?xf64>, %pad : f32, %mask : vector<[2]x[2]xi1>) {
%c0 = arith.constant 0 : index
// expected-note@-2 {{prior use here}}
// expected-error@+1 {{use of value '%pad' expects different type than prior uses: 'f64' vs 'f32'}}
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}

// -----

/// Negative test: `%mask` is declared as `vector<[4]x[4]xi1>`, but the mask of
/// a tile_load producing `vector<[2]x[2]xf64>` is expected to be
/// `vector<[2]x[2]xi1>`. The mismatch is reported as a conflicting use of
/// `%mask`, with a note on the function signature where it was first typed.
func.func @arm_sme_tile_load__bad_mask_type(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[4]x[4]xi1>) {
%c0 = arith.constant 0 : index
// expected-note@-2 {{prior use here}}
// expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<[2]x[2]xi1>' vs 'vector<[4]x[4]xi1>}}
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}

// -----

/// Negative test: a padding value is supplied without a mask (note the
/// trailing `,` after `%pad`). The op must reject this because `padding` and
/// `mask` have to be provided together or not at all.
func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{op failed to verify that both `padding` and `mask` should be provided or neither}}
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}
10 changes: 10 additions & 0 deletions mlir/test/Dialect/ArmSME/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,16 @@ func.func @arm_sme_tile_load_ver_f64(%src : memref<?x?xf64>) {

// -----

/// Padding and mask are optional operands; verify that a tile_load with both
/// supplied is parsed and printed back with the two extra operands after the
/// indices.
func.func @arm_sme_tile_load_hor_pad_f64(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[2]x[2]xi1>) {
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}

// -----

/// Layout is optional and horizontal is the default, verify it's still parsed.
func.func @arm_sme_tile_load_explicit_hor(%src : memref<?x?xi8>) {
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
Expand Down