
[mlir] Canonicalization pattern for 'shape.shape_of' #98531


Merged
merged 5 commits on Jul 19, 2024
Changes from all commits
36 changes: 27 additions & 9 deletions mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1702,18 +1702,36 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
}
};

-struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
+// Canonicalize
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+//
+// to
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = %shape
+//
+struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

LogicalResult matchAndRewrite(shape::ShapeOfOp op,
PatternRewriter &rewriter) const override {
-    if (!llvm::isa<ShapedType>(op.getArg().getType()))
-      return failure();
-    if (llvm::isa<ShapedType>(op.getType()))
-      return failure();
-
-    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
-                                                  op.getArg());
+    auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
+    if (!tensorReshapeOp)
+      return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape");
+    if (!isa<TensorType>(op.getType()))
+      return rewriter.notifyMatchFailure(op, "result is not a tensor");
+
+    // Operand 'shape' of 'tensor.reshape' may now be used as the result of
+    // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
+    // formed IR, it may not be identical (dynamically vs statically shaped),
+    // in which case it needs to be cast first.
+    Value shape = tensorReshapeOp.getShape();
+    if (op.getType() != shape.getType())
+      shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);
+
+    rewriter.replaceOp(op, shape);
return success();
}
};
@@ -1753,7 +1771,7 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {

void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
-  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
+  patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
context);
}
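
For context, the pattern list above is what the -canonicalize pass picks up for 'shape.shape_of'. The snippet below is a minimal sketch, not part of this change, of how the same patterns could be applied directly from custom code; it assumes a ModuleOp 'moduleOp' and an MLIRContext 'ctx' are already available, and the helper name 'runShapeOfCanonicalization' is made up for illustration.

#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sketch only: collect the shape.shape_of canonicalization patterns
// (including the new ShapeOfFromReshape) and apply them greedily.
static mlir::LogicalResult runShapeOfCanonicalization(mlir::ModuleOp moduleOp,
                                                      mlir::MLIRContext *ctx) {
  mlir::RewritePatternSet patterns(ctx);
  mlir::shape::ShapeOfOp::getCanonicalizationPatterns(patterns, ctx);
  // Rewrites until a fixed point is reached; fails only if the rewrite
  // driver does not converge.
  return mlir::applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
}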
14 changes: 13 additions & 1 deletion mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1585,13 +1585,25 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
getResult().getType()))
return reshapedSource;

// If the producer of operand 'source' is another 'tensor.reshape' op, use the
// producer's input instead as the original tensor to reshape. This could
// render such producer dead code.
if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
getSourceMutable().assign(reshapeOpProducer.getSource());
return getResult();
}

auto source = getSource();
auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
auto resultTy = dyn_cast<RankedTensorType>(getType());

if (!sourceTy || !resultTy || sourceTy != resultTy)
return {};

// If the source and result are both 1D tensors and have the same type, the
// reshape has no effect, even if the tensor is dynamically shaped.
if (sourceTy.getRank() == 1)
return source;

if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
auto elements = fromElements.getElements();
bool dynamicNoop =
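
The two folds added above (reshape-of-reshape chains and no-op 1-D reshapes) go through the op's folder hook, so they also fire when IR is constructed with createOrFold. Below is a small sketch, not part of this change, assuming an OpBuilder 'b', a Location 'loc', and suitably typed values; the helper name 'buildReshape' is made up for illustration.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

// Sketch only: createOrFold invokes ReshapeOp::fold right away, so a
// reshape whose source is itself a tensor.reshape re-uses the original
// source operand, and a 1-D reshape with identical source and result
// types folds to the source value without creating an op.
static mlir::Value buildReshape(mlir::OpBuilder &b, mlir::Location loc,
                                mlir::Type resultType, mlir::Value source,
                                mlir::Value shape) {
  return b.createOrFold<mlir::tensor::ReshapeOp>(loc, resultType, source,
                                                 shape);
}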
39 changes: 39 additions & 0 deletions mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1361,6 +1361,45 @@ func.func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape

// -----

// CHECK-LABEL: func @shape_of_from_reshape
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
// CHECK: return %[[SHAPE]] : tensor<?xindex>
%0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
%1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
return %1 : tensor<?xindex>
}

// -----

// CHECK-LABEL: func @shape_of_from_reshape_compatible_types
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
// CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex>
func.func @shape_of_from_reshape_compatible_types(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
// CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<5xindex> to tensor<?xindex>
// CHECK: return %[[CAST_SHAPE]] : tensor<?xindex>
%0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<5xindex>) -> tensor<*xf32>
%1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
return %1 : tensor<?xindex>
}

// -----

// CHECK-LABEL: func @shape_of_from_reshape_nofold
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
func.func @shape_of_from_reshape_nofold(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> !shape.shape {
// CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[RESHAPED]] : tensor<*xf32> -> !shape.shape
// CHECK: return %[[SHAPE_OF]] : !shape.shape
%0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
%1 = shape.shape_of %0 : tensor<*xf32> -> !shape.shape
return %1 : !shape.shape
}

// -----

// CHECK-LABEL: @cast_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
func.func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
90 changes: 90 additions & 0 deletions mlir/test/Dialect/Shape/unranked-tensor-lowering.mlir
@@ -0,0 +1,90 @@
// RUN: mlir-opt -split-input-file -canonicalize -cse %s | FileCheck %s

// This test verifies the simplification of IR patterns that emerge when
// lowering high-level element-wise ops with unranked tensor inputs. Consider
// the following function incrementing and doubling the value of an input
// unranked tensor using ops in a hypothetical high-level dialect called 'hl':
//
// func.func @f(%input: tensor<*xf32>) -> tensor<*xf32> {
// %0 = hl.inc %input : tensor<*xf32>
// %1 = hl.double %0 : tensor<*xf32>
// return %1 : tensor<*xf32>
// }
//
// A possible strategy to lower 'hl.inc' consists in reshaping its operand into
// a 1D tensor, creating a 1D tensor splat with the same total size as the input
// operand and with value 1.0, adding both 1D tensors using 'arith.addf', and
// reshaping the result back into the original input shape. A similar process
// applies for 'hl.double', except with a tensor splat with value 2.0 and an
// 'arith.mulf' op. The body of the function in the test below contains the full
// sequence.
//
// Since such a lowering process would operate on individual 'hl' ops in a
// context-oblivious manner, the emitted code produces a redundant IR pattern
// where the result of 'arith.addf' is reshaped into an unranked tensor, just
// for it to be immediately reshaped back into the 1D tensor consumed by
// 'arith.mulf'. This entails the overhead of re-computing the unranked tensor
// shape ('shape.shape_of') and size ('shape.num_elements').
//
// This test verifies that the consecutive application of a canonicalization and
// a CSE pass successfully simplifies this emerging pattern, leading to a
// version of the code in which the result of the emitted 'arith.addf' op
// associated with 'hl.inc' is directly consumed by the 'arith.mulf' op
// associated with 'hl.double', as observed in the FileCheck directives. The
// main rewrite patterns at play are 'shape.shape_of' canonicalization,
// 'tensor.reshape' canonicalization, and 'shape.num_elements' subexpression
// elimination.
//

// CHECK-LABEL: @unranked_tensor_lowering
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>

// CHECK-DAG: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32

// CHECK: %[[INPUT_SHAPE:.*]] = shape.shape_of %[[INPUT]] : tensor<*xf32> -> tensor<?xindex>
// CHECK: %[[INPUT_SIZE:.*]] = shape.num_elements %[[INPUT_SHAPE]] : tensor<?xindex> -> index
// CHECK: %[[INPUT_COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[INPUT_SIZE]] : tensor<1xindex>
// CHECK: %[[INPUT_COLLAPSED:.*]] = tensor.reshape %[[INPUT]](%[[INPUT_COLLAPSED_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>

// CHECK: %[[ONE_SPLAT:.*]] = tensor.splat %[[ONE]]{{\[}}%[[INPUT_SIZE]]] : tensor<?xf32>
// CHECK: %[[SUM_COLLAPSED:.*]] = arith.addf %[[INPUT_COLLAPSED]], %[[ONE_SPLAT]] : tensor<?xf32>

// CHECK: %[[TWO_SPLAT:.*]] = tensor.splat %[[TWO]]{{\[}}%[[INPUT_SIZE]]] : tensor<?xf32>
// CHECK: %[[PRODUCT_COLLAPSED:.*]] = arith.mulf %[[SUM_COLLAPSED]], %[[TWO_SPLAT]] : tensor<?xf32>

// CHECK: %[[PRODUCT:.*]] = tensor.reshape %[[PRODUCT_COLLAPSED]](%[[INPUT_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[PRODUCT]] : tensor<*xf32>

func.func @unranked_tensor_lowering(%input: tensor<*xf32>) -> tensor<*xf32> {

// Collapse input
%input_shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
%input_size = shape.num_elements %input_shape : tensor<?xindex> -> index
%input_collapsed_shape = tensor.from_elements %input_size : tensor<1xindex>
%input_collapsed = tensor.reshape %input(%input_collapsed_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>

// Second operand for sum
%one = arith.constant 1.0 : f32
%one_splat = tensor.splat %one[%input_size] : tensor<?xf32>

// Compute sum and expand it
%sum_collapsed = arith.addf %input_collapsed, %one_splat : tensor<?xf32>
%sum = tensor.reshape %sum_collapsed(%input_shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>

// Collapse sum
%sum_shape = shape.shape_of %sum : tensor<*xf32> -> tensor<?xindex>
%sum_size = shape.num_elements %sum_shape : tensor<?xindex> -> index
%sum_collapsed_shape = tensor.from_elements %sum_size : tensor<1xindex>
%sum_collapsed_0 = tensor.reshape %sum(%sum_collapsed_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>

// Second operand for product
%two = arith.constant 2.0 : f32
%two_splat = tensor.splat %two[%sum_size] : tensor<?xf32>

// Compute product and expand it
%product_collapsed = arith.mulf %sum_collapsed_0, %two_splat : tensor<?xf32>
%product = tensor.reshape %product_collapsed(%sum_shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>

return %product : tensor<*xf32>
}
27 changes: 27 additions & 0 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -847,6 +847,33 @@ func.func @fold_reshape_constant_splat(%shape : tensor<1xi32>) -> tensor<4xf32>

// -----

// CHECK-LABEL: func @fold_reshape_chain
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<*xf32>
// CHECK-SAME: %[[SHAPE_0:[a-zA-Z0-9_]+]]: tensor<?xindex>
// CHECK-SAME: %[[SHAPE_1:[a-zA-Z0-9_]+]]: tensor<?xindex>
// CHECK-SAME: %[[SHAPE_2:[a-zA-Z0-9_]+]]: tensor<?xindex>
// CHECK: %[[RESULT:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE_2]])
// CHECK: return %[[RESULT]]
func.func @fold_reshape_chain(%input: tensor<*xf32>, %shape_0: tensor<?xindex>, %shape_1: tensor<?xindex>, %shape_2: tensor<?xindex>) -> tensor<*xf32> {
%0 = tensor.reshape %input(%shape_0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
%1 = tensor.reshape %0(%shape_1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
%2 = tensor.reshape %1(%shape_2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
return %2 : tensor<*xf32>
}

// -----

// CHECK-LABEL: func @fold_reshape_1d
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?xf32>
// CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<1xindex>
// CHECK: return %[[INPUT]]
func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> tensor<?xf32> {
%0 = tensor.reshape %input(%shape) : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @fold_extract_constant_splat
// CHECK-NOT: tensor.extract_slice
// CHECK: arith.constant dense<42> : tensor<4x4xi32>