[mlir][tensor] Loosen restrictions on folding dynamic reshapes #137963
base: main
Conversation
The main idea behind the change is to allow expand-of-collapse folds for reshapes like `?x?xk` -> `?` (k>1). The rationale here is that the expand op must have a coherent index/affine expression specified in its `output_shape` argument (see example below), and if it doesn't, the IR has already been invalidated at an earlier stage:

```
%c32 = arith.constant 32 : index
%div = arith.divsi %<some_index>, %c32 : index
%collapsed = tensor.collapse_shape %41#1 [[0], [1, 2], [3, 4]] : tensor<9x?x32x?x32xf32> into tensor<9x?x?xf32>
%affine = affine.apply affine_map<()[s0] -> (s0 * 32)> ()[%div]
%expanded = tensor.expand_shape %collapsed [[0], [1, 2], [3]] output_shape [9, %div, 32, %affine] : tensor<9x?x?xf32> into tensor<9x?x32x?xf32>
```

On the above assumption, adjust the routine in `getReassociationIndicesForCollapse()` to allow dynamic reshapes beyond just `?x..?x1x1x..x1` -> `?`.

Moreover, the reassociation util was refactored to clearly distinguish between dynamic and static subshapes. A few known caveats were noted as a comment; it doesn't seem possible to fold all qualifying dynamic shape patterns in a deterministic way without looking into affine expressions simultaneously. That would be difficult to maintain in a single general utility. Other implementation ideas/larger refactoring could include:

- abandoning the util usage in the `ComposeExpandOfCollapseOp` pattern, employing similar logic to `ComposeCollapseOfExpandOp`;
- providing dialect-specific implementations for Linalg/Tensor.

Signed-off-by: Artem Gindinson <[email protected]>
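As an illustration of the newly allowed class of folds, here is a hypothetical gtest-style expectation (not part of this PR's diff) for the `?x32` -> `?` collapse that drives the `unpack_dynamic` test update below:

```
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "gtest/gtest.h"

using namespace mlir;

// With the relaxed utility, ?x32 collapsing into ? yields the single group
// {0, 1} instead of std::nullopt.
TEST(ReassociationForCollapseSketch, TrailingStaticIntoSingleDynamicDim) {
  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 32},
                                               {ShapedType::kDynamic}),
            (SmallVector<ReassociationIndices>{{0, 1}}));
}
```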
Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using `@` followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment "Ping". The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: Artem Gindinson (AGindinson)

Changes: same as the PR description above.

Full diff: https://github.com/llvm/llvm-project/pull/137963.diff

3 Files Affected:

- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
- mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
- mlir/test/Dialect/Tensor/canonicalize.mlir
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index ed40a080441bc..694783849198a 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -31,59 +31,70 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> targetShape) {
- if (sourceShape.size() <= targetShape.size())
+ unsigned numSourceDims = sourceShape.size(),
+ numTargetDims = targetShape.size();
+ if (numSourceDims <= numTargetDims)
return std::nullopt;
- unsigned sourceDim = 0;
- SmallVector<ReassociationIndices> reassociationMap;
- reassociationMap.reserve(targetShape.size());
-
- ReassociationIndices currIndices;
- int64_t prodOfCollapsedDims = 1;
- while (sourceDim < sourceShape.size()) {
- unsigned targetDim = reassociationMap.size();
- // If we have mapped all the target dimensions stop and handle the remaining
- // tail of size-1 dimensions explicitly.
- if (targetDim == targetShape.size())
- break;
+ SmallVector<ReassociationIndices, 4> reassociationMap;
+ reassociationMap.reserve(numTargetDims);
+ unsigned sourceDim = 0, targetDim = 0;
+ for (; targetDim < numTargetDims; ++targetDim) {
int64_t currTargetShape = targetShape[targetDim];
- while (sourceDim < (sourceShape.size() - 1) &&
- sourceShape[sourceDim] != ShapedType::kDynamic &&
- prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+ ReassociationIndices currIndices;
+ // 1. Target dimension is dynamic. Source shape should contain at least
+ // one dynamic dimension.
+ if (currTargetShape == ShapedType::kDynamic) {
+ // FIXME: We stop the search with the first dynamic dimension, while in
+ // fact, we can have a valid pattern like 2x?x?x4x8 -> ?x4x8. It becomes
+ // indeterministic altogether when we have neighboring dynamic dimensions
+ // in the target shape. Most of these patterns will be safely rejected,
+ // however we might achieve more correct folds by taking affine
+ // expressions into account, if these can be passed on by the call sites.
+ bool foundDynamic = false;
+ while (sourceDim < numSourceDims) {
+ currIndices.push_back(sourceDim);
+ if (sourceShape[sourceDim++] == ShapedType::kDynamic) {
+ foundDynamic = true;
+ break;
+ }
+ }
+ if (!foundDynamic)
+ return std::nullopt;
+
+ reassociationMap.push_back(currIndices);
+ continue;
+ }
+ // 2. Target dimension is static. The product of dimensions of the expanded
+ // shape should match the collapsed dimension shape.
+ int64_t prodOfCollapsedDims = 1;
+ bool reachedTargetDimSize = false;
+ while (sourceDim < numSourceDims) {
+ // Source shape cannot be dynamic if the target dim is static.
+ if (sourceShape[sourceDim] == ShapedType::kDynamic)
+ return std::nullopt;
prodOfCollapsedDims *= sourceShape[sourceDim];
- currIndices.push_back(sourceDim++);
+ if (prodOfCollapsedDims > currTargetShape)
+ break;
+ else if (prodOfCollapsedDims == currTargetShape) {
+ currIndices.push_back(sourceDim++);
+ reachedTargetDimSize = true;
+ break;
+ } else // prodOfCollapsedDims < currTargetShape
+ currIndices.push_back(sourceDim++);
}
-
- // If the current expanded dimension is dynamic, then the collapsed
- // dimensions should also be dynamic and product of all previous unprocessed
- // dimensions of the expanded shape should be 1.
- if (sourceShape[sourceDim] == ShapedType::kDynamic &&
- (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
+ if (!reachedTargetDimSize)
return std::nullopt;
-
- // If the collapsed dim is dynamic, the current expanded dim should also
- // be dynamic.
- if (currTargetShape == ShapedType::kDynamic &&
- sourceShape[sourceDim] != ShapedType::kDynamic)
- return std::nullopt;
-
- // For static shapes, if the product of dimensions of the expanded shape
- // should match the collapsed dimension shape.
- if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
- return std::nullopt;
-
- currIndices.push_back(sourceDim++);
- reassociationMap.emplace_back(ReassociationIndices{});
- std::swap(reassociationMap.back(), currIndices);
- prodOfCollapsedDims = 1;
+ reassociationMap.push_back(currIndices);
}
- // All the dimensions in the target must have been processed.
- if (reassociationMap.size() != targetShape.size())
- return std::nullopt;
- // Process any remaining entries in the source shape. They all need to be
- // 1 or dynamic.
- for (; sourceDim < sourceShape.size(); sourceDim++) {
- if (sourceShape[sourceDim] != ShapedType::kDynamic &&
+ // Now that we've mapped all the target dimensions, process any remaining
+ // entries in the source shape explicitly. Either the last target dimension
+ // is dynamic, or all remaining source entries need to be 1 or dynamic. Same
+ // applies when target shape is empty (can be the case for subshape
+ // reassociations).
+ for (; sourceDim < numSourceDims; sourceDim++) {
+ if ((targetShape.empty() || targetShape.back() != ShapedType::kDynamic) &&
+ sourceShape[sourceDim] != ShapedType::kDynamic &&
sourceShape[sourceDim] != 1)
return std::nullopt;
// The map is empty when the target type is a scalar.
diff --git a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
index 51350e5bc8498..6979770154bab 100644
--- a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
@@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
// -----
// CHECK-LABEL: func.func @unpack_dynamic
-// CHECK-NOT: tensor.collapse
-// CHECK: linalg.unpack
+// CHECK: tensor.collapse
+// CHECK-NOT: linalg.unpack
func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 85bf6fba52aa4..443f931745557 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1068,7 +1068,7 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf3
// -----
-func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
-> tensor<?x4x?xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
: tensor<?x4x?xf32> into tensor<?x?xf32>
@@ -1076,12 +1076,28 @@ func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: ind
: tensor<?x?xf32> into tensor<?x4x?xf32>
return %1 : tensor<?x4x?xf32>
}
-// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape
// CHECK-NOT: tensor.{{.*}}_shape
// -----
-func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index)
+ -> tensor<?x4x?xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
+ : tensor<?x4x?x2xf32> into tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+ : tensor<?x?xf32> into tensor<?x4x?xf32>
+ return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape
+// CHECK-NOT: tensor.expand_shape
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]]
+// CHECK-SAME: : tensor<?x4x?x2xf32> into tensor<?x4x?xf32>
+// CHECK-NEXT: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
-> tensor<?x?x?xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
: tensor<?x?x?xf32> into tensor<?x?xf32>
@@ -1089,7 +1105,7 @@ func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1:
: tensor<?x?xf32> into tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
-// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic
// CHECK: tensor.collapse_shape
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
// CHECK: return %[[EXPAND]]
Signed-off-by: Artem Gindinson <[email protected]>
Signed-off-by: Artem Gindinson <[email protected]>
Signed-off-by: Artem Gindinson <[email protected]> Co-authored-by: Ian Wood <[email protected]>
Signed-off-by: Artem Gindinson <[email protected]>
Signed-off-by: Artem Gindinson <[email protected]>
Just following along loosely. I think this is fairly involved and tricky. Marking as request changes since I intend to come back and review it in depth.
I think there is a bug here. For example:
```
EXPECT_EQ(getReassociationIndicesForCollapse(
              {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic},
              {ShapedType::kDynamic, 10, ShapedType::kDynamic}),
          std::nullopt);
```
This fails with the output `{ { 0 }, { 1 }, { 2, 3, 4 } }` (I think `std::nullopt` is the correct result here, since the grouping is ambiguous: e.g. `{ { 0, 1, 2 }, { 3 }, { 4 } }` would be equally valid).
Other examples:
```
EXPECT_EQ(getReassociationIndicesForCollapse(
              {ShapedType::kDynamic, 3, 4, 3, ShapedType::kDynamic},
              {ShapedType::kDynamic, 12, ShapedType::kDynamic}),
          std::nullopt);
EXPECT_EQ(getReassociationIndicesForCollapse(
              {ShapedType::kDynamic, 8, 4, 2, 16, ShapedType::kDynamic},
              {ShapedType::kDynamic, 32, ShapedType::kDynamic}),
          std::nullopt);
EXPECT_EQ(getReassociationIndicesForCollapse(
              {ShapedType::kDynamic, 2, 2, 2, ShapedType::kDynamic},
              {ShapedType::kDynamic, 2, 2, ShapedType::kDynamic}),
          std::nullopt);
```
I'm trying to think of a generalized rule for when it becomes ambiguous. One possible way might be to check against a reversed target/source shape:

```
getReassociationIndicesForCollapse(source, target) ==
    reverseReassociation(
        getReassociationIndicesForCollapse(reverse(source), reverse(target)))
```

The idea is that the current algorithm is "anti-greedy" and will try to minimize the size of the reassociations towards the beginning. So reversing the ordering will compare anti-greedy with greedy. The results will be the same if there is a non-ambiguous reassociation.
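A rough sketch of that check, for illustration only (this is not code from the PR; the index remapping is an assumption, as is the utility returning exactly one group per target dimension):

```
#include <optional>

#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Accept a grouping only if the "anti-greedy" forward pass and the "greedy"
// reversed pass agree; a mismatch signals an ambiguous reshape.
static std::optional<SmallVector<ReassociationIndices>>
getUnambiguousReassociation(ArrayRef<int64_t> sourceShape,
                            ArrayRef<int64_t> targetShape) {
  auto forward = getReassociationIndicesForCollapse(sourceShape, targetShape);
  if (!forward)
    return std::nullopt;

  SmallVector<int64_t> revSource(sourceShape.rbegin(), sourceShape.rend());
  SmallVector<int64_t> revTarget(targetShape.rbegin(), targetShape.rend());
  auto backward = getReassociationIndicesForCollapse(revSource, revTarget);
  if (!backward)
    return std::nullopt;

  int64_t numSrc = sourceShape.size(), numTgt = targetShape.size();
  if (backward->size() != static_cast<size_t>(numTgt))
    return std::nullopt;

  // Map the reversed grouping back into the original index space.
  SmallVector<ReassociationIndices> remapped(numTgt);
  for (int64_t revTgtDim = 0; revTgtDim < numTgt; ++revTgtDim) {
    ReassociationIndices &group = remapped[numTgt - 1 - revTgtDim];
    for (int64_t revSrcIdx : llvm::reverse((*backward)[revTgtDim]))
      group.push_back(numSrc - 1 - revSrcIdx);
  }
  if (remapped != *forward)
    return std::nullopt;
  return forward;
}
```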
Thanks, great catch first and foremost! This may be a good approach, but I'm not sure how to conveniently let the ambiguity for collapsed dimensions of 1 pass through. Either tweaking the reverse iteration logic to push the 1's over to the next iteration (assuming this wouldn't lead to more peculiar edge cases, it might clutter the readability further) or writing the final comparison by hand... I'll mull this over and submit a new version next week.
@IanWood1, while making adjustments, I've realized:
For a full solution, I'm taking another quick look at an approach with slicing the shapes into mixed & static subshapes. IMO, full-on backtracking is an overkill given the context. If not, let's just stick with 1 and accept the missed cases - we'll still get a good increase in coverage. I'll post what I have EO tomorrow.
Too much drama, because:
This problem can be mathematically reduced to my current version. We just need to handle static head and/or tail separately from the dynamic/mixed subshape in between. A few examples for success & failure: https://gist.github.com/AGindinson/2942c0d9667d62f367d8893918b45076. The logic from the current version would be used heavily, so a review iteration would be of great help at this point @IanWood1 @MaheshRavishankar. As you see fit, we can also push the current approach over the merge line, then I'll iterate further to a complete solution.
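For a hypothetical illustration of that decomposition (shapes invented here; the expected grouping assumes the complete solution behaves as described): in `5x?x3x4x2` -> `5x?x12x2`, the static head `5` and tail `2` are peeled off, and only the mixed middle `?x3x4` -> `?x12` needs the dynamic-aware matching:

```
// Hypothetical expectation, not an existing test: head {0} and tail {4}
// match statically; the mixed middle resolves to {1} and {2, 3}.
EXPECT_EQ(getReassociationIndicesForCollapse(
              {5, ShapedType::kDynamic, 3, 4, 2},
              {5, ShapedType::kDynamic, 12, 2}),
          (SmallVector<ReassociationIndices>{{0}, {1}, {2, 3}, {4}}));
```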
Signed-off-by: Artem Gindinson <[email protected]>
The logic looks good from what I understand but is a bit complicated so I may need to reread to make sure I fully grasp what is going on.
```
if (!iterationRange.isInRange(sourceShapeAsRange))
  return failure();
```
I think this is not needed since the for loop will be skipped and return failure
```
bool foundDynamic = false;
for (; iterationRange.isInRange(sourceShapeAsRange);
     iterationRange.expandRight()) {
  int64_t sourceSize = sourceShape[iterationRange.rightIdx];
  if (foundDynamic && !matchGreedily)
    break;
  if (sourceSize == ShapedType::kDynamic)
    foundDynamic = true;
  resultRange = iterationRange;
}
```
I think this can be simplified to stop after the first dynamic dim is found. Then, depending on `matchGreedily`, keep the current `resultRange` or set the rhs to the rightmost dim in the shape.
Good point, this is essentially a leftover from when I had separate greedy parameters for static & dynamic sources.
```
if (!iterationRange.isInRange(sourceShapeAsRange))
  return failure();
```
same here
```
bool containsSingleIndex() const { return size() == 1; }

void expandRight() { ++rightIdx; }
void shrinkLeft() { ++leftIdx; }
```
These are simple enough that it might make sense to remove them but either way works.
Yep, these no longer make sense after I've added some proceed-with-caution comments for the struct
```
// tail of size-1 dimensions explicitly.
if (targetDim == targetShape.size())
bool reachedTargetDimSize = false;
while (iterationRange.isInRange(sourceShapeAsRange)) {
```
Splitting this into two loops may improve complexity. This loop would handle finding a minimal `ReassociationIndexRange`, and the second would extend it with 1s. That way you only need 1 break (after reaching the correct product).
True, leftover from before the split into static & dynamic helpers with different "greed logic"
```
result.reserve(size() + rhs.size() / 2); // Attempt to amortize
for (int64_t idx = this->leftIdx; idx <= this->rightIdx; ++idx) {
  if (idx < rhs.leftIdx || idx > rhs.rightIdx)
    result.push_back(idx);
}
for (int64_t rhsIndex = rhs.leftIdx; rhsIndex <= rhs.rightIdx; ++rhsIndex) {
  if (rhsIndex < leftIdx || rhsIndex > rightIdx)
    result.push_back(rhsIndex);
}
```
I don't think you need to iterate over the entirety of both intervals. Something like this should work:
Suggested change:

```
-result.reserve(size() + rhs.size() / 2); // Attempt to amortize
-for (int64_t idx = this->leftIdx; idx <= this->rightIdx; ++idx) {
-  if (idx < rhs.leftIdx || idx > rhs.rightIdx)
-    result.push_back(idx);
-}
-for (int64_t rhsIndex = rhs.leftIdx; rhsIndex <= rhs.rightIdx; ++rhsIndex) {
-  if (rhsIndex < leftIdx || rhsIndex > rightIdx)
-    result.push_back(rhsIndex);
-}
+int64_t leftStart = std::min(this->leftIndex, rhs->leftIndex);
+int64_t leftEnd = std::max(this->leftIndex, rhs->leftIndex);
+llvm::append_range(result, llvm::seq(leftStart, leftEnd + 1));
+/// same for the right
```
I think I messed up the math a bit, but it also solves the `result.reserve` problem.
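For reference, here is a hedged, corrected sketch of that idea, written as a standalone helper with invented names and assuming the two inclusive ranges overlap (which should hold when both cover the same target dimension):

```
#include <algorithm>
#include <cstdint>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"

// Indices belonging to exactly one of the overlapping inclusive ranges
// [aL, aR] and [bL, bR], i.e. the "flanks" on either side of their
// intersection, in ascending order.
static llvm::SmallVector<int64_t>
nonOverlappingIndices(int64_t aL, int64_t aR, int64_t bL, int64_t bR) {
  llvm::SmallVector<int64_t> result;
  // Left flank: [min(aL, bL), max(aL, bL) - 1].
  llvm::append_range(result,
                     llvm::seq<int64_t>(std::min(aL, bL), std::max(aL, bL)));
  // Right flank: [min(aR, bR) + 1, max(aR, bR)].
  llvm::append_range(result, llvm::seq<int64_t>(std::min(aR, bR) + 1,
                                                std::max(aR, bR) + 1));
  return result;
}
```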
```
if (failed(range.verify()))
  return failure();
```
This shouldn't fail, right? Maybe assert instead?
```
// All source dimensions must be unit or dynamic.
if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
  return std::nullopt;
allSourceIndices.emplace_back(sourceDimIdx);
```
Suggested change:

```
-allSourceIndices.emplace_back(sourceDimIdx);
+allSourceIndices.push_back(sourceDimIdx);
```
```
auto &range = ranges[targetDimIdx];
auto &reverseRange = reverseRanges[targetDimIdx];
// Get non-overlapping indices between the ranges
ReassociationIndices nonMatchingIndices = range ^ reverseRange;
```
Although `^` makes sense, it would probably be more clear and easier to search for if this was a normal method instead of an overloaded operator.
```
// Store the gathered information as required for the next iteration.
prevTargetSize = targetSize;
sourceDimIdx = sourceRange->rightIdx + 1;
reassocRanges.emplace_back(std::move(*sourceRange));
```
I don't think `emplace_back` or `move` is needed since `ReassociationIndexRange` only contains POD types (also, this is giving a clangd warning). `move` is a bit less clear than explicitly calling `reset`.
The main idea behind the change is to allow expand-of-collapse folds for reshapes like `?x?xk` -> `?` (k>1). The rationale here is that the expand op must have a coherent index/affine expression specified in its `output_shape` argument (see example below), and if it doesn't, the IR has already been invalidated at an earlier stage:
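The referenced IR is the same example as in the original description at the top of this page; the dynamic `output_shape` operands of the `tensor.expand_shape` are derived from the collapsed sizes via `arith.divsi` and `affine.apply`:

```
%c32 = arith.constant 32 : index
%div = arith.divsi %<some_index>, %c32 : index
%collapsed = tensor.collapse_shape %41#1 [[0], [1, 2], [3, 4]] : tensor<9x?x32x?x32xf32> into tensor<9x?x?xf32>
%affine = affine.apply affine_map<()[s0] -> (s0 * 32)> ()[%div]
%expanded = tensor.expand_shape %collapsed [[0], [1, 2], [3]] output_shape [9, %div, 32, %affine] : tensor<9x?x?xf32> into tensor<9x?x32x?xf32>
```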
On the above assumption, adjust the routine in `getReassociationIndicesForCollapse()` to allow dynamic reshapes beyond just `?x..?x1x1x..x1` -> `?`. Dynamic subshapes introduce two kinds of issues:

1. neighboring dynamic dimensions in the target shape, where the grouping is ambiguous (e.g. `?x?x10x?` into `?x?`);
2. static target dimensions preceded by a dynamic one, where leading static source dimensions may have to be absorbed into the dynamic group (e.g. `?x2x3x4` into `?x12`).

To address 1, we should detect such sequences in the target shape before assigning multiple dynamic dimensions into the same index set. For 2, we take note that a static target dimension was preceded by a dynamic one and allow an "offset" subshape of source static dimensions, as long as there's an exact sequence for the target size later in the source shape.
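For illustration, hedged gtest-style expectations for the two cases (assuming the utility ends up behaving exactly as described above; these are not tests from this PR):

```
// 1. Neighboring dynamic target dimensions make the grouping ambiguous, so
//    the utility should refuse to pick one.
EXPECT_EQ(getReassociationIndicesForCollapse(
              {ShapedType::kDynamic, ShapedType::kDynamic, 10,
               ShapedType::kDynamic},
              {ShapedType::kDynamic, ShapedType::kDynamic}),
          std::nullopt);
// 2. A static target dimension preceded by a dynamic one: the leading static
//    2 is absorbed ("offset") into the dynamic group, and 3x4 matches 12.
EXPECT_EQ(getReassociationIndicesForCollapse(
              {ShapedType::kDynamic, 2, 3, 4}, {ShapedType::kDynamic, 12}),
          (SmallVector<ReassociationIndices>{{0, 1}, {2, 3}}));
```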
This PR aims to address all reshapes that can be determined based purely on shapes (and original reassociation maps, as done in `ComposeExpandOfCollapseOp::findCollapsingReassociation`). It doesn't seem possible to fold all qualifying dynamic shape patterns in a deterministic way without looking into affine expressions simultaneously. That would be difficult to maintain in a single general utility, so a path forward would be to provide dialect-specific implementations for Linalg/Tensor.

Signed-off-by: Artem Gindinson [email protected]