Skip to content

[mlir][Vector] Support poison in vector.shuffle mask #122188

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -434,10 +434,9 @@ def Vector_ShuffleOp
The shuffle operation constructs a permutation (or duplication) of elements
from two input vectors, returning a vector with the same element type as
the input and a length that is the same as the shuffle mask. The two input
vectors must have the same element type, same rank , and trailing dimension
sizes and shuffles their values in the
leading dimension (which may differ in size) according to the given mask.
The legality rules are:
vectors must have the same element type, same rank, and trailing dimension
sizes and shuffles their values in the leading dimension (which may differ
in size) according to the given mask. The legality rules are:
* the two operands must have the same element type as the result
- Either, the two operands and the result must have the same
rank and trailing dimension sizes, viz. given two k-D operands
Expand All @@ -448,7 +447,9 @@ def Vector_ShuffleOp
* the mask length equals the leading dimension size of the result
* numbering the input vector indices left to right across the operands, all
mask values must be within range, viz. given two k-D operands v1 and v2
above, all mask values are in the range [0,s_1+t_1)
above, all mask values are in the range [0,s_1+t_1). The value `-1`
represents a poison mask value, which specifies that the selected element
is poison.

Note, scalable vectors are not supported.

Expand All @@ -463,10 +464,15 @@ def Vector_ShuffleOp
: vector<2xf32>, vector<2xf32> ; yields vector<4xf32>
%3 = vector.shuffle %a, %b[0, 1]
: vector<f32>, vector<f32> ; yields vector<2xf32>
%4 = vector.shuffle %a, %b[0, 4, -1, -1, -1, -1]
: vector<4xf32>, vector<4xf32> ; yields vector<6xf32>
```
}];

let extraClassDeclaration = [{
// Integer to represent a poison value in a vector shuffle mask.
static constexpr int64_t kMaskPoisonValue = -1;

VectorType getV1VectorType() {
return ::llvm::cast<VectorType>(getV1().getType());
}
Expand Down Expand Up @@ -700,6 +706,8 @@ def Vector_ExtractOp :
%4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
%5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
```

TODO: Implement support for poison indices.
}];

let arguments = (ins
Expand Down Expand Up @@ -890,6 +898,8 @@ def Vector_InsertOp :
%11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
```

TODO: Implement support for poison indices.
}];

let arguments = (ins
Expand Down Expand Up @@ -980,6 +990,8 @@ def Vector_ScalableInsertOp :
```mlir
%2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
```

TODO: Implement support for poison indices.
}];

let assemblyFormat = [{
Expand Down Expand Up @@ -1031,6 +1043,8 @@ def Vector_ScalableExtractOp :
```mlir
%1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
```

TODO: Implement support for poison indices.
}];

let assemblyFormat = [{
Expand Down Expand Up @@ -1075,6 +1089,8 @@ def Vector_InsertStridedSliceOp :
{offsets = [0, 0, 2], strides = [1, 1]}:
vector<2x4xf32> into vector<16x4x8xf32>
```

TODO: Implement support for poison indices.
}];

let assemblyFormat = [{
Expand Down Expand Up @@ -1220,6 +1236,8 @@ def Vector_ExtractStridedSliceOp :
%1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
vector<4x8x16xf32> to vector<2x4x16xf32>
```

TODO: Implement support for poison indices.
}];
let builders = [
OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$offsets,
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2600,7 +2600,7 @@ LogicalResult ShuffleOp::verify() {
int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
(v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
for (auto [idx, maskPos] : llvm::enumerate(mask)) {
if (maskPos < 0 || maskPos >= indexSize)
if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
return emitOpError("mask index #") << (idx + 1) << " out of range";
}
return success();
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,16 @@ func.func @shuffle_1D_index_direct(%arg0: vector<2xindex>, %arg1: vector<2xindex

// -----

func.func @shuffle_poison_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<4xf32> {
%1 = vector.shuffle %arg0, %arg1 [0, -1, 3, -1] : vector<2xf32>, vector<2xf32>
return %1 : vector<4xf32>
}
// CHECK-LABEL: @shuffle_poison_mask(
// CHECK-SAME: %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>)
// CHECK: %[[s:.*]] = llvm.shufflevector %[[A]], %[[B]] [0, -1, 3, -1] : vector<2xf32>

// -----

func.func @shuffle_1D(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<5xf32> {
%1 = vector.shuffle %arg0, %arg1 [4, 3, 2, 1, 0] : vector<2xf32>, vector<3xf32>
return %1 : vector<5xf32>
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,17 @@ func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {

// -----

// CHECK-LABEL: func @shuffle
// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [1 : i32, -1 : i32, 5 : i32, -1 : i32] %[[ARG0]], %[[ARG1]] : vector<4xi32>, vector<4xi32> -> vector<4xi32>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// CHECK: return %[[SHUFFLE]] : vector<4xi32>
func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<4xi32> {
%shuffle = vector.shuffle %v0, %v1 [1, -1, 5, -1] : vector<4xi32>, vector<4xi32>
return %shuffle : vector<4xi32>
}

// -----

// CHECK-LABEL: func @interleave
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>
Expand Down
7 changes: 7 additions & 0 deletions mlir/test/Dialect/Vector/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,13 @@ func.func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32
return %1 : vector<3x4xf32>
}

// CHECK-LABEL: @shuffle_poison_mask
func.func @shuffle_poison_mask(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<4xf32> {
// CHECK: vector.shuffle %{{.*}}, %{{.*}}[1, -1, 6, -1] : vector<4xf32>, vector<4xf32>
%1 = vector.shuffle %a, %a[1, -1, 6, -1] : vector<4xf32>, vector<4xf32>
return %1 : vector<4xf32>
}

// CHECK-LABEL: @extract_element_0d
func.func @extract_element_0d(%a: vector<f32>) -> f32 {
// CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32>
Expand Down
Loading