Skip to content

Commit 041baf2

Browse files
authored
[mlir][ArmSME] Use liveness information in the tile allocator (#90448)
This patch rewrites the ArmSME tile allocator to use liveness information to make better tile allocation decisions and improve the correctness of the ArmSME dialect. This algorithm used here is a linear scan over live ranges, where live ranges are assigned to tiles as they appear in the program (chronologically). Live ranges release their assigned tile ID when the current program point is passed their end. This is a greedy algorithm (which is mainly to keep the implementation relatively straightforward), and because it seems to be sufficient for most kernels (e.g. matmuls) that use ArmSME. The general steps of this are roughly from https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf, though there have been a few simplifications and assumptions made for our use case. Hopefully, the only changes needed for a user of the ArmSME dialect is that: - `-allocate-arm-sme-tiles` will no longer be a standalone pass - `-test-arm-sme-tile-allocation` is only for unit tests - `-convert-arm-sme-to-llvm` must happen after `-convert-scf-to-cf` - SME tile allocation is now part of the LLVM conversion By integrating this into the `ArmSME -> LLVM` conversion we can allow high-level (value-based) ArmSME operations to be side-effect-free, as we can guarantee nothing will rearrange ArmSME operations before we emit intrinsics (which could invalidate the tile allocation). The hope is for ArmSME operations to have no hidden state/side effects and allow easily lowering dialects such as `vector` and `arith` to SME, without making assumptions about how the input IR looks, as the semantics of the operations will be the same. That is no (new) side effects and the IR follows the rules of SSA (a value will never change). The aim is correctness, so we have a base for working on optimizations.
1 parent d422e90 commit 041baf2

30 files changed

+1470
-486
lines changed

mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <memory>
1313

1414
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
15+
#include "mlir/Interfaces/FunctionInterfaces.h"
1516

1617
namespace mlir {
1718
class Pass;
@@ -21,7 +22,8 @@ class RewritePatternSet;
2122
#include "mlir/Conversion/Passes.h.inc"
2223

2324
/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
24-
std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
25+
std::unique_ptr<Pass>
26+
createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges = false);
2527

2628
/// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
2729
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);

mlir/include/mlir/Conversion/Passes.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1285,14 +1285,19 @@ def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
12851285
// ArmSMEToLLVM
12861286
//===----------------------------------------------------------------------===//
12871287

1288-
def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
1288+
def ConvertArmSMEToLLVM : InterfacePass<"convert-arm-sme-to-llvm", "FunctionOpInterface"> {
12891289
let summary = "Lower the operations from the ArmSME dialect into the LLVM "
12901290
"dialect";
12911291
let constructor = "mlir::createConvertArmSMEToLLVMPass()";
12921292
let dependentDialects = [
12931293
"arm_sme::ArmSMEDialect",
12941294
"LLVM::LLVMDialect"
12951295
];
1296+
let options = [
1297+
Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
1298+
"bool", /*default=*/"false",
1299+
"Dump the live ranges of SME tiles (for debugging)">
1300+
];
12961301
}
12971302

12981303
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "mlir/Bytecode/BytecodeOpInterface.h"
1717
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
18+
#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
1819
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
1920
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
2021
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -24,11 +25,6 @@
2425
#include "mlir/IR/OpDefinition.h"
2526
#include "mlir/Interfaces/SideEffectInterfaces.h"
2627

27-
namespace mlir::arm_sme {
28-
static constexpr unsigned kInMemoryTileIdBase = 16;
29-
#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
30-
} // namespace mlir::arm_sme
31-
3228
#define GET_ATTRDEF_CLASSES
3329
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
3430

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- ArmSMEOpInterfaces.h - Arm SME Dialect OpInterfaces ------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_ARMSME_OPINTERFACES_H
10+
#define MLIR_DIALECT_ARMSME_OPINTERFACES_H
11+
12+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
13+
14+
namespace mlir::arm_sme {
15+
16+
namespace detail {
17+
LogicalResult verifyArmSMETileOpInterface(Operation *);
18+
}
19+
20+
// The first in-memory SME tile ID. This is set to 16 as that is the first tile
21+
// ID larger than any virtual tile ID supported by the SME ISA.
22+
static constexpr unsigned kInMemoryTileIdBase = 16;
23+
24+
#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
25+
} // namespace mlir::arm_sme
26+
27+
#endif // MLIR_DIALECT_ARMSME_OPINTERFACES_H

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 47 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
3939

4040
def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
4141
let description = [{
42-
An interface for operations that use or allocate Arm SME tiles. These
43-
operations need to be assigned a tile ID, an i32 attribute, which specifies
44-
which virtual tile within the ZA storage to use. The number of tiles
45-
available depends on the type of the tile. This is summarized below:
42+
An interface for operations that use Arm SME tiles. These operations need to
43+
be assigned a tile ID, an i32 attribute, which specifies which virtual tile
44+
within the ZA storage to use. The number of tiles available depends on the
45+
type of the tile. This is summarized below:
4646

4747
| Tile Vector Types | Possible Tile IDs |
4848
|-------------------------------------------------------------------------|---------------------|
@@ -51,10 +51,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
5151
| `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` | 0 to 3 (inclusive) |
5252
| `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` | 0 to 7 (inclusive) |
5353
| `vector<[1]x[1]xi128>` | 0 to 15 (inclusive) |
54-
55-
Operations that allocate a new tile (such as arm_sme.get_tile), are used as
56-
the roots for tile allocation, with all operations that (transitively)
57-
depend on a root being assigned the same tile ID.
5854
}];
5955
let methods = [
6056
InterfaceMethod<
@@ -84,20 +80,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
8480
return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
8581
}]
8682
>,
87-
InterfaceMethod<
88-
[{
89-
The type of tile this operation allocates. Returns none (std::nullopt)
90-
if this operation does not allocate a tile.
91-
}],
92-
/*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
93-
/*methodName=*/"getAllocatedTileType",
94-
/*arguments=*/(ins),
95-
/*methodBody=*/[{}],
96-
/*defaultImpl=*/ [{
97-
// This operation does not allocate a tile.
98-
return std::nullopt;
99-
}]
100-
>,
10183
InterfaceMethod<
10284
"Returns the VectorType of the tile used by this operation.",
10385
/*returnType=*/"VectorType",
@@ -106,30 +88,13 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
10688
];
10789

10890
let extraSharedClassDeclaration = [{
109-
// A helper to create a new operation and propagate this operations tile ID.
110-
template<typename T, typename... Args>
111-
T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
112-
auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
113-
if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
114-
tileOp.setTileId($_op.getTileId());
115-
return op;
116-
}
117-
118-
// A helper to replace this operation and forward its tile ID (if present).
119-
template<typename T, typename... Args>
120-
T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
121-
auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
122-
rewriter.replaceOp($_op, newOp);
123-
return newOp;
124-
}
125-
12691
bool isInMemoryTile() {
12792
auto tileId = getTileId();
12893
return tileId && tileId.getInt() >= kInMemoryTileIdBase;
12994
}
13095
}];
13196

132-
let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
97+
let verify = [{ return detail::verifyArmSMETileOpInterface($_op); }];
13398
}
13499

135100
//===----------------------------------------------------------------------===//
@@ -255,30 +220,30 @@ def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
255220
class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
256221
Op<ArmSME_Dialect, mnemonic, traits> {}
257222

258-
def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
259-
let summary = "Returns a SME virtual tile";
223+
def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface, Pure]> {
224+
let summary = "Creates an undefined value of SME virtual tile type";
260225
let description = [{
261-
Allocates a new SME "virtual tile" within a function. The contents of the
262-
tile returned from this operation are undefined.
226+
Creates a new SME "virtual tile" value within a function. The contents of
227+
the tile returned from this operation are undefined.
263228

264229
Example 1:
265230

266231
```mlir
267-
// Allocate an 8-bit element "virtual tile"
232+
// Create an 8-bit element "virtual tile" value:
268233
%za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
269234
```
270235

271236
Example 2:
272237

273238
```mlir
274-
// Allocate two 16-bit element "virtual tiles"
239+
// Create two 16-bit element "virtual tiles" values:
275240
%za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
276241
%za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
277242
```
278243

279244
Example 3:
280245
```mlir
281-
// Allocate an 128-bit element "virtual tile"
246+
// Create an 128-bit element "virtual tile" value:
282247
%za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
283248
```
284249
}];
@@ -290,37 +255,15 @@ def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
290255
VectorType getTileType() {
291256
return ::llvm::cast<VectorType>(getTile().getType());
292257
}
293-
294-
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
295-
return arm_sme::getSMETileType(getTileType());
296-
}
297-
}];
298-
}
299-
300-
def MaterializeSSATileOp : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
301-
let summary = "SME tile placeholder";
302-
let description = [{
303-
A placeholder to preserve dataflow while lowering to SME intrinsics (which
304-
do not take or return SME virtual tile values). This operation is intended
305-
to be DCE'd once all ArmSME operations have been lowered.
306-
307-
This operation is not intended to be used outside of the ArmSME -> LLVM
308-
conversion.
309258
}];
310-
let results = (outs SMETile:$tile);
311-
let assemblyFormat = "attr-dict `:` type($tile)";
312259
}
313260

314-
//
315-
// Tile reset.
316-
//
317-
318-
def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
319-
let summary = "Initialize the two-dimensional ZA array with 0s";
261+
def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface, Pure]> {
262+
let summary = "Creates a zero-initialized value of SME virtual tile type";
320263
let results = (outs SMETile:$res);
321264
let description = [{
322-
Initialise ZA with 0. This operation is convenient wrapper for the SME
323-
`zero` intrinsic and instruction.
265+
Creates a new SME "virtual tile" value within a function. The contents of
266+
the tile returned from this operation are zero-initialized.
324267

325268
Example 1: Zero an 8-bit element ZA tile.
326269

@@ -338,16 +281,39 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
338281
VectorType getVectorType() {
339282
return ::llvm::cast<VectorType>(getRes().getType());
340283
}
341-
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
342-
return arm_sme::getSMETileType(getVectorType());
343-
}
344284
VectorType getTileType() {
345285
return getVectorType();
346286
}
347287
}];
348288
let assemblyFormat = "attr-dict `:` type($res)";
349289
}
350290

291+
def CopyTileOp : ArmSME_Op<"copy_tile", [
292+
Pure,
293+
ArmSMETileOpInterface,
294+
AllTypesMatch<["tile", "result"]>
295+
]> {
296+
let summary = "Copies an SME tile value";
297+
let arguments = (ins SMETile:$tile);
298+
let results = (outs SMETile:$result);
299+
let description = [{
300+
Copies an SME "virtual tile" value to a new SSA value. This operation is
301+
primarily intended to be used to normalize the IR prior to tile allocation.
302+
303+
Example:
304+
305+
```mlir
306+
%copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
307+
```
308+
}];
309+
let extraClassDeclaration = [{
310+
VectorType getTileType() {
311+
return ::llvm::cast<VectorType>(getResult().getType());
312+
}
313+
}];
314+
let assemblyFormat = "$tile attr-dict `:` type($result)";
315+
}
316+
351317
def TileLoadOp : ArmSME_Op<"tile_load", [
352318
ArmSMETileOpInterface,
353319
AttrSizedOperandSegments,
@@ -417,9 +383,6 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
417383
VectorType getVectorType() {
418384
return ::llvm::cast<VectorType>(getResult().getType());
419385
}
420-
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
421-
return arm_sme::getSMETileType(getVectorType());
422-
}
423386
VectorType getTileType() {
424387
return getVectorType();
425388
}
@@ -545,7 +508,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
545508
```
546509
}];
547510
let arguments = (ins
548-
Arg<AnyMemRef, "the reference to load from">:$base, SVEPredicate:$mask,
511+
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base, SVEPredicate:$mask,
549512
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
550513
ArmSME_TileSliceLayoutAttr:$layout
551514
);
@@ -630,7 +593,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
630593
}
631594

632595
def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
633-
ArmSMETileOpInterface,
596+
ArmSMETileOpInterface, Pure,
634597
AllTypesMatch<["tile", "result"]>,
635598
TypesMatchWith<
636599
"type of 'vector' matches type of 'tile' slice",
@@ -679,7 +642,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
679642
}
680643

681644
def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
682-
ArmSMETileOpInterface,
645+
ArmSMETileOpInterface, Pure,
683646
TypesMatchWith<
684647
"type of 'result' matches type of 'tile' slice",
685648
"tile", "result",
@@ -736,6 +699,7 @@ class OuterProductResultTileTypeConstraint<string operand> :
736699

737700
def OuterProductOp :
738701
ArmSME_Op<"outerproduct", [
702+
Pure,
739703
ArmSMETileOpInterface,
740704
AttrSizedOperandSegments,
741705
AllTypesMatch<["lhs", "rhs"]>,
@@ -802,12 +766,6 @@ let arguments = (ins
802766
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
803767
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
804768
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
805-
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
806-
// The outerproduct op allocates a new tile if no accumulator is passed.
807-
if (!getAcc())
808-
return arm_sme::getSMETileType(getResultType());
809-
return std::nullopt;
810-
}
811769
VectorType getTileType() {
812770
return getResultType();
813771
}
@@ -819,6 +777,7 @@ class OuterProductWideningBase<string mnemonic,
819777
list<Type> allowedResultVectorTypes,
820778
int numOuterProducts> :
821779
ArmSME_Op<mnemonic, [
780+
Pure,
822781
ArmSMETileOpInterface,
823782
AttrSizedOperandSegments,
824783
AllTypesMatch<["lhs", "rhs"]>,
@@ -857,12 +816,6 @@ class OuterProductWideningBase<string mnemonic,
857816
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
858817
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
859818
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
860-
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
861-
// The outerproduct op allocates a new tile if no accumulator is passed.
862-
if (!getAcc())
863-
return arm_sme::getSMETileType(getResultType());
864-
return std::nullopt;
865-
}
866819
VectorType getTileType() {
867820
return getResultType();
868821
}

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
2929
const ArmStreamingMode = ArmStreamingMode::Streaming,
3030
const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
3131

32-
/// Pass that allocates tile IDs to ArmSME operations.
33-
std::unique_ptr<Pass> createTileAllocationPass();
34-
3532
/// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
3633
/// variants.
3734
std::unique_ptr<Pass> createOuterProductFusionPass();

0 commit comments

Comments
 (0)