Skip to content

[MLIR][Bufferization] BufferResultsToOutParams: Add an option to eliminate AllocOp and avoid Copy #90011

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 8, 2024

Conversation

Menooker
Copy link

@Menooker Menooker commented Apr 25, 2024

Add an option hoist-static-allocs to remove the unnecessary memref.alloc and memref.copy after this pass, when the memref in ReturnOp is allocated by memref.alloc and is statically shaped. Instead, it replaces the uses of the allocated memref with the memref in the out argument.
By default, BufferResultsToOutParams will result in a memcpy operation to copy the originally returned memref to the output argument memref. This is inefficient when the source of memcpy (the returned memref in the original ReturnOp) is from a local AllocOp. The pass can use the output argument memref to replace the locally allocated memref for better performance.hoist-static-allocs avoids dynamic allocation and memory movement.
This option will be critical for performance-sensivtive applications, which require BufferResultsToOutParams pass for a caller-owned output buffer calling convension.

Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository. In which case you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate
is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels Apr 25, 2024
@llvmbot
Copy link
Member

llvmbot commented Apr 25, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-bufferization

Author: Menooker (Menooker)

Changes

Add an option elim-alloc-copy to remove the unnecessary memref.alloc and memref.copy after this pass, when the memref in ReturnOp is allocated by memref.alloc. Instead, it replaces the uses of the allocated memref with the memref in the out argument.
By default, BufferResultsToOutParams will result in a memcpy operation to copy the originally returned memref to the output argument memref. This is inefficient when the source of memcpy (the returned memref in the original ReturnOp) is from a local AllocOp. The pass can use the output argument memref to replace the locally allocated memref for better performance. elim-alloc-copy avoids dynamic allocation and memory movement.
This option will be critical for performance-sensivtive applications, which require BufferResultsToOutParams pass for a caller-owned output buffer calling convension.


Full diff: https://github.com/llvm/llvm-project/pull/90011.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h (+4)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td (+5)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp (+14-6)
  • (added) mlir/test/Transforms/buffer-results-to-out-params-elim.mlir (+24)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index a729bc99b987cd..e5d026d7469f98 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -166,6 +166,10 @@ struct BufferResultsToOutParamsOpts {
   /// If true, the pass adds a "bufferize.result" attribute to each output
   /// parameter.
   bool addResultAttribute = false;
+  
+  /// If true, the pass eliminates the memref.alloc and memcpy if the returned
+  /// memref is allocated in the current function.
+  bool avoidBufferResultAllocAndCopy = false;
 };
 
 /// Creates a pass that converts memref function results to out-params.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 1303dc2c9ae10f..e3197cc16377ee 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -320,6 +320,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
     Option<"addResultAttribute", "add-result-attr", "bool",
        /*default=*/"false",
        "Add the attribute 'bufferize.result' to all output parameters.">,
+    Option<"avoidBufferResultAllocAndCopy", "avoid-buffer-result-alloc-copy",
+       "bool", /*default=*/"false",
+       "When the returned memref is allocated by `memref.alloc` in the function"
+       ", eliminate the allocation and the memref.copy. And use the memref"
+       " given in function argument instead">,
   ];
   let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
   let dependentDialects = ["memref::MemRefDialect"];
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index a2222e169c4d64..ce6a4821ccc202 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -107,7 +107,8 @@ updateFuncOp(func::FuncOp func,
 // the given out-params.
 static LogicalResult updateReturnOps(func::FuncOp func,
                                      ArrayRef<BlockArgument> appendedEntryArgs,
-                                     MemCpyFn memCpyFn) {
+                                     MemCpyFn memCpyFn,
+                                     bool avoidBufferResultAllocAndCopy) {
   auto res = func.walk([&](func::ReturnOp op) {
     SmallVector<Value, 6> copyIntoOutParams;
     SmallVector<Value, 6> keepAsReturnOperands;
@@ -118,10 +119,14 @@ static LogicalResult updateReturnOps(func::FuncOp func,
         keepAsReturnOperands.push_back(operand);
     }
     OpBuilder builder(op);
-    for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
-      if (failed(
-              memCpyFn(builder, op.getLoc(), std::get<0>(t), std::get<1>(t))))
-        return WalkResult::interrupt();
+    for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
+      if (avoidBufferResultAllocAndCopy && isa<memref::AllocOp>(orig.getDefiningOp())) {
+        orig.replaceAllUsesWith(arg);
+        orig.getDefiningOp()->erase();
+      } else {
+        if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
+          return WalkResult::interrupt();
+      }
     }
     builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
     op.erase();
@@ -212,7 +217,8 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
       return success();
     };
     if (failed(updateReturnOps(func, appendedEntryArgs,
-                               options.memCpyFn.value_or(defaultMemCpyFn)))) {
+                               options.memCpyFn.value_or(defaultMemCpyFn),
+                               options.avoidBufferResultAllocAndCopy))) {
       return failure();
     }
   }
@@ -233,6 +239,8 @@ struct BufferResultsToOutParamsPass
     // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
     if (addResultAttribute)
       options.addResultAttribute = true;
+    if (avoidBufferResultAllocAndCopy)
+      options.avoidBufferResultAllocAndCopy = true;
 
     if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
                                                               options)))
diff --git a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
new file mode 100644
index 00000000000000..0b2a0b6e14d180
--- /dev/null
+++ b/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{avoid-buffer-result-alloc-copy})'  %s | FileCheck %s
+
+// CHECK-LABEL:   func @basic(
+// CHECK-SAME:                %[[ARG:.*]]: memref<8x64xf32>) {
+// CHECK-NOT:        memref.alloc()
+// CHECK:           "test.source"(%[[ARG]])  : (memref<8x64xf32>) -> ()
+// CHECK:           return
+// CHECK:         }
+func.func @basic() -> (memref<8x64xf32>) {
+  %b = memref.alloc() : memref<8x64xf32>
+  "test.source"(%b)  : (memref<8x64xf32>) -> ()
+  return %b : memref<8x64xf32>
+}
+
+// CHECK-LABEL:   func @basic_no_change(
+// CHECK-SAME:                %[[ARG:.*]]: memref<f32>) {
+// CHECK:           %[[RESULT:.*]] = "test.source"() : () -> memref<f32>
+// CHECK:           memref.copy %[[RESULT]], %[[ARG]]  : memref<f32> to memref<f32>
+// CHECK:           return
+// CHECK:         }
+func.func @basic_no_change() -> (memref<f32>) {
+  %0 = "test.source"() : () -> (memref<f32>)
+  return %0 : memref<f32>
+}
\ No newline at end of file

memCpyFn(builder, op.getLoc(), std::get<0>(t), std::get<1>(t))))
return WalkResult::interrupt();
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
if (avoidBufferResultAllocAndCopy && isa<memref::AllocOp>(orig.getDefiningOp())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious about this: does it support the case where the shape of the MemRef is unknown at compile time?

I mean for example like following:

// the shape of %arg0 and %arg1 may be different.
func.func @main(%arg0: memref<?x?xf32>, %arg1 : memref<?x?xf32>) -> memref<?x?xf32> {
  %0 = scf.if %pred {
    scf.yield %arg0 :  memref<?x?xf32>
  } else {
    scf.yield %arg1 :  memref<?x?xf32>
  }
  %2 = memref.alloc(…)
  concat(%0, %0, %2)
  return %2 : memref<?x?xf32>
}

->

// the shape of %2 is depends on the runtime value, how does the user to alloc the buffer in advance and pass it to the funciton?
func.func @main(%arg0: memref<?x?xf32>, %arg1 : memref<?x?xf32>, %2 : memref<?x?xf32>) {
  %0 = scf.if %pred {
    scf.yield %arg0 :  memref<?x?xf32>
  } else {
    scf.yield %arg1 :  memref<?x?xf32>
  }
  concat(%0, %0, %2)
  return
}

Copy link
Author

@Menooker Menooker Apr 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comments. This added option in the pass was supposed to work on statically shaped allocated memref only. I will update the code to filter out dynamic-shaped cases.

@@ -320,6 +320,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
Option<"addResultAttribute", "add-result-attr", "bool",
/*default=*/"false",
"Add the attribute 'bufferize.result' to all output parameters.">,
Option<"avoidBufferResultAllocAndCopy", "avoid-buffer-result-alloc-copy",
"bool", /*default=*/"false",
"When the returned memref has static shape and is allocated by "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This description is rendered as a table (I think when running mlir-opt -help) and usually just one line.

How about renaming this option to hoist-static-allocs? And as a description for the option: "Hoist static allocations to call sites." (and moving the current description to the pass description above).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestions! I have updated the code.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This enhancement is indeed helpful to remove the overhead of result copy.
Just curious about the new name: as the original pass always requires the the caller to allocate the buffer for the output params(newly added function argument), it might be a bit confusing that we're running with "buffer-results-to-out-params" pass but with "hoist-static-allocs=0".

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A new allocation is indeed always placed at the call site. But I would call this "hoisting" only if an equivalent allocation is removed in the body of the callee. (Hoisting = Moving) This happens only if hoist-static-allocs=1.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks for the clarification:)

Copy link

@ZhennanQin ZhennanQin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

…inate AllocOp and avoid Copy

Add an option hoist-static-allocs to remove the unnecessary memref.alloc and memref.copy after this pass, when the memref in ReturnOp is allocated by memref.alloc and is statically shaped. Instead, it replaces the uses of the allocated memref with the memref in the out argument.
By default, BufferResultsToOutParams will result in a memcpy operation to copy the originally returned memref to the output argument memref. This is inefficient when the source of memcpy (the returned memref in the original ReturnOp) is from a local AllocOp. The pass can use the output argument memref to replace the locally allocated memref for better performance. elim-alloc-copy avoids dynamic allocation and memory movement.
This option will be critical for performance-sensivtive applications, which require BufferResultsToOutParams pass for a caller-owned output buffer calling convension.
@Menooker Menooker force-pushed the fix_buffer_results_to_arg branch from 5a15f77 to acdbdbf Compare May 7, 2024 07:54
@Menooker Menooker merged commit 0af448b into llvm:main May 8, 2024
3 checks passed
Copy link

github-actions bot commented May 8, 2024

@Menooker Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested
by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as
the builds can include changes from many authors. It is not uncommon for your
change to be included in a build that fails due to someone else's changes, or
infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself.
This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

@Menooker Menooker deleted the fix_buffer_results_to_arg branch May 8, 2024 06:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants