Skip to content

Commit 75a71c2

Browse files
authored
[mlir][ArmSME] Support vertical layout in load and store ops (#66758)
In SME a ZA tile slice is a one-dimensional set of horizontally or vertically contiguous elements within a ZA tile. Currently the load and store ops only support horizontal tile slices. This patch adds a tile slice layout attribute to the load and store ops to support both horizontal and vertical tile slices. When lowering from Vector dialect horizontal layout is the default.
1 parent 3fc7af5 commit 75a71c2

File tree

11 files changed

+1237
-176
lines changed

11 files changed

+1237
-176
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@
2121
#include "mlir/IR/OpDefinition.h"
2222
#include "mlir/Interfaces/SideEffectInterfaces.h"
2323

24+
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
25+
26+
#define GET_ATTRDEF_CLASSES
27+
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
28+
2429
#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc"
2530

2631
#define GET_OP_CLASSES

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td

Lines changed: 83 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef ARMSME_OPS
1515
#define ARMSME_OPS
1616

17+
include "mlir/IR/EnumAttr.td"
1718
include "mlir/IR/OpBase.td"
1819
include "mlir/Interfaces/SideEffectInterfaces.td"
1920
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
@@ -36,6 +37,7 @@ def ArmSME_Dialect : Dialect {
3637
https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
3738
}];
3839
let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
40+
let useDefaultAttributePrinterParser = 1;
3941
}
4042

4143
//===----------------------------------------------------------------------===//
@@ -83,6 +85,24 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
8385
"::llvm::cast<VectorType>($_self).getElementType())"
8486
".getWidth())">;
8587

88+
//===----------------------------------------------------------------------===//
89+
// ArmSME attr definitions
90+
//===----------------------------------------------------------------------===//
91+
92+
def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
93+
I32EnumAttrCase<"Horizontal", 0, "horizontal">,
94+
I32EnumAttrCase<"Vertical", 1, "vertical">,
95+
]> {
96+
let cppNamespace = "::mlir::arm_sme";
97+
let genSpecializedAttr = 0;
98+
}
99+
100+
/// An attribute that specifies the layout of a tile slice in a tile.
101+
def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
102+
"layout"> {
103+
let assemblyFormat = "`<` $value `>`";
104+
}
105+
86106
//===----------------------------------------------------------------------===//
87107
// ArmSME op definitions
88108
//===----------------------------------------------------------------------===//
@@ -240,28 +260,33 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
240260
let description = [{
241261
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
242262
with the shape defined by the 2D scalable vector type of the result tile.
243-
The slice of memory must be contiguous. The memref must be either rank 1 or
244-
rank 2 with dynamic dimensions, since the operation is scalable, and the
245-
element type must be a scalar that matches the element type of the result.
263+
An optional tile slice layout attribute specifies whether the slices of the
264+
tile being loaded are horizontal (default) or vertical. The slice of memory
265+
must be contiguous. The memref must be either rank 1 or rank 2 with dynamic
266+
dimensions, since the operation is scalable, and the element type must be a
267+
scalar that matches the element type of the result.
246268

247-
Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
269+
Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
248270
```mlir
249271
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
250272
```
251273

252-
Example 2: Load a FP 32-bit element ZA tile from memory.
274+
Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
253275
```mlir
254-
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
276+
%tile = arm_sme.tile_load %base[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
255277
```
256278

257-
Example 3: Load a 128-bit element ZA tile from memory.
279+
Example 3: Load a 128-bit element ZA tile with horizontal layout (default) from memory.
258280
```mlir
259-
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
281+
%tile = arm_sme.tile_load %base[%c0, %c0], <horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
260282
```
261283
}];
262284
let arguments = (ins
263-
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
264-
Variadic<Index>:$indices);
285+
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
286+
Variadic<Index>:$indices,
287+
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
288+
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
289+
);
265290
let results = (outs SMETile:$result);
266291

267292
let extraClassDeclaration = [{
@@ -274,37 +299,42 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
274299
}];
275300

276301
let assemblyFormat =
277-
"$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
302+
"$base `[` $indices `]` (`,` $layout^)? attr-dict "
303+
"`:` type($base) `,` type($result)";
278304
}
279305

280306
def TileStoreOp : ArmSME_Op<"tile_store"> {
281307
let summary = "Tile store operation";
282308
let description = [{
283309
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
284310
with the shape defined by the 2D scalable vector type of the tile being
285-
stored. The slice of memory must be contiguous. The memref must be either
286-
rank 1 or rank 2 with dynamic dimensions, since the operation is scalable,
287-
and the element type must be a scalar that matches the element type of the
288-
result.
311+
stored. An optional tile slice layout attribute specifies whether the
312+
slices of the tile being stored are horizontal (default) or vertical. The
313+
slice of memory must be contiguous. The memref must be either rank 1 or
314+
rank 2 with dynamic dimensions, since the operation is scalable, and the
315+
element type must be a scalar that matches the element type of the result.
289316

290-
Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
317+
Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).
291318
```mlir
292319
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
293320
```
294321

295-
Example 2: Store a FP 32-bit element ZA tile to memory.
322+
Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
296323
```mlir
297-
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
324+
arm_sme.tile_store %tile, %base[%c0, %c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
298325
```
299326

300-
Example 3: Store a 128-bit element ZA tile to memory.
327+
Example 3: Store a 128-bit element ZA tile with horizontal (default) layout to memory.
301328
```mlir
302-
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
329+
arm_sme.tile_store %tile, %base[%c0, %c0], <horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
303330
```
304331
}];
305332
let arguments = (ins SMETile:$valueToStore,
306-
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
307-
Variadic<Index>:$indices);
333+
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
334+
Variadic<Index>:$indices,
335+
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
336+
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
337+
);
308338
let extraClassDeclaration = [{
309339
MemRefType getMemRefType() {
310340
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -314,8 +344,9 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
314344
}
315345
}];
316346

317-
let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
318-
"`:` type($base) `,` type($valueToStore)";
347+
let assemblyFormat =
348+
"$valueToStore `,` $base `[` $indices `]` (`,` $layout^)? attr-dict "
349+
"`:` type($base) `,` type($valueToStore)";
319350
}
320351

321352
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
@@ -326,31 +357,36 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
326357
Loads a 1D tile slice from memory into a 2D SME "virtual tile". The tile
327358
slice is defined by the dimension of the 2D scalable vector type pointed by
328359
the index. A tile slice index describes where in the input tile the tile
329-
slice is loaded to. The updated tile is returned as the result.
360+
slice is loaded to. An optional tile slice layout attribute specifies
361+
whether the tile slice being loaded at the given index is horizontal
362+
(default) or vertical. The updated tile is returned as the result.
330363

331364
The slice of memory read is defined by a base and indices and must be
332365
contiguous. The memref must be either rank 1 or rank 2, have dynamic
333366
dimensions since the operation is scalable, and the element type must be a
334367
scalar that matches the element type of the result.
335368

336-
Example 1: Load a vector<[16]xi8> tile slice from memory into tile at given index.
369+
Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
337370
```mlir
338371
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
339372
```
340373

341-
Example 2: Load a vector<[4]xf32> tile slice from memory into tile at given index.
374+
Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
342375
```mlir
343-
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
376+
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
344377
```
345378

346-
Example 3: Load a vector<[1]xi128> tile slice from memory into tile at given index.
379+
Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
347380
```mlir
348-
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
381+
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
349382
```
350383
}];
351384
let arguments = (ins
352-
Arg<AnyMemRef, "the reference to load from">:$base,
353-
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index);
385+
Arg<AnyMemRef, "the reference to load from">:$base,
386+
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
387+
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
388+
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
389+
);
354390
let results = (outs SMETile:$result);
355391

356392
let extraClassDeclaration = [{
@@ -363,7 +399,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
363399
}];
364400

365401
let assemblyFormat = [{
366-
$base `[` $indices `]` `,` $tile `,` $tile_slice_index
402+
$base `[` $indices `]` `,` $tile `,` $tile_slice_index (`,` $layout^)?
367403
attr-dict `:` type($base) `,` type($result)
368404
}];
369405
}
@@ -374,31 +410,36 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
374410
Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
375411
slice is defined by the dimension of the 2D scalable vector type pointed by
376412
the index. A tile slice index describes where in the input tile the tile
377-
slice is stored from.
413+
slice is stored from. An optional tile slice layout attribute specifies
414+
whether the tile slice being stored from the given index is horizontal
415+
(default) or vertical.
378416

379417
The slice of memory written is defined by a base and indices and must be
380418
contiguous. The memref must be either rank 1 or rank 2, have dynamic
381419
dimensions since the operation is scalable, and the element type must be a
382420
scalar that matches the element type of the input tile.
383421

384-
Example 1: Store vector<[16]xi8> tile slice from tile at given index to memory.
422+
Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory.
385423
```mlir
386424
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
387425
```
388426

389-
Example 2: Store vector<[4]xf32> tile slice from tile at given index to memory.
427+
Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
390428
```mlir
391-
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
429+
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
392430
```
393431

394-
Example 3: Store a vector<[1]xi128> tile slice from tile at given index to memory.
432+
Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
395433
```mlir
396-
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
434+
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
397435
```
398436
}];
399437
let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
400-
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
401-
Variadic<Index>:$indices);
438+
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
439+
Variadic<Index>:$indices,
440+
DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
441+
"::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
442+
);
402443
let extraClassDeclaration = [{
403444
MemRefType getMemRefType() {
404445
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -409,7 +450,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
409450
}];
410451

411452
let assemblyFormat = [{
412-
$tile `,` $tile_slice_index `,` $base `[` $indices `]`
453+
$tile `,` $tile_slice_index `,` $base `[` $indices `]` (`,` $layout^)?
413454
attr-dict `:` type($base) `,` type($tile)
414455
}];
415456
}

mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,9 @@ add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme)
44
set(LLVM_TARGET_DEFINITIONS ArmSME.td)
55
mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions)
66
add_public_tablegen_target(MLIRArmSMEConversionsIncGen)
7+
8+
mlir_tablegen(ArmSMEEnums.h.inc -gen-enum-decls)
9+
mlir_tablegen(ArmSMEEnums.cpp.inc -gen-enum-defs)
10+
mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme)
11+
mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme)
12+
add_public_tablegen_target(MLIRArmSMEAttrDefsIncGen)

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
116116
getMemrefIndices(tileLoadOp.getIndices(),
117117
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
118118
numTileSlices, memrefIndices, loc, rewriter);
119-
rewriter.create<arm_sme::LoadTileSliceOp>(loc, tileType,
120-
tileLoadOp.getBase(), tile,
121-
memrefIndices, tileSliceIndex);
119+
rewriter.create<arm_sme::LoadTileSliceOp>(
120+
loc, tileType, tileLoadOp.getBase(), tile, memrefIndices,
121+
tileSliceIndex, tileLoadOp.getLayout());
122122

123123
rewriter.setInsertionPointAfter(forOp);
124124

@@ -134,7 +134,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
134134
///
135135
/// BEFORE:
136136
/// ```mlir
137-
/// arm_sme.tile_store %tile, %dest[%c0, %c0]
137+
/// arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical>
138138
/// : memref<?x?xi32>, vector<[4]x[4]xi32
139139
/// ```
140140
///
@@ -146,8 +146,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
146146
/// %min_svl_s = arith.constant 4 : index
147147
/// %svl_s = arith.muli %min_svl_s, %vscale : index
148148
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
149-
/// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx]
150-
/// : memref<?x?xi32>, vector<[4]x[4]xi32>
149+
/// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
150+
/// <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
151151
/// }
152152
/// ```
153153
struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
@@ -184,7 +184,7 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
184184
numTileSlices, memrefIndices, loc, rewriter);
185185
rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
186186
tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
187-
tileStoreOp.getBase(), memrefIndices);
187+
tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
188188

189189
return success();
190190
}

mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212

1313
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
1414
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15+
#include "mlir/IR/DialectImplementation.h"
1516
#include "mlir/IR/TypeUtilities.h"
17+
#include "llvm/ADT/TypeSwitch.h"
1618

1719
using namespace mlir;
1820
using namespace mlir::arm_sme;
@@ -23,13 +25,23 @@ using namespace mlir::arm_sme;
2325

2426
#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.cpp.inc"
2527

28+
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.cpp.inc"
29+
2630
#define GET_OP_CLASSES
2731
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
2832

2933
#define GET_TYPEDEF_CLASSES
3034
#include "mlir/Dialect/ArmSME/IR/ArmSMETypes.cpp.inc"
3135

36+
#define GET_ATTRDEF_CLASSES
37+
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.cpp.inc"
38+
3239
void ArmSMEDialect::initialize() {
40+
addAttributes<
41+
#define GET_ATTRDEF_LIST
42+
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.cpp.inc"
43+
>();
44+
3345
addOperations<
3446
#define GET_OP_LIST
3547
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"

mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRArmSMEDialect
66

77
DEPENDS
88
MLIRArmSMEIncGen
9+
MLIRArmSMEAttrDefsIncGen
910

1011
LINK_LIBS PUBLIC
1112
MLIRIR

0 commit comments

Comments
 (0)