Skip to content

Commit 5979e1d

Browse files
authored
[mlir] Fix empty-tensor-elimination around self-copies (#68129)
* Fixes #67977, a crash in `empty-tensor-elimination`. * Also improves `linalg.copy` canonicalization. * Also improves indentation indentation in `mlir-linalg-ods-yaml-gen.cpp`.
1 parent 3a35ca0 commit 5979e1d

File tree

4 files changed

+28
-14
lines changed

4 files changed

+28
-14
lines changed

mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
149149
op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
150150
if (!replacement)
151151
continue;
152+
if (emptyTensorOp == replacement.getDefiningOp())
153+
continue;
152154
if (replacement.getType() != v.getType()) {
153155
rewriter.setInsertionPointAfterValue(replacement);
154156
replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -545,16 +545,17 @@ class RegionBuilderHelper {
545545

546546
namespace {
547547

548-
struct EraseSelfCopyOnBuffers : OpRewritePattern<CopyOp> {
548+
struct EraseSelfCopy : OpRewritePattern<CopyOp> {
549549
using OpRewritePattern<CopyOp>::OpRewritePattern;
550550
LogicalResult matchAndRewrite(CopyOp copyOp,
551551
PatternRewriter &rewriter) const override {
552-
if (!copyOp.hasBufferSemantics())
553-
return rewriter.notifyMatchFailure(copyOp,
554-
"does not have buffer semantics");
555-
if (copyOp.getInputs().front() != copyOp.getOutputs().front())
552+
if (copyOp.getInputs() != copyOp.getOutputs())
556553
return rewriter.notifyMatchFailure(copyOp, "not a self copy");
557-
rewriter.eraseOp(copyOp);
554+
if (copyOp.hasBufferSemantics())
555+
rewriter.eraseOp(copyOp);
556+
else
557+
rewriter.replaceOp(copyOp, copyOp.getInputs());
558+
558559
return success();
559560
}
560561
};
@@ -563,7 +564,7 @@ struct EraseSelfCopyOnBuffers : OpRewritePattern<CopyOp> {
563564

564565
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
565566
MLIRContext *context) {
566-
results.add<EraseSelfCopyOnBuffers>(context);
567+
results.add<EraseSelfCopy>(context);
567568
}
568569

569570
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,14 @@ func.func @linalg_copy(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> {
317317
%1 = linalg.copy ins(%filled : tensor<5xf32>) outs(%t : tensor<5xf32>) -> tensor<5xf32>
318318
return %1 : tensor<5xf32>
319319
}
320+
321+
// -----
322+
323+
// CHECK-LABEL: func @linalg_copy_empty(
324+
// CHECK: %[[ret:.*]] = memref.alloc()
325+
// CHECK-NEXT: return %[[ret]]
326+
func.func @linalg_copy_empty() -> tensor<26xi32> {
327+
%0 = tensor.empty() : tensor<26xi32>
328+
%1 = linalg.copy ins(%0 : tensor<26xi32>) outs(%0 : tensor<26xi32>) -> tensor<26xi32>
329+
return %1 : tensor<26xi32>
330+
}

mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,13 +1029,13 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
10291029
// {1}: attribute name
10301030
// {2}: default type function name
10311031
static const char attrDef[] = R"FMT(
1032-
{0} {1}Val = {0}::{2};
1033-
auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
1034-
return attr.getName() == "{1}"; });
1035-
if ({1}Iter != attrs.end()) {{
1036-
if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue()))
1037-
{1}Val = attr.getValue();
1038-
}
1032+
{0} {1}Val = {0}::{2};
1033+
auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
1034+
return attr.getName() == "{1}"; });
1035+
if ({1}Iter != attrs.end()) {{
1036+
if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue()))
1037+
{1}Val = attr.getValue();
1038+
}
10391039
)FMT";
10401040
std::string enumName = convertOperandKindToEnumName(arg.kind);
10411041
attrs.push_back(

0 commit comments

Comments
 (0)