Commit 6b65d79

[mlir][linalg] Fix for invalid IR in eliminate_empty_tensors (#73513)
The transform.structured.eliminate_empty_tensors transform can produce mis-typed IR when it traverses use-def chains past tensor-reshaping operations in search of sharing candidates. The result is Linalg operations whose output types do not match their 'outs' arguments. This patch filters out candidate tensor.empty operations whose type does not match the type of the candidate input operand.
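For illustration, a hand-written sketch (not taken from this patch; the names %input, %empty, and #map are made up) of the kind of mis-typed IR the step could previously produce, where an empty tensor of the pre-reshape type ends up as the 'outs' operand of a linalg.generic whose result type no longer matches:

    // Hypothetical invalid IR: the 'outs' operand has type tensor<56x1xf32>,
    // but the linalg.generic is declared to produce tensor<56xf32>, so the
    // verifier rejects the operation.
    %empty = tensor.empty() : tensor<56x1xf32>
    %0 = linalg.generic {
      indexing_maps = [#map, #map],
      iterator_types = ["parallel"]
    } ins(%input : tensor<56xf32>)
      outs(%empty : tensor<56x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<56xf32>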
1 parent f33245a commit 6b65d79


2 files changed: +90 -1 lines changed


mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp

Lines changed: 4 additions & 1 deletion
@@ -60,7 +60,10 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
       config.alwaysIncludeLeaves = false;
       SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
           in->get(), /*condition=*/
-          [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
+          [&](Value val) {
+            return val.getDefiningOp<tensor::EmptyOp>() &&
+                   val.getType() == in->get().getType();
+          },
           config);
       if (emptyTensors.empty())
         continue;

mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir

Lines changed: 86 additions & 0 deletions
@@ -42,3 +42,89 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+// This test is intended to check that the produced IR does not contain any
+// type errors from sharing empty tensor operations with different types.
+// The verifiers are sufficient to lock down the intended behavior.
+
+// CHECK-LABEL: func.func @collapse_shape_prevents_reuse(
+func.func @collapse_shape_prevents_reuse(%fill_value: f32) -> tensor<56xf32>
+{
+  %init0 = tensor.empty() : tensor<56xf32>
+  %init1 = tensor.empty() : tensor<56x1xf32>
+
+  %filled_tensor = linalg.fill
+    ins(%fill_value : f32)
+    outs(%init1 : tensor<56x1xf32>) -> tensor<56x1xf32>
+
+  // The collapse shape alters the tensor rank, so the %init1 tensor.empty cannot be
+  // pushed into the output of the linalg.generic.
+  %reshaped_tensor = tensor.collapse_shape %filled_tensor [[0, 1]]
+    : tensor<56x1xf32> into tensor<56xf32>
+
+  %bias = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel"]
+  } ins(%reshaped_tensor : tensor<56xf32>)
+    outs(%init0 : tensor<56xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<56xf32>
+
+  return %bias : tensor<56xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+
+// This test is intended to check that the produced IR does not contain any
+// type errors from sharing empty tensor operations with different types.
+// The verifiers are sufficient to lock down the intended behavior.
+
+// CHECK-LABEL: func.func @collapse_cast_prevents_reuse(
+func.func @collapse_cast_prevents_reuse(%fill_value: f32) -> tensor<56x?xf32>
+{
+  %c1 = arith.constant 1 : index
+  %init0 = tensor.empty(%c1) : tensor<56x?xf32>
+  %init1 = tensor.empty() : tensor<56x1xf32>
+
+  %filled_tensor = linalg.fill
+    ins(%fill_value : f32)
+    outs(%init1 : tensor<56x1xf32>) -> tensor<56x1xf32>
+
+  // The cast alters the number of dynamic dims, so the %init1 tensor.empty cannot be
+  // pushed into the output of the linalg.generic.
+  %cast = tensor.cast %filled_tensor : tensor<56x1xf32> to tensor<56x?xf32>
+
+  %bias = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%cast : tensor<56x?xf32>)
+    outs(%init0 : tensor<56x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<56x?xf32>
+
+  return %bias : tensor<56x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
+    transform.yield
+  }
+}
