Skip to content

[mlir][ArmSME] Add optional padding and mask operands to tile_load #69195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 30, 2023

Conversation

c-rhodes
Copy link
Collaborator

Padding and mask are optional, but if one is specified both must be specified. This is consistent with vector.transfer_read.

@llvmbot
Copy link
Member

llvmbot commented Oct 16, 2023

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir

Author: Cullen Rhodes (c-rhodes)

Changes

Padding and mask are optional, but if one is specified both must be specified. This is consistent with vector.transfer_read.


Full diff: https://github.com/llvm/llvm-project/pull/69195.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+47-3)
  • (modified) mlir/test/Dialect/ArmSME/invalid.mlir (+44)
  • (modified) mlir/test/Dialect/ArmSME/roundtrip.mlir (+10)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index dab54b63d8d22be..6f6b54aad0058e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -231,7 +231,24 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
-def TileLoadOp : ArmSME_Op<"tile_load"> {
+def TileLoadOp : ArmSME_Op<"tile_load", [
+  AttrSizedOperandSegments,
+  TypesMatchWith<
+    "padding type matches element type of result (if present)",
+    "result", "padding",
+    "::llvm::cast<VectorType>($_self).getElementType()",
+    "!getPadding() || std::equal_to<>()"
+  >,
+  TypesMatchWith<
+    "mask has i1 element type and same shape as result (if present)",
+    "result", "mask",
+    "VectorType("
+      "VectorType::Builder("
+        "::llvm::cast<mlir::VectorType>($_self)"
+      ").setElementType(IntegerType::get($_self.getContext(), 1)))",
+    "!getMask() || std::equal_to<>()"
+  >
+]> {
   let summary = "Tile load operation";
   let description = [{
     Loads a 2D SME "virtual tile" from memory defined by a base and indices,
@@ -242,6 +259,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
     dimensions, since the operation is scalable, and the element type must be a
     scalar that matches the element type of the result.
 
+    An optional SSA value `padding` of the same elemental type as the MemRef is
+    provided to specify a fallback value in the case of masking.
+
+    An optional SSA value `mask` may be specified to mask out elements read
+    from the MemRef. The `mask` type is an `i1` vector with a shape that
+    matches how elements are read from the MemRef. Elements whose corresponding
+    mask element is `0` are masked out and replaced with `padding`.
+
+    If either `padding` or `mask` are specified, both must be specified.
+
     Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
     ```mlir
     %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
@@ -256,10 +283,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
     ```mlir
     %tile = arm_sme.tile_load %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
+
+    Example 4: Masked load of int 32-bit element ZA tile with horizontal layout (default) from memory.
+    ```mlir
+    %tile = arm_sme.tile_load %base[%c0, %c0], %pad, %mask : memref<?x?xf32>, vector<[4]x[4]xf32>
+    ```
   }];
   let arguments = (ins
     Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
     Variadic<Index>:$indices,
+    Optional<AnyType>:$padding, Optional<AnyVector>:$mask,
     ArmSME_TileSliceLayoutAttr:$layout
   );
   let results = (outs SMETile:$result);
@@ -273,9 +306,20 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
     }
   }];
 
+  let builders = [
+    OpBuilder<(ins "VectorType":$resultType, "Value":$base,
+                   "ValueRange":$indices, "TileSliceLayout":$layout), [{
+      build($_builder, $_state, resultType, base, indices, {}, {}, layout);
+    }]>,
+    OpBuilder<(ins "VectorType":$resultType, "Value":$base,
+                   "ValueRange":$indices), [{
+      build($_builder, $_state, resultType, base, indices, {}, {}, {});
+    }]>,
+  ];
+
   let assemblyFormat =
-    "$base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
-      "`:` type($base) `,` type($result)";
+    "$base `[` $indices `]` (`,` $padding `,` $mask^)? (`layout` `` $layout^)?"
+      "attr-dict `:` type($base) `,` type($result)";
 }
 
 def TileStoreOp : ArmSME_Op<"tile_store"> {
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 431009b1b9ede2f..9229f0415c076c3 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_tile_to_vector
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> {
@@ -48,6 +52,10 @@ func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[1
   return %0 : vector<[4]x[16]xi8>
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_vector_to_tile
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 {
@@ -64,6 +72,10 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -
   return %0 : i8
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.get_tile_id
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_get_tile_id__bad_type() -> i1 {
@@ -72,6 +84,10 @@ func.func @arm_sme_get_tile_id__bad_type() -> i1 {
   return %0 : i1
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.move_vector_to_tile_slice
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_move_vector_to_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
@@ -90,6 +106,10 @@ func.func @arm_sme_move_vector_to_tile_slice_f32__bad_vector_type(%vector : vect
   return %0 : vector<[4]x[4]xf32>
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.move_tile_slice_to_vector
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
@@ -97,3 +117,27 @@ func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]
   %0 = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
   return %0 : vector<[2]xf64>
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_load
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_tile_load__bad_padding_type(%src : memref<?x?xf64>, %pad : f32, %mask : vector<[2]x[2]xi1>) {
+  %c0 = arith.constant 0 : index
+  // expected-note@-2 {{prior use here}}
+  // expected-error@+1 {{use of value '%pad' expects different type than prior uses: 'f64' vs 'f32'}}
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load__bad_mask_type(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[4]x[4]xi1>) {
+  %c0 = arith.constant 0 : index
+  // expected-note@-2 {{prior use here}}
+  // expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<[2]x[2]xi1>' vs 'vector<[4]x[4]xi1>}}
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 427154158e797fd..f6459f085843655 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -438,6 +438,16 @@ func.func @arm_sme_tile_load_ver_f64(%src : memref<?x?xf64>) {
 
 // -----
 
+/// Padding and mask are optional
+func.func @arm_sme_tile_load_hor_pad_f64(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[2]x[2]xi1>) {
+  // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
 /// Layout is optional and horizontal is the default, verify it's still parsed.
 func.func @arm_sme_tile_load_explicit_hor(%src : memref<?x?xi8>) {
   // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>

@c-rhodes c-rhodes requested a review from dcaballe October 16, 2023 12:53
MacDue added a commit that referenced this pull request Oct 19, 2023
…Format (#68876)

This is just a slight specialization of `TypesMatchWith` that returns
success if an optional parameter is missing.

There may be other places this could help e.g.:

https://github.com/llvm/llvm-project/blob/eb21049b4b904b072679ece60e73c6b0dc0d1ebf/mlir/include/mlir/Dialect/X86Vector/X86Vector.td#L58-L59
...but I'm leaving those to avoid some churn.

This constraint will be handy for us in some later patches, it's a
formalization of a short circuiting trick with the `comparator` of the
`TypesMatchWith` constraint (devised for #69195).

```
TypesMatchWith<
  "padding type matches element type of result (if present)",
  "result", "padding",
  "::llvm::cast<VectorType>($_self).getElementType()",
  // This returns true if no padding is present, or it's present with a type that matches the element type of `result`.
  "!getPadding() || std::equal_to<>()">
```

This is a little non-obvious, so after this patch you can instead do:
```
OptionalTypesMatchWith<
  "padding type matches element type of result (if present)",
  "result", "padding",
  "::llvm::cast<VectorType>($_self).getElementType()">
```
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, ta!

The extra test for the verifier (to make sure that neither mask nor pad can be used without the other) could be added separately.

Padding and mask are optional, but if one is specified both must be
specified. This is consistent with vector.transfer_read.
* Use OptionalTypesMatchWith
* Add constraint to verify both padding and mask are specified, as well
  as test.
@c-rhodes c-rhodes force-pushed the mlir-arm-sme-tile-load-pad-and-mask branch from f4b89ab to dbc1c02 Compare October 30, 2023 11:26
@c-rhodes
Copy link
Collaborator Author

LGTM, ta!

The extra test for the verifier (to make sure that neither mask nor pad can be used without the other) could be added separately.

Thanks for reviewing. I've added the test.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants