-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][bufferization] Allow mixed static/dynamic shapes in materialize_in_destination
op
#92681
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][bufferization] Allow mixed static/dynamic shapes in materialize_in_destination
op
#92681
Conversation
…ze_in_destination` op This commit relaxes the verifier of `bufferization.materialize_in_destination` such that mixed static/dynamic dimensions are allowed for the source and destination operands. E.g., `tensor<5xf32>` and `tensor<?xf32>` are now compatible, but it is assumed that the dynamic dimension is `5` at runtime. This commit fixes #91265.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-bufferization Author: Matthias Springer (matthias-springer) ChangesThis commit relaxes the verifier of This commit fixes #91265. Full diff: https://github.com/llvm/llvm-project/pull/92681.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 4f609ddff9a41..1c70a4b8df925 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -217,8 +217,7 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
def Bufferization_MaterializeInDestinationOp
: Bufferization_Op<"materialize_in_destination",
- [AllShapesMatch<["source", "dest"]>,
- AllElementTypesMatch<["source", "dest"]>,
+ [AllElementTypesMatch<["source", "dest"]>,
BufferizableOpInterface, DestinationStyleOpInterface,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<SubsetOpInterface,
@@ -239,9 +238,9 @@ def Bufferization_MaterializeInDestinationOp
memref, `source` materializes in `dest`, which is already a buffer. The op
has no results in that case.
- `source`, `dest` and `result` (if present) must have the same shape and
- element type. If the op has a result, the types of `result` and `dest` must
- match exactly (e.g., including any tensor encodings).
+ `source`, `dest` and `result` (if present) must have the same runtime shape
+ and element type. If the op has a result, the types of `result` and `dest`
+ must match exactly (e.g., including any tensor encodings).
By default, this op bufferizes to a memcpy from the future buffer of the
`source` tensor to the future buffer of the `dest` tensor or to the `dest`
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 0acb0c24ab313..3b7b412842bfb 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -686,6 +686,24 @@ LogicalResult MaterializeInDestinationOp::verify() {
if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
return emitOpError("'writable' must be specified if and only if the "
"destination is of memref type");
+ TensorType srcType = getSource().getType();
+ ShapedType destType = cast<ShapedType>(getDest().getType());
+ if (srcType.hasRank() != destType.hasRank())
+ return emitOpError("source/destination shapes are incompatible");
+ if (srcType.hasRank()) {
+ if (srcType.getRank() != destType.getRank())
+ return emitOpError("rank mismatch between source and destination shape");
+ for (auto [src, dest] :
+ llvm::zip(srcType.getShape(), destType.getShape())) {
+ if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
+ // Cannot verify dynamic dimension size. Assume that that they match at
+ // runtime.
+ continue;
+ }
+ if (src != dest)
+ return emitOpError("source/destination shapes are incompatible");
+ }
+ }
return success();
}
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 4ebdb0a8f0490..2c8807b66de74 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -43,9 +43,16 @@ func.func @invalid_writable_on_op() {
// -----
-func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
- // expected-error @below{{failed to verify that all of {source, dest} have same shape}}
- bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+func.func @invalid_materialize_in_destination(%arg0: tensor<4xf32>, %arg1: tensor<5xf32>) {
+ // expected-error @below{{source/destination shapes are incompatible}}
+ bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<4xf32>, tensor<5xf32>) -> tensor<5xf32>
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination(%arg0: tensor<5x5xf32>, %arg1: tensor<5xf32>) {
+ // expected-error @below{{rank mismatch between source and destination shape}}
+ bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5x5xf32>, tensor<5xf32>) -> tensor<5xf32>
}
// -----
diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
index d4bda0632189d..ad4a66c1b7978 100644
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -59,12 +59,15 @@ func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) {
}
// CHECK-LABEL: func @test_materialize_in_destination_op
-func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>)
- -> tensor<?xf32> {
+func.func @test_materialize_in_destination_op(
+ %arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>,
+ %arg4: tensor<5xf32>) -> tensor<?xf32> {
// CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%1 = bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, memref<?xf32, 3>) -> ()
bufferization.materialize_in_destination %arg0 in restrict writable %arg2 : (tensor<?xf32>, memref<?xf32, 3>) -> ()
+ // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+ %2 = bufferization.materialize_in_destination %arg0 in %arg4 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
return %1 : tensor<?xf32>
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This commit relaxes the verifier of
bufferization.materialize_in_destination
such that mixed static/dynamic dimensions are allowed for the source and destination operands. E.g.,tensor<5xf32>
andtensor<?xf32>
are now compatible, but it is assumed that the dynamic dimension is5
at runtime.This commit fixes #91265.