Skip to content

Commit 5417a5f

Browse files
authored
[mlir][ArmSME] Add rudimentary support for tile spills to the stack (#76086)
This adds very basic (and inelegant) support for something like spilling and reloading tiles, if you use more SME tiles than physically exist.

This is purely implemented to prevent the compiler from aborting if a function uses too many tiles (e.g. due to bad unrolling), but is expected to perform very poorly.

Currently, this works in two stages:

During tile allocation, if we run out of tiles, instead of giving up, we switch to allocating 'in-memory' tile IDs. These are tile IDs that start at 16 (which is higher than any real tile ID). A warning will also be emitted for each (root) tile op assigned an in-memory tile ID:

```
warning: failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance
```

Everything after this works like normal until `-convert-arm-sme-to-llvm`.

Here the in-memory tile op:

```mlir
arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
```

Is lowered to:

```mlir
// At function entry:
%alloca = memref.alloca ... : memref<?x?xty>

// Around the op:

// Swap the contents of %alloca and tile 0.
scf.for %slice_idx {
  %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
  "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx) <{tile_id = 0 : i32}>
  vector.store %current_slice, %alloca[%slice_idx, %c0]
}

// Execute op using tile 0.
arm_sme.tile_op { tile_id = 0 }

// Swap the contents of %alloca and tile 0.
// This restores tile 0 to its original state.
scf.for %slice_idx {
  %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
  "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx) <{tile_id = 0 : i32}>
  vector.store %current_slice, %alloca[%slice_idx, %c0]
}
```

This is inserted during the lowering to LLVM as spilling/reloading registers is a very low-level concept, that can't really be modeled correctly at a high level in MLIR.

Note: This is always doing the worst case full-tile swap. This could be optimized to only spill/load data the tile op will use, which could be just a slice.
It's also not making any use of liveness, which could allow reusing tiles. But this is not seen as important, as correct code should only use the available number of tiles.
1 parent 8751bbe commit 5417a5f

File tree

8 files changed

+595
-61
lines changed

8 files changed

+595
-61
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
#include "mlir/Interfaces/SideEffectInterfaces.h"
2626

2727
namespace mlir::arm_sme {
28+
static constexpr unsigned kInMemoryTileIdBase = 16;
2829
#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
29-
}
30+
} // namespace mlir::arm_sme
3031

3132
#define GET_ATTRDEF_CLASSES
3233
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
9797
// This operation does not allocate a tile.
9898
return std::nullopt;
9999
}]
100+
>,
101+
InterfaceMethod<
102+
"Returns the VectorType of the tile used by this operation.",
103+
/*returnType=*/"VectorType",
104+
/*methodName=*/"getTileType"
100105
>
101106
];
102107

@@ -117,6 +122,11 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
117122
rewriter.replaceOp($_op, newOp);
118123
return newOp;
119124
}
125+
126+
bool isInMemoryTile() {
127+
auto tileId = getTileId();
128+
return tileId && tileId.getInt() >= kInMemoryTileIdBase;
129+
}
120130
}];
121131

122132
let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
@@ -331,6 +341,9 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
331341
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
332342
return arm_sme::getSMETileType(getVectorType());
333343
}
344+
VectorType getTileType() {
345+
return getVectorType();
346+
}
334347
}];
335348
let assemblyFormat = "attr-dict `:` type($res)";
336349
}
@@ -407,6 +420,9 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
407420
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
408421
return arm_sme::getSMETileType(getVectorType());
409422
}
423+
VectorType getTileType() {
424+
return getVectorType();
425+
}
410426
}];
411427

412428
let builders = [
@@ -475,6 +491,9 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
475491
VectorType getVectorType() {
476492
return ::llvm::cast<VectorType>(getValueToStore().getType());
477493
}
494+
VectorType getTileType() {
495+
return getVectorType();
496+
}
478497
}];
479498

480499
let builders = [
@@ -539,6 +558,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
539558
VectorType getVectorType() {
540559
return ::llvm::cast<VectorType>(getResult().getType());
541560
}
561+
VectorType getTileType() {
562+
return getVectorType();
563+
}
542564
}];
543565

544566
let assemblyFormat = [{
@@ -596,6 +618,9 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
596618
VectorType getVectorType() {
597619
return ::llvm::cast<VectorType>(getTile().getType());
598620
}
621+
VectorType getTileType() {
622+
return getVectorType();
623+
}
599624
}];
600625

601626
let assemblyFormat = [{
@@ -688,6 +713,9 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
688713

689714
let extraClassDeclaration = [{
690715
VectorType getSliceType() { return getResult().getType(); }
716+
VectorType getTileType() {
717+
return ::llvm::cast<VectorType>(getTile().getType());
718+
}
691719
}];
692720

693721
let assemblyFormat = [{
@@ -780,6 +808,9 @@ let arguments = (ins
780808
return arm_sme::getSMETileType(getResultType());
781809
return std::nullopt;
782810
}
811+
VectorType getTileType() {
812+
return getResultType();
813+
}
783814
}];
784815
}
785816

0 commit comments

Comments
 (0)