[mlir] Add bubbling patterns for non-intersecting reshapes #94637
Conversation
This is based on #94631. Please only review the last commit.
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir

Author: None (Max191)

Changes

This PR adds fusion by expansion patterns to push a tensor.expand_shape up through a tensor.collapse_shape with non-intersecting reassociations. Sometimes parallel collapse_shape ops like this can block propagation of expand_shape ops, so this allows them to pass through each other.

Full diff: https://github.com/llvm/llvm-project/pull/94637.diff

4 Files Affected:
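To make the rewrite concrete, here is a minimal before/after sketch; it mirrors the `bubble_parallel_reshapes` test added below, and the value names are illustrative:

```mlir
// Before: the collapse of dims [1, 2] blocks bubbling the expand up.
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]]
    : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
    output_shape [%s0, %s1, %s2, %s3]
    : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>

// After: the expand happens first on the uncollapsed tensor, with its
// reassociation and output_shape remapped to the pre-collapse rank.
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%d1 = tensor.dim %arg0, %c1 : tensor<?x?x?x?xf32>
%d2 = tensor.dim %arg0, %c2 : tensor<?x?x?x?xf32>
%new_expand = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]]
    output_shape [%s0, %d1, %d2, %s2, %s3]
    : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
%new_collapse = tensor.collapse_shape %new_expand [[0], [1, 2], [3], [4]]
    : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
```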
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index e8f6edc3f133e..96f0f7bf1aa49 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -85,21 +85,51 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
ArrayRef<Attribute> operands) {
-
+ // Fold identity reshape.
if (reshapeOp.getSrcType() == reshapeOp.getType())
return reshapeOp.getSrc();
- // Fold producer-consumer reshape ops where the operand type of the
- // producer is same as the return type of the consumer.
- auto reshapeSrcOp =
- reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
- if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
- return reshapeSrcOp.getSrc();
-
// Reshape of a constant can be replaced with a new constant.
if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
+ // Fold if the producer reshape source has the same shape with at most 1
+ // dynamic dimension.
+ auto reshapeSrcOp =
+ reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
+ if (!reshapeSrcOp)
+ return nullptr;
+ auto srcType = reshapeSrcOp.getSrcType();
+ auto resultType = reshapeOp.getResultType();
+ if (srcType != resultType)
+ return nullptr;
+
+ if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
+ return reshapeSrcOp.getSrc();
+ }
+
+ // Fold producer-consumer reshape ops when they are perfect inverses of each
+ // other:
+ // 1) Reassociation indices are equivalent.
+ // 2) Boundary types are equivalent.
+ // 3) No reassociations have more than 1 dynamic dimension, and reassociated
+ // shapes are equal for each reassociation.
+ auto reassociations = reshapeOp.getReassociationIndices();
+ if (reassociations != reshapeSrcOp.getReassociationIndices())
+ return nullptr;
+ // If the reshapes are expanding and then collapsing, the ops can be folded
+ // despite multiple dynamic dimensions.
+ if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
+ return reshapeSrcOp.getSrc();
+ ArrayRef<int64_t> expandedSrcShape = srcType.getShape();
+ ArrayRef<int64_t> expandedResultShape = resultType.getShape();
+ if (llvm::all_of(reassociations, [&](auto reInd) {
+ ArrayRef<int64_t> srcSlice =
+ expandedSrcShape.slice(reInd.front(), reInd.size());
+ return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
+ })) {
+ return reshapeSrcOp.getSrc();
+ }
return nullptr;
}
@@ -360,10 +390,12 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
resultShape.slice(resultIndices.front(), resultIndices.size());
if (srcSubShape.size() == resultSubShape.size()) {
- if (srcSubShape == resultSubShape)
+ if (srcSubShape == resultSubShape &&
+ llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
composedReassociation.push_back(srcIndices);
- else
+ } else {
return std::nullopt;
+ }
}
// Find reassociation to collapse `srcSubShape` into `resultSubShape`.
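In short, the new folding rule is: when the boundary types match, an expand followed by a collapse with identical reassociations always folds, while a collapse followed by an expand folds only if no reassociation group contains more than one dynamic dimension, since re-expanding a collapsed dynamic size could choose a different factorization. A sketch mirroring the new canonicalize tests further below:

```mlir
// Folds: expand then collapse with matching reassociations is the
// identity, even with multiple dynamic sizes.
%0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%a, %b, %c]
    : tensor<?x?xf32> into tensor<?x?x?xf32>
%1 = tensor.collapse_shape %0 [[0, 1], [2]]
    : tensor<?x?x?xf32> into tensor<?x?xf32>   // folds to %arg0

// Does not fold: collapsing ?x? into ? and re-expanding may pick a
// different split of the dynamic size.
%2 = tensor.collapse_shape %arg1 [[0, 1], [2]]
    : tensor<?x?x?xf32> into tensor<?x?xf32>
%3 = tensor.expand_shape %2 [[0, 1], [2]] output_shape [%x, %y, %z]
    : tensor<?x?xf32> into tensor<?x?x?xf32>   // may differ from %arg1
```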
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index ad313c2d5ce60..579116904aad2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1023,6 +1023,76 @@ struct FoldReshapeWithGenericOpByExpansion
private:
ControlFusionFn controlFoldingReshapes;
};
+
+/// Pattern to bubble up a tensor.expand_shape op through a producer
+/// tensor.collapse_shape op that has non-intersecting reassociations.
+struct BubbleUpExpandThroughParallelCollapse
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
+ using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) const override {
+ auto collapseOp =
+ expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
+ if (!collapseOp || !collapseOp->hasOneUse())
+ return failure();
+ auto expandReInds = expandOp.getReassociationIndices();
+ auto collapseReInds = collapseOp.getReassociationIndices();
+
+ // Reshapes are parallel to each other if none of the reassociation indices
+ // have greater than 1 index for both reshapes.
+ for (auto [expandReassociation, collapseReassociation] :
+ llvm::zip_equal(expandReInds, collapseReInds)) {
+ if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
+ return failure();
+ }
+
+ // Compute new reassociation indices and expanded/collapsed shapes.
+ SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
+ Location loc = expandOp->getLoc();
+ SmallVector<OpFoldResult> collapseSizes =
+ tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
+ SmallVector<OpFoldResult> expandSizes(getMixedValues(
+ expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+ SmallVector<OpFoldResult> newExpandSizes;
+ int64_t index = 0, expandIndex = 0, collapseIndex = 0;
+ for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+ if (collapseReassociation.size() != 1) {
+ ReassociationIndices newCollapseReassociation;
+ for (size_t i = 0; i < collapseReassociation.size(); ++i) {
+ newCollapseReassociation.push_back(index);
+ newExpandReInds.push_back({index++});
+ newExpandSizes.push_back(collapseSizes[collapseIndex++]);
+ }
+ newCollapseReInds.push_back(newCollapseReassociation);
+ expandIndex++;
+ continue;
+ }
+ ReassociationIndices newExpandReassociation;
+ auto expandReassociation = expandReInds[idx];
+ for (size_t i = 0; i < expandReassociation.size(); ++i) {
+ newExpandReassociation.push_back(index);
+ newCollapseReInds.push_back({index++});
+ newExpandSizes.push_back(expandSizes[expandIndex++]);
+ }
+ newExpandReInds.push_back(newExpandReassociation);
+ collapseIndex++;
+ }
+
+ // Swap reshape order.
+ SmallVector<Value> dynamicSizes;
+ SmallVector<int64_t> staticSizes;
+ dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
+ auto expandResultType = expandOp.getResultType().clone(staticSizes);
+ auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
+ loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
+ newExpandSizes);
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ expandOp, newExpand.getResult(), newCollapseReInds);
+ return success();
+ }
+};
+
} // namespace
//===---------------------------------------------------------------------===//
@@ -1939,6 +2009,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
+ patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}
void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
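For context, a minimal sketch of how these populated patterns are typically driven; the wrapper function and permissive control function here are illustrative, not part of this PR:

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative driver: greedily propagate reshapes under `op`.
static LogicalResult propagateReshapes(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  // Allow fusion everywhere; a real pass would filter candidates here.
  linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
    return true;
  };
  linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFn);
  // BubbleUpExpandThroughParallelCollapse is now part of this set.
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}
```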
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index f42666f81bbad..1354b138983a0 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -826,3 +826,37 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
// CHECK: return %[[T4]]
+
+// -----
+
+func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+ %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
+ output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+ return %expand : tensor<?x?x?x?xf32>
+}
+// CHECK: func @bubble_parallel_reshapes
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
+// CHECK-SAME: output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_bubble_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+ %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
+ output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+ return %expand : tensor<?x?x?x?xf32>
+}
+// CHECK: func @no_bubble_intersecting_reshapes
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
+// CHECK: return %[[EXPAND]]
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index f7fbd3834288b..9a6b03986ccb6 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1139,7 +1139,7 @@ func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
return %1 : tensor<12x4xf32>
}
// CHECK-LABEL: @fold_collapse_of_expand
-// CHECK-NOT: linalg.{{.*}}shape
+// CHECK-NOT: tensor.{{.*}}_shape
// -----
@@ -1152,7 +1152,75 @@ func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: @fold_collapse_of_expand_dynamic
-// CHECK-NOT: linalg.{{.*}}_shape
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+ -> tensor<?x?xf32> {
+ %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+ : tensor<?x?xf32> into tensor<?x?x?xf32>
+ %1 = tensor.collapse_shape %0 [[0, 1], [2]]
+ : tensor<?x?x?xf32> into tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @no_fold_parallel_collapse_of_expand_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
+ -> tensor<?x?x?xf32> {
+ %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3]] output_shape [%arg1, %arg2, %arg3, %arg4]
+ : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+ %1 = tensor.collapse_shape %0 [[0], [1], [2, 3]]
+ : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: @no_fold_parallel_collapse_of_expand_dynamic
+// CHECK: tensor.expand_shape
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
+func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+ : tensor<3x4x4xf32> into tensor<12x4xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4]
+ : tensor<12x4xf32> into tensor<3x4x4xf32>
+ return %1 : tensor<3x4x4xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+ -> tensor<?x4x?xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+ : tensor<?x4x?xf32> into tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+ : tensor<?x?xf32> into tensor<?x4x?xf32>
+ return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+ -> tensor<?x?x?xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+ : tensor<?x?x?xf32> into tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+ : tensor<?x?xf32> into tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK: tensor.collapse_shape
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
+// CHECK: return %[[EXPAND]]
// -----
The reverse of this pattern (collapse_shape up through expand_shape) should also be implemented, but I'd rather leave that as another PR later.
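For reference, a hypothetical sketch of that reverse rewrite (not implemented in this PR; shapes and names are illustrative):

```mlir
// Input: an expand of dim 2 feeding a collapse of dims [0, 1]; the
// reassociations do not intersect.
%e = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
    output_shape [%s0, %s1, %s2, %s3]
    : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
%c = tensor.collapse_shape %e [[0, 1], [2], [3]]
    : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>

// Bubbled: the collapse happens first; %d01 would be the collapsed
// size %s0 * %s1 (or a tensor.dim of %new_c).
%new_c = tensor.collapse_shape %arg0 [[0, 1], [2]]
    : tensor<?x?x?xf32> into tensor<?x?xf32>
%new_e = tensor.expand_shape %new_c [[0], [1, 2]]
    output_shape [%d01, %s2, %s3]
    : tensor<?x?xf32> into tensor<?x?x?xf32>
```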
rebased now
Ok, I think I understand the code and this makes sense. Can you add a test for partially intersecting as well?
// Reshapes are parallel to each other if none of the reassociation indices
// have greater than 1 index for both reshapes.
for (auto [expandReassociation, collapseReassociation] : |
Such reshapes should just be folded away.
This comment means that there are no reassociations where the size is greater than 1 for both the expand and collapse at the same time. There could be cases where only one of the collapse or expand shapes has size > 1, which would be parallel reshapes, but not identity reshapes. I can update the comment to be clearer.
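A small static example of the distinction (shapes invented for illustration):

```mlir
// Inverse pair: the same reassociation group has size > 1 on both ops.
// This is handled by the folders, not by this pattern.
%c0 = tensor.collapse_shape %t [[0, 1], [2]]
    : tensor<2x3x4xf32> into tensor<6x4xf32>
%e0 = tensor.expand_shape %c0 [[0, 1], [2]] output_shape [2, 3, 4]
    : tensor<6x4xf32> into tensor<2x3x4xf32>

// Parallel pair: each zipped reassociation has size > 1 on at most one
// side, so the ops do not fold but can pass through each other.
%c1 = tensor.collapse_shape %t [[0, 1], [2]]
    : tensor<2x3x4xf32> into tensor<6x4xf32>
%e1 = tensor.expand_shape %c1 [[0], [1, 2]] output_shape [6, 2, 2]
    : tensor<6x4xf32> into tensor<6x2x2xf32>
```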
These patterns don't have anything to do with Linalg ops. Can we move this to `TensorDialect`? Probably need a `populate*` method there that you can include in this file.
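A sketch of what that refactor could look like; the header location and function name here are guesses for illustration, not the actual follow-up API:

```cpp
// Hypothetical declaration in a Tensor-dialect transforms header,
// e.g. mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h.
namespace mlir {
class RewritePatternSet;
namespace tensor {
/// Adds the pattern that bubbles a tensor.expand_shape up through a
/// producer tensor.collapse_shape with non-intersecting reassociations.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns);
} // namespace tensor
} // namespace mlir
```

Linalg's populateFoldReshapeOpsByExpansionPatterns could then call this instead of constructing the pattern directly.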
Refactored @Max191's PR #94637 to move it to `Tensor`.

From the original PR:

> This PR adds fusion by expansion patterns to push a tensor.expand_shape up through a tensor.collapse_shape with non-intersecting reassociations. Sometimes parallel collapse_shape ops like this can block propagation of expand_shape ops, so this allows them to pass through each other.

I'm not sure if I put the code/tests in the right places, so let me know where those go if they aren't.

cc @MaheshRavishankar @hanhanW

Co-authored-by: Max Dawkins <[email protected]>