Skip to content

Commit a1b2ace

Browse files
authored
[mlir][ArmSME] Add optional padding and mask operands to tile_load (#69195)
Padding and mask are optional, but if one is specified both must be specified. This is consistent with vector.transfer_read.
1 parent f664326 commit a1b2ace

File tree

3 files changed

+112
-3
lines changed

3 files changed

+112
-3
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,26 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
231231
let assemblyFormat = "attr-dict `:` type($res)";
232232
}
233233

234-
def TileLoadOp : ArmSME_Op<"tile_load"> {
234+
def TileLoadOp : ArmSME_Op<"tile_load", [
235+
AttrSizedOperandSegments,
236+
OptionalTypesMatchWith<
237+
"padding type matches element type of result",
238+
"result", "padding",
239+
"::llvm::cast<VectorType>($_self).getElementType()"
240+
>,
241+
OptionalTypesMatchWith<
242+
"mask has i1 element type and same shape as result",
243+
"result", "mask",
244+
"VectorType("
245+
"VectorType::Builder("
246+
"::llvm::cast<mlir::VectorType>($_self)"
247+
").setElementType(IntegerType::get($_self.getContext(), 1)))"
248+
>,
249+
PredOpTrait<
250+
"both `padding` and `mask` should be provided or neither",
251+
CPred<"bool(getPadding()) == bool(getMask())">
252+
>,
253+
]> {
235254
let summary = "Tile load operation";
236255
let description = [{
237256
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
@@ -242,6 +261,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
242261
dimensions, since the operation is scalable, and the element type must be a
243262
scalar that matches the element type of the result.
244263

264+
An optional SSA value `padding` of the same elemental type as the MemRef is
265+
provided to specify a fallback value in the case of masking.
266+
267+
An optional SSA value `mask` may be specified to mask out elements read
268+
from the MemRef. The `mask` type is an `i1` vector with a shape that
269+
matches how elements are read from the MemRef. Elements whose corresponding
270+
mask element is `0` are masked out and replaced with `padding`.
271+
272+
If either `padding` or `mask` is specified, both must be specified.
273+
245274
Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
246275
```mlir
247276
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
@@ -256,10 +285,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
256285
```mlir
257286
%tile = arm_sme.tile_load %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
258287
```
288+
289+
Example 4: Masked load of a 32-bit floating-point element ZA tile with horizontal layout (default) from memory.
290+
```mlir
291+
%tile = arm_sme.tile_load %base[%c0, %c0], %pad, %mask : memref<?x?xf32>, vector<[4]x[4]xf32>
292+
```
259293
}];
260294
let arguments = (ins
261295
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
262296
Variadic<Index>:$indices,
297+
Optional<AnyType>:$padding, Optional<AnyVector>:$mask,
263298
ArmSME_TileSliceLayoutAttr:$layout
264299
);
265300
let results = (outs SMETile:$result);
@@ -273,9 +308,20 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
273308
}
274309
}];
275310

311+
let builders = [
312+
OpBuilder<(ins "VectorType":$resultType, "Value":$base,
313+
"ValueRange":$indices, "TileSliceLayout":$layout), [{
314+
build($_builder, $_state, resultType, base, indices, {}, {}, layout);
315+
}]>,
316+
OpBuilder<(ins "VectorType":$resultType, "Value":$base,
317+
"ValueRange":$indices), [{
318+
build($_builder, $_state, resultType, base, indices, {}, {}, {});
319+
}]>,
320+
];
321+
276322
let assemblyFormat =
277-
"$base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
278-
"`:` type($base) `,` type($result)";
323+
"$base `[` $indices `]` (`,` $padding `,` $mask^)? (`layout` `` $layout^)?"
324+
"attr-dict `:` type($base) `,` type($result)";
279325
}
280326

281327
def TileStoreOp : ArmSME_Op<"tile_store"> {

mlir/test/Dialect/ArmSME/invalid.mlir

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
22

3+
//===----------------------------------------------------------------------===//
4+
// arm_sme.cast_tile_to_vector
5+
//===----------------------------------------------------------------------===//
6+
37
// -----
48

59
func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> {
@@ -48,6 +52,10 @@ func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[1
4852
return %0 : vector<[4]x[16]xi8>
4953
}
5054

55+
//===----------------------------------------------------------------------===//
56+
// arm_sme.cast_vector_to_tile
57+
//===----------------------------------------------------------------------===//
58+
5159
// -----
5260

5361
func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 {
@@ -64,6 +72,10 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -
6472
return %0 : i8
6573
}
6674

75+
//===----------------------------------------------------------------------===//
76+
// arm_sme.get_tile_id
77+
//===----------------------------------------------------------------------===//
78+
6779
// -----
6880

6981
func.func @arm_sme_get_tile_id__bad_type() -> i1 {
@@ -72,6 +84,10 @@ func.func @arm_sme_get_tile_id__bad_type() -> i1 {
7284
return %0 : i1
7385
}
7486

87+
//===----------------------------------------------------------------------===//
88+
// arm_sme.move_vector_to_tile_slice
89+
//===----------------------------------------------------------------------===//
90+
7591
// -----
7692

7793
func.func @arm_sme_move_vector_to_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
@@ -90,10 +106,47 @@ func.func @arm_sme_move_vector_to_tile_slice_f32__bad_vector_type(%vector : vect
90106
return %0 : vector<[4]x[4]xf32>
91107
}
92108

109+
//===----------------------------------------------------------------------===//
110+
// arm_sme.move_tile_slice_to_vector
111+
//===----------------------------------------------------------------------===//
112+
93113
// -----
94114

95115
func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
96116
// expected-error@+1 {{op failed to verify that type of 'result' matches type of 'tile' slice}}
97117
%0 = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
98118
return %0 : vector<[2]xf64>
99119
}
120+
121+
//===----------------------------------------------------------------------===//
122+
// arm_sme.tile_load
123+
//===----------------------------------------------------------------------===//
124+
125+
// -----
126+
127+
func.func @arm_sme_tile_load__bad_padding_type(%src : memref<?x?xf64>, %pad : f32, %mask : vector<[2]x[2]xi1>) {
128+
%c0 = arith.constant 0 : index
129+
// expected-note@-2 {{prior use here}}
130+
// expected-error@+1 {{use of value '%pad' expects different type than prior uses: 'f64' vs 'f32'}}
131+
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
132+
return
133+
}
134+
135+
// -----
136+
137+
func.func @arm_sme_tile_load__bad_mask_type(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[4]x[4]xi1>) {
138+
%c0 = arith.constant 0 : index
139+
// expected-note@-2 {{prior use here}}
140+
// expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<[2]x[2]xi1>' vs 'vector<[4]x[4]xi1>'}}
141+
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
142+
return
143+
}
144+
145+
// -----
146+
147+
func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64) {
148+
%c0 = arith.constant 0 : index
149+
// expected-error@+1 {{op failed to verify that both `padding` and `mask` should be provided or neither}}
150+
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, : memref<?x?xf64>, vector<[2]x[2]xf64>
151+
return
152+
}

mlir/test/Dialect/ArmSME/roundtrip.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,16 @@ func.func @arm_sme_tile_load_ver_f64(%src : memref<?x?xf64>) {
438438

439439
// -----
440440

441+
/// Padding and mask are optional
442+
func.func @arm_sme_tile_load_hor_pad_f64(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[2]x[2]xi1>) {
443+
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
444+
%c0 = arith.constant 0 : index
445+
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
446+
return
447+
}
448+
449+
// -----
450+
441451
/// Layout is optional and horizontal is the default, verify it's still parsed.
442452
func.func @arm_sme_tile_load_explicit_hor(%src : memref<?x?xi8>) {
443453
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>

0 commit comments

Comments
 (0)