Skip to content

Commit 0af448b

Browse files
author
Menooker
authored
[MLIR][Bufferization] BufferResultsToOutParams: Add an option to eliminate AllocOp and avoid Copy (#90011)
Add an option hoist-static-allocs to remove the unnecessary memref.alloc and memref.copy after this pass, when the memref in ReturnOp is allocated by memref.alloc and is statically shaped. Instead, it replaces the uses of the allocated memref with the memref in the out argument. By default, BufferResultsToOutParams will result in a memcpy operation to copy the originally returned memref to the output argument memref. This is inefficient when the source of memcpy (the returned memref in the original ReturnOp) is from a local AllocOp. The pass can use the output argument memref to replace the locally allocated memref for better performance.hoist-static-allocs avoids dynamic allocation and memory movement. This option will be critical for performance-sensivtive applications, which require BufferResultsToOutParams pass for a caller-owned output buffer calling convension.
1 parent bb01b89 commit 0af448b

File tree

4 files changed

+65
-6
lines changed

4 files changed

+65
-6
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,10 @@ struct BufferResultsToOutParamsOpts {
166166
/// If true, the pass adds a "bufferize.result" attribute to each output
167167
/// parameter.
168168
bool addResultAttribute = false;
169+
170+
/// If true, the pass eliminates the memref.alloc and memcpy if the returned
171+
/// memref is allocated in the current function.
172+
bool hoistStaticAllocs = false;
169173
};
170174

171175
/// Creates a pass that converts memref function results to out-params.

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,11 +315,20 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
315315
The main issue with this pass (and the out-param calling convention) is that
316316
buffers for results need to be allocated in the caller. This currently only
317317
works for static shaped memrefs.
318+
319+
If the hoist-static-allocs option is on, the pass tries to eliminate the
320+
allocation for the returned memref and avoid the memory-copy if possible.
321+
This optimization applies on the returned memref which has static shape and
322+
is allocated by memref.alloc in the function. It will use the memref given
323+
in function argument to replace the allocated memref.
318324
}];
319325
let options = [
320326
Option<"addResultAttribute", "add-result-attr", "bool",
321327
/*default=*/"false",
322328
"Add the attribute 'bufferize.result' to all output parameters.">,
329+
Option<"hoistStaticAllocs", "hoist-static-allocs",
330+
"bool", /*default=*/"false",
331+
"Hoist static allocations to call sites.">,
323332
];
324333
let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
325334
let dependentDialects = ["memref::MemRefDialect"];

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ updateFuncOp(func::FuncOp func,
107107
// the given out-params.
108108
static LogicalResult updateReturnOps(func::FuncOp func,
109109
ArrayRef<BlockArgument> appendedEntryArgs,
110-
MemCpyFn memCpyFn) {
110+
MemCpyFn memCpyFn,
111+
bool hoistStaticAllocs) {
111112
auto res = func.walk([&](func::ReturnOp op) {
112113
SmallVector<Value, 6> copyIntoOutParams;
113114
SmallVector<Value, 6> keepAsReturnOperands;
@@ -118,10 +119,15 @@ static LogicalResult updateReturnOps(func::FuncOp func,
118119
keepAsReturnOperands.push_back(operand);
119120
}
120121
OpBuilder builder(op);
121-
for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
122-
if (failed(
123-
memCpyFn(builder, op.getLoc(), std::get<0>(t), std::get<1>(t))))
124-
return WalkResult::interrupt();
122+
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
123+
if (hoistStaticAllocs && isa<memref::AllocOp>(orig.getDefiningOp()) &&
124+
orig.getType().cast<MemRefType>().hasStaticShape()) {
125+
orig.replaceAllUsesWith(arg);
126+
orig.getDefiningOp()->erase();
127+
} else {
128+
if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
129+
return WalkResult::interrupt();
130+
}
125131
}
126132
builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
127133
op.erase();
@@ -212,7 +218,8 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
212218
return success();
213219
};
214220
if (failed(updateReturnOps(func, appendedEntryArgs,
215-
options.memCpyFn.value_or(defaultMemCpyFn)))) {
221+
options.memCpyFn.value_or(defaultMemCpyFn),
222+
options.hoistStaticAllocs))) {
216223
return failure();
217224
}
218225
}
@@ -233,6 +240,8 @@ struct BufferResultsToOutParamsPass
233240
// Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
234241
if (addResultAttribute)
235242
options.addResultAttribute = true;
243+
if (hoistStaticAllocs)
244+
options.hoistStaticAllocs = true;
236245

237246
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
238247
options)))
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{hoist-static-allocs})' %s | FileCheck %s
2+
3+
// CHECK-LABEL: func @basic(
4+
// CHECK-SAME: %[[ARG:.*]]: memref<8x64xf32>) {
5+
// CHECK-NOT: memref.alloc()
6+
// CHECK: "test.source"(%[[ARG]]) : (memref<8x64xf32>) -> ()
7+
// CHECK: return
8+
// CHECK: }
9+
func.func @basic() -> (memref<8x64xf32>) {
10+
%b = memref.alloc() : memref<8x64xf32>
11+
"test.source"(%b) : (memref<8x64xf32>) -> ()
12+
return %b : memref<8x64xf32>
13+
}
14+
15+
// CHECK-LABEL: func @basic_no_change(
16+
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) {
17+
// CHECK: %[[RESULT:.*]] = "test.source"() : () -> memref<f32>
18+
// CHECK: memref.copy %[[RESULT]], %[[ARG]] : memref<f32> to memref<f32>
19+
// CHECK: return
20+
// CHECK: }
21+
func.func @basic_no_change() -> (memref<f32>) {
22+
%0 = "test.source"() : () -> (memref<f32>)
23+
return %0 : memref<f32>
24+
}
25+
26+
// CHECK-LABEL: func @basic_dynamic(
27+
// CHECK-SAME: %[[D:.*]]: index, %[[ARG:.*]]: memref<?xf32>) {
28+
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[D]]) : memref<?xf32>
29+
// CHECK: "test.source"(%[[RESULT]]) : (memref<?xf32>) -> ()
30+
// CHECK: memref.copy %[[RESULT]], %[[ARG]]
31+
// CHECK: return
32+
// CHECK: }
33+
func.func @basic_dynamic(%d: index) -> (memref<?xf32>) {
34+
%b = memref.alloc(%d) : memref<?xf32>
35+
"test.source"(%b) : (memref<?xf32>) -> ()
36+
return %b : memref<?xf32>
37+
}

0 commit comments

Comments
 (0)