[mlir][Vector] Support poison in vector.shuffle mask #122188

Merged: 2 commits merged into llvm:main on Jan 18, 2025

Contributor

@dcaballe dcaballe commented Jan 8, 2025

This PR extends the existing poison support in https://mlir.llvm.org/docs/Dialects/UBOps/ by representing poison mask values in vector.shuffle. Similar to LLVM (see https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/Instructions.h#L1884) this requires defining an integer value (-1) to represent poison in the vector.shuffle mask.

The current implementation parses and prints -1 for the poison value. I implemented a custom parser/printer to use the poison keyword instead but, on second thought, removed it from the PR. I think it's overkill to introduce a hand-written parser/printer for every operation supporting poison. I also explored adding new flavors of DenseIXArrayAttr that could take an argument to represent the poison value, but the code became quite complex, so I dropped that as well. Happy to get feedback about this and improve the assembly format as a follow-up!
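
For readers who want the legality rule in isolation, here is a minimal standalone sketch of the mask check that the description implies: an entry is accepted if it equals the poison sentinel (-1) or lies in the half-open range [0, indexSize). The function name `firstInvalidMaskEntry` and the plain `indexSize` parameter are hypothetical and exist only for illustration; the actual check lives in `ShuffleOp::verify()` in the diff below.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Sentinel for a poison mask entry; mirrors the kMaskPoisonValue constant
// that the PR adds to the op's extra class declaration.
constexpr int64_t kMaskPoisonValue = -1;

// Hypothetical helper (not an MLIR API). Returns the 0-based position of the
// first invalid mask entry, or std::nullopt if every entry is either poison
// or within [0, indexSize). `indexSize` stands for the sum of the leading
// dimension sizes of the two shuffle operands.
std::optional<std::size_t>
firstInvalidMaskEntry(const std::vector<int64_t> &mask, int64_t indexSize) {
  assert(indexSize >= 0 && "operand sizes cannot be negative");
  for (std::size_t idx = 0; idx < mask.size(); ++idx) {
    const int64_t maskPos = mask[idx];
    // Poison is checked first, so -1 never reaches the range test.
    if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
      return idx;
  }
  return std::nullopt;
}
```

Under this sketch, the mask `[0, -1, 3, -1]` over two 2-element operands is accepted, while `[0, -2, 3]` is rejected at position 1, since only -1 is treated as poison and other negative values stay out of range.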

@llvmbot
Member

llvmbot commented Jan 8, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-spirv

Author: Diego Caballero (dcaballe)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/122188.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+8-2)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+1-1)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+10)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+11)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+7)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 30a5b06374fad1..a786e4696415cb 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -434,7 +434,7 @@ def Vector_ShuffleOp
     The shuffle operation constructs a permutation (or duplication) of elements
     from two input vectors, returning a vector with the same element type as
     the input and a length that is the same as the shuffle mask. The two input
-    vectors must have the same element type, same rank , and trailing dimension
+    vectors must have the same element type, same rank, and trailing dimension
     sizes and shuffles their values in the
     leading dimension (which may differ in size) according to the given mask.
     The legality rules are:
@@ -448,7 +448,8 @@ def Vector_ShuffleOp
     * the mask length equals the leading dimension size of the result
     * numbering the input vector indices left to right across the operands, all
       mask values must be within range, viz. given two k-D operands v1 and v2
-      above, all mask values are in the range [0,s_1+t_1)
+      above, all mask values are in the range [0,s_1+t_1). -1 is used to
+      represent a poison mask value.
 
     Note, scalable vectors are not supported.
 
@@ -463,10 +464,15 @@ def Vector_ShuffleOp
                : vector<2xf32>, vector<2xf32>       ; yields vector<4xf32>
     %3 = vector.shuffle %a, %b[0, 1]
                : vector<f32>, vector<f32>           ; yields vector<2xf32>
+    %4 = vector.shuffle %a, %b[0, 4, -1, -1, -1, -1]
+               : vector<4xf32>, vector<4xf32>       ; yields vector<6xf32>
     ```
   }];
 
   let extraClassDeclaration = [{
+    // Integer to represent a poison value in a vector shuffle mask.
+    static constexpr int64_t kMaskPoisonValue = -1;
+
     VectorType getV1VectorType() {
       return ::llvm::cast<VectorType>(getV1().getType());
     }
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ae1cf95732336a..696d1e0f9b1e68 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2600,7 +2600,7 @@ LogicalResult ShuffleOp::verify() {
   int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                       (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
   for (auto [idx, maskPos] : llvm::enumerate(mask)) {
-    if (maskPos < 0 || maskPos >= indexSize)
+    if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
       return emitOpError("mask index #") << (idx + 1) << " out of range";
   }
   return success();
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index f95e943250bd44..931cc36c9d4a88 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1105,6 +1105,16 @@ func.func @shuffle_1D_index_direct(%arg0: vector<2xindex>, %arg1: vector<2xindex
 
 // -----
 
+func.func @shuffle_poison_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<4xf32> {
+  %1 = vector.shuffle %arg0, %arg1 [0, -1, 3, -1] : vector<2xf32>, vector<2xf32>
+  return %1 : vector<4xf32>
+}
+// CHECK-LABEL: @shuffle_poison_mask(
+//  CHECK-SAME:   %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>)
+//       CHECK:     %[[s:.*]] = llvm.shufflevector %[[A]], %[[B]] [0, -1, 3, -1] : vector<2xf32>
+
+// -----
+
 func.func @shuffle_1D(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<5xf32> {
   %1 = vector.shuffle %arg0, %arg1 [4, 3, 2, 1, 0] : vector<2xf32>, vector<3xf32>
   return %1 : vector<5xf32>
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 103148633bf97c..fd73cea5e4f306 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -613,6 +613,17 @@ func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
 
 // -----
 
+// CHECK-LABEL:  func @shuffle
+//  CHECK-SAME:  %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+//       CHECK:    %[[SHUFFLE:.*]] = spirv.VectorShuffle [1 : i32, -1 : i32, 5 : i32, -1 : i32] %[[ARG0]], %[[ARG1]] : vector<4xi32>, vector<4xi32> -> vector<4xi32>
+//       CHECK:    return %[[SHUFFLE]] : vector<4xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<4xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [1, -1, 5, -1] : vector<4xi32>, vector<4xi32>
+  return %shuffle : vector<4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @interleave
 //  CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
 //       CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 961f1b5ffeabec..cd6f3f518a1c07 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -190,6 +190,13 @@ func.func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32
   return %1 : vector<3x4xf32>
 }
 
+// CHECK-LABEL: @shuffle_poison_mask
+func.func @shuffle_poison_mask(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: vector.shuffle %{{.*}}, %{{.*}}[1, -1, 6, -1] : vector<4xf32>, vector<4xf32>
+  %1 = vector.shuffle %a, %a[1, -1, 6, -1] : vector<4xf32>, vector<4xf32>
+  return %1 : vector<4xf32>
+}
+
 // CHECK-LABEL: @extract_element_0d
 func.func @extract_element_0d(%a: vector<f32>) -> f32 {
   // CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32>

@dcaballe
Contributor Author

Ping :)

Member

@kuhar kuhar left a comment

Looks good overall, just one minor issue.

Would be nice to leave a TODO somewhere to handle this in folds/canonicalization patterns, e.g., extract, extract_strided_slice, etc.
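
To illustrate why folds need attention, here is a small standalone sketch of evaluating a 1-D shuffle on constant inputs when the mask may contain poison. It assumes that a poison mask entry yields a poison result lane (as with LLVM's `shufflevector`) and models poison as `std::nullopt`, which is purely an illustrative choice and not how MLIR represents it. The name `shuffle1D` and the `Lane` alias are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

constexpr int64_t kMaskPoisonValue = -1;

// One result lane: a concrete value, or std::nullopt standing in for poison.
using Lane = std::optional<float>;

// Hypothetical reference evaluation of a 1-D shuffle. Mask positions index
// into the concatenation of v1 and v2; poison entries produce poison lanes.
std::vector<Lane> shuffle1D(const std::vector<float> &v1,
                            const std::vector<float> &v2,
                            const std::vector<int64_t> &mask) {
  const int64_t s1 = static_cast<int64_t>(v1.size());
  [[maybe_unused]] const int64_t total =
      s1 + static_cast<int64_t>(v2.size());
  std::vector<Lane> result;
  result.reserve(mask.size());
  for (int64_t pos : mask) {
    if (pos == kMaskPoisonValue) {
      // A fold that indexed the operands here would read out of bounds.
      result.push_back(std::nullopt);
      continue;
    }
    assert(pos >= 0 && pos < total && "mask must already be verified");
    result.push_back(pos < s1 ? v1[pos] : v2[pos - s1]);
  }
  return result;
}
```

The point of the sketch is the early `continue`: any fold or pattern that maps a mask entry back to an operand element has to special-case -1 before indexing, which is the kind of handling the TODO asks for.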

@@ -613,6 +613,17 @@ func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {

// -----

// CHECK-LABEL: func @shuffle
// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [1 : i32, -1 : i32, 5 : i32, -1 : i32] %[[ARG0]], %[[ARG1]] : vector<4xi32>, vector<4xi32> -> vector<4xi32>
Member

This PR extends the existing poison support in https://mlir.llvm.org/docs/Dialects/UBOps/
by representing poison mask values in `vector.shuffle`. Similar to LLVM (see
https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/Instructions.h#L1884)
this requires defining an integer value (`-1`) representing poison in the `vector.shuffle` mask.

The current implementation parses and prints `-1` for the poison value. I implemented a custom
parser/printer to use the `poison` keyword instead, but I think it's overkill to introduce
hand-written parsers/printers for every operation supporting poison. I also explored adding new
flavors of `DenseIXArrayAttr` that could take an argument to represent the poison value, but I
dropped that as well because the resulting code was too complex. Happy to get feedback about this
and improve the assembly format as a follow-up.
@dcaballe
Contributor Author

Thanks!

Would be nice to leave a TODO somewhere to handle this in folds/canonicalization patterns, e.g., extract, extract_strided_slice, etc.

I added this to the ops' doc. I already have some of this implemented.

@dcaballe dcaballe merged commit eae5ca9 into llvm:main Jan 18, 2025
8 checks passed
dcaballe added a commit to dcaballe/llvm-project that referenced this pull request Jan 18, 2025
Following up on llvm#122188, this PR adds support for poison indices to
`ExtractOp` and `InsertOp`. It also includes canonicalization patterns
to turn extract/insert ops with poison indices into `ub.poison`.
dcaballe added a commit that referenced this pull request Jan 28, 2025
…123488)

Following up on #122188, this PR adds support for poison indices to
`ExtractOp` and `InsertOp`. It also includes canonicalization patterns
to turn extract/insert ops with poison indices into `ub.poison`.
3 participants