
Commit 1d4ce57

[mlir][bufferization] skip empty tensor elimination if they have different element type (#96998)
In the original implementation, empty tensor elimination would insert a `tensor.cast` and eliminate the empty tensor even when the replacement value and the empty tensor have different element types (e.g. f32 vs. bf16). This patch adds an element-type check and skips the elimination when the element types differ.
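For illustration, a minimal sketch of the problematic pattern (a hypothetical function, simplified from the test added below): the replacement value reaching the `tensor.empty` has element type f32, while the empty tensor is bf16. `tensor.cast` can only reconcile shape mismatches, not element-type mismatches, so inserting one here would produce invalid IR; with this change the candidate is simply skipped.

// Hypothetical example: %empty (bf16) cannot be rewritten in terms of an
// f32 replacement value, because the bridging tensor.cast would have to
// change the element type, which tensor.cast does not allow.
func.func @sketch(%dest: tensor<?xf32> {bufferization.writable = true}, %sz: index) -> tensor<?xf32> {
  %empty = tensor.empty(%sz) : tensor<?xbf16>
  %tmp = linalg.copy ins(%dest : tensor<?xf32>) outs(%empty : tensor<?xbf16>) -> tensor<?xbf16>
  %back = linalg.copy ins(%tmp : tensor<?xbf16>) outs(%dest : tensor<?xf32>) -> tensor<?xf32>
  %res = tensor.insert_slice %back into %dest[0] [%sz] [1] : tensor<?xf32> into tensor<?xf32>
  return %res : tensor<?xf32>
}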
1 parent 9b94056 commit 1d4ce57

File tree

2 files changed: +32 −0 lines

mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp

Lines changed: 3 additions & 0 deletions
@@ -152,6 +152,9 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
       if (emptyTensorOp == replacement.getDefiningOp())
         continue;
       if (replacement.getType() != v.getType()) {
+        if (cast<ShapedType>(replacement.getType()).getElementType() !=
+            cast<ShapedType>(v.getType()).getElementType())
+          continue;
         rewriter.setInsertionPointAfterValue(replacement);
         replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
                                                       replacement);
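For readability, the guarded region as it reads after this change (reconstructed from the hunk above, with explanatory comments added; the closing brace lies just outside the shown context):

  if (emptyTensorOp == replacement.getDefiningOp())
    continue;
  if (replacement.getType() != v.getType()) {
    // Newly added: a tensor.cast cannot change the element type, so skip
    // candidates whose element types differ instead of creating an invalid cast.
    if (cast<ShapedType>(replacement.getType()).getElementType() !=
        cast<ShapedType>(v.getType()).getElementType())
      continue;
    // Shape-only mismatches are still bridged with a tensor.cast, as before.
    rewriter.setInsertionPointAfterValue(replacement);
    replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
                                                  replacement);
  }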

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir

Lines changed: 29 additions & 0 deletions
@@ -47,3 +47,32 @@ func.func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {bufferization.wri
   // CHECK-SAME: __equivalent_func_args__ = [0, 0]
   return %2, %2 : tensor<?xf32>, tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @buffer_forwarding_conflict_with_different_element_type
+func.func @buffer_forwarding_conflict_with_different_element_type(%arg0: tensor<?xf32> {bufferization.writable = true}, %arg1: index) -> (tensor<?xf32>, tensor<?xf32>) {
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty(%arg1) : tensor<?xf32>
+
+  // CHECK: bufferization.alloc_tensor(%arg1)
+  %1 = tensor.empty(%arg1) : tensor<?xbf16>
+
+  // CHECK: linalg.copy
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]
+  %2 = linalg.copy ins(%0 : tensor<?xf32>) outs(%1 : tensor<?xbf16>) -> tensor<?xbf16>
+
+  // CHECK: linalg.copy
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]
+  %3 = linalg.copy ins(%2 : tensor<?xbf16>) outs(%0 : tensor<?xf32>) -> tensor<?xf32>
+
+  // CHECK: tensor.insert_slice
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none"]
+  %4 = tensor.insert_slice %3 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
+
+  // CHECK: return
+  // CHECK-SAME: __equivalent_func_args__ = [0, 0]
+  return %4, %4 : tensor<?xf32>, tensor<?xf32>
+}
