-
Notifications
You must be signed in to change notification settings - Fork 13.7k
[mlir][linalg] Fix for invalid IR in eliminate_empty_tensors #73513
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-linalg Author: Spenser Bauman (sabauma) ChangesThe transform.structured.eliminate_empty_tensors can produce mis-typed IR when traversing use-def chains past tensor reshaping operations for sharing candidates. This results in Linalg operations whose output types do not match their 'outs' arguments. This patch filters out candidate tensor.empty operations when their types do not match the candidate input operand. Full diff: https://github.com/llvm/llvm-project/pull/73513.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
index 5a8320bdb287533..f28f8f0d34a4da5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -60,7 +60,10 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
config.alwaysIncludeLeaves = false;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
in->get(), /*condition=*/
- [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
+ [&](Value val) {
+ return val.getDefiningOp<tensor::EmptyOp>() &&
+ val.getType() == in->get().getType();
+ },
config);
if (emptyTensors.empty())
continue;
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
index 0172760576efc51..7b575119c9cc44b 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -42,3 +42,44 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// This test is intended to check that the produced IR does not contain any
+// type errors from sharing empty tensor operations with different types.
+// The verifiers are sufficient to lock down the intended behavior.
+
+// CHECK-LABEL: func.func @collapse_shape_prevents_reuse(
+func.func @collapse_shape_prevents_reuse(%fill_value: f32) -> tensor<1x128x128x56xf32>
+{
+ %init0 = tensor.empty() : tensor<1x128x128x56xf32>
+ %init1 = tensor.empty() : tensor<1x128x128x56x1xf32>
+
+ %filled_tensor = linalg.fill
+ ins(%fill_value : f32)
+ outs(%init1 : tensor<1x128x128x56x1xf32>) -> tensor<1x128x128x56x1xf32>
+
+ %reshaped_tensor = tensor.collapse_shape %filled_tensor [[0], [1], [2], [3, 4]]
+ : tensor<1x128x128x56x1xf32> into tensor<1x128x128x56xf32>
+
+ %bias = linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ } ins(%reshaped_tensor : tensor<1x128x128x56xf32>)
+ outs(%init0 : tensor<1x128x128x56xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<1x128x128x56xf32>
+
+ return %bias : tensor<1x128x128x56xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.eliminate_empty_tensors %0 : !transform.any_op
+ transform.yield
+ }
+}
|
@llvm/pr-subscribers-mlir Author: Spenser Bauman (sabauma) ChangesThe transform.structured.eliminate_empty_tensors can produce mis-typed IR when traversing use-def chains past tensor reshaping operations for sharing candidates. This results in Linalg operations whose output types do not match their 'outs' arguments. This patch filters out candidate tensor.empty operations when their types do not match the candidate input operand. Full diff: https://github.com/llvm/llvm-project/pull/73513.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
index 5a8320bdb287533..f28f8f0d34a4da5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -60,7 +60,10 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
config.alwaysIncludeLeaves = false;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
in->get(), /*condition=*/
- [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
+ [&](Value val) {
+ return val.getDefiningOp<tensor::EmptyOp>() &&
+ val.getType() == in->get().getType();
+ },
config);
if (emptyTensors.empty())
continue;
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
index 0172760576efc51..7b575119c9cc44b 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -42,3 +42,44 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// This test is intended to check that the produced IR does not contain any
+// type errors from sharing empty tensor operations with different types.
+// The verifiers are sufficient to lock down the intended behavior.
+
+// CHECK-LABEL: func.func @collapse_shape_prevents_reuse(
+func.func @collapse_shape_prevents_reuse(%fill_value: f32) -> tensor<1x128x128x56xf32>
+{
+ %init0 = tensor.empty() : tensor<1x128x128x56xf32>
+ %init1 = tensor.empty() : tensor<1x128x128x56x1xf32>
+
+ %filled_tensor = linalg.fill
+ ins(%fill_value : f32)
+ outs(%init1 : tensor<1x128x128x56x1xf32>) -> tensor<1x128x128x56x1xf32>
+
+ %reshaped_tensor = tensor.collapse_shape %filled_tensor [[0], [1], [2], [3, 4]]
+ : tensor<1x128x128x56x1xf32> into tensor<1x128x128x56xf32>
+
+ %bias = linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ } ins(%reshaped_tensor : tensor<1x128x128x56xf32>)
+ outs(%init0 : tensor<1x128x128x56xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<1x128x128x56xf32>
+
+ return %bias : tensor<1x128x128x56xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.eliminate_empty_tensors %0 : !transform.any_op
+ transform.yield
+ }
+}
|
Hi @matthias-springer, would you mind taking a look? |
The transform.structured.eliminate_empty_tensors can produce mis-typed IR when traversing use-def chains past tensor reshaping operations for sharing candidates. This results in Linalg operations whose output types do not match their 'outs' arguments. This patch filters out candidate tensor.empty operations when their types do not match the candidate input operand.
06d14a7
to
a9cff82
Compare
commit 6b65d79 Author: Spenser Bauman <[email protected]> Date: Mon Jan 1 12:12:40 2024 -0500 [mlir][linalg] Fix for invalid IR in eliminate_empty_tensors (llvm#73513)
The transform.structured.eliminate_empty_tensors can produce mis-typed IR when traversing use-def chains past tensor reshaping operations for sharing candidates. This results in Linalg operations whose output types do not match their 'outs' arguments.
This patch filters out candidate tensor.empty operations when their types do not match the candidate input operand.