-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][ArmSME] Add optional padding and mask operands to tile_load #69195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ArmSME] Add optional padding and mask operands to tile_load #69195
Conversation
@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir Author: Cullen Rhodes (c-rhodes) ChangesPadding and mask are optional, but if one is specified both must be specified. This is consistent with vector.transfer_read. Full diff: https://github.com/llvm/llvm-project/pull/69195.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index dab54b63d8d22be..6f6b54aad0058e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -231,7 +231,24 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
let assemblyFormat = "attr-dict `:` type($res)";
}
-def TileLoadOp : ArmSME_Op<"tile_load"> {
+def TileLoadOp : ArmSME_Op<"tile_load", [
+ AttrSizedOperandSegments,
+ TypesMatchWith<
+ "padding type matches element type of result (if present)",
+ "result", "padding",
+ "::llvm::cast<VectorType>($_self).getElementType()",
+ "!getPadding() || std::equal_to<>()"
+ >,
+ TypesMatchWith<
+ "mask has i1 element type and same shape as result (if present)",
+ "result", "mask",
+ "VectorType("
+ "VectorType::Builder("
+ "::llvm::cast<mlir::VectorType>($_self)"
+ ").setElementType(IntegerType::get($_self.getContext(), 1)))",
+ "!getMask() || std::equal_to<>()"
+ >
+]> {
let summary = "Tile load operation";
let description = [{
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
@@ -242,6 +259,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
dimensions, since the operation is scalable, and the element type must be a
scalar that matches the element type of the result.
+ An optional SSA value `padding` of the same elemental type as the MemRef is
+ provided to specify a fallback value in the case of masking.
+
+ An optional SSA value `mask` may be specified to mask out elements read
+ from the MemRef. The `mask` type is an `i1` vector with a shape that
+ matches how elements are read from the MemRef. Elements whose corresponding
+ mask element is `0` are masked out and replaced with `padding`.
+
+ If either `padding` or `mask` are specified, both must be specified.
+
Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
@@ -256,10 +283,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
```
+
+ Example 4: Masked load of int 32-bit element ZA tile with horizontal layout (default) from memory.
+ ```mlir
+ %tile = arm_sme.tile_load %base[%c0, %c0], %pad, %mask : memref<?x?xf32>, vector<[4]x[4]xf32>
+ ```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices,
+ Optional<AnyType>:$padding, Optional<AnyVector>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
);
let results = (outs SMETile:$result);
@@ -273,9 +306,20 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
}
}];
+ let builders = [
+ OpBuilder<(ins "VectorType":$resultType, "Value":$base,
+ "ValueRange":$indices, "TileSliceLayout":$layout), [{
+ build($_builder, $_state, resultType, base, indices, {}, {}, layout);
+ }]>,
+ OpBuilder<(ins "VectorType":$resultType, "Value":$base,
+ "ValueRange":$indices), [{
+ build($_builder, $_state, resultType, base, indices, {}, {}, {});
+ }]>,
+ ];
+
let assemblyFormat =
- "$base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
- "`:` type($base) `,` type($result)";
+ "$base `[` $indices `]` (`,` $padding `,` $mask^)? (`layout` `` $layout^)?"
+ "attr-dict `:` type($base) `,` type($result)";
}
def TileStoreOp : ArmSME_Op<"tile_store"> {
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 431009b1b9ede2f..9229f0415c076c3 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -1,5 +1,9 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_tile_to_vector
+//===----------------------------------------------------------------------===//
+
// -----
func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> {
@@ -48,6 +52,10 @@ func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[1
return %0 : vector<[4]x[16]xi8>
}
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_vector_to_tile
+//===----------------------------------------------------------------------===//
+
// -----
func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 {
@@ -64,6 +72,10 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -
return %0 : i8
}
+//===----------------------------------------------------------------------===//
+// arm_sme.get_tile_id
+//===----------------------------------------------------------------------===//
+
// -----
func.func @arm_sme_get_tile_id__bad_type() -> i1 {
@@ -72,6 +84,10 @@ func.func @arm_sme_get_tile_id__bad_type() -> i1 {
return %0 : i1
}
+//===----------------------------------------------------------------------===//
+// arm_sme.move_vector_to_tile_slice
+//===----------------------------------------------------------------------===//
+
// -----
func.func @arm_sme_move_vector_to_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
@@ -90,6 +106,10 @@ func.func @arm_sme_move_vector_to_tile_slice_f32__bad_vector_type(%vector : vect
return %0 : vector<[4]x[4]xf32>
}
+//===----------------------------------------------------------------------===//
+// arm_sme.move_tile_slice_to_vector
+//===----------------------------------------------------------------------===//
+
// -----
func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
@@ -97,3 +117,27 @@ func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]
%0 = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
return %0 : vector<[2]xf64>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_load
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_tile_load__bad_padding_type(%src : memref<?x?xf64>, %pad : f32, %mask : vector<[2]x[2]xi1>) {
+ %c0 = arith.constant 0 : index
+ // expected-note@-2 {{prior use here}}
+ // expected-error@+1 {{use of value '%pad' expects different type than prior uses: 'f64' vs 'f32'}}
+ %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load__bad_mask_type(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[4]x[4]xi1>) {
+ %c0 = arith.constant 0 : index
+ // expected-note@-2 {{prior use here}}
+ // expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<[2]x[2]xi1>' vs 'vector<[4]x[4]xi1>}}
+ %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 427154158e797fd..f6459f085843655 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -438,6 +438,16 @@ func.func @arm_sme_tile_load_ver_f64(%src : memref<?x?xf64>) {
// -----
+/// Padding and mask are optional
+func.func @arm_sme_tile_load_hor_pad_f64(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[2]x[2]xi1>) {
+ // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
/// Layout is optional and horizontal is the default, verify it's still parsed.
func.func @arm_sme_tile_load_explicit_hor(%src : memref<?x?xi8>) {
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
|
…Format (#68876) This is just a slight specialization of `TypesMatchWith` that returns success if an optional parameter is missing. There may be other places this could help e.g.: https://github.com/llvm/llvm-project/blob/eb21049b4b904b072679ece60e73c6b0dc0d1ebf/mlir/include/mlir/Dialect/X86Vector/X86Vector.td#L58-L59 ...but I'm leaving those to avoid some churn. This constraint will be handy for us in some later patches, it's a formalization of a short circuiting trick with the `comparator` of the `TypesMatchWith` constraint (devised for #69195). ``` TypesMatchWith< "padding type matches element type of result (if present)", "result", "padding", "::llvm::cast<VectorType>($_self).getElementType()", // This returns true if no padding is present, or it's present with a type that matches the element type of `result`. "!getPadding() || std::equal_to<>()"> ``` This is a little non-obvious, so after this patch you can instead do: ``` OptionalTypesMatchWith< "padding type matches element type of result (if present)", "result", "padding", "::llvm::cast<VectorType>($_self).getElementType()"> ```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, ta!
The extra test for the verifier (to make sure that neither mask
nor pad
can be used without the other) could be added separately.
Padding and mask are optional, but if one is specified both must be specified. This is consistent with vector.transfer_read.
f4b89ab
to
dbc1c02
Compare
Thanks for reviewing. I've added the test. |
Padding and mask are optional, but if one is specified both must be specified. This is consistent with vector.transfer_read.