Skip to content

[MLIR][Bufferization] BufferResultsToOutParams: Add an option to eliminate AllocOp and avoid Copy #90011

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ struct BufferResultsToOutParamsOpts {
/// If true, the pass adds a "bufferize.result" attribute to each output
/// parameter.
bool addResultAttribute = false;

/// If true, the pass eliminates the memref.alloc and memcpy if the returned
/// memref is allocated in the current function.
bool hoistStaticAllocs = false;
};

/// Creates a pass that converts memref function results to out-params.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,20 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
The main issue with this pass (and the out-param calling convention) is that
buffers for results need to be allocated in the caller. This currently only
works for static shaped memrefs.

If the hoist-static-allocs option is on, the pass tries to eliminate the
allocation for the returned memref and avoid the memory-copy if possible.
This optimization applies on the returned memref which has static shape and
is allocated by memref.alloc in the function. It will use the memref given
in function argument to replace the allocated memref.
}];
let options = [
Option<"addResultAttribute", "add-result-attr", "bool",
/*default=*/"false",
"Add the attribute 'bufferize.result' to all output parameters.">,
Option<"hoistStaticAllocs", "hoist-static-allocs",
"bool", /*default=*/"false",
"Hoist static allocations to call sites.">,
];
let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
let dependentDialects = ["memref::MemRefDialect"];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ updateFuncOp(func::FuncOp func,
// the given out-params.
static LogicalResult updateReturnOps(func::FuncOp func,
ArrayRef<BlockArgument> appendedEntryArgs,
MemCpyFn memCpyFn) {
MemCpyFn memCpyFn,
bool hoistStaticAllocs) {
auto res = func.walk([&](func::ReturnOp op) {
SmallVector<Value, 6> copyIntoOutParams;
SmallVector<Value, 6> keepAsReturnOperands;
Expand All @@ -118,10 +119,15 @@ static LogicalResult updateReturnOps(func::FuncOp func,
keepAsReturnOperands.push_back(operand);
}
OpBuilder builder(op);
for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
if (failed(
memCpyFn(builder, op.getLoc(), std::get<0>(t), std::get<1>(t))))
return WalkResult::interrupt();
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
if (hoistStaticAllocs && isa<memref::AllocOp>(orig.getDefiningOp()) &&
orig.getType().cast<MemRefType>().hasStaticShape()) {
orig.replaceAllUsesWith(arg);
orig.getDefiningOp()->erase();
} else {
if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
return WalkResult::interrupt();
}
}
builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
op.erase();
Expand Down Expand Up @@ -212,7 +218,8 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return success();
};
if (failed(updateReturnOps(func, appendedEntryArgs,
options.memCpyFn.value_or(defaultMemCpyFn)))) {
options.memCpyFn.value_or(defaultMemCpyFn),
options.hoistStaticAllocs))) {
return failure();
}
}
Expand All @@ -233,6 +240,8 @@ struct BufferResultsToOutParamsPass
// Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
if (addResultAttribute)
options.addResultAttribute = true;
if (hoistStaticAllocs)
options.hoistStaticAllocs = true;

if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
options)))
Expand Down
37 changes: 37 additions & 0 deletions mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{hoist-static-allocs})' %s | FileCheck %s

// CHECK-LABEL: func @basic(
// CHECK-SAME: %[[ARG:.*]]: memref<8x64xf32>) {
// CHECK-NOT: memref.alloc()
// CHECK: "test.source"(%[[ARG]]) : (memref<8x64xf32>) -> ()
// CHECK: return
// CHECK: }
func.func @basic() -> (memref<8x64xf32>) {
%b = memref.alloc() : memref<8x64xf32>
"test.source"(%b) : (memref<8x64xf32>) -> ()
return %b : memref<8x64xf32>
}

// CHECK-LABEL: func @basic_no_change(
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) {
// CHECK: %[[RESULT:.*]] = "test.source"() : () -> memref<f32>
// CHECK: memref.copy %[[RESULT]], %[[ARG]] : memref<f32> to memref<f32>
// CHECK: return
// CHECK: }
func.func @basic_no_change() -> (memref<f32>) {
%0 = "test.source"() : () -> (memref<f32>)
return %0 : memref<f32>
}

// CHECK-LABEL: func @basic_dynamic(
// CHECK-SAME: %[[D:.*]]: index, %[[ARG:.*]]: memref<?xf32>) {
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[D]]) : memref<?xf32>
// CHECK: "test.source"(%[[RESULT]]) : (memref<?xf32>) -> ()
// CHECK: memref.copy %[[RESULT]], %[[ARG]]
// CHECK: return
// CHECK: }
func.func @basic_dynamic(%d: index) -> (memref<?xf32>) {
%b = memref.alloc(%d) : memref<?xf32>
"test.source"(%b) : (memref<?xf32>) -> ()
return %b : memref<?xf32>
}
Loading