Skip to content

[mlir][ArmSME] Support vertical layout in load and store ops #66758

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"

#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc"

#define GET_OP_CLASSES
Expand Down
125 changes: 83 additions & 42 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef ARMSME_OPS
#define ARMSME_OPS

include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
Expand All @@ -36,6 +37,7 @@ def ArmSME_Dialect : Dialect {
https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
}];
let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
let useDefaultAttributePrinterParser = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -83,6 +85,24 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
"::llvm::cast<VectorType>($_self).getElementType())"
".getWidth())">;

//===----------------------------------------------------------------------===//
// ArmSME attr definitions
//===----------------------------------------------------------------------===//

def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
I32EnumAttrCase<"Horizontal", 0, "horizontal">,
I32EnumAttrCase<"Vertical", 1, "vertical">,
]> {
let cppNamespace = "::mlir::arm_sme";
let genSpecializedAttr = 0;
}

/// An attribute that specifies the layout of a tile slice in a tile.
def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
"layout"> {
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// ArmSME op definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -240,28 +260,33 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
let description = [{
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
with the shape defined by the 2D scalable vector type of the result tile.
The slice of memory must be contiguous. The memref must be either rank 1 or
rank 2 with dynamic dimensions, since the operation is scalable, and the
element type must be a scalar that matches the element type of the result.
An optional tile slice layout attribute specifies whether the slices of the
tile being loaded are horizontal (default) or vertical. The slice of memory
must be contiguous. The memref must be either rank 1 or rank 2 with dynamic
dimensions, since the operation is scalable, and the element type must be a
scalar that matches the element type of the result.

Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
```

Example 2: Load a FP 32-bit element ZA tile from memory.
Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
%tile = arm_sme.tile_load %base[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
```

Example 3: Load a 128-bit element ZA tile from memory.
Example 3: Load a 128-bit element ZA tile with horizontal layout (default) from memory.
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
%tile = arm_sme.tile_load %base[%c0, %c0], <horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices);
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices,
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
);
let results = (outs SMETile:$result);

let extraClassDeclaration = [{
Expand All @@ -274,37 +299,42 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
}];

let assemblyFormat =
"$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
"$base `[` $indices `]` (`,` $layout^)? attr-dict "
"`:` type($base) `,` type($result)";
}

def TileStoreOp : ArmSME_Op<"tile_store"> {
let summary = "Tile store operation";
let description = [{
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
with the shape defined by the 2D scalable vector type of the tile being
stored. The slice of memory must be contiguous. The memref must be either
rank 1 or rank 2 with dynamic dimensions, since the operation is scalable,
and the element type must be a scalar that matches the element type of the
result.
stored. An optional tile slice layout attribute specifies whether the
slices of the tile being stored are horizontal (default) or vertical. The
slice of memory must be contiguous. The memref must be either rank 1 or
rank 2 with dynamic dimensions, since the operation is scalable, and the
element type must be a scalar that matches the element type of the result.

Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
```

Example 2: Store a FP 32-bit element ZA tile to memory.
Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
arm_sme.tile_store %tile, %base[%c0, %c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
```

Example 3: Store a 128-bit element ZA tile to memory.
Example 3: Store a 128-bit element ZA tile with horizontal (default) layout to memory.
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
arm_sme.tile_store %tile, %base[%c0, %c0], <horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
```
}];
let arguments = (ins SMETile:$valueToStore,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices);
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices,
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
Expand All @@ -314,8 +344,9 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
}
}];

let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
"`:` type($base) `,` type($valueToStore)";
let assemblyFormat =
"$valueToStore `,` $base `[` $indices `]` (`,` $layout^)? attr-dict "
"`:` type($base) `,` type($valueToStore)";
}

def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
Expand All @@ -326,31 +357,36 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
Loads a 1D tile slice from memory into a 2D SME "virtual tile". The tile
slice is defined by the dimension of the 2D scalable vector type pointed by
the index. A tile slice index describes where in the input tile the tile
slice is loaded to. The updated tile is returned as the result.
slice is loaded to. An optional tile slice layout attribute specifies
whether the tile slice being loaded at the given index is horizontal
(default) or vertical. The updated tile is returned as the result.

The slice of memory read is defined by a base and indices and must be
contiguous. The memref must be either rank 1 or rank 2, have dynamic
dimensions since the operation is scalable, and the element type must be a
scalar that matches the element type of the result.

Example 1: Load a vector<[16]xi8> tile slice from memory into tile at given index.
Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
```

Example 2: Load a vector<[4]xf32> tile slice from memory into tile at given index.
Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
```

Example 3: Load a vector<[1]xi128> tile slice from memory into tile at given index.
Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from">:$base,
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index);
Arg<AnyMemRef, "the reference to load from">:$base,
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
);
let results = (outs SMETile:$result);

let extraClassDeclaration = [{
Expand All @@ -363,7 +399,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
}];

let assemblyFormat = [{
$base `[` $indices `]` `,` $tile `,` $tile_slice_index
$base `[` $indices `]` `,` $tile `,` $tile_slice_index (`,` $layout^)?
attr-dict `:` type($base) `,` type($result)
}];
}
Expand All @@ -374,31 +410,36 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
slice is defined by the dimension of the 2D scalable vector type pointed by
the index. A tile slice index describes where in the input tile the tile
slice is stored from.
slice is stored from. An optional tile slice layout attribute specifies
whether the tile slice being stored from the given index is horizontal
(default) or vertical.

The slice of memory written is defined by a base and indices and must be
contiguous. The memref must be either rank 1 or rank 2, have dynamic
dimensions since the operation is scalable, and the element type must be a
scalar that matches the element type of the input tile.

Example 1: Store vector<[16]xi8> tile slice from tile at given index to memory.
Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory.
```mlir
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
```

Example 2: Store vector<[4]xf32> tile slice from tile at given index to memory.
Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
```mlir
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
```

Example 3: Store a vector<[1]xi128> tile slice from tile at given index to memory.
Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
```mlir
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
```
}];
let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices);
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices,
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
Expand All @@ -409,7 +450,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
}];

let assemblyFormat = [{
$tile `,` $tile_slice_index `,` $base `[` $indices `]`
$tile `,` $tile_slice_index `,` $base `[` $indices `]` (`,` $layout^)?
attr-dict `:` type($base) `,` type($tile)
}];
}
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@ add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme)
set(LLVM_TARGET_DEFINITIONS ArmSME.td)
mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRArmSMEConversionsIncGen)

mlir_tablegen(ArmSMEEnums.h.inc -gen-enum-decls)
mlir_tablegen(ArmSMEEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme)
mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme)
add_public_tablegen_target(MLIRArmSMEAttrDefsIncGen)
14 changes: 7 additions & 7 deletions mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
getMemrefIndices(tileLoadOp.getIndices(),
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
numTileSlices, memrefIndices, loc, rewriter);
rewriter.create<arm_sme::LoadTileSliceOp>(loc, tileType,
tileLoadOp.getBase(), tile,
memrefIndices, tileSliceIndex);
rewriter.create<arm_sme::LoadTileSliceOp>(
loc, tileType, tileLoadOp.getBase(), tile, memrefIndices,
tileSliceIndex, tileLoadOp.getLayout());

rewriter.setInsertionPointAfter(forOp);

Expand All @@ -134,7 +134,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
///
/// BEFORE:
/// ```mlir
/// arm_sme.tile_store %tile, %dest[%c0, %c0]
/// arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical>
/// : memref<?x?xi32>, vector<[4]x[4]xi32
/// ```
///
Expand All @@ -146,8 +146,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// %min_svl_s = arith.constant 4 : index
/// %svl_s = arith.muli %min_svl_s, %vscale : index
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
/// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx]
/// : memref<?x?xi32>, vector<[4]x[4]xi32>
/// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
/// <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
/// }
/// ```
struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
Expand Down Expand Up @@ -184,7 +184,7 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
numTileSlices, memrefIndices, loc, rewriter);
rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
tileStoreOp.getBase(), memrefIndices);
tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());

return success();
}
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::arm_sme;
Expand All @@ -23,13 +25,23 @@ using namespace mlir::arm_sme;

#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.cpp.inc"

#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMETypes.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.cpp.inc"

void ArmSMEDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.cpp.inc"
>();

addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRArmSMEDialect

DEPENDS
MLIRArmSMEIncGen
MLIRArmSMEAttrDefsIncGen

LINK_LIBS PUBLIC
MLIRIR
Expand Down
Loading