14
14
#ifndef ARMSME_OPS
15
15
#define ARMSME_OPS
16
16
17
+ include "mlir/IR/EnumAttr.td"
17
18
include "mlir/IR/OpBase.td"
18
19
include "mlir/Interfaces/SideEffectInterfaces.td"
19
20
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
@@ -36,6 +37,7 @@ def ArmSME_Dialect : Dialect {
36
37
https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
37
38
}];
38
39
let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
40
+ let useDefaultAttributePrinterParser = 1;
39
41
}
40
42
41
43
//===----------------------------------------------------------------------===//
@@ -83,6 +85,24 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
83
85
"::llvm::cast<VectorType>($_self).getElementType())"
84
86
".getWidth())">;
85
87
88
+ //===----------------------------------------------------------------------===//
89
+ // ArmSME attr definitions
90
+ //===----------------------------------------------------------------------===//
91
+
92
+ def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
93
+ I32EnumAttrCase<"Horizontal", 0, "horizontal">,
94
+ I32EnumAttrCase<"Vertical", 1, "vertical">,
95
+ ]> {
96
+ let cppNamespace = "::mlir::arm_sme";
97
+ let genSpecializedAttr = 0;
98
+ }
99
+
100
+ /// An attribute that specifies the layout of a tile slice in a tile.
101
+ def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
102
+ "layout"> {
103
+ let assemblyFormat = "`<` $value `>`";
104
+ }
105
+
86
106
//===----------------------------------------------------------------------===//
87
107
// ArmSME op definitions
88
108
//===----------------------------------------------------------------------===//
@@ -240,28 +260,33 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
240
260
let description = [{
241
261
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
242
262
with the shape defined by the 2D scalable vector type of the result tile.
243
- The slice of memory must be contiguous. The memref must be either rank 1 or
244
- rank 2 with dynamic dimensions, since the operation is scalable, and the
245
- element type must be a scalar that matches the element type of the result.
263
+ An optional tile slice layout attribute specifies whether the slices of the
264
+ tile being loaded are horizontal (default) or vertical. The slice of memory
265
+ must be contiguous. The memref must be either rank 1 or rank 2 with dynamic
266
+ dimensions, since the operation is scalable, and the element type must be a
267
+ scalar that matches the element type of the result.
246
268
247
- Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
269
+ Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
248
270
```mlir
249
271
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
250
272
```
251
273
252
- Example 2: Load a FP 32-bit element ZA tile from memory.
274
+ Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
253
275
```mlir
254
- %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
276
+ %tile = arm_sme.tile_load %base[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
255
277
```
256
278
257
- Example 3: Load a 128-bit element ZA tile from memory.
279
+ Example 3: Load a 128-bit element ZA tile with horizontal layout (default) from memory.
258
280
```mlir
259
- %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
281
+ %tile = arm_sme.tile_load %base[%c0, %c0], <horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
260
282
```
261
283
}];
262
284
let arguments = (ins
263
- Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
264
- Variadic<Index>:$indices);
285
+ Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
286
+ Variadic<Index>:$indices,
287
+ DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
288
+ "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
289
+ );
265
290
let results = (outs SMETile:$result);
266
291
267
292
let extraClassDeclaration = [{
@@ -274,37 +299,42 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
274
299
}];
275
300
276
301
let assemblyFormat =
277
- "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
302
+ "$base `[` $indices `]` (`,` $layout^)? attr-dict "
303
+ "`:` type($base) `,` type($result)";
278
304
}
279
305
280
306
def TileStoreOp : ArmSME_Op<"tile_store"> {
281
307
let summary = "Tile store operation";
282
308
let description = [{
283
309
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
284
310
with the shape defined by the 2D scalable vector type of the tile being
285
- stored. The slice of memory must be contiguous. The memref must be either
286
- rank 1 or rank 2 with dynamic dimensions, since the operation is scalable,
287
- and the element type must be a scalar that matches the element type of the
288
- result.
311
+ stored. An optional tile slice layout attribute specifies whether the
312
+ slices of the tile being stored are horizontal (default) or vertical. The
313
+ slice of memory must be contiguous. The memref must be either rank 1 or
314
+ rank 2 with dynamic dimensions, since the operation is scalable, and the
315
+ element type must be a scalar that matches the element type of the result.
289
316
290
- Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
317
+ Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).
291
318
```mlir
292
319
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
293
320
```
294
321
295
- Example 2: Store a FP 32-bit element ZA tile to memory.
322
+ Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
296
323
```mlir
297
- arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
324
+ arm_sme.tile_store %tile, %base[%c0, %c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
298
325
```
299
326
300
- Example 3: Store a 128-bit element ZA tile to memory.
327
+ Example 3: Store a 128-bit element ZA tile with horizontal (default) layout to memory.
301
328
```mlir
302
- arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
329
+ arm_sme.tile_store %tile, %base[%c0, %c0], <horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
303
330
```
304
331
}];
305
332
let arguments = (ins SMETile:$valueToStore,
306
- Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
307
- Variadic<Index>:$indices);
333
+ Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
334
+ Variadic<Index>:$indices,
335
+ DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
336
+ "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
337
+ );
308
338
let extraClassDeclaration = [{
309
339
MemRefType getMemRefType() {
310
340
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -314,8 +344,9 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
314
344
}
315
345
}];
316
346
317
- let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
318
- "`:` type($base) `,` type($valueToStore)";
347
+ let assemblyFormat =
348
+ "$valueToStore `,` $base `[` $indices `]` (`,` $layout^)? attr-dict "
349
+ "`:` type($base) `,` type($valueToStore)";
319
350
}
320
351
321
352
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
@@ -326,31 +357,36 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
326
357
Loads a 1D tile slice from memory into a 2D SME "virtual tile". The tile
327
358
slice is defined by the dimension of the 2D scalable vector type pointed by
328
359
the index. A tile slice index describes where in the input tile the tile
329
- slice is loaded to. The updated tile is returned as the result.
360
+ slice is loaded to. An optional tile slice layout attribute specifies
361
+ whether the tile slice being loaded at the given index is horizontal
362
+ (default) or vertical. The updated tile is returned as the result.
330
363
331
364
The slice of memory read is defined by a base and indices and must be
332
365
contiguous. The memref must be either rank 1 or rank 2, have dynamic
333
366
dimensions since the operation is scalable, and the element type must be a
334
367
scalar that matches the element type of the result.
335
368
336
- Example 1: Load a vector<[16]xi8> tile slice from memory into tile at given index.
369
+ Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
337
370
```mlir
338
371
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
339
372
```
340
373
341
- Example 2: Load a vector<[4]xf32> tile slice from memory into tile at given index.
374
+ Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
342
375
```mlir
343
- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
376
+ %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
344
377
```
345
378
346
- Example 3: Load a vector<[1]xi128> tile slice from memory into tile at given index.
379
+ Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
347
380
```mlir
348
- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
381
+ %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
349
382
```
350
383
}];
351
384
let arguments = (ins
352
- Arg<AnyMemRef, "the reference to load from">:$base,
353
- SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index);
385
+ Arg<AnyMemRef, "the reference to load from">:$base,
386
+ SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
387
+ DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
388
+ "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
389
+ );
354
390
let results = (outs SMETile:$result);
355
391
356
392
let extraClassDeclaration = [{
@@ -363,7 +399,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
363
399
}];
364
400
365
401
let assemblyFormat = [{
366
- $base `[` $indices `]` `,` $tile `,` $tile_slice_index
402
+ $base `[` $indices `]` `,` $tile `,` $tile_slice_index (`,` $layout^)?
367
403
attr-dict `:` type($base) `,` type($result)
368
404
}];
369
405
}
@@ -374,31 +410,36 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
374
410
Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
375
411
slice is defined by the dimension of the 2D scalable vector type pointed by
376
412
the index. A tile slice index describes where in the input tile the tile
377
- slice is stored from.
413
+ slice is stored from. An optional tile slice layout attribute specifies
414
+ whether the tile slice being stored from the given index is horizontal
415
+ (default) or vertical.
378
416
379
417
The slice of memory written is defined by a base and indices and must be
380
418
contiguous. The memref must be either rank 1 or rank 2, have dynamic
381
419
dimensions since the operation is scalable, and the element type must be a
382
420
scalar that matches the element type of the input tile.
383
421
384
- Example 1: Store vector<[16]xi8> tile slice from tile at given index to memory.
422
+ Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory.
385
423
```mlir
386
424
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
387
425
```
388
426
389
- Example 2: Store vector<[4]xf32> tile slice from tile at given index to memory.
427
+ Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
390
428
```mlir
391
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
429
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
392
430
```
393
431
394
- Example 3: Store a vector<[1]xi128> tile slice from tile at given index to memory.
432
+ Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
395
433
```mlir
396
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
434
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
397
435
```
398
436
}];
399
437
let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
400
- Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
401
- Variadic<Index>:$indices);
438
+ Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
439
+ Variadic<Index>:$indices,
440
+ DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
441
+ "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
442
+ );
402
443
let extraClassDeclaration = [{
403
444
MemRefType getMemRefType() {
404
445
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -409,7 +450,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
409
450
}];
410
451
411
452
let assemblyFormat = [{
412
- $tile `,` $tile_slice_index `,` $base `[` $indices `]`
453
+ $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`,` $layout^)?
413
454
attr-dict `:` type($base) `,` type($tile)
414
455
}];
415
456
}
0 commit comments