[mlir][ArmSME] Add masking support to memory ops #69148

Closed
178 changes: 143 additions & 35 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -231,7 +231,24 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
let assemblyFormat = "attr-dict `:` type($res)";
}

def TileLoadOp : ArmSME_Op<"tile_load"> {
def TileLoadOp : ArmSME_Op<"tile_load", [
AttrSizedOperandSegments,
TypesMatchWith<
"padding type matches element type of result (if present)",
"result", "padding",
"::llvm::cast<VectorType>($_self).getElementType()",
"!getPadding() || std::equal_to<>()"
>,
TypesMatchWith<
"mask has i1 element type and same shape as result (if present)",
"result", "mask",
"VectorType("
"VectorType::Builder("
"::llvm::cast<mlir::VectorType>($_self)"
").setElementType(IntegerType::get($_self.getContext(), 1)))",
"!getMask() || std::equal_to<>()"
>
]> {
let summary = "Tile load operation";
let description = [{
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
@@ -242,6 +259,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
dimensions, since the operation is scalable, and the element type must be a
scalar that matches the element type of the result.

An optional SSA value `padding`, of the same element type as the MemRef,
specifies the fallback value used for masked-out elements.

An optional SSA value `mask` may be specified to mask out elements read
from the MemRef. The `mask` type is an `i1` vector with a shape that
matches how elements are read from the MemRef. Elements whose corresponding
mask element is `0` are masked out and replaced with `padding`.

If either `padding` or `mask` is specified, both must be specified.

Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
@@ -256,10 +283,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
```

Example 4: Masked load of a 32-bit element ZA tile with horizontal layout (default) from memory.
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0], %pad, %mask : memref<?x?xf32>, vector<[4]x[4]xf32>
```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices,
Optional<AnyType>:$padding, Optional<AnyVector>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
);
let results = (outs SMETile:$result);
@@ -273,12 +306,34 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
}
}];

let builders = [
OpBuilder<(ins "VectorType":$resultType, "Value":$base,
"ValueRange":$indices, "TileSliceLayout":$layout), [{
build($_builder, $_state, resultType, base, indices, {}, {}, layout);
}]>,
OpBuilder<(ins "VectorType":$resultType, "Value":$base,
"ValueRange":$indices), [{
build($_builder, $_state, resultType, base, indices, {}, {}, {});
}]>,
];

let assemblyFormat =
"$base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
"`:` type($base) `,` type($result)";
"$base `[` $indices `]` (`,` $padding `,` $mask^)? (`layout` `` $layout^)?"
"attr-dict `:` type($base) `,` type($result)";
}
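
The `mask` and `padding` behaviour that the `TileLoadOp` description spells out can be modelled in plain C++. This is only a sketch under assumptions: `Tile` and `maskedTileLoad` are hypothetical names introduced for illustration, not part of the ArmSME dialect, and a fixed-size nested `std::vector` stands in for a scalable ZA tile.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a 2-D tile: a row-major matrix of elements.
template <typename T>
using Tile = std::vector<std::vector<T>>;

// Reference model of a masked tile load as documented: an element is read
// from memory only when its mask bit is set, otherwise `padding` is used.
template <typename T>
Tile<T> maskedTileLoad(const Tile<T> &memory, const Tile<bool> &mask,
                       T padding) {
  Tile<T> result(mask.size());
  for (std::size_t row = 0; row < mask.size(); ++row) {
    result[row].resize(mask[row].size());
    for (std::size_t col = 0; col < mask[row].size(); ++col) {
      if (mask[row][col]) {
        // Active element: the read must be in bounds.
        assert(row < memory.size() && col < memory[row].size());
        result[row][col] = memory[row][col];
      } else {
        // Masked-out element: memory is not touched.
        result[row][col] = padding;
      }
    }
  }
  return result;
}
```

In this model the result always takes the shape of the mask, which mirrors the `TypesMatchWith` constraint requiring the mask to have the same shape as the result.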

def TileStoreOp : ArmSME_Op<"tile_store"> {
def TileStoreOp : ArmSME_Op<"tile_store", [
AttrSizedOperandSegments,
TypesMatchWith<
"mask has i1 element type and same shape as value to store (if present)",
"valueToStore", "mask",
"VectorType("
"VectorType::Builder("
"::llvm::cast<mlir::VectorType>($_self)"
").setElementType(IntegerType::get($_self.getContext(), 1)))",
"!getMask() || std::equal_to<>()"
>
]> {
let summary = "Tile store operation";
let description = [{
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
@@ -289,6 +344,11 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
rank 2 with dynamic dimensions, since the operation is scalable, and the
element type must be a scalar that matches the element type of the tile being stored.

An optional SSA value `mask` may be specified to mask out elements written
to the MemRef. The `mask` type is an `i1` vector of the same shape as the
vector type that matches how elements are written into the MemRef. Elements
whose corresponding mask element is `0` are masked out.

Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
@@ -303,10 +363,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0] layout<horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
```

Example 4: Masked store of a 32-bit element ZA tile with vertical layout to memory.
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0], %mask layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
```
}];
let arguments = (ins SMETile:$valueToStore,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
Variadic<Index>:$indices, Optional<AnyVector>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
@@ -317,13 +383,28 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
}
}];

let builders = [
OpBuilder<(ins "Value":$valueToStore, "Value":$base,
"ValueRange":$indices), [{
build($_builder, $_state, valueToStore, base, indices, {});
}]>,
];

let assemblyFormat =
"$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
"`:` type($base) `,` type($valueToStore)";
"$valueToStore `,` $base `[` $indices `]` (`,` $mask^)? (`layout` `` $layout^)?"
"attr-dict `:` type($base) `,` type($valueToStore)";
}
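
For the store side, the `TileStoreOp` description says that masked-out elements are simply not written. A minimal C++ model of that behaviour follows. `Tile` and `maskedTileStore` are hypothetical names used for illustration only, and the layout attribute is deliberately left out of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a 2-D tile: a row-major matrix of elements.
template <typename T>
using Tile = std::vector<std::vector<T>>;

// Reference model of a masked tile store: only elements whose mask bit is
// set are written; every other memory location keeps its previous value.
template <typename T>
void maskedTileStore(const Tile<T> &tile, const Tile<bool> &mask,
                     Tile<T> &memory) {
  // The mask must have the same shape as the value being stored.
  assert(tile.size() == mask.size());
  for (std::size_t row = 0; row < mask.size(); ++row) {
    assert(tile[row].size() == mask[row].size());
    for (std::size_t col = 0; col < mask[row].size(); ++col) {
      if (!mask[row][col])
        continue; // Masked out: leave memory untouched.
      assert(row < memory.size() && col < memory[row].size());
      memory[row][col] = tile[row][col];
    }
  }
}
```

Unlike the load, no `padding` operand is needed here because a masked-out store has nothing to fill in.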

def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
AllTypesMatch<["tile", "result"]>
AllTypesMatch<["tile", "result"]>,
TypesMatchWith<
"mask has i1 element type and is a slice of the result",
"result", "mask",
"VectorType("
"VectorType::Builder("
"::llvm::cast<mlir::VectorType>($_self)"
").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))"
")">,
]> {
let summary = "Tile slice load and update operation";
let description = [{
@@ -339,23 +420,27 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
dimensions since the operation is scalable, and the element type must be a
scalar that matches the element type of the result.

An SSA value `mask` specifies which elements read from the MemRef are masked out.
The `mask` type is an `i1` vector with a shape that matches how elements
are read from the MemRef.

Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
%tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
```

Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
%tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
```

Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
%tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from">:$base,
Arg<AnyMemRef, "the reference to load from">:$base, AnyVector:$mask,
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout
);
@@ -371,12 +456,22 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
}];

let assemblyFormat = [{
$base `[` $indices `]` `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
attr-dict `:` type($base) `,` type($result)
$base `[` $indices `]` `,` $mask `,` $tile `,` $tile_slice_index
(`layout` `` $layout^)? attr-dict `:` type($base) `,` type($mask) `,`
type($result)
}];
}
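
The `TypesMatchWith` constraints in this diff derive the expected mask type from the tile type in two ways: the tile ops keep the full shape, while the tile slice ops call `dropDim(0)` first. The sketch below restates that rule in standalone C++. `SketchVectorType`, `tileMaskType`, and `sliceMaskType` are hypothetical names and not MLIR APIs, and scalable-dimension flags are ignored for brevity.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for mlir::VectorType.
struct SketchVectorType {
  std::vector<std::int64_t> shape;
  std::string elementType;

  bool operator==(const SketchVectorType &other) const {
    return shape == other.shape && elementType == other.elementType;
  }
};

// tile_load / tile_store: the mask keeps the tile shape and uses i1 elements.
SketchVectorType tileMaskType(const SketchVectorType &tileType) {
  return {tileType.shape, "i1"};
}

// load_tile_slice / store_tile_slice: the mask covers one slice, so the
// leading dimension is dropped before switching the element type to i1.
SketchVectorType sliceMaskType(const SketchVectorType &tileType) {
  assert(!tileType.shape.empty());
  std::vector<std::int64_t> sliceShape(tileType.shape.begin() + 1,
                                       tileType.shape.end());
  return {sliceShape, "i1"};
}
```

For a `vector<[16]x[16]xi8>` tile this gives a `[16]x[16]` `i1` mask for the tile ops and a `[16]` `i1` mask for the slice ops, matching the types used in the examples in the op descriptions.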

def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
TypesMatchWith<
"mask has i1 element type and same shape as tile slice",
"tile", "mask",
"VectorType("
"VectorType::Builder("
"::llvm::cast<mlir::VectorType>($_self)"
").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))"
")">
]> {
let summary = "Tile slice store operation";
let description = [{
Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
@@ -391,22 +486,27 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
dimensions since the operation is scalable, and the element type must be a
scalar that matches the element type of the input tile.

An SSA value `mask` specifies which elements written to the MemRef are masked out.
The `mask` type is an `i1` vector with a shape that matches how elements
are written to the MemRef.

Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory.
```mlir
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] : vector<[16]x[16]xi8>, vector<[16]xi1>, memref<?x?xi8>
```

Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
```mlir
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, vector<[4]xi1>, memref<?x?xf32>
```

Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
```mlir
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, vector<[1]xi1>, memref<?x?xi128>
```
}];
let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
let arguments = (ins
SMETile:$tile, Index:$tile_slice_index, AnyVector:$mask,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
);
@@ -420,8 +520,8 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
}];

let assemblyFormat = [{
$tile `,` $tile_slice_index `,` $base `[` $indices `]` (`layout` `` $layout^)?
attr-dict `:` type($base) `,` type($tile)
$tile `,` $tile_slice_index `,` $mask `,` $base `[` $indices `]` (`layout` `` $layout^)?
attr-dict `:` type($base) `,` type($mask) `,` type($tile)
}];
}
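
How the 1-D mask, the slice index, and the layout attribute interact in `store_tile_slice` can be sketched as follows. This is an illustrative model only: `TileSliceLayout` here is a local enum that mirrors the attribute name, `maskedStoreTileSlice` is a hypothetical helper, and memory is flattened to a 1-D buffer with an assumed `offset` parameter standing in for the base and indices.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Local enum mirroring the layout attribute; not the dialect's own type.
enum class TileSliceLayout { Horizontal, Vertical };

// Hypothetical stand-in for a 2-D tile: a row-major matrix of elements.
template <typename T>
using Tile = std::vector<std::vector<T>>;

// Reference model of a masked tile slice store. A horizontal slice is row
// `sliceIndex` of the tile, a vertical slice is column `sliceIndex`.
// Elements whose mask bit is clear are not written to memory.
template <typename T>
void maskedStoreTileSlice(const Tile<T> &tile, std::size_t sliceIndex,
                          const std::vector<bool> &mask,
                          std::vector<T> &memory, std::size_t offset,
                          TileSliceLayout layout) {
  for (std::size_t i = 0; i < mask.size(); ++i) {
    if (!mask[i])
      continue; // Masked out: leave memory untouched.
    T element = layout == TileSliceLayout::Horizontal ? tile[sliceIndex][i]
                                                      : tile[i][sliceIndex];
    assert(offset + i < memory.size());
    memory[offset + i] = element;
  }
}
```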

@@ -441,21 +541,24 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
of a 2-D scalable vector tile at the given index. The type of the 1-D
scalable vector to be moved must match the type of the tile slice. A tile
slice is a 1-D vector of horizontally or vertically contiguous elements
within a ZA tile. Horizontal tile slices are currently assumed when
lowering to intrinsics. The updated tile is returned as the result.
within a ZA tile. The updated tile is returned as the result.

An optional tile slice layout attribute specifies whether the tile slice is
horizontal (default) or vertical.

Example 1: Move a vector<[16]xi8> into tile at given index.
Example 1: Move a vector<[16]xi8> into tile horizontally (default) at given index.
```mlir
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8>
```

Example 2: Move a vector<[2]xf64> into tile at given index.
Example 2: Move a vector<[2]xf64> into tile vertically at given index.
```mlir
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[2]xf64> into vector<[2]x[2]xf64>
```
}];
let arguments = (ins
SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index);
SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout);
let results = (outs SMETile:$result);

let extraClassDeclaration = [{
@@ -465,7 +568,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
}];

let assemblyFormat = [{
$vector `,` $tile `,` $tile_slice_index
$vector `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
attr-dict `:` type($vector) `into` type($result)
}];
}
@@ -480,29 +583,34 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
let description = [{
The tile slice to vector operation extracts a 1-D scalable slice from a 2-D
scalable tile at the given index. A tile slice is a 1-D vector of
horizontally or vertically contiguous elements within a ZA tile. Horizontal
tile slices are currently assumed when lowering to intrinsics.
horizontally or vertically contiguous elements within a ZA tile.

An optional tile slice layout attribute specifies whether the tile slice is
horizontal (default) or vertical.

Example 1: Extract `vector<[16]xi8>` from tile at the given index.
Example 1: Extract `vector<[16]xi8>` from tile horizontally at the given index.
```mlir
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
```

Example 2: Extract `vector<[2]xf64>` from tile at the given index.
Example 2: Extract `vector<[2]xf64>` from tile vertically at the given index.
```mlir
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
```
}];

let arguments = (ins SMETile:$tile, Index:$tile_slice_index);
let arguments = (ins
SMETile:$tile, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout
);
let results = (outs SVEVector:$result);

let extraClassDeclaration = [{
VectorType getSliceType() { return getResult().getType(); }
}];

let assemblyFormat = [{
$tile `[` $tile_slice_index `]` attr-dict
$tile `[` $tile_slice_index `]` (`layout` `` $layout^)? attr-dict
`:` type($result) `from` type($tile)
}];
}
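
The new `layout` attribute on the two move ops selects between a row and a column of the tile. The C++ sketch below models both directions. The names `moveTileSliceToVector` and `moveVectorToTileSlice` are hypothetical helpers named after the ops, `TileSliceLayout` is a local enum mirroring the attribute, and the tile is assumed to be square, as in the examples in the op descriptions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Local enum mirroring the layout attribute; not the dialect's own type.
enum class TileSliceLayout { Horizontal, Vertical };

// Hypothetical stand-in for a square 2-D tile in row-major order.
template <typename T>
using Tile = std::vector<std::vector<T>>;

// Extracts row (horizontal) or column (vertical) `sliceIndex` from the tile.
template <typename T>
std::vector<T> moveTileSliceToVector(const Tile<T> &tile,
                                     std::size_t sliceIndex,
                                     TileSliceLayout layout) {
  assert(sliceIndex < tile.size());
  std::vector<T> slice(tile.size());
  for (std::size_t i = 0; i < tile.size(); ++i)
    slice[i] = layout == TileSliceLayout::Horizontal ? tile[sliceIndex][i]
                                                     : tile[i][sliceIndex];
  return slice;
}

// Returns a copy of the tile with row (horizontal) or column (vertical)
// `sliceIndex` replaced by `slice`, mirroring the op returning the updated
// tile as its result.
template <typename T>
Tile<T> moveVectorToTileSlice(const std::vector<T> &slice, Tile<T> tile,
                              std::size_t sliceIndex,
                              TileSliceLayout layout) {
  assert(slice.size() == tile.size() && sliceIndex < tile.size());
  for (std::size_t i = 0; i < slice.size(); ++i) {
    if (layout == TileSliceLayout::Horizontal)
      tile[sliceIndex][i] = slice[i];
    else
      tile[i][sliceIndex] = slice[i];
  }
  return tile;
}
```

Taking the tile by value in `moveVectorToTileSlice` keeps the input unchanged, which matches the SSA style of the op where the updated tile is a new value.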