[mlir] Canonicalization pattern for 'shape.shape_of' #98531

Merged (5 commits, Jul 19, 2024)

Conversation

@rafaelubalmw (Contributor) commented Jul 11, 2024

This PR includes 3 new canonicalization patterns:

- Operation `shape.shape_of`: shape of reshape

```
// Before
func.func @f(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
  %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %0 = shape.shape_of %reshape : tensor<*xf32> -> tensor<?xindex>
  return %0 : tensor<?xindex>
}

// After
func.func @f(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
  return %arg1 : tensor<?xindex>
}
```

- Operation `tensor.reshape`: reshape of reshape

```
// Before
func.func @fold_tensor_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: tensor<?xindex>) -> tensor<*xf32> {
  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %1 = tensor.reshape %0(%arg2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  return %1 : tensor<*xf32>
}

// After
func.func @fold_tensor_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: tensor<?xindex>) -> tensor<*xf32> {
  %reshape = tensor.reshape %arg0(%arg2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  return %reshape : tensor<*xf32>
}
```

- Operation `tensor.reshape`: reshape 1D to 1D

```
// Before
func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> tensor<?xf32> {
  %0 = tensor.reshape %input(%shape) : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// After
func.func @fold_reshape_1d(%arg0: tensor<?xf32>, %arg1: tensor<1xindex>) -> tensor<?xf32> {
  return %arg0 : tensor<?xf32>
}
```

These three canonicalization patterns cooperate to simplify the IR structure emerging from the lowering of certain element-wise ops with unranked tensor inputs. See file unranked-tensor-lowering.mlir in the proposed change list for a detailed example and description.

For context, this PR is meant to enable optimizations of the code generated while lowering ops `quant.qcast` and `quant.dcast` with unranked tensors, as proposed in https://discourse.llvm.org/t/rfc-improvements-in-the-quant-dialect/79942 (implementation currently in progress).
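
The `ShapeOfFromReshape` pattern appears verbatim in the diff below, but the two `tensor.reshape` folds are not shown in this excerpt. As a rough illustration only, the reshape-of-reshape rewrite could be sketched as an `OpRewritePattern` like the following (the merged change touches `mlir/lib/Dialect/Tensor/IR/TensorOps.cpp` and may structure it differently, e.g. as a folder; names here are illustrative):

```cpp
// A minimal sketch, not the exact code merged in this PR: fold a reshape
// whose source is itself produced by a reshape.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

struct ReshapeOfReshape : public OpRewritePattern<tensor::ReshapeOp> {
  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    // Match only when this reshape's source comes from another reshape.
    auto producer = op.getSource().getDefiningOp<tensor::ReshapeOp>();
    if (!producer)
      return failure();
    // The outer shape operand fully determines the result, so reshape the
    // original tensor directly; the inner reshape becomes dead code.
    rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(
        op, op.getType(), producer.getSource(), op.getShape());
    return success();
  }
};
```

This is sound because `tensor.reshape` takes its output shape entirely from the shape operand, so the intermediate reshape cannot affect the final result.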

@llvmbot (Member) commented Jul 11, 2024

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-shape

Author: Rafael Ubal (rafaelubalmw)

Changes

The proposed canonicalization pattern converts

```
func.func @f(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
  %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %0 = shape.shape_of %reshape : tensor<*xf32> -> tensor<?xindex>
  return %0 : tensor<?xindex>
}
```

to

```
func.func @f(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
  return %arg1 : tensor<?xindex>
}
```

When lowering element-wise ops with unranked tensor operands, it may be necessary to reshape inputs into a 1D tensor. The following op pattern emerges:

```
%unranked_shape = shape.shape_of %unranked_input : tensor<*xf32> -> tensor<?xindex>
%num_elements = shape.num_elements %unranked_shape : tensor<?xindex> -> index
%ranked_shape = tensor.from_elements %num_elements : tensor<1xindex>
%ranked_input = tensor.reshape %unranked_input(%ranked_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>

%ranked_result = ... %ranked_input ...

%unranked_result = tensor.reshape %ranked_result(%unranked_shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
```

When 2 consecutive element-wise operations op1 and op2 with unranked inputs are lowered into such a pattern, the proposed canonicalization pattern fuses the last `tensor.reshape` from op1 with the first `shape.shape_of` from op2. CSE may then merge the two occurrences of `shape.num_elements` from op1 and op2.


Full diff: https://github.com/llvm/llvm-project/pull/98531.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Shape/IR/Shape.cpp (+16-6)
  • (modified) mlir/test/Dialect/Shape/canonicalize.mlir (+26)
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 58c3f4c334577..639bd7851c35d 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1702,18 +1702,28 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
   }
 };
 
-struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
+// Canonicalize
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+//
+// to
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = %shape
+//
+struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!llvm::isa<ShapedType>(op.getArg().getType()))
+    auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
+    if (!tensorReshapeOp)
       return failure();
-    if (llvm::isa<ShapedType>(op.getType()))
+    if (op.getType() != tensorReshapeOp.getShape().getType())
       return failure();
 
-    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
-                                                  op.getArg());
+    rewriter.replaceOp(op, tensorReshapeOp.getShape());
     return success();
   }
 };
@@ -1753,7 +1763,7 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
 
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
-  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
+  patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
                ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
       context);
 }
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 40b137f1fa36e..a17a7d1499935 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1361,6 +1361,32 @@ func.func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape
 
 // -----
 
+// CHECK-LABEL: func @shape_of_from_reshape
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
+  // CHECK: return %[[SHAPE]] : tensor<?xindex>
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+  return %1 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func @shape_of_from_reshape_nofold
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape_nofold(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> !shape.shape {
+  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[RESHAPED]] : tensor<*xf32> -> !shape.shape
+  // CHECK: return %[[SHAPE_OF]] : !shape.shape
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> !shape.shape
+  return %1 : !shape.shape
+}
+
+// -----
+
 // CHECK-LABEL: @cast_extent_tensor
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
 func.func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {

github-actions bot commented Jul 11, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
```
git-clang-format --diff d31603eefc2d8becfd1f41327b6a8db3e0e91a27 0e26420d3a21ad4b68db609d54d164457b293080 --extensions cpp -- mlir/lib/Dialect/Shape/IR/Shape.cpp mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
```
View the diff from clang-format here.
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 8eb8e57995..1a51ff8022 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1704,13 +1704,13 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
 
 // Canonicalize
 //
-// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
-// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) ->
+// tensor<*xf32> %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
 //
 // to
 //
-// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
-// %1 = %shape
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) ->
+// tensor<*xf32> %1 = %shape
 //
 struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

@sjarus (Contributor) commented Jul 12, 2024

Would it make sense to also add a LIT test that validates the canonicalization behavior you describe, i.e.:

> When 2 consecutive element-wise operations op1 and op2 with unranked inputs are lowered into such a pattern, the proposed canonicalization pattern fuses the last `tensor.reshape` from op1 with the first `shape.shape_of` from op2. CSE may then merge the two occurrences of `shape.num_elements` from op1 and op2.

It ought to serve as a guard against ineffective canonicalizations and also offer a descriptive use case within the test suite.

@sjarus (Contributor) left a review: Added a comment.

@jpienaar (Member) left a review:

I was confused about the pattern you replaced, yes. I think it was due to not having type inference defined, so it did a little local type inference by canonicalizing the type (the same ops remain).

…Added new comprehensive test 'unranked-tensor-lowering.mlir'
@rafaelubalmw (Contributor, Author) commented

@sjarus - This was actually great advice, Suraj. In the process of creating the unit test you suggested, I discovered additional simplification opportunities and introduced 2 additional folding mechanisms for `tensor.reshape`, which I've added to this PR (see the updated PR description for details). The first detects sequences of `tensor.reshape` ops and forces all consumers to use the original tensor. The second folds away a `tensor.reshape` op whose source and result are both 1D, even with dynamic shapes.

I created a new test file called unranked-tensor-lowering.mlir that runs the specific pass sequence -canonicalize -cse. A detailed comment in the file provides the necessary context to understand the purpose of that test. My only concern is that this test stresses a combination of rewrite patterns that are not exclusively specific to the shape dialect. Maybe it should be moved elsewhere?
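
For illustration, the second mechanism could look something like the following fold, as it might appear in `mlir/lib/Dialect/Tensor/IR/TensorOps.cpp` (a sketch under the assumption that it is implemented as an op folder; the merged code may differ):

```cpp
// A minimal sketch, not the exact merged code: fold a 1D -> 1D
// tensor.reshape. A rank-1 reshape preserves the single dimension, so it
// is the identity whenever the source and result types agree, even when
// that dimension is dynamic (e.g. tensor<?xf32> -> tensor<?xf32>).
OpFoldResult ReshapeOp::fold(FoldAdaptor /*adaptor*/) {
  auto srcTy = llvm::dyn_cast<RankedTensorType>(getSource().getType());
  auto resTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (srcTy && resTy && srcTy == resTy && srcTy.getRank() == 1)
    return getSource();
  return {};
}
```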

@rafaelubalmw (Contributor, Author) commented

> I was confused about the pattern you replaced, yes. I think it was due to not having type inference defined, so it did a little local type inference by canonicalizing the type (the same ops remain).

@jpienaar Thanks for the comment, Jacques. Just to be clear, would you like me to remove pattern ShapeOfWithTensor as I propose, or should I revert my change and keep the pattern? If I keep it, there should be a unit test stressing it, whose structure I'm unable to envision at the moment. Any idea of what that test would look like?

@rafaelubalmw merged commit 38d0b2d into llvm:main on Jul 19, 2024
3 of 7 checks passed
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
…shape (#98531)

joker-eph added a commit that referenced this pull request Apr 4, 2025
This PR fixes a bug in a canonicalization pattern (operation `shape.shape_of`: shape of reshape).

```
// Before
func.func @f(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
  %reshape = tensor.reshape %arg0(%arg1) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x1xf32>
  %0 = shape.shape_of %reshape : tensor<?x1x1xf32> -> tensor<3xindex>
  return %0 : tensor<3xindex>
}
// This will error out as follows:
error: 'tensor.cast' op operand type 'tensor<3xi32>' and result type 'tensor<3xindex>' are cast incompatible
  %0 = shape.shape_of %reshape : tensor<?x1x1xf32> -> tensor<3xindex>
       ^
note: see current operation: %0 = "tensor.cast"(%arg1) : (tensor<3xi32>) -> tensor<3xindex>
```

```
// After
func.func @f(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
  %0 = arith.index_cast %arg1 : tensor<3xi32> to tensor<3xindex>
  return %0 : tensor<3xindex>
}
```
See file canonicalize.mlir in the change list for an example.

For context, this bug was found while running a test on Keras 3: the canonicalizer errored out due to an invalid tensor.cast operation when the batch size is dynamic. The operand of the op is a tensor<3xi32> cast to tensor<3xindex>.
This change is related to a previous PR:
#98531

---------

Co-authored-by: Alaa Ali <[email protected]>
Co-authored-by: Mehdi Amini <[email protected]>
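
Conceptually, the fix bridges a mismatched shape operand and result type with an explicit element-type conversion instead of an invalid tensor.cast. A rough sketch of the repaired pattern follows (illustrative only; it approximates the fix described above, not the exact merged code):

```cpp
// Sketch of the repaired ShapeOfFromReshape pattern; names approximate.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    auto reshape = op.getArg().getDefiningOp<tensor::ReshapeOp>();
    if (!reshape)
      return failure();

    Value shape = reshape.getShape();
    auto shapeTy = llvm::dyn_cast<RankedTensorType>(shape.getType());
    auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
    if (!shapeTy || !resultTy)
      return failure();

    if (shape.getType() != op.getType()) {
      if (shapeTy.getElementType() == resultTy.getElementType()) {
        // Same element type: only static/dynamic extents differ, which a
        // tensor.cast can reconcile.
        shape = rewriter.create<tensor::CastOp>(op.getLoc(), resultTy, shape);
      } else if (isa<IndexType>(shapeTy.getElementType()) ||
                 isa<IndexType>(resultTy.getElementType())) {
        // tensor<Nxi32> -> tensor<Nxindex> is not a valid tensor.cast (the
        // bug shown above); convert element types with arith.index_cast.
        shape = rewriter.create<arith::IndexCastOp>(op.getLoc(), resultTy,
                                                    shape);
      } else {
        return failure();
      }
    }
    rewriter.replaceOp(op, shape);
    return success();
  }
};
```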
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Apr 4, 2025
…#134234)
