@@ -39,10 +39,10 @@ def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
39
39
40
40
def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
41
41
let description = [{
42
- An interface for operations that use or allocate Arm SME tiles. These
43
- operations need to be assigned a tile ID, an i32 attribute, which specifies
44
- which virtual tile within the ZA storage to use. The number of tiles
45
- available depends on the type of the tile. This is summarized below:
42
+ An interface for operations that use Arm SME tiles. These operations need to
43
+ be assigned a tile ID, an i32 attribute, which specifies which virtual tile
44
+ within the ZA storage to use. The number of tiles available depends on the
45
+ type of the tile. This is summarized below:
46
46
47
47
| Tile Vector Types | Possible Tile IDs |
48
48
|-------------------------------------------------------------------------|---------------------|
@@ -51,10 +51,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
51
51
| `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` | 0 to 3 (inclusive) |
52
52
| `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` | 0 to 7 (inclusive) |
53
53
| `vector<[1]x[1]xi128>` | 0 to 15 (inclusive) |
54
-
55
- Operations that allocate a new tile (such as arm_sme.get_tile), are used as
56
- the roots for tile allocation, with all operations that (transitively)
57
- depend on a root being assigned the same tile ID.
58
54
}];
59
55
let methods = [
60
56
InterfaceMethod<
@@ -84,20 +80,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
84
80
return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
85
81
}]
86
82
>,
87
- InterfaceMethod<
88
- [{
89
- The type of tile this operation allocates. Returns none (std::nullopt)
90
- if this operation does not allocate a tile.
91
- }],
92
- /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
93
- /*methodName=*/"getAllocatedTileType",
94
- /*arguments=*/(ins),
95
- /*methodBody=*/[{}],
96
- /*defaultImpl=*/ [{
97
- // This operation does not allocate a tile.
98
- return std::nullopt;
99
- }]
100
- >,
101
83
InterfaceMethod<
102
84
"Returns the VectorType of the tile used by this operation.",
103
85
/*returnType=*/"VectorType",
@@ -106,30 +88,13 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
106
88
];
107
89
108
90
let extraSharedClassDeclaration = [{
109
- // A helper to create a new operation and propagate this operations tile ID.
110
- template<typename T, typename... Args>
111
- T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
112
- auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
113
- if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
114
- tileOp.setTileId($_op.getTileId());
115
- return op;
116
- }
117
-
118
- // A helper to replace this operation and forward its tile ID (if present).
119
- template<typename T, typename... Args>
120
- T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
121
- auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
122
- rewriter.replaceOp($_op, newOp);
123
- return newOp;
124
- }
125
-
126
91
bool isInMemoryTile() {
127
92
auto tileId = getTileId();
128
93
return tileId && tileId.getInt() >= kInMemoryTileIdBase;
129
94
}
130
95
}];
131
96
132
- let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId ($_op); }];
97
+ let verify = [{ return detail::verifyArmSMETileOpInterface ($_op); }];
133
98
}
134
99
135
100
//===----------------------------------------------------------------------===//
@@ -255,30 +220,30 @@ def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
255
220
class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
256
221
Op<ArmSME_Dialect, mnemonic, traits> {}
257
222
258
- def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
259
- let summary = "Returns a SME virtual tile";
223
+ def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface, Pure ]> {
224
+ let summary = "Creates an undefined value of SME virtual tile type ";
260
225
let description = [{
261
- Allocates a new SME "virtual tile" within a function. The contents of the
262
- tile returned from this operation are undefined.
226
+ Creates a new SME "virtual tile" value within a function. The contents of
227
+ the tile returned from this operation are undefined.
263
228
264
229
Example 1:
265
230
266
231
```mlir
267
- // Allocate an 8-bit element "virtual tile"
232
+ // Create an 8-bit element "virtual tile" value:
268
233
%za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
269
234
```
270
235
271
236
Example 2:
272
237
273
238
```mlir
274
- // Allocate two 16-bit element "virtual tiles"
239
+ // Create two 16-bit element "virtual tiles" values:
275
240
%za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
276
241
%za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
277
242
```
278
243
279
244
Example 3:
280
245
```mlir
281
- // Allocate an 128-bit element "virtual tile"
246
+ // Create an 128-bit element "virtual tile" value:
282
247
%za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
283
248
```
284
249
}];
@@ -290,37 +255,15 @@ def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
290
255
VectorType getTileType() {
291
256
return ::llvm::cast<VectorType>(getTile().getType());
292
257
}
293
-
294
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
295
- return arm_sme::getSMETileType(getTileType());
296
- }
297
- }];
298
- }
299
-
300
- def MaterializeSSATileOp : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
301
- let summary = "SME tile placeholder";
302
- let description = [{
303
- A placeholder to preserve dataflow while lowering to SME intrinsics (which
304
- do not take or return SME virtual tile values). This operation is intended
305
- to be DCE'd once all ArmSME operations have been lowered.
306
-
307
- This operation is not intended to be used outside of the ArmSME -> LLVM
308
- conversion.
309
258
}];
310
- let results = (outs SMETile:$tile);
311
- let assemblyFormat = "attr-dict `:` type($tile)";
312
259
}
313
260
314
- //
315
- // Tile reset.
316
- //
317
-
318
- def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
319
- let summary = "Initialize the two-dimensional ZA array with 0s";
261
+ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface, Pure]> {
262
+ let summary = "Creates a zero-initialized value of SME virtual tile type";
320
263
let results = (outs SMETile:$res);
321
264
let description = [{
322
- Initialise ZA with 0. This operation is convenient wrapper for the SME
323
- `zero` intrinsic and instruction .
265
+ Creates a new SME "virtual tile" value within a function. The contents of
266
+ the tile returned from this operation are zero-initialized .
324
267
325
268
Example 1: Zero an 8-bit element ZA tile.
326
269
@@ -338,16 +281,39 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
338
281
VectorType getVectorType() {
339
282
return ::llvm::cast<VectorType>(getRes().getType());
340
283
}
341
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
342
- return arm_sme::getSMETileType(getVectorType());
343
- }
344
284
VectorType getTileType() {
345
285
return getVectorType();
346
286
}
347
287
}];
348
288
let assemblyFormat = "attr-dict `:` type($res)";
349
289
}
350
290
291
+ def CopyTileOp : ArmSME_Op<"copy_tile", [
292
+ Pure,
293
+ ArmSMETileOpInterface,
294
+ AllTypesMatch<["tile", "result"]>
295
+ ]> {
296
+ let summary = "Copies an SME tile value";
297
+ let arguments = (ins SMETile:$tile);
298
+ let results = (outs SMETile:$result);
299
+ let description = [{
300
+ Copies an SME "virtual tile" value to a new SSA value. This operation is
301
+ primarily intended to be used to normalize the IR prior to tile allocation.
302
+
303
+ Example:
304
+
305
+ ```mlir
306
+ %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
307
+ ```
308
+ }];
309
+ let extraClassDeclaration = [{
310
+ VectorType getTileType() {
311
+ return ::llvm::cast<VectorType>(getResult().getType());
312
+ }
313
+ }];
314
+ let assemblyFormat = "$tile attr-dict `:` type($result)";
315
+ }
316
+
351
317
def TileLoadOp : ArmSME_Op<"tile_load", [
352
318
ArmSMETileOpInterface,
353
319
AttrSizedOperandSegments,
@@ -417,9 +383,6 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
417
383
VectorType getVectorType() {
418
384
return ::llvm::cast<VectorType>(getResult().getType());
419
385
}
420
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
421
- return arm_sme::getSMETileType(getVectorType());
422
- }
423
386
VectorType getTileType() {
424
387
return getVectorType();
425
388
}
@@ -545,7 +508,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
545
508
```
546
509
}];
547
510
let arguments = (ins
548
- Arg<AnyMemRef, "the reference to load from">:$base, SVEPredicate:$mask,
511
+ Arg<AnyMemRef, "the reference to load from", [MemRead] >:$base, SVEPredicate:$mask,
549
512
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
550
513
ArmSME_TileSliceLayoutAttr:$layout
551
514
);
@@ -630,7 +593,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
630
593
}
631
594
632
595
def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
633
- ArmSMETileOpInterface,
596
+ ArmSMETileOpInterface, Pure,
634
597
AllTypesMatch<["tile", "result"]>,
635
598
TypesMatchWith<
636
599
"type of 'vector' matches type of 'tile' slice",
@@ -679,7 +642,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
679
642
}
680
643
681
644
def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
682
- ArmSMETileOpInterface,
645
+ ArmSMETileOpInterface, Pure,
683
646
TypesMatchWith<
684
647
"type of 'result' matches type of 'tile' slice",
685
648
"tile", "result",
@@ -736,6 +699,7 @@ class OuterProductResultTileTypeConstraint<string operand> :
736
699
737
700
def OuterProductOp :
738
701
ArmSME_Op<"outerproduct", [
702
+ Pure,
739
703
ArmSMETileOpInterface,
740
704
AttrSizedOperandSegments,
741
705
AllTypesMatch<["lhs", "rhs"]>,
@@ -802,12 +766,6 @@ let arguments = (ins
802
766
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
803
767
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
804
768
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
805
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
806
- // The outerproduct op allocates a new tile if no accumulator is passed.
807
- if (!getAcc())
808
- return arm_sme::getSMETileType(getResultType());
809
- return std::nullopt;
810
- }
811
769
VectorType getTileType() {
812
770
return getResultType();
813
771
}
@@ -819,6 +777,7 @@ class OuterProductWideningBase<string mnemonic,
819
777
list<Type> allowedResultVectorTypes,
820
778
int numOuterProducts> :
821
779
ArmSME_Op<mnemonic, [
780
+ Pure,
822
781
ArmSMETileOpInterface,
823
782
AttrSizedOperandSegments,
824
783
AllTypesMatch<["lhs", "rhs"]>,
@@ -857,12 +816,6 @@ class OuterProductWideningBase<string mnemonic,
857
816
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
858
817
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
859
818
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
860
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
861
- // The outerproduct op allocates a new tile if no accumulator is passed.
862
- if (!getAcc())
863
- return arm_sme::getSMETileType(getResultType());
864
- return std::nullopt;
865
- }
866
819
VectorType getTileType() {
867
820
return getResultType();
868
821
}
0 commit comments