[mlir][tensor] Loosen restrictions on folding dynamic reshapes #137963

Open · wants to merge 18 commits into main · showing changes from 1 commit
103 changes: 57 additions & 46 deletions mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -31,59 +31,70 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
 std::optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                          ArrayRef<int64_t> targetShape) {
-  if (sourceShape.size() <= targetShape.size())
+  unsigned numSourceDims = sourceShape.size(),
+           numTargetDims = targetShape.size();
+  if (numSourceDims <= numTargetDims)
     return std::nullopt;
-  unsigned sourceDim = 0;
-  SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetShape.size());
-
-  ReassociationIndices currIndices;
-  int64_t prodOfCollapsedDims = 1;
-  while (sourceDim < sourceShape.size()) {
-    unsigned targetDim = reassociationMap.size();
-    // If we have mapped all the target dimensions stop and handle the remaining
-    // tail of size-1 dimensions explicitly.
-    if (targetDim == targetShape.size())
-      break;
+  SmallVector<ReassociationIndices, 4> reassociationMap;
+  reassociationMap.reserve(numTargetDims);
+
+  unsigned sourceDim = 0, targetDim = 0;
+  for (; targetDim < numTargetDims; ++targetDim) {
     int64_t currTargetShape = targetShape[targetDim];
-    while (sourceDim < (sourceShape.size() - 1) &&
-           sourceShape[sourceDim] != ShapedType::kDynamic &&
-           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+    ReassociationIndices currIndices;
+    // 1. Target dimension is dynamic. Source shape should contain at least
+    // one dynamic dimension.
+    if (currTargetShape == ShapedType::kDynamic) {
+      // FIXME: We stop the search with the first dynamic dimension, while in
+      // fact, we can have a valid pattern like 2x?x?x4x8 -> ?x4x8. It becomes
+      // indeterministic altogether when we have neighboring dynamic dimensions
+      // in the target shape. Most of these patterns will be safely rejected,
+      // however we might achieve more correct folds by taking affine
+      // expressions into account, if these can be passed on by the call sites.
+      bool foundDynamic = false;
+      while (sourceDim < numSourceDims) {
+        currIndices.push_back(sourceDim);
+        if (sourceShape[sourceDim++] == ShapedType::kDynamic) {
+          foundDynamic = true;
+          break;
+        }
+      }
+      if (!foundDynamic)
+        return std::nullopt;
+
+      reassociationMap.push_back(currIndices);
+      continue;
+    }
+    // 2. Target dimension is static. The product of dimensions of the expanded
+    // shape should match the collapsed dimension shape.
+    int64_t prodOfCollapsedDims = 1;
+    bool reachedTargetDimSize = false;
+    while (sourceDim < numSourceDims) {
+      // Source shape cannot be dynamic if the target dim is static.
+      if (sourceShape[sourceDim] == ShapedType::kDynamic)
+        return std::nullopt;
       prodOfCollapsedDims *= sourceShape[sourceDim];
-      currIndices.push_back(sourceDim++);
+      if (prodOfCollapsedDims > currTargetShape)
+        break;
+      else if (prodOfCollapsedDims == currTargetShape) {
+        currIndices.push_back(sourceDim++);
+        reachedTargetDimSize = true;
+        break;
+      } else // prodOfCollapsedDims < currTargetShape
+        currIndices.push_back(sourceDim++);
     }
-
-    // If the current expanded dimension is dynamic, then the collapsed
-    // dimensions should also be dynamic and product of all previous unprocessed
-    // dimensions of the expanded shape should be 1.
-    if (sourceShape[sourceDim] == ShapedType::kDynamic &&
-        (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
+    if (!reachedTargetDimSize)
       return std::nullopt;
 
-    // If the collapsed dim is dynamic, the current expanded dim should also
-    // be dynamic.
-    if (currTargetShape == ShapedType::kDynamic &&
-        sourceShape[sourceDim] != ShapedType::kDynamic)
-      return std::nullopt;
-
-    // For static shapes, if the product of dimensions of the expanded shape
-    // should match the collapsed dimension shape.
-    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
-      return std::nullopt;
-
-    currIndices.push_back(sourceDim++);
-    reassociationMap.emplace_back(ReassociationIndices{});
-    std::swap(reassociationMap.back(), currIndices);
-    prodOfCollapsedDims = 1;
+    reassociationMap.push_back(currIndices);
   }
-  // All the dimensions in the target must have been processed.
-  if (reassociationMap.size() != targetShape.size())
-    return std::nullopt;
-  // Process any remaining entries in the source shape. They all need to be
-  // 1 or dynamic.
-  for (; sourceDim < sourceShape.size(); sourceDim++) {
-    if (sourceShape[sourceDim] != ShapedType::kDynamic &&
+  // Now that we've mapped all the target dimensions, process any remaining
+  // entries in the source shape explicitly. Either the last target dimension
+  // is dynamic, or all remaining source entries need to be 1 or dynamic. Same
+  // applies when target shape is empty (can be the case for subshape
+  // reassociations).
+  for (; sourceDim < numSourceDims; sourceDim++) {
+    if ((targetShape.empty() || targetShape.back() != ShapedType::kDynamic) &&
+        sourceShape[sourceDim] != ShapedType::kDynamic &&
         sourceShape[sourceDim] != 1)
       return std::nullopt;
     // The map is empty when the target type is a scalar.
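
Note (not part of the diff): a minimal standalone sketch of how a caller might exercise the updated utility on the mixed static/dynamic case covered by the new canonicalize.mlir test further down (collapsing tensor<?x4x?x2xf32> into tensor<?x4x?xf32>). It assumes an MLIR build to compile and link against; the headers and the function signature come from the diff above, the expected grouping mirrors the test's CHECK lines, and everything else (the standalone main, the printing) is purely illustrative.

// Illustrative sketch only: queries the reassociation that the loosened
// inference is expected to find for ?x4x?x2 -> ?x4x?.
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

int main() {
  const int64_t kDyn = mlir::ShapedType::kDynamic;

  // Source tensor<?x4x?x2xf32> collapsing into target tensor<?x4x?xf32>.
  llvm::SmallVector<int64_t> sourceShape = {kDyn, 4, kDyn, 2};
  llvm::SmallVector<int64_t> targetShape = {kDyn, 4, kDyn};

  auto groups =
      mlir::getReassociationIndicesForCollapse(sourceShape, targetShape);
  if (!groups) {
    llvm::errs() << "no reassociation found\n";
    return 1;
  }

  // With the loosened restrictions this is expected to print: [0] [1] [2, 3]
  for (const mlir::ReassociationIndices &group : *groups) {
    llvm::outs() << "[";
    llvm::interleaveComma(group, llvm::outs());
    llvm::outs() << "] ";
  }
  llvm::outs() << "\n";
  return 0;
}
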
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
@@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
 // -----
 
 // CHECK-LABEL: func.func @unpack_dynamic
-// CHECK-NOT: tensor.collapse
-// CHECK: linalg.unpack
+// CHECK: tensor.collapse
+// CHECK-NOT: linalg.unpack
 func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
   %c32 = arith.constant 32 : index
   %c0 = arith.constant 0 : index
24 changes: 20 additions & 4 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1068,28 +1068,44 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
 
 // -----
 
-func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
     -> tensor<?x4x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x4x?xf32> into tensor<?x?xf32>
   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
       : tensor<?x?xf32> into tensor<?x4x?xf32>
   return %1 : tensor<?x4x?xf32>
 }
-// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape
 // CHECK-NOT: tensor.{{.*}}_shape
 
 // -----
 
-func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x4x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
+      : tensor<?x4x?x2xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+      : tensor<?x?xf32> into tensor<?x4x?xf32>
+  return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape
+// CHECK-NOT: tensor.expand_shape
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]]
+// CHECK-SAME: : tensor<?x4x?x2xf32> into tensor<?x4x?xf32>
+// CHECK-NEXT: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
     -> tensor<?x?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x?x?xf32> into tensor<?x?xf32>
   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
       : tensor<?x?xf32> into tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
-// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic
 // CHECK: tensor.collapse_shape
 // CHECK: %[[EXPAND:.+]] = tensor.expand_shape
 // CHECK: return %[[EXPAND]]