@@ -60,6 +60,12 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
60
60
"::llvm::cast<VectorType>($_self).getElementType())"
61
61
".getWidth())">;
62
62
63
+ class HasMatchingMaskTypeConstraint<string vector, string mask> :
64
+ OptionalTypesMatchWith<
65
+ mask # " has i1 element type and same shape as " # vector,
66
+ vector, mask,
67
+ "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
68
+
63
69
//===----------------------------------------------------------------------===//
64
70
// ArmSME attr definitions
65
71
//===----------------------------------------------------------------------===//
@@ -259,14 +265,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
259
265
"result", "padding",
260
266
"::llvm::cast<VectorType>($_self).getElementType()"
261
267
>,
262
- OptionalTypesMatchWith<
263
- "mask has i1 element type and same shape as result",
264
- "result", "mask",
265
- "VectorType("
266
- "VectorType::Builder("
267
- "::llvm::cast<mlir::VectorType>($_self)"
268
- ").setElementType(IntegerType::get($_self.getContext(), 1)))"
269
- >,
268
+ HasMatchingMaskTypeConstraint<"result", "mask">,
270
269
PredOpTrait<
271
270
"both `padding` and `mask` should be provided or neither",
272
271
CPred<"bool(getPadding()) == bool(getMask())">
@@ -345,7 +344,10 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
345
344
"attr-dict `:` type($base) `,` type($result)";
346
345
}
347
346
348
- def TileStoreOp : ArmSME_Op<"tile_store"> {
347
+ def TileStoreOp : ArmSME_Op<"tile_store", [
348
+ AttrSizedOperandSegments,
349
+ HasMatchingMaskTypeConstraint<"valueToStore", "mask">,
350
+ ]> {
349
351
let summary = "Tile store operation";
350
352
let description = [{
351
353
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
@@ -356,6 +358,9 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
356
358
rank 2 with dynamic dimensions, since the operation is scalable, and the
357
359
element type must be a scalar that matches the element type of the result.
358
360
361
+ An optional `mask` may be provided, the shape of which corresponds to the
362
+ `tile`, and selects which elements of the tile will be stored.
363
+
359
364
Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).
360
365
```mlir
361
366
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
@@ -370,10 +375,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
370
375
```mlir
371
376
arm_sme.tile_store %tile, %base[%c0, %c0] layout<horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
372
377
```
378
+
379
+ Example 4: Masked store a int 32-bit element ZA tile with vertical layout to memory.
380
+ ```mlir
381
+ arm_sme.tile_store %tile, %base[%c0, %c0], %mask layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
382
+ ```
373
383
}];
374
384
let arguments = (ins SMETile:$valueToStore,
375
385
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
376
- Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
386
+ Variadic<Index>:$indices, Optional<AnyVector>:$mask,
387
+ ArmSME_TileSliceLayoutAttr:$layout
377
388
);
378
389
let extraClassDeclaration = [{
379
390
MemRefType getMemRefType() {
@@ -384,9 +395,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
384
395
}
385
396
}];
386
397
398
+ let builders = [
399
+ OpBuilder<(ins "Value":$valueToStore, "Value":$base,
400
+ "ValueRange":$indices), [{
401
+ build($_builder, $_state, valueToStore, base, indices, {});
402
+ }]>,
403
+ ];
404
+
387
405
let assemblyFormat =
388
- "$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
389
- "`:` type($base) `,` type($valueToStore)";
406
+ "$valueToStore `,` $base `[` $indices `]` (`,` $mask^)? (` layout` `` $layout^)?"
407
+ "attr-dict `:` type($base) `,` type($valueToStore)";
390
408
}
391
409
392
410
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
@@ -595,12 +613,6 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
595
613
}];
596
614
}
597
615
598
- class HasMatchingMaskTypeConstraint<string operand> :
599
- OptionalTypesMatchWith<
600
- "shape of `" # operand # "Mask` matches `" # operand # "`",
601
- operand, operand # "Mask",
602
- "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
603
-
604
616
class OuterProductResultTileTypeConstraint<string operand> :
605
617
OptionalTypesMatchWith<operand # "type is derived from `lhs` and `rhs`",
606
618
"lhs", operand,
@@ -615,8 +627,8 @@ def OuterProductOp :
615
627
ArmSME_Op<"outerproduct", [Pure,
616
628
AttrSizedOperandSegments,
617
629
AllTypesMatch<["lhs", "rhs"]>,
618
- HasMatchingMaskTypeConstraint<"lhs">,
619
- HasMatchingMaskTypeConstraint<"rhs">,
630
+ HasMatchingMaskTypeConstraint<"lhs", "lhsMask" >,
631
+ HasMatchingMaskTypeConstraint<"rhs", "rhsMask" >,
620
632
PredOpTrait<
621
633
"both `lhsMask` and `rhsMask` should be provided or neither",
622
634
CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,
0 commit comments