[mlir][Vector] Add vector.to_elements op #141457
This PR introduces the `vector.to_elements` op, which decomposes a vector into its scalar elements. This operation is symmetrical to the existing `vector.from_elements`.

Examples:

```
// Decompose a 0-D vector.
%0 = vector.to_elements %v0 : vector<f32>
// %0 = %v0[0]

// Decompose a 1-D vector.
%0:2 = vector.to_elements %v1 : vector<2xf32>
// %0#0 = %v1[0]
// %0#1 = %v1[1]

// Decompose a 2-D.
%0:6 = vector.to_elements %v2 : vector<2x3xf32>
// %0#0 = %v2[0, 0]
// %0#1 = %v2[0, 1]
// %0#2 = %v2[0, 2]
// %0#3 = %v2[1, 0]
// %0#4 = %v2[1, 1]
// %0#5 = %v2[1, 2]
```

This op is aimed at reducing code size when modeling "structured" vector extractions and simplifying canonicalizations of large sequences of `vector.extract` and `vector.insert` ops into `vector.shuffle` and other sophisticated ops that can re-arrange vector elements.

More related PRs to come!
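To make the code-size argument concrete, here is a hedged sketch (not taken from the patch) that extracts every element of a `vector<2x3xf32>` twice: once with individual `vector.extract` ops and once with the proposed op. The function names and SSA value names are made up for illustration, and the snippet assumes the `vector.extract` assembly format used upstream around the time of this PR.

```mlir
// Hypothetical baseline: one vector.extract per element (six ops).
func.func @extract_all(%v: vector<2x3xf32>) -> (f32, f32, f32, f32, f32, f32) {
  %e0 = vector.extract %v[0, 0] : f32 from vector<2x3xf32>
  %e1 = vector.extract %v[0, 1] : f32 from vector<2x3xf32>
  %e2 = vector.extract %v[0, 2] : f32 from vector<2x3xf32>
  %e3 = vector.extract %v[1, 0] : f32 from vector<2x3xf32>
  %e4 = vector.extract %v[1, 1] : f32 from vector<2x3xf32>
  %e5 = vector.extract %v[1, 2] : f32 from vector<2x3xf32>
  return %e0, %e1, %e2, %e3, %e4, %e5 : f32, f32, f32, f32, f32, f32
}

// Same values with the op proposed in this PR: a single op whose six
// results are produced in row-major order.
func.func @to_elements_all(%v: vector<2x3xf32>) -> (f32, f32, f32, f32, f32, f32) {
  %e:6 = vector.to_elements %v : vector<2x3xf32>
  return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5 : f32, f32, f32, f32, f32, f32
}
```

In the second form the extraction order is implied by the result index, so a pattern matching on it does not need to inspect the indices of each extraction.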
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Diego Caballero (dcaballe)

Full diff: https://github.com/llvm/llvm-project/pull/141457.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 5e8421ed67d66..3da47d8e612e2 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -789,6 +789,57 @@ def Vector_FMAOp :
}];
}
+def Vector_ToElementsOp : Vector_Op<"to_elements", [
+ Pure,
+ TypesMatchWith<"operand element type matches result types",
+ "input", "elements", "SmallVector<Type>("
+ "::llvm::cast<VectorType>($_self).getNumElements(), "
+ "::llvm::cast<VectorType>($_self).getElementType())">]> {
+ let summary = "operation that decomposes a vector into all its scalar elements";
+ let description = [{
+ This operation decomposes all the scalar elements from a vector. The
+ decomposed scalar elements are returned in row-major order. The number of
+ scalar results must match the number of elements in the input vector type.
+ All the result elements have the same result type, which must match the
+ element type of the input vector. Scalable vectors are not supported.
+
+ Examples:
+
+ ```mlir
+ // Decompose a 0-D vector.
+ %0 = vector.to_elements %v0 : vector<f32>
+ // %0 = %v0[0]
+
+ // Decompose a 1-D vector.
+ %0:2 = vector.to_elements %v1 : vector<2xf32>
+ // %0#0 = %v1[0]
+ // %0#1 = %v1[1]
+
+ // Decompose a 2-D.
+ %0:6 = vector.to_elements %v2 : vector<2x3xf32>
+ // %0#0 = %v2[0, 0]
+ // %0#1 = %v2[0, 1]
+ // %0#2 = %v2[0, 2]
+ // %0#3 = %v2[1, 0]
+ // %0#4 = %v2[1, 1]
+ // %0#5 = %v2[1, 2]
+
+ // Decompose a 3-D vector.
+ %0:6 = vector.to_elements %v3 : vector<3x1x2xf32>
+ // %0#0 = %v3[0, 0, 0]
+ // %0#1 = %v3[0, 0, 1]
+ // %0#2 = %v3[1, 0, 0]
+ // %0#3 = %v3[1, 0, 1]
+ // %0#4 = %v3[2, 0, 0]
+ // %0#5 = %v3[2, 0, 1]
+ ```
+ }];
+
+ let arguments = (ins AnyVectorOfAnyRank:$input);
+ let results = (outs Variadic<AnyType>:$elements);
+ let assemblyFormat = "$input attr-dict `:` type($input)";
+}
+
def Vector_FromElementsOp : Vector_Op<"from_elements", [
Pure,
TypesMatchWith<"operand types match result element type",
@@ -798,26 +849,30 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
let summary = "operation that defines a vector from scalar elements";
let description = [{
This operation defines a vector from one or multiple scalar elements. The
- number of elements must match the number of elements in the result type.
- All elements must have the same type, which must match the element type of
- the result vector type.
-
- `elements` are a flattened version of the result vector in row-major order.
+ scalar elements are arranged in row-major within the vector. The number of
+ elements must match the number of elements in the result type. All elements
+ must have the same type, which must match the element type of the result
+ vector type. Scalable vectors are not supported.
- Example:
+ Examples:
```mlir
- // %f1
+ // Define a 0-D vector.
%0 = vector.from_elements %f1 : vector<f32>
- // [%f1, %f2]
+ // [%f1]
+
+ // Define a 1-D vector.
%1 = vector.from_elements %f1, %f2 : vector<2xf32>
- // [[%f1, %f2, %f3], [%f4, %f5, %f6]]
+ // [%f1, %f2]
+
+ // Define a 2-D vector.
%2 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<2x3xf32>
- // [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
+ // [[%f1, %f2, %f3], [%f4, %f5, %f6]]
+
+ // Define a 3-D vector.
%3 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<3x1x2xf32>
+ // [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
```
-
- Note, scalable vectors are not supported.
}];
let arguments = (ins Variadic<AnyType>:$elements);
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 04810ed52584f..70a7274182442 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1896,7 +1896,24 @@ func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
// -----
-func.func @invalid_from_elements(%a: f32) {
+func.func @to_elements_wrong_num_results(%a: vector<1x1x2xf32>) {
+ // expected-error @+1 {{operation defines 2 results but was provided 4 to bind}}
+ %0:4 = vector.to_elements %a : vector<1x1x2xf32>
+ return
+}
+
+// -----
+
+func.func @to_elements_wrong_result_type(%a: vector<2xf32>) -> i32 {
+ // expected-error @+3 {{use of value '%0' expects different type than prior uses: 'i32'}}
+ // expected-note @+1 {{prior use here}}
+ %0:2 = vector.to_elements %a : vector<2xf32>
+ return %0#0 : i32
+}
+
+// -----
+
+func.func @from_elements_wrong_num_operands(%a: f32) {
// expected-error @+1 {{'vector.from_elements' number of operands and types do not match: got 1 operands and 2 types}}
vector.from_elements %a : vector<2xf32>
return
@@ -1905,12 +1922,11 @@ func.func @invalid_from_elements(%a: f32) {
// -----
// expected-note @+1 {{prior use here}}
-func.func @invalid_from_elements(%a: f32, %b: i32) {
+func.func @from_elements_wrong_operand_type(%a: f32, %b: i32) {
// expected-error @+1 {{use of value '%b' expects different type than prior uses: 'f32' vs 'i32'}}
vector.from_elements %a, %b : vector<2xf32>
return
}
-
// -----
func.func @invalid_from_elements_scalable(%a: f32, %b: i32) {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f3220aed4360c..7cfe4e89d6e2f 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1175,6 +1175,25 @@ func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4
return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
}
+// CHECK-LABEL: func @to_elements(
+// CHECK-SAME: %[[A_VEC:.*]]: vector<f32>, %[[B_VEC:.*]]: vector<4xf32>,
+// CHECK-SAME: %[[C_VEC:.*]]: vector<1xf32>, %[[D_VEC:.*]]: vector<2x2xf32>)
+func.func @to_elements(%a_vec : vector<f32>, %b_vec : vector<4xf32>, %c_vec : vector<1xf32>, %d_vec : vector<2x2xf32>)
+ -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
+ // CHECK: %[[A_ELEMS:.*]] = vector.to_elements %[[A_VEC]] : vector<f32>
+ %0 = vector.to_elements %a_vec : vector<f32>
+ // CHECK: %[[B_ELEMS:.*]]:4 = vector.to_elements %[[B_VEC]] : vector<4xf32>
+ %1:4 = vector.to_elements %b_vec : vector<4xf32>
+ // CHECK: %[[C_ELEMS:.*]] = vector.to_elements %[[C_VEC]] : vector<1xf32>
+ %2 = vector.to_elements %c_vec : vector<1xf32>
+ // CHECK: %[[D_ELEMS:.*]]:4 = vector.to_elements %[[D_VEC]] : vector<2x2xf32>
+ %3:4 = vector.to_elements %d_vec : vector<2x2xf32>
+ // CHECK: return %[[A_ELEMS]], %[[B_ELEMS]]#0, %[[B_ELEMS]]#1, %[[B_ELEMS]]#2,
+ // CHECK-SAME: %[[B_ELEMS]]#3, %[[C_ELEMS]], %[[D_ELEMS]]#0, %[[D_ELEMS]]#1,
+ // CHECK-SAME: %[[D_ELEMS]]#2, %[[D_ELEMS]]#3
+ return %0, %1#0, %1#1, %1#2, %1#3, %2, %3#0, %3#1, %3#2, %3#3 : f32, f32, f32, f32, f32, f32, f32, f32, f32, f32
+}
+
// CHECK-LABEL: func @from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>) {
Thanks, Diego! I’ve left a few minor comments inline. I also have one broader question.
That mostly makes sense, but I’m curious about the practical next steps and the overall direction we’re aiming for. Specifically:
Basically, I want to make sure we avoid "dangling" ops 😅 - currently,
Thanks!
TypesMatchWith<"operand element type matches result types",
"input", "elements", "SmallVector<Type>("
"::llvm::cast<VectorType>($_self).getNumElements(), "
"::llvm::cast<VectorType>($_self).getElementType())">]> {
Could this be shared with `vector.from_elements`?
let arguments = (ins AnyVectorOfAnyRank:$input);
`$input` is not very descriptive and quite uncommon in Vector. Instead, IMHO, we should re-use one of the existing names to maintain consistency. My suggestion would be:
- `$source` for `to_elements`,
- `$dest` for `from_elements`.

Why `$source`? Basically, I looked at other Vector Ops that take one argument:
- https://mlir.llvm.org/docs/Dialects/Vector/#vectorbroadcast-vectorbroadcastop
- https://mlir.llvm.org/docs/Dialects/Vector/#vectorbitcast-vectorbitcastop

Why `$dest`? It naturally complements `$source`.
This operation decomposes all the scalar elements from a vector. The
decomposed scalar elements are returned in row-major order. The number of
scalar results must match the number of elements in the input vector type.
All the result elements have the same result type, which must match the
element type of the input vector. Scalable vectors are not supported.
Is it important that it decomposes into all elements? This op could be really useful for unrolling a dimension if we could do it dim-wise. Something like:
%0:16 = vector.to_elements %v : vector<16x4xf32> -> vector<4xf32>
This should have the exact same semantics as `vector.extract`, just doing multiple extracts at once.
I would much rather have this form of the operation: it is much closer to `vector.extract` and works much better for N-D vectors.
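For reference, a rough sketch of what the dim-wise form would stand in for, using a smaller, made-up shape than the one in the comment. The function and value names are hypothetical, and the dim-wise syntax shown in the trailing comment is only the reviewer's suggestion, not something this PR implements.

```mlir
// Hypothetical illustration: unrolling the leading dimension of a
// vector<3x4xf32> with one vector.extract per row.
func.func @unroll_leading_dim(%v: vector<3x4xf32>)
    -> (vector<4xf32>, vector<4xf32>, vector<4xf32>) {
  %r0 = vector.extract %v[0] : vector<4xf32> from vector<3x4xf32>
  %r1 = vector.extract %v[1] : vector<4xf32> from vector<3x4xf32>
  %r2 = vector.extract %v[2] : vector<4xf32> from vector<3x4xf32>
  return %r0, %r1, %r2 : vector<4xf32>, vector<4xf32>, vector<4xf32>
}

// The suggested dim-wise variant would, if it were ever added, express the
// three extractions above as a single op (syntax NOT part of this PR):
//   %r:3 = vector.to_elements %v : vector<3x4xf32> -> vector<4xf32>
```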
I think that keeping the symmetry with `from_elements` is valuable. I'm not sure I follow the suggestion, but is it doing something that chaining `extract` / `extract_strided_slice` / `shape_cast` / `to_elements` cannot achieve?
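A minimal sketch of the kind of chaining this reply refers to, assuming the op syntax from this PR and the upstream `vector.extract` format. The function name and the choice of row 1 are purely illustrative.

```mlir
// Hypothetical example: get the scalars of one row of a 2-D vector by
// first slicing the row out and then decomposing the 1-D slice.
func.func @row_elements(%v: vector<16x4xf32>) -> (f32, f32, f32, f32) {
  // Slice out row 1 as a 1-D vector.
  %row = vector.extract %v[1] : vector<4xf32> from vector<16x4xf32>
  // Decompose that row into its four scalar elements.
  %e:4 = vector.to_elements %row : vector<4xf32>
  return %e#0, %e#1, %e#2, %e#3 : f32, f32, f32, f32
}
```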
While I think that overall we should be deprecating/simplifying as per the suggestions in this doc, I think this particular new op is a sensible one to have alongside `from_elements`. Thanks!
Suggested change:
- let arguments = (ins AnyVectorOfAnyRank:$input);
+ let arguments = (ins AnyFixedVectorOfAnyRank:$input);
perhaps?
Thanks for the feedback! Let me elaborate a bit more on the value that this op (and also the existing
I don’t anticipate major changes other than replacing loops creating
Yeah, mostly because we needed the symmetrical op that this MR is introducing, and we still need to implement the corresponding canonicalization patterns and lowering, which should come after this.
As I mentioned above, it’s important that all the elements are decomposed to offer an implicit and trivial extraction order that doesn’t have to be analyzed. However, I think decomposing into sub-vectors is a natural follow-up that would be very helpful for unrolling, yes! I suggest, though, that we approach this incrementally by first having all the pieces in place for the simple scalar cases before enabling more complex cases. Does that sound reasonable?
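As an illustration of the kind of pattern the follow-up canonicalizations might target (an assumption based on the PR description, not code from this patch), a full decomposition followed by a `vector.from_elements` that re-orders the scalars carries the permutation directly in its operand list. The sketch below reverses a vector that way and shows a `vector.shuffle` that computes the same result; all names are hypothetical.

```mlir
// Hypothetical input pattern: decompose, then rebuild in reversed order.
func.func @reverse_via_elements(%v: vector<4xf32>) -> vector<4xf32> {
  %e:4 = vector.to_elements %v : vector<4xf32>
  %r = vector.from_elements %e#3, %e#2, %e#1, %e#0 : vector<4xf32>
  return %r : vector<4xf32>
}

// A shuffle producing the same vector. Whether and how a canonicalization
// would rewrite the function above into this form is left to later PRs.
func.func @reverse_via_shuffle(%v: vector<4xf32>) -> vector<4xf32> {
  %r = vector.shuffle %v, %v [3, 2, 1, 0] : vector<4xf32>, vector<4xf32>
  return %r : vector<4xf32>
}
```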
Thanks for elaborating, Diego! This makes sense to me, the benefits of including these Ops in
Overall this looks good to me, modulo the ongoing discussions.