Skip to content

Commit 9d4b20a

Browse files
[mlir][bufferization] Allow mixed static/dynamic shapes in materialize_in_destination op (#92681)
This commit relaxes the verifier of `bufferization.materialize_in_destination` such that mixed static/dynamic dimensions are allowed for the source and destination operands. E.g., `tensor<5xf32>` and `tensor<?xf32>` are now compatible, but it is assumed that the dynamic dimension is `5` at runtime. This commit fixes #91265.
1 parent cd676e5 commit 9d4b20a

File tree

4 files changed

+37
-10
lines changed

4 files changed

+37
-10
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,7 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
217217

218218
def Bufferization_MaterializeInDestinationOp
219219
: Bufferization_Op<"materialize_in_destination",
220-
[AllShapesMatch<["source", "dest"]>,
221-
AllElementTypesMatch<["source", "dest"]>,
220+
[AllElementTypesMatch<["source", "dest"]>,
222221
BufferizableOpInterface, DestinationStyleOpInterface,
223222
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
224223
DeclareOpInterfaceMethods<SubsetOpInterface,
@@ -239,9 +238,9 @@ def Bufferization_MaterializeInDestinationOp
239238
memref, `source` materializes in `dest`, which is already a buffer. The op
240239
has no results in that case.
241240

242-
`source`, `dest` and `result` (if present) must have the same shape and
243-
element type. If the op has a result, the types of `result` and `dest` must
244-
match exactly (e.g., including any tensor encodings).
241+
`source`, `dest` and `result` (if present) must have the same runtime shape
242+
and element type. If the op has a result, the types of `result` and `dest`
243+
must match exactly (e.g., including any tensor encodings).
245244

246245
By default, this op bufferizes to a memcpy from the future buffer of the
247246
`source` tensor to the future buffer of the `dest` tensor or to the `dest`

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,24 @@ LogicalResult MaterializeInDestinationOp::verify() {
686686
if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
687687
return emitOpError("'writable' must be specified if and only if the "
688688
"destination is of memref type");
689+
TensorType srcType = getSource().getType();
690+
ShapedType destType = cast<ShapedType>(getDest().getType());
691+
if (srcType.hasRank() != destType.hasRank())
692+
return emitOpError("source/destination shapes are incompatible");
693+
if (srcType.hasRank()) {
694+
if (srcType.getRank() != destType.getRank())
695+
return emitOpError("rank mismatch between source and destination shape");
696+
for (auto [src, dest] :
697+
llvm::zip(srcType.getShape(), destType.getShape())) {
698+
if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
699+
// Cannot verify dynamic dimension size. Assume that that they match at
700+
// runtime.
701+
continue;
702+
}
703+
if (src != dest)
704+
return emitOpError("source/destination shapes are incompatible");
705+
}
706+
}
689707
return success();
690708
}
691709

mlir/test/Dialect/Bufferization/invalid.mlir

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,16 @@ func.func @invalid_writable_on_op() {
4343

4444
// -----
4545

46-
func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
47-
// expected-error @below{{failed to verify that all of {source, dest} have same shape}}
48-
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
46+
func.func @invalid_materialize_in_destination(%arg0: tensor<4xf32>, %arg1: tensor<5xf32>) {
47+
// expected-error @below{{source/destination shapes are incompatible}}
48+
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<4xf32>, tensor<5xf32>) -> tensor<5xf32>
49+
}
50+
51+
// -----
52+
53+
func.func @invalid_materialize_in_destination(%arg0: tensor<5x5xf32>, %arg1: tensor<5xf32>) {
54+
// expected-error @below{{rank mismatch between source and destination shape}}
55+
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5x5xf32>, tensor<5xf32>) -> tensor<5xf32>
4956
}
5057

5158
// -----

mlir/test/Dialect/Bufferization/ops.mlir

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,15 @@ func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) {
5959
}
6060

6161
// CHECK-LABEL: func @test_materialize_in_destination_op
62-
func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>)
63-
-> tensor<?xf32> {
62+
func.func @test_materialize_in_destination_op(
63+
%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>,
64+
%arg4: tensor<5xf32>) -> tensor<?xf32> {
6465
// CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
6566
%1 = bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
6667
// CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, memref<?xf32, 3>) -> ()
6768
bufferization.materialize_in_destination %arg0 in restrict writable %arg2 : (tensor<?xf32>, memref<?xf32, 3>) -> ()
69+
// CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
70+
%2 = bufferization.materialize_in_destination %arg0 in %arg4 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
6871
return %1 : tensor<?xf32>
6972
}
7073

0 commit comments

Comments
 (0)