Skip to content

Commit df9ed9c

Browse files
authored
[mlir][transform] Fix failure in flattening already flattened linalg ops (#86037)
The previous implementation was doing an early successful return on `rank <= 1` without adding the original op to transform results. This resulted in errors about number of returns. This patch fixes this by adding the original op to results. Additionally, we first check if op is elementwise and return a slienceable failure early if not.
1 parent 733640d commit df9ed9c

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3269,15 +3269,20 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
32693269
transform::ApplyToEachResultList &results,
32703270
transform::TransformState &state) {
32713271
rewriter.setInsertionPoint(target);
3272-
if (target.getNumLoops() <= 1)
3272+
if (!isElementwise(target)) {
3273+
failed(rewriter.notifyMatchFailure(
3274+
target, "only elementwise flattening is supported"));
3275+
return emitDefaultSilenceableFailure(target);
3276+
}
3277+
// If rank <= 1, do nothing
3278+
if (target.getNumLoops() <= 1) {
3279+
results.push_back(target);
32733280
return DiagnosedSilenceableFailure::success();
3281+
}
32743282
ReassociationIndices reassociation(target.getNumLoops());
32753283
std::iota(reassociation.begin(), reassociation.end(), 0);
32763284
auto maybeFlattened =
3277-
(isElementwise(target))
3278-
? collapseOpIterationDims(target, reassociation, rewriter)
3279-
: FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
3280-
target, "only elementwise flattening is supported"));
3285+
collapseOpIterationDims(target, reassociation, rewriter);
32813286
if (failed(maybeFlattened))
32823287
return emitDefaultSilenceableFailure(target);
32833288
results.push_back(maybeFlattened->collapsedOp);

mlir/test/Dialect/Linalg/flatten-elementwise.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,27 @@ module attributes {transform.with_named_sequence} {
6767

6868
// -----
6969

70+
// CHECK-LABEL: func.func @map_already_flat(
71+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32xf32>
72+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32xf32>
73+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32xf32>
74+
// CHECK-NEXT: linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG1]] : memref<32xf32>, memref<32xf32>) outs(%[[ARG2]] : memref<32xf32>)
75+
func.func @map_already_flat(%arg0: memref<32xf32>, %arg1: memref<32xf32>, %arg2: memref<32xf32>) {
76+
linalg.map {arith.addf} ins(%arg0, %arg1: memref<32xf32>, memref<32xf32>) outs(%arg2: memref<32xf32>)
77+
return
78+
}
79+
80+
module attributes {transform.with_named_sequence} {
81+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
82+
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
83+
%flattened = transform.structured.flatten_elementwise %0
84+
: (!transform.any_op) -> !transform.any_op
85+
transform.yield
86+
}
87+
}
88+
89+
// -----
90+
7091
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
7192
// CHECK-LABEL: func.func @generic
7293
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>

0 commit comments

Comments
 (0)