@@ -231,7 +231,26 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
231
231
let assemblyFormat = "attr-dict `:` type($res)";
232
232
}
233
233
234
- def TileLoadOp : ArmSME_Op<"tile_load"> {
234
+ def TileLoadOp : ArmSME_Op<"tile_load", [
235
+ AttrSizedOperandSegments,
236
+ OptionalTypesMatchWith<
237
+ "padding type matches element type of result",
238
+ "result", "padding",
239
+ "::llvm::cast<VectorType>($_self).getElementType()"
240
+ >,
241
+ OptionalTypesMatchWith<
242
+ "mask has i1 element type and same shape as result",
243
+ "result", "mask",
244
+ "VectorType("
245
+ "VectorType::Builder("
246
+ "::llvm::cast<mlir::VectorType>($_self)"
247
+ ").setElementType(IntegerType::get($_self.getContext(), 1)))"
248
+ >,
249
+ PredOpTrait<
250
+ "both `padding` and `mask` should be provided or neither",
251
+ CPred<"bool(getPadding()) == bool(getMask())">
252
+ >,
253
+ ]> {
235
254
let summary = "Tile load operation";
236
255
let description = [{
237
256
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
@@ -242,6 +261,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
242
261
dimensions, since the operation is scalable, and the element type must be a
243
262
scalar that matches the element type of the result.
244
263
264
+ An optional SSA value `padding` of the same elemental type as the MemRef is
265
+ provided to specify a fallback value in the case of masking.
266
+
267
+ An optional SSA value `mask` may be specified to mask out elements read
268
+ from the MemRef. The `mask` type is an `i1` vector with a shape that
269
+ matches how elements are read from the MemRef. Elements whose corresponding
270
+ mask element is `0` are masked out and replaced with `padding`.
271
+
272
+ If either `padding` or `mask` are specified, both must be specified.
273
+
245
274
Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
246
275
```mlir
247
276
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
@@ -256,10 +285,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
256
285
```mlir
257
286
%tile = arm_sme.tile_load %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
258
287
```
288
+
289
+ Example 4: Masked load of int 32-bit element ZA tile with horizontal layout (default) from memory.
290
+ ```mlir
291
+ %tile = arm_sme.tile_load %base[%c0, %c0], %pad, %mask : memref<?x?xf32>, vector<[4]x[4]xf32>
292
+ ```
259
293
}];
260
294
let arguments = (ins
261
295
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
262
296
Variadic<Index>:$indices,
297
+ Optional<AnyType>:$padding, Optional<AnyVector>:$mask,
263
298
ArmSME_TileSliceLayoutAttr:$layout
264
299
);
265
300
let results = (outs SMETile:$result);
@@ -273,9 +308,20 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
273
308
}
274
309
}];
275
310
311
+ let builders = [
312
+ OpBuilder<(ins "VectorType":$resultType, "Value":$base,
313
+ "ValueRange":$indices, "TileSliceLayout":$layout), [{
314
+ build($_builder, $_state, resultType, base, indices, {}, {}, layout);
315
+ }]>,
316
+ OpBuilder<(ins "VectorType":$resultType, "Value":$base,
317
+ "ValueRange":$indices), [{
318
+ build($_builder, $_state, resultType, base, indices, {}, {}, {});
319
+ }]>,
320
+ ];
321
+
276
322
let assemblyFormat =
277
- "$base `[` $indices `]` (`layout` ` ` $layout ^)? attr-dict "
278
- "`:` type($base) `,` type($result)";
323
+ "$base `[` $indices `]` (`,` $padding `, ` $mask ^)? (`layout` `` $layout^)? "
324
+ "attr-dict `:` type($base) `,` type($result)";
279
325
}
280
326
281
327
def TileStoreOp : ArmSME_Op<"tile_store"> {
0 commit comments