[mlir] Canonicalization pattern for 'shape.shape_of' #98531
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir-shape

Author: Rafael Ubal (rafaelubalmw)

Changes

The proposed canonicalization pattern converts

```mlir
%0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
%1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
```

to

```mlir
%0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
%1 = %shape
```

When lowering element-wise ops with unranked tensor operands, it may be necessary to reshape inputs into a 1D tensor. The following op pattern emerges:
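A rough sketch of that pattern, flattening an unranked input to 1D and restoring its shape afterwards (illustrative only; the value names are placeholders and `math.absf` stands in for an arbitrary element-wise op):

```mlir
%shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
%n = shape.num_elements %shape : tensor<?xindex> -> index
%flat_shape = tensor.from_elements %n : tensor<1xindex>
%flat = tensor.reshape %input(%flat_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
%flat_result = math.absf %flat : tensor<?xf32>
%result = tensor.reshape %flat_result(%shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
```

A `shape.shape_of` applied to `%result` can then be folded directly to `%shape`, which is exactly what the proposed pattern does.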
When 2 consecutive element-wise operations are lowered this way, the `shape.shape_of` op in the second lowering takes the result of a `tensor.reshape` as its argument, and the proposed pattern folds it away to the reshape's shape operand.

Full diff: https://github.com/llvm/llvm-project/pull/98531.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 58c3f4c334577..639bd7851c35d 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1702,18 +1702,28 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
}
};
-struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
+// Canonicalize
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+//
+// to
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = %shape
+//
+struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
PatternRewriter &rewriter) const override {
- if (!llvm::isa<ShapedType>(op.getArg().getType()))
+ auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
+ if (!tensorReshapeOp)
return failure();
- if (llvm::isa<ShapedType>(op.getType()))
+ if (op.getType() != tensorReshapeOp.getShape().getType())
return failure();
- rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
- op.getArg());
+ rewriter.replaceOp(op, tensorReshapeOp.getShape());
return success();
}
};
@@ -1753,7 +1763,7 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
+ patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
context);
}
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 40b137f1fa36e..a17a7d1499935 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1361,6 +1361,32 @@ func.func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape
// -----
+// CHECK-LABEL: func @shape_of_from_reshape
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
+ // CHECK: return %[[SHAPE]] : tensor<?xindex>
+ %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+ return %1 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func @shape_of_from_reshape_nofold
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape_nofold(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> !shape.shape {
+ // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ // CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[RESHAPED]] : tensor<*xf32> -> !shape.shape
+ // CHECK: return %[[SHAPE_OF]] : !shape.shape
+ %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %1 = shape.shape_of %0 : tensor<*xf32> -> !shape.shape
+ return %1 : !shape.shape
+}
+
+// -----
+
// CHECK-LABEL: @cast_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
func.func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
You can test this locally with the following command:

git-clang-format --diff d31603eefc2d8becfd1f41327b6a8db3e0e91a27 0e26420d3a21ad4b68db609d54d164457b293080 --extensions cpp -- mlir/lib/Dialect/Shape/IR/Shape.cpp mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

View the diff from clang-format here:

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 8eb8e57995..1a51ff8022 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1704,13 +1704,13 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
// Canonicalize
//
-// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
-// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) ->
+// tensor<*xf32> %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
//
// to
//
-// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
-// %1 = %shape
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) ->
+// tensor<*xf32> %1 = %shape
//
struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
Would it make sense to also add a LIT test that validates the canonicalization behavior you describe? It ought to serve as a guard against ineffective canonicalizations and also offer a descriptive use case within the test suite.
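For reference, a minimal sketch of such a lit test (hypothetical; the tests the PR ultimately added appear in the diff above):

```mlir
// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s

// Guard against an ineffective canonicalization: after folding, no
// shape.shape_of should remain.
// CHECK-LABEL: func @fold_shape_of_reshape
// CHECK-NOT: shape.shape_of
func.func @fold_shape_of_reshape(%t: tensor<*xf32>, %s: tensor<?xindex>) -> tensor<?xindex> {
  %r = tensor.reshape %t(%s) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %0 = shape.shape_of %r : tensor<*xf32> -> tensor<?xindex>
  return %0 : tensor<?xindex>
}
```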
Added a comment.
I was confused about the pattern you replaced, yes. I think it was due to not having type inference defined, so it did a little local type inference by canonicalizing the type (the same ops remain).
…Added new comprehensive test 'unranked-tensor-lowering.mlir'
@sjarus - This was actually great advice, Suraj. In the process of creating the unit test you suggest, I discovered additional simplification opportunities through the introduction of 2 additional folding mechanisms for `tensor.reshape`. I created a new test file called `unranked-tensor-lowering.mlir`.
@jpienaar Thanks for the comment, Jacques. Just to be clear, would you like me to remove pattern `ShapeOfWithTensor`?
…shape (#98531)

This PR includes 3 new canonicalization patterns:

- Operation `shape.shape_of`: shape of reshape

```
// Before
func.func @f(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
  %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %0 = shape.shape_of %reshape : tensor<*xf32> -> tensor<?xindex>
  return %0 : tensor<?xindex>
}

// After
func.func @f(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
  return %arg1 : tensor<?xindex>
}
```

- Operation `tensor.reshape`: reshape of reshape

```
// Before
func.func @fold_tensor_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: tensor<?xindex>) -> tensor<*xf32> {
  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %1 = tensor.reshape %0(%arg2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  return %1 : tensor<*xf32>
}

// After
func.func @fold_tensor_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: tensor<?xindex>) -> tensor<*xf32> {
  %reshape = tensor.reshape %arg0(%arg2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  return %reshape : tensor<*xf32>
}
```

- Operation `tensor.reshape`: reshape 1D to 1D

```
// Before
func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> tensor<?xf32> {
  %0 = tensor.reshape %input(%shape) : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// After
func.func @fold_reshape_1d(%arg0: tensor<?xf32>, %arg1: tensor<1xindex>) -> tensor<?xf32> {
  return %arg0 : tensor<?xf32>
}
```

These three canonicalization patterns cooperate to simplify the IR structure emerging from the lowering of certain element-wise ops with unranked tensor inputs. See file `unranked-tensor-lowering.mlir` in the proposed change list for a detailed example and description.

For context, this PR is meant to enable code optimizations for the code generated while lowering ops `quant.qcast` and `quant.dcast` with unranked tensors, as proposed in https://discourse.llvm.org/t/rfc-improvements-in-the-quant-dialect/79942 (implementation currently in progress).
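To see how the three patterns compose, consider a rough sketch of two chained element-wise ops lowered through 1D flattening (hypothetical IR, loosely modeled on the scenario `unranked-tensor-lowering.mlir` exercises; `math.absf` and `math.sqrt` stand in for arbitrary element-wise ops):

```mlir
// First op: flatten %arg to 1D, apply the op, restore the original shape.
%shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
%n = shape.num_elements %shape : tensor<?xindex> -> index
%flat_shape = tensor.from_elements %n : tensor<1xindex>
%flat = tensor.reshape %arg(%flat_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
%abs = math.absf %flat : tensor<?xf32>
%res1 = tensor.reshape %abs(%shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>

// Second op: starts over from %res1.
%shape2 = shape.shape_of %res1 : tensor<*xf32> -> tensor<?xindex>  // pattern 1 folds this to %shape
%n2 = shape.num_elements %shape2 : tensor<?xindex> -> index
%flat_shape2 = tensor.from_elements %n2 : tensor<1xindex>
%flat2 = tensor.reshape %res1(%flat_shape2) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// After pattern 1 fires, %flat2 is a reshape of a reshape; pattern 2 rewrites it
// to reshape %abs directly, and pattern 3 then folds that 1D-to-1D reshape to %abs.
%sqrt = math.sqrt %flat2 : tensor<?xf32>
%res2 = tensor.reshape %sqrt(%shape2) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
```

The net effect is that the restore/flatten round trip between the two ops disappears.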
…#134234)

This PR fixes a bug in a canonicalization pattern (operation `shape.shape_of`: shape of reshape).

```
// Before
func.func @f(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
  %reshape = tensor.reshape %arg0(%arg1) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x1xf32>
  %0 = shape.shape_of %reshape : tensor<?x1x1xf32> -> tensor<3xindex>
  return %0 : tensor<3xindex>
}

// This will error out as follows:
error: 'tensor.cast' op operand type 'tensor<3xi32>' and result type 'tensor<3xindex>' are cast incompatible
  %0 = shape.shape_of %reshape : tensor<?x1x1xf32> -> tensor<3xindex>
       ^
note: see current operation: %0 = "tensor.cast"(%arg1) : (tensor<3xi32>) -> tensor<3xindex>
```

```
// After
func.func @f(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
  %0 = arith.index_cast %arg1 : tensor<3xi32> to tensor<3xindex>
  return %0 : tensor<3xindex>
}
```

See file canonicalize.mlir in the change list for an example. For context, this bug was found while running a test on Keras 3: the canonicalizer errored out due to an invalid tensor.cast operation when the batch size is dynamic, with an operand of type tensor<3xi32> cast to tensor<3xindex>. This change is related to a previous PR: #98531

Co-authored-by: Alaa Ali <[email protected]>
Co-authored-by: Mehdi Amini <[email protected]>
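A simplified sketch of what the fixed pattern logic amounts to (illustrative, not the verbatim upstream code): when the shape operand and the result differ only in shape, a `tensor.cast` is still legal, but when their element types differ (e.g. `i32` vs `index`) an `arith.index_cast` is needed instead:

```cpp
// Inside ShapeOfFromReshape::matchAndRewrite, after matching the
// tensor.reshape that defines the shape.shape_of operand:
Value shape = tensorReshapeOp.getShape();
auto opTensorTy = dyn_cast<RankedTensorType>(op.getType());
auto shapeTensorTy = dyn_cast<RankedTensorType>(shape.getType());
if (!opTensorTy || !shapeTensorTy)
  return failure();

if (opTensorTy != shapeTensorTy) {
  if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
    // Same element type: only the shape differs, so tensor.cast is valid.
    shape = rewriter.create<tensor::CastOp>(op.getLoc(), opTensorTy, shape);
  else
    // Different element type (e.g. i32 -> index): tensor.cast would be
    // invalid; convert the extents with arith.index_cast instead.
    shape = rewriter.create<arith::IndexCastOp>(op.getLoc(), opTensorTy, shape);
}

rewriter.replaceOp(op, shape);
return success();
```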