[mlir][scf][bufferize] Improve bufferization of allocs yielded from scf.for #68089

Merged 1 commit on Oct 3, 2023

56 changes: 40 additions & 16 deletions mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -45,6 +45,28 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`. A value is external if it is defined outside of the given
/// region or if it is an entry block argument of the region.
static bool doesNotAliasExternalValue(Value value, Region *region,
ValueRange exceptions,
const OneShotAnalysisState &state) {
assert(region->getBlocks().size() == 1 &&
"expected region with single block");
bool result = true;
state.applyOnAliases(value, [&](Value alias) {
if (llvm::is_contained(exceptions, alias))
return;
Region *aliasRegion = alias.getParentRegion();
if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
result = false;
if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
result = false;
});
return result;
}
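
For intuition, here is a hypothetical pair of loops (not taken from this patch; SSA names such as %t, %f, and %ext are illustrative). In the first loop the yielded value aliases only the region iter_arg, which the caller passes via `exceptions`, so the helper would return "true"; in the second loop the yielded value is %ext, a tensor defined outside of the loop region, so it would return "false".

// Loop 1: the yield operand aliases only the iter_arg %arg (an exception),
// so the loop result does not alias any external tensor.
%r1 = scf.for %iv = %lb to %ub step %c1 iter_args(%arg = %t) -> (tensor<?xf32>) {
  %ins = tensor.insert %f into %arg[%iv] : tensor<?xf32>
  scf.yield %ins : tensor<?xf32>
}

// Loop 2: the yield operand is %ext itself, a tensor defined outside of the
// loop region, so the check fails for this yielded value.
%r2 = scf.for %iv = %lb to %ub step %c1 iter_args(%arg = %t) -> (tensor<?xf32>) {
  scf.yield %ext : tensor<?xf32>
}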

/// Bufferization of scf.condition.
struct ConditionOpInterface
: public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
@@ -633,12 +655,10 @@ struct ForOpInterface
return success();

// According to the `getAliasing...` implementations, a bufferized OpResult
- // may alias only with the corresponding bufferized init_arg and with no
- // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
- // but not with any other OpOperand. If a corresponding OpResult/init_arg
- // pair bufferizes to equivalent buffers, this aliasing requirement is
- // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
- // (New buffer copies do not alias with any buffer.)
+ // may alias only with the corresponding bufferized init_arg (or with a
+ // newly allocated buffer) and not with other buffers defined outside of the
+ // loop. I.e., the i-th OpResult may alias with the i-th init_arg;
+ // but not with any other OpOperand.
auto forOp = cast<scf::ForOp>(op);
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
OpBuilder::InsertionGuard g(rewriter);
@@ -647,20 +667,24 @@
// Indices of all iter_args that have tensor type. These are the ones that
// are bufferized.
DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
- // For every yielded value, is the value equivalent to its corresponding
- // bbArg?
- DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
-     forOp.getRegionIterArgs(), yieldOp.getResults(), state);
+ // For every yielded value, does it alias with something defined outside of
+ // the loop?
SmallVector<Value> yieldValues;
- for (int64_t idx = 0;
-      idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
-   Value value = yieldOp.getResults()[idx];
-   if (!indices.contains(idx) || equivalentYields.contains(idx)) {
-     yieldValues.push_back(value);
+ for (const auto it : llvm::enumerate(yieldOp.getResults())) {
+   // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
+   // type cannot be used in the signature of `resolveConflicts` because the
+   // op interface is in the "IR" build unit and the `OneShotAnalysisState`
+   // is defined in the "Transforms" build unit.
+   if (!indices.contains(it.index()) ||
+       doesNotAliasExternalValue(
+           it.value(), &forOp.getRegion(),
+           /*exceptions=*/forOp.getRegionIterArg(it.index()),
+           static_cast<const OneShotAnalysisState &>(state))) {
+     yieldValues.push_back(it.value());
continue;
}
FailureOr<Value> alloc = allocateTensorForShapedValue(
-     rewriter, yieldOp.getLoc(), value, state.getOptions());
+     rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
if (failed(alloc))
return failure();
yieldValues.push_back(*alloc);
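
A minimal sketch of the resulting rewrite, assuming the common case where the offending yield operand is a tensor defined outside of the loop (names are illustrative; the copy is whatever `allocateTensorForShapedValue` emits, typically a bufferization.alloc_tensor with a copy operand):

// Before conflict resolution: %ext is defined outside of the loop region.
%r = scf.for %iv = %lb to %ub step %c1 iter_args(%arg = %t) -> (tensor<?xf32>) {
  scf.yield %ext : tensor<?xf32>
}

// After conflict resolution (sketch): a fresh allocation that copies %ext is
// yielded instead, so the loop result no longer aliases %ext.
%r = scf.for %iv = %lb to %ub step %c1 iter_args(%arg = %t) -> (tensor<?xf32>) {
  %copy = bufferization.alloc_tensor() copy(%ext) : tensor<?xf32>
  scf.yield %copy : tensor<?xf32>
}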
14 changes: 3 additions & 11 deletions mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -269,20 +269,12 @@ func.func @scf_for_yield_non_equivalent(

// -----

- // Note: This bufferizes to inefficient code, but bufferization should not see
- // such IR in the first place. The iter_arg would canonicalize away. This test
- // case is just to ensure that the bufferization generates correct code.
-
// CHECK-LABEL: func @scf_for_yield_allocation(
// CHECK-SAME: %[[t:.*]]: memref<?xf32
// CHECK: %[[for:.*]] = scf.for {{.*}} iter_args(%[[iter:.*]] = %[[t]])
- // This alloc is for the bufferization.alloc_tensor.
- // CHECK-DAG: %[[alloc2:.*]] = memref.alloc(%{{.*}})
- // This alloc is for the scf.yield.
- // CHECK: %[[alloc3:.*]] = memref.alloc(%{{.*}})
- // CHECK: memref.copy %[[alloc2]], %[[alloc3]]
- // CHECK: %[[casted3:.*]] = memref.cast %[[alloc3]]
- // CHECK: scf.yield %[[casted3]]
+ // CHECK-DAG: %[[alloc:.*]] = memref.alloc(%{{.*}})
+ // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+ // CHECK: scf.yield %[[casted]]
// CHECK: return %[[for]]
func.func @scf_for_yield_allocation(%t: tensor<?xf32>, %lb : index, %ub : index,
%step : index) -> tensor<?xf32> {
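
For comparison with the updated CHECK lines above, a hedged sketch of the bufferized loop body the test now expects (value names and layout types are illustrative): the allocation backing the bufferization.alloc_tensor is cast to the iter_arg type and yielded directly, without the second alloc-and-copy that the old CHECK lines matched.

%for = scf.for %iv = %lb to %ub step %step iter_args(%iter = %t) -> (memref<?xf32, strided<[?], offset: ?>>) {
  // Single allocation for the bufferization.alloc_tensor in the loop body.
  %alloc = memref.alloc(%sz) : memref<?xf32>
  // ... body that writes into %alloc elided ...
  %casted = memref.cast %alloc : memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
  scf.yield %casted : memref<?xf32, strided<[?], offset: ?>>
}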