Skip to content

Commit c005df3

Browse files
authored
[mlir][linalg] Fix EraseIdentityLinalgOp on fill-like ops (#130000)
Adds a check to make sure that the linalg op is safe to erase by ensuring that the `linalg.yield` is yielding one of the linalg op's block args. This check already exists for linalg ops with pure tensor semantics. Closes #129414 --------- Signed-off-by: Ian Wood <[email protected]>
1 parent 4e9794f commit c005df3

File tree

2 files changed

+51
-11
lines changed

2 files changed

+51
-11
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,8 +1278,9 @@ LogicalResult GenericOp::verify() { return success(); }
12781278

12791279
namespace {
12801280

1281-
/// Remove any linalg operation (on tensors) that are just copying
1282-
/// the values from inputs to the results. Requirements are
1281+
/// Remove linalg operations that are just copying the values from inputs to
1282+
/// results. In the memref case, the operation must be copying to and from the
1283+
/// same value. Requirements are:
12831284
/// 1) All iterator types are parallel
12841285
/// 2) The body contains just a yield operation with the yielded values being
12851286
/// the arguments corresponding to the operands.
@@ -1304,18 +1305,27 @@ struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
13041305

13051306
// In the buffer case, we need to check exact buffer equality.
13061307
if (linalgOp.hasPureBufferSemantics()) {
1307-
if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
1308-
linalgOp.getDpsInputOperand(0)->get() ==
1308+
if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
1309+
linalgOp.getDpsInputOperand(0)->get() !=
13091310
linalgOp.getDpsInitOperand(0)->get()) {
1310-
rewriter.eraseOp(linalgOp);
1311-
return success();
1311+
return rewriter.notifyMatchFailure(
1312+
linalgOp, "expected single input and output to be the same value");
13121313
}
1313-
return failure();
1314+
1315+
auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
1316+
if (!yieldArg || yieldArg.getOwner() != &body) {
1317+
return rewriter.notifyMatchFailure(linalgOp,
1318+
"cannot fold fill-like op");
1319+
}
1320+
1321+
rewriter.eraseOp(linalgOp);
1322+
return success();
13141323
}
13151324

1316-
// Mixed semantics is not supported yet.
1317-
if (!linalgOp.hasPureTensorSemantics())
1318-
return failure();
1325+
if (!linalgOp.hasPureTensorSemantics()) {
1326+
return rewriter.notifyMatchFailure(
1327+
linalgOp, "mixed semantics is not supported yet");
1328+
}
13191329

13201330
// Get the argument number of the returned values. That is the operand
13211331
// number to use for replacing uses of this operation.

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@ func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
495495

496496
// -----
497497

498-
// CHECK: func @fold_self_copy
498+
// CHECK-LABEL: func @fold_self_copy
499499
func.func @fold_self_copy(%0 : memref<4x16xf32>) {
500500
// CHECK-NEXT: return
501501
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
@@ -511,6 +511,36 @@ func.func @fold_self_copy(%0 : memref<4x16xf32>) {
511511

512512
// -----
513513

514+
// CHECK-LABEL: func @no_fold_fill_like_memref
515+
// CHECK-NEXT: linalg.generic
516+
func.func @no_fold_fill_like_memref(%in_out : memref<4x16xf32>, %fill_val : f32) {
517+
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
518+
affine_map<(d0, d1) -> (d0, d1)>],
519+
iterator_types = ["parallel", "parallel"]}
520+
ins(%in_out : memref<4x16xf32>)
521+
outs(%in_out : memref<4x16xf32>) {
522+
^bb0(%arg0: f32, %arg1: f32):
523+
linalg.yield %fill_val : f32
524+
}
525+
return
526+
}
527+
528+
// -----
529+
530+
// CHECK-LABEL: func @no_fold_fill_like_tensor
531+
// CHECK-NEXT: linalg.generic
532+
func.func @no_fold_fill_like_tensor(%in_out : tensor<4x16xf32>, %fill_val : f32) -> tensor<4x16xf32> {
533+
%result = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
534+
affine_map<(d0, d1) -> (d0, d1)>],
535+
iterator_types = ["parallel", "parallel"]}
536+
ins(%in_out : tensor<4x16xf32>)
537+
outs(%in_out : tensor<4x16xf32>) {
538+
^bb0(%arg0: f32, %arg1: f32):
539+
linalg.yield %fill_val : f32
540+
} -> tensor<4x16xf32>
541+
return %result : tensor<4x16xf32>
542+
}
543+
514544
// CHECK-LABEL: func @fold_static_pad_fill
515545
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
516546
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<412x276xf32>

0 commit comments

Comments
 (0)