[mlir][scf][bufferize] Improve bufferization of allocs yielded from scf.for #68089

Merged
merged 1 commit into from
Oct 3, 2023

matthias-springer

The BufferizableOpInterface implementation of scf.for currently assumes that an OpResult does not alias with any tensor apart from the corresponding init OpOperand. Newly allocated buffers (inside of the loop) are also allowed. The current implementation checks whether the respective init_arg and yielded value are equivalent. This is overly strict and causes extra buffer allocations/copies when yielding a new buffer allocation from a loop.
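
For readers unfamiliar with the region-nesting rules behind this check, here is a minimal, self-contained sketch of the relaxed condition. It is a toy model only: `ToyRegion`, `ToyValue`, and `ValueKind` are hypothetical stand-ins and not MLIR types, and the alias set is passed in explicitly instead of being queried from the analysis state. The function name mirrors the helper added in this PR, but the code below is an illustration under those assumptions, not the actual implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Toy stand-in for a region: only the parent link matters here.
struct ToyRegion {
  const ToyRegion *parent = nullptr;

  // True if `other` is this region or is nested (at any depth) inside it.
  bool isAncestor(const ToyRegion *other) const {
    for (const ToyRegion *r = other; r != nullptr; r = r->parent)
      if (r == this)
        return true;
    return false;
  }

  // Same as isAncestor, but excludes the region itself.
  bool isProperAncestor(const ToyRegion *other) const {
    return other != this && isAncestor(other);
  }
};

enum class ValueKind { BlockArgument, OpResult };

// Toy stand-in for an SSA value: its kind and the region that contains it.
struct ToyValue {
  ValueKind kind;
  const ToyRegion *parentRegion;
};

// Returns true if none of `aliases` is "external" to `region`, ignoring any
// value listed in `exceptions`. A value counts as external if it is defined
// outside of `region` or if it is an entry block argument of `region`.
bool doesNotAliasExternalValue(const std::vector<const ToyValue *> &aliases,
                               const ToyRegion &region,
                               const std::vector<const ToyValue *> &exceptions) {
  assert(!aliases.empty() && "a value always aliases at least with itself");
  for (const ToyValue *alias : aliases) {
    if (std::find(exceptions.begin(), exceptions.end(), alias) !=
        exceptions.end())
      continue;
    // Block arguments are internal only if they belong to a nested region.
    if (alias->kind == ValueKind::BlockArgument &&
        !region.isProperAncestor(alias->parentRegion))
      return false;
    // Op results are internal if defined in the region or anything nested.
    if (alias->kind == ValueKind::OpResult &&
        !region.isAncestor(alias->parentRegion))
      return false;
  }
  return true;
}
```

In this model, a buffer allocated inside the loop body (an `OpResult` whose parent region is the loop region) passes the check when the only exception is the matching iter_arg, so it can be yielded directly. If the alias set also contains a value defined outside of the loop, the check fails, which corresponds to the case where the yielded value still needs a fresh copy.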

llvmbot commented Oct 3, 2023

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Full diff: https://github.com/llvm/llvm-project/pull/68089.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+40-16)
  • (modified) mlir/test/Dialect/SCF/one-shot-bufferize.mlir (+3-11)
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 88c025c6b2b2e8b..0d02a590f296934 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -45,6 +45,28 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
   return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
 }
 
+/// Helper function for loop bufferization. Return "true" if the given value
+/// is guaranteed to not alias with an external tensor apart from values in
+/// `exceptions`. A value is external if it is defined outside of the given
+/// region or if it is an entry block argument of the region.
+static bool doesNotAliasExternalValue(Value value, Region *region,
+                                      ValueRange exceptions,
+                                      const OneShotAnalysisState &state) {
+  assert(region->getBlocks().size() == 1 &&
+         "expected region with single block");
+  bool result = true;
+  state.applyOnAliases(value, [&](Value alias) {
+    if (llvm::is_contained(exceptions, alias))
+      return;
+    Region *aliasRegion = alias.getParentRegion();
+    if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
+      result = false;
+    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
+      result = false;
+  });
+  return result;
+}
+
 /// Bufferization of scf.condition.
 struct ConditionOpInterface
     : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
@@ -633,12 +655,10 @@ struct ForOpInterface
       return success();
 
     // According to the `getAliasing...` implementations, a bufferized OpResult
-    // may alias only with the corresponding bufferized init_arg and with no
-    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
-    // but not with any other OpOperand. If a corresponding OpResult/init_arg
-    // pair bufferizes to equivalent buffers, this aliasing requirement is
-    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
-    // (New buffer copies do not alias with any buffer.)
+    // may alias only with the corresponding bufferized init_arg (or with a
+    // newly allocated buffer) and not with other buffers defined outside of the
+    // loop. I.e., the i-th OpResult may alias with the i-th init_arg;
+    // but not with any other OpOperand.
     auto forOp = cast<scf::ForOp>(op);
     auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
     OpBuilder::InsertionGuard g(rewriter);
@@ -647,20 +667,24 @@ struct ForOpInterface
     // Indices of all iter_args that have tensor type. These are the ones that
     // are bufferized.
     DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
-    // For every yielded value, is the value equivalent to its corresponding
-    // bbArg?
-    DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
-        forOp.getRegionIterArgs(), yieldOp.getResults(), state);
+    // For every yielded value, does it alias with something defined outside of
+    // the loop?
     SmallVector<Value> yieldValues;
-    for (int64_t idx = 0;
-         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
-      Value value = yieldOp.getResults()[idx];
-      if (!indices.contains(idx) || equivalentYields.contains(idx)) {
-        yieldValues.push_back(value);
+    for (const auto it : llvm::enumerate(yieldOp.getResults())) {
+      // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
+      // type cannot be used in the signature of `resolveConflicts` because the
+      // op interface is in the "IR" build unit and the `OneShotAnalysisState`
+      // is defined in the "Transforms" build unit.
+      if (!indices.contains(it.index()) ||
+          doesNotAliasExternalValue(
+              it.value(), &forOp.getRegion(),
+              /*exceptions=*/forOp.getRegionIterArg(it.index()),
+              static_cast<const OneShotAnalysisState &>(state))) {
+        yieldValues.push_back(it.value());
         continue;
       }
       FailureOr<Value> alloc = allocateTensorForShapedValue(
-          rewriter, yieldOp.getLoc(), value, state.getOptions());
+          rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
       if (failed(alloc))
         return failure();
       yieldValues.push_back(*alloc);
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 347f253906933ee..24da8d84b18e260 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -269,20 +269,12 @@ func.func @scf_for_yield_non_equivalent(
 
 // -----
 
-// Note: This bufferizes to inefficient code, but bufferization should not see
-// such IR in the first place. The iter_arg would canonicalize away. This test
-// case is just to ensure that the bufferization generates correct code.
-
 // CHECK-LABEL: func @scf_for_yield_allocation(
 //  CHECK-SAME:     %[[t:.*]]: memref<?xf32
 //       CHECK:   %[[for:.*]] = scf.for {{.*}} iter_args(%[[iter:.*]] = %[[t]])
-// This alloc is for the bufferization.alloc_tensor.
-//   CHECK-DAG:     %[[alloc2:.*]] = memref.alloc(%{{.*}})
-// This alloc is for the scf.yield.
-//       CHECK:     %[[alloc3:.*]] = memref.alloc(%{{.*}})
-//       CHECK:     memref.copy %[[alloc2]], %[[alloc3]]
-//       CHECK:     %[[casted3:.*]] = memref.cast %[[alloc3]]
-//       CHECK:     scf.yield %[[casted3]]
+//   CHECK-DAG:     %[[alloc:.*]] = memref.alloc(%{{.*}})
+//       CHECK:     %[[casted:.*]] = memref.cast %[[alloc]]
+//       CHECK:     scf.yield %[[casted]]
 //       CHECK:   return %[[for]]
 func.func @scf_for_yield_allocation(%t: tensor<?xf32>, %lb : index, %ub : index,
                                %step : index) -> tensor<?xf32> {

@matthias-springer matthias-springer merged commit 173fd67 into llvm:main Oct 3, 2023