-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][ArmSME] Support vertical layout in load and store ops #66758
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ArmSME] Support vertical layout in load and store ops #66758
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir ChangesIn SME a ZA tile slice is a one-dimensional set of horizontally or vertically contiguous elements within a ZA tile. Currently the load and store ops only support horizontal tile slices. This patch adds a tile slice layout attribute to the load and store ops to support both horizontal and vertical tile slices. When lowering from Vector dialect horizontal layout is the default. Patch is 98.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66758.diff 14 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index d1ed02abfd5c552..f947fc8fe1631b8 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -21,6 +21,11 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
+
#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc"
#define GET_OP_CLASSES
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..1a4984f3bd6ba27 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -14,6 +14,7 @@
#ifndef ARMSME_OPS
#define ARMSME_OPS
+include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
@@ -36,6 +37,7 @@ def ArmSME_Dialect : Dialect {
https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
}];
let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
+ let useDefaultAttributePrinterParser = 1;
}
//===----------------------------------------------------------------------===//
@@ -83,6 +85,24 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
"::llvm::cast<VectorType>($_self).getElementType())"
".getWidth())">;
+//===----------------------------------------------------------------------===//
+// ArmSME attr definitions
+//===----------------------------------------------------------------------===//
+
+def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
+ I32EnumAttrCase<"Horizontal", 0, "hor">,
+ I32EnumAttrCase<"Vertical", 1, "ver">,
+]> {
+ let cppNamespace = "::mlir::arm_sme";
+ let genSpecializedAttr = 0;
+}
+
+/// An attribute that specifies the layout of a tile slice in a tile.
+def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
+ "layout"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
//===----------------------------------------------------------------------===//
// ArmSME op definitions
//===----------------------------------------------------------------------===//
@@ -239,27 +259,33 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
let summary = "Tile load operation";
let description = [{
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
- with the shape defined by the 2D scalable vector type of the result tile.
- The slice of memory must be contiguous. The memref must be either rank 1 or
- rank 2 with dynamic dimensions, since the operation is scalable, and the
- element type must be a scalar that matches the element type of the result.
+ with the shape defined by the 2D scalable vector type of the result tile. A
+ tile slice layout attribute specifies whether the slices of the tile being
+ loaded are horizontal or vertical. The slice of memory must be contiguous.
+ The memref must be either rank 1 or rank 2 with dynamic dimensions, since
+ the operation is scalable, and the element type must be a scalar that
+ matches the element type of the result.
+
+ The default tile slice layout when lowering from higher-level dialects is
+ horizontal.
- Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
+ Example 1: Load an 8-bit element ZA tile with horizontal layout from memory (ZA0.B).
```mlir
- %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
```
- Example 2: Load a FP 32-bit element ZA tile from memory.
+ Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
```mlir
- %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %tile = arm_sme.tile_load <ver>, %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
```
- Example 3: Load a 128-bit element ZA tile from memory.
+ Example 3: Load a 128-bit element ZA tile with horizontal layout from memory.
```mlir
- %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
+ ArmSME_TileSliceLayoutAttr:$layout,
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices);
let results = (outs SMETile:$result);
@@ -274,7 +300,8 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
}];
let assemblyFormat =
- "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
+ "$layout `,` $base `[` $indices `]` attr-dict "
+ "`:` type($base) `,` type($result)";
}
def TileStoreOp : ArmSME_Op<"tile_store"> {
@@ -282,27 +309,32 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
let description = [{
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
with the shape defined by the 2D scalable vector type of the tile being
- stored. The slice of memory must be contiguous. The memref must be either
- rank 1 or rank 2 with dynamic dimensions, since the operation is scalable,
- and the element type must be a scalar that matches the element type of the
- result.
+ stored. A tile slice layout attribute specifies whether the slices of the
+ tile being stored are horizontal or vertical. The slice of memory must be
+ contiguous. The memref must be either rank 1 or rank 2 with dynamic
+ dimensions, since the operation is scalable, and the element type must be a
+ scalar that matches the element type of the result.
+
+ The default tile slice layout when lowering from higher-level dialects is
+ horizontal.
- Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
+ Example 1: Store an 8-bit element ZA tile with horizontal layout to memory (ZA0.B).
```mlir
- arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+ arm_sme.tile_store %tile, <hor>, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
```
- Example 2: Store a FP 32-bit element ZA tile to memory.
+ Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
```mlir
- arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+ arm_sme.tile_store %tile, <ver>, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
```
- Example 3: Store a 128-bit element ZA tile to memory.
+ Example 3: Store a 128-bit element ZA tile with horizontal layout to memory.
```mlir
- arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
+ arm_sme.tile_store %tile, <hor>, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
```
}];
let arguments = (ins SMETile:$valueToStore,
+ ArmSME_TileSliceLayoutAttr:$layout,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices);
let extraClassDeclaration = [{
@@ -314,8 +346,9 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
}
}];
- let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
- "`:` type($base) `,` type($valueToStore)";
+ let assemblyFormat =
+ "$valueToStore `,` $layout `,` $base `[` $indices `]` attr-dict "
+ "`:` type($base) `,` type($valueToStore)";
}
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
@@ -326,29 +359,32 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
Loads a 1D tile slice from memory into a 2D SME "virtual tile". The tile
slice is defined by the dimension of the 2D scalable vector type pointed by
the index. A tile slice index describes where in the input tile the tile
- slice is loaded to. The updated tile is returned as the result.
+ slice is loaded to. A tile slice layout attribute specifies whether the
+ tile slice being loaded at the given index is horizontal or vertical. The
+ updated tile is returned as the result.
The slice of memory read is defined by a base and indices and must be
contiguous. The memref must be either rank 1 or rank 2, have dynamic
dimensions since the operation is scalable, and the element type must be a
scalar that matches the element type of the result.
- Example 1: Load a vector<[16]xi8> tile slice from memory into tile at given index.
+ Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally at given index.
```mlir
- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %tile_update = arm_sme.load_tile_slice <hor>, %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
```
- Example 2: Load a vector<[4]xf32> tile slice from memory into tile at given index.
+ Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
```mlir
- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %tile_update = arm_sme.load_tile_slice <ver>, %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
```
- Example 3: Load a vector<[1]xi128> tile slice from memory into tile at given index.
+ Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
```mlir
- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %tile_update = arm_sme.load_tile_slice <ver>, %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
+ ArmSME_TileSliceLayoutAttr:$layout,
Arg<AnyMemRef, "the reference to load from">:$base,
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index);
let results = (outs SMETile:$result);
@@ -363,7 +399,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
}];
let assemblyFormat = [{
- $base `[` $indices `]` `,` $tile `,` $tile_slice_index
+ $layout `,` $base `[` $indices `]` `,` $tile `,` $tile_slice_index
attr-dict `:` type($base) `,` type($result)
}];
}
@@ -374,29 +410,31 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
slice is defined by the dimension of the 2D scalable vector type pointed by
the index. A tile slice index describes where in the input tile the tile
- slice is stored from.
+ slice is stored from. A tile slice layout attribute specifies whether the
+ tile slice being stored from the given index is horizontal or vertical.
The slice of memory written is defined by a base and indices and must be
contiguous. The memref must be either rank 1 or rank 2, have dynamic
dimensions since the operation is scalable, and the element type must be a
scalar that matches the element type of the input tile.
- Example 1: Store vector<[16]xi8> tile slice from tile at given index to memory.
+ Example 1: Store vector<[16]xi8> horizontal tile slice from tile at given index to memory.
```mlir
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
```
- Example 2: Store vector<[4]xf32> tile slice from tile at given index to memory.
+ Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
```mlir
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
```
- Example 3: Store a vector<[1]xi128> tile slice from tile at given index to memory.
+ Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
```mlir
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
```
}];
let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
+ ArmSME_TileSliceLayoutAttr:$layout,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices);
let extraClassDeclaration = [{
@@ -409,7 +447,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
}];
let assemblyFormat = [{
- $tile `,` $tile_slice_index `,` $base `[` $indices `]`
+ $tile `,` $tile_slice_index `,` $layout `,` $base `[` $indices `]`
attr-dict `:` type($base) `,` type($tile)
}];
}
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index d20ee65e62e7dc0..7afd0d014541687 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -4,3 +4,10 @@ add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme)
set(LLVM_TARGET_DEFINITIONS ArmSME.td)
mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRArmSMEConversionsIncGen)
+
+mlir_tablegen(ArmSMEEnums.h.inc -gen-enum-decls)
+mlir_tablegen(ArmSMEEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme)
+mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme)
+add_public_tablegen_target(MLIRArmSMEEnumsIncGen)
+add_dependencies(mlir-headers MLIRArmSMEEnumsIncGen)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 4028a7ad0870b51..86cabe67f2695f1 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -54,7 +54,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
///
/// BEFORE:
/// ```mlir
-/// %tile = arm_sme.tile_load %src[%c0, %c0] :
+/// %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] :
/// memref<?x?xi32>, vector<[4]x[4]xi32>
/// ```
///
@@ -68,7 +68,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
/// %min_svl_s = arith.constant 4 : index
/// %svl_s = arith.muli %min_svl_s, %vscale : index
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
-/// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
+/// %tile_update = arm_sme.load_tile_slice <hor>, %src[%tile_slice_idx],
/// %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]x[4]xi32>
/// }
/// ```
@@ -116,9 +116,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
getMemrefIndices(tileLoadOp.getIndices(),
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
numTileSlices, memrefIndices, loc, rewriter);
- rewriter.create<arm_sme::LoadTileSliceOp>(loc, tileType,
- tileLoadOp.getBase(), tile,
- memrefIndices, tileSliceIndex);
+ rewriter.create<arm_sme::LoadTileSliceOp>(
+ loc, tileType, tileLoadOp.getLayout(), tileLoadOp.getBase(), tile,
+ memrefIndices, tileSliceIndex);
rewriter.setInsertionPointAfter(forOp);
@@ -134,7 +134,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
///
/// BEFORE:
/// ```mlir
-/// arm_sme.tile_store %tile, %dest[%c0, %c0]
+/// arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0]
/// : memref<?x?xi32>, vector<[4]x[4]xi32
/// ```
///
@@ -146,8 +146,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// %min_svl_s = arith.constant 4 : index
/// %svl_s = arith.muli %min_svl_s, %vscale : index
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
-/// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx]
-/// : memref<?x?xi32>, vector<[4]x[4]xi32>
+/// arm_sme.store_tile_slice %tile, %tile_slice_idx, <ver>,
+/// %dest[%tile_slice_idx] : memref<?x?xi32>, vector<[4]x[4]xi32>
/// }
/// ```
struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
@@ -184,7 +184,7 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
numTileSlices, memrefIndices, loc, rewriter);
rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
- tileStoreOp.getBase(), memrefIndices);
+ tileStoreOp.getLayout(), tileStoreOp.getBase(), memrefIndices);
return success();
}
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 0a1a087d9c8d6c7..feaec0e035ed9fd 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -65,8 +65,8 @@ namespace {
///
/// is converted to:
///
-/// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
-/// vector<[16]x[16]xi8>
+/// arm_sme.tile_store %vector, <hor>, %source[%c0, %c0]
+/// : memref<?x?xi8>, vector<[16]x[16]xi8>
struct TransferWriteToArmSMELowering
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -81,8 +81,8 @@ struct TransferWriteToArmSMELowering
return failure();
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
- writeOp, writeOp.getVector(), writeOp.getSource(),
- writeOp.getIndices());
+ writeOp, writeOp.getVector(), arm_sme::TileSliceLayout::Horizontal,
+ writeOp.getSource(), writeOp.getIndices());
return success();
}
};
@@ -97,7 +97,8 @@ struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
return failure();
rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
- load, load.getVectorType(), load.getBase(), load.getIndices());
+ load, load.getVectorType(), arm_sme::TileSliceLayout::Horizontal,
+ load.getBase(), load.getIndices());
return success();
}
@@ -113,7 +114,8 @@ struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
return failure();
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
- store, store.getValueToStore(), store.getBase(), store.getIndices());
+ store, store.getValueToStore(), arm_sme::TileSliceLayout::Horizontal,
+ store.getBase(), store.getIndices());
return success();
}
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 750627421215dfb..92fb146691a0beb 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -12,6 +12,8 @@
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::arm_sme;
@@ -22,13 +24,23 @@ using namespace mlir::arm_sme;
#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.cpp.inc"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.cpp.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMETypes.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.cpp.inc"
+
void ArmSMEDialect::initialize() {
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.cpp.inc"
+ >();
+
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index 9b6332a478ad...
[truncated]
|
Quite a large change, but very well executed and makes a lot of sense. Btw, what logic would be deciding to generate vertical instead of horizontal loads/stores? |
Thanks for reviewing. For
run:
We can extend to the lowering to SME to support this. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome! Just a small refactoring comment
switch (tileElementWidth) { | ||
default: | ||
llvm_unreachable("unexpected element type!"); | ||
case 8: | ||
rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>( | ||
loc, allActiveMask, ptr, tileI32, tileSliceI32); | ||
break; | ||
case 16: | ||
rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>( | ||
loc, allActiveMask, ptr, tileI32, tileSliceI32); | ||
break; | ||
case 32: | ||
rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>( | ||
loc, allActiveMask, ptr, tileI32, tileSliceI32); | ||
break; | ||
case 64: | ||
rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>( | ||
loc, allActiveMask, ptr, tileI32, tileSliceI32); | ||
break; | ||
case 128: | ||
rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>( | ||
loc, allActiveMask, ptr, tileI32, tileSliceI32); | ||
break; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
worth refactor this into a template function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's not obvious to me how I could improve this with a template function, could you give us some pointers?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the only thing that came to my mind:
template <int N>
void callLoadIntrinsic(ConversionPatternRewriter &rewriter,
arm_sme::StoreTileSliceOp storeTileSliceOp,
mlir::vector::SplatOp allActiveMask, Value ptr, Value tileI32,
mlir::arith::IndexCastUIOp tileSliceI32) {
}
template <>
void callLoadIntrinsic<8>(ConversionPatternRewriter &rewriter,
arm_sme::StoreTileSliceOp storeTileSliceOp,
mlir::vector::SplatOp allActiveMask, Value ptr,
Value tileI32,
mlir::arith::IndexCastUIOp tileSliceI32) {
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
}
But that wouldn't be an improvement IMHO, so I'd go ahead with what you have here already.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could have a template function with a single switch and then:
callLoadIntrinsic<arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz...>(...)
callLoadIntrinsic<the vertical ones>(...)
That should reduce the code by 2x but maybe I'm missing something :)
case 128: | ||
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>( | ||
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); | ||
break; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same?
In SME a ZA tile slice is a one-dimensional set of horizontally or vertically contiguous elements within a ZA tile. Currently the load and store ops only support horizontal tile slices. This patch adds a tile slice layout attribute to the load and store ops to support both horizontal and vertical tile slices. When lowering from Vector dialect horizontal layout is the default.
a3322e8
to
802cbec
Compare
thanks for reviewing all. I've updated this to make tile slice layout optional with a default of horizontal, as suggested by @dcaballe on #66760. Optional custom attributes aren't supported as the first argument in the asm string so I've moved layout to the end of the asm string for both loads and stores for consistency. Quite a lot of churn to support this unfortunately, would appreciate a second pair of eyes on the changes before I land. |
Forgot to mention, also changes the format in the IR from: |
✅ With the latest revision this PR passed the C/C++ code formatter. |
This patch adds support for lowering vector.transpose to ArmSME. It's implemented by storing the input tile of the tranpose to memory and reloading vertically, building on top of the tile slice layout support. Tranposing via memory is obviously expensive, the current intention is to avoid the transpose if possible, this is therefore intended as a fallback and to provide base support for Vector ops. If it turns out transposes can't be avoided then this should be replaced with a more optimal implementation, perhaps with tile <-> vector (MOVA) ops. Depends on llvm#66758.
This patch adds support for lowering vector.transpose to ArmSME. It's implemented by storing the input tile of the tranpose to memory and reloading vertically, building on top of the tile slice layout support. Tranposing via memory is obviously expensive, the current intention is to avoid the transpose if possible, this is therefore intended as a fallback and to provide base support for Vector ops. If it turns out transposes can't be avoided then this should be replaced with a more optimal implementation, perhaps with tile <-> vector (MOVA) ops. Depends on #66758.
In SME a ZA tile slice is a one-dimensional set of horizontally or vertically contiguous elements within a ZA tile. Currently the load and store ops only support horizontal tile slices. This patch adds a tile slice layout attribute to the load and store ops to support both horizontal and vertical tile slices.
When lowering from Vector dialect horizontal layout is the default.