Skip to content

Commit 1908f47

Browse files
authored
[mlir][ArmSME] Add optional mask operand to tile_store (#70657)
1 parent 2862d17 commit 1908f47

File tree

5 files changed

+71
-22
lines changed

5 files changed

+71
-22
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
6060
"::llvm::cast<VectorType>($_self).getElementType())"
6161
".getWidth())">;
6262

63+
class HasMatchingMaskTypeConstraint<string vector, string mask> :
64+
OptionalTypesMatchWith<
65+
mask # " has i1 element type and same shape as " # vector,
66+
vector, mask,
67+
"::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
68+
6369
//===----------------------------------------------------------------------===//
6470
// ArmSME attr definitions
6571
//===----------------------------------------------------------------------===//
@@ -259,14 +265,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
259265
"result", "padding",
260266
"::llvm::cast<VectorType>($_self).getElementType()"
261267
>,
262-
OptionalTypesMatchWith<
263-
"mask has i1 element type and same shape as result",
264-
"result", "mask",
265-
"VectorType("
266-
"VectorType::Builder("
267-
"::llvm::cast<mlir::VectorType>($_self)"
268-
").setElementType(IntegerType::get($_self.getContext(), 1)))"
269-
>,
268+
HasMatchingMaskTypeConstraint<"result", "mask">,
270269
PredOpTrait<
271270
"both `padding` and `mask` should be provided or neither",
272271
CPred<"bool(getPadding()) == bool(getMask())">
@@ -345,7 +344,10 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
345344
"attr-dict `:` type($base) `,` type($result)";
346345
}
347346

348-
def TileStoreOp : ArmSME_Op<"tile_store"> {
347+
def TileStoreOp : ArmSME_Op<"tile_store", [
348+
AttrSizedOperandSegments,
349+
HasMatchingMaskTypeConstraint<"valueToStore", "mask">,
350+
]> {
349351
let summary = "Tile store operation";
350352
let description = [{
351353
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
@@ -356,6 +358,9 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
356358
rank 2 with dynamic dimensions, since the operation is scalable, and the
357359
element type must be a scalar that matches the element type of the tile being stored.
358360

361+
An optional `mask` may be provided, the shape of which corresponds to the
362+
`tile`, and selects which elements of the tile will be stored.
363+
359364
Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).
360365
```mlir
361366
arm_sme.tile_store %tile, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
@@ -370,10 +375,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
370375
```mlir
371376
arm_sme.tile_store %tile, %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
372377
```
378+
379+
Example 4: Masked store of a 32-bit floating-point element ZA tile with vertical layout to memory.
380+
```mlir
381+
arm_sme.tile_store %tile, %base[%c0, %c0], %mask layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
382+
```
373383
}];
374384
let arguments = (ins SMETile:$valueToStore,
375385
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
376-
Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
386+
Variadic<Index>:$indices, Optional<AnyVector>:$mask,
387+
ArmSME_TileSliceLayoutAttr:$layout
377388
);
378389
let extraClassDeclaration = [{
379390
MemRefType getMemRefType() {
@@ -384,9 +395,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
384395
}
385396
}];
386397

398+
let builders = [
399+
OpBuilder<(ins "Value":$valueToStore, "Value":$base,
400+
"ValueRange":$indices), [{
401+
build($_builder, $_state, valueToStore, base, indices, {});
402+
}]>,
403+
];
404+
387405
let assemblyFormat =
388-
"$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
389-
"`:` type($base) `,` type($valueToStore)";
406+
"$valueToStore `,` $base `[` $indices `]` (`,` $mask^)? (`layout` `` $layout^)?"
407+
"attr-dict `:` type($base) `,` type($valueToStore)";
390408
}
391409

392410
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
@@ -595,12 +613,6 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
595613
}];
596614
}
597615

598-
class HasMatchingMaskTypeConstraint<string operand> :
599-
OptionalTypesMatchWith<
600-
"shape of `" # operand # "Mask` matches `" # operand # "`",
601-
operand, operand # "Mask",
602-
"::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
603-
604616
class OuterProductResultTileTypeConstraint<string operand> :
605617
OptionalTypesMatchWith<operand # "type is derived from `lhs` and `rhs`",
606618
"lhs", operand,
@@ -615,8 +627,8 @@ def OuterProductOp :
615627
ArmSME_Op<"outerproduct", [Pure,
616628
AttrSizedOperandSegments,
617629
AllTypesMatch<["lhs", "rhs"]>,
618-
HasMatchingMaskTypeConstraint<"lhs">,
619-
HasMatchingMaskTypeConstraint<"rhs">,
630+
HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
631+
HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
620632
PredOpTrait<
621633
"both `lhsMask` and `rhsMask` should be provided or neither",
622634
CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ struct TransferWriteToArmSMELowering
144144
return failure();
145145

146146
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
147-
writeOp, writeOp.getVector(), writeOp.getSource(),
148-
writeOp.getIndices());
147+
writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
148+
writeOp.getMask());
149149
return success();
150150
}
151151
};

mlir/test/Dialect/ArmSME/invalid.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,20 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask :
164164
return
165165
}
166166

167+
//===----------------------------------------------------------------------===//
168+
// arm_sme.tile_store
169+
//===----------------------------------------------------------------------===//
170+
171+
// -----
172+
173+
func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask : vector<[1]x[1]xi1>, %dest : memref<?x?xi8>) {
174+
%c0 = arith.constant 0 : index
175+
// expected-note@-2 {{prior use here}}
176+
// expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<[16]x[16]xi1>' vs 'vector<[1]x[1]xi1>}}
177+
arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi8>, vector<[16]x[16]xi8>
178+
return
179+
}
180+
167181
//===----------------------------------------------------------------------===//
168182
// arm_sme.outerproduct
169183
//===----------------------------------------------------------------------===//

mlir/test/Dialect/ArmSME/roundtrip.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,15 @@ func.func @arm_sme_tile_store_ver_f64(%tile : vector<[2]x[2]xf64>, %dest : memre
624624

625625
// -----
626626

627+
func.func @arm_sme_tile_store_with_mask_ver_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %mask : vector<[4]x[4]xi1>) {
628+
// CHECK: arm_sme.tile_store {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
629+
%c0 = arith.constant 0 : index
630+
arm_sme.tile_store %tile, %dest[%c0, %c0], %mask layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
631+
return
632+
}
633+
634+
// -----
635+
627636
/// Layout is optional and horizontal is the default, verify it's still parsed.
628637
func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
629638
// CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>

mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,20 @@ func.func @transfer_write_2d_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?
315315

316316
// -----
317317

318+
// CHECK-LABEL: func.func @transfer_write_2d_with_mask_f64(
319+
// CHECK-SAME: %[[VECTOR:.*]]: vector<[2]x[2]xf64>,
320+
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf64>,
321+
// CHECK-SAME: %[[MASK:.*]]: vector<[2]x[2]xi1>) {
322+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
323+
// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], %[[MASK]] : memref<?x?xf64>, vector<[2]x[2]xf64>
324+
func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) {
325+
%c0 = arith.constant 0 : index
326+
vector.transfer_write %vector, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[2]x[2]xf64>, memref<?x?xf64>
327+
return
328+
}
329+
330+
// -----
331+
318332
// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
319333
// lowering only occurs for vector types of correct rank, shape, element size
320334
// and number of scalable dims.

0 commit comments

Comments
 (0)