Skip to content

Commit 173fd67

Browse files
[mlir][scf][bufferize] Improve bufferization of allocs yielded from scf.for (#68089)
The `BufferizableOpInterface` implementation of `scf.for` currently assumes that an OpResult does not alias with any tensor apart from the corresponding init OpOperand. Newly allocated buffers (inside of the loop) are also allowed. The current implementation checks whether the respective init_arg and yielded value are equivalent. This is overly strict and causes extra buffer allocations/copies when yielding a new buffer allocation from a loop.
1 parent 464dfeb commit 173fd67

File tree

2 files changed

+43
-27
lines changed

2 files changed

+43
-27
lines changed

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,28 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
4545
return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
4646
}
4747

48+
/// Helper function for loop bufferization. Return "true" if the given value
49+
/// is guaranteed to not alias with an external tensor apart from values in
50+
/// `exceptions`. A value is external if it is defined outside of the given
51+
/// region or if it is an entry block argument of the region.
///
/// The aliases of `value` are obtained through `state.applyOnAliases`; every
/// alias that is contained in `exceptions` is ignored. `region` is expected
/// to consist of exactly one block (asserted below).
52+
static bool doesNotAliasExternalValue(Value value, Region *region,
53+
ValueRange exceptions,
54+
const OneShotAnalysisState &state) {
55+
assert(region->getBlocks().size() == 1 &&
56+
"expected region with single block");
57+
// The alias callback returns nothing, so the verdict is accumulated here and
// flipped to "false" as soon as one external alias is seen.
bool result = true;
58+
state.applyOnAliases(value, [&](Value alias) {
59+
// Aliases listed in `exceptions` are explicitly permitted; skip them.
if (llvm::is_contained(exceptions, alias))
60+
return;
61+
Region *aliasRegion = alias.getParentRegion();
62+
// A block argument is only acceptable if it belongs to a region nested
// strictly inside `region`. Block arguments of `region` itself (proper
// ancestry excludes equality) or of anything outside of it are external.
if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
63+
result = false;
64+
// An op result is only acceptable if its parent region is `region` or a
// region nested inside of it; otherwise it is external.
if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
65+
result = false;
66+
});
67+
return result;
68+
}
69+
4870
/// Bufferization of scf.condition.
4971
struct ConditionOpInterface
5072
: public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
@@ -633,12 +655,10 @@ struct ForOpInterface
633655
return success();
634656

635657
// According to the `getAliasing...` implementations, a bufferized OpResult
636-
// may alias only with the corresponding bufferized init_arg and with no
637-
// other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
638-
// but not with any other OpOperand. If a corresponding OpResult/init_arg
639-
// pair bufferizes to equivalent buffers, this aliasing requirement is
640-
// satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
641-
// (New buffer copies do not alias with any buffer.)
658+
// may alias only with the corresponding bufferized init_arg (or with a
659+
// newly allocated buffer) and not with other buffers defined outside of the
660+
// loop. I.e., the i-th OpResult may alias with the i-th init_arg;
661+
// but not with any other OpOperand.
642662
auto forOp = cast<scf::ForOp>(op);
643663
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
644664
OpBuilder::InsertionGuard g(rewriter);
@@ -647,20 +667,24 @@ struct ForOpInterface
647667
// Indices of all iter_args that have tensor type. These are the ones that
648668
// are bufferized.
649669
DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
650-
// For every yielded value, is the value equivalent to its corresponding
651-
// bbArg?
652-
DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
653-
forOp.getRegionIterArgs(), yieldOp.getResults(), state);
670+
// For every yielded value, does it alias with something defined outside of
671+
// the loop or with a region iter_arg other than its own?
654672
SmallVector<Value> yieldValues;
655-
for (int64_t idx = 0;
656-
idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
657-
Value value = yieldOp.getResults()[idx];
658-
if (!indices.contains(idx) || equivalentYields.contains(idx)) {
659-
yieldValues.push_back(value);
673+
for (const auto it : llvm::enumerate(yieldOp.getResults())) {
674+
// Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
675+
// type cannot be used in the signature of `resolveConflicts` because the
676+
// op interface is in the "IR" build unit and the `OneShotAnalysisState`
677+
// is defined in the "Transforms" build unit.
678+
if (!indices.contains(it.index()) ||
679+
doesNotAliasExternalValue(
680+
it.value(), &forOp.getRegion(),
681+
/*exceptions=*/forOp.getRegionIterArg(it.index()),
682+
static_cast<const OneShotAnalysisState &>(state))) {
683+
yieldValues.push_back(it.value());
660684
continue;
661685
}
662686
FailureOr<Value> alloc = allocateTensorForShapedValue(
663-
rewriter, yieldOp.getLoc(), value, state.getOptions());
687+
rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
664688
if (failed(alloc))
665689
return failure();
666690
yieldValues.push_back(*alloc);

mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -269,20 +269,12 @@ func.func @scf_for_yield_non_equivalent(
269269

270270
// -----
271271

272-
// Note: This bufferizes to inefficient code, but bufferization should not see
273-
// such IR in the first place. The iter_arg would canonicalize away. This test
274-
// case is just to ensure that the bufferization generates correct code.
275-
276272
// CHECK-LABEL: func @scf_for_yield_allocation(
277273
// CHECK-SAME: %[[t:.*]]: memref<?xf32
278274
// CHECK: %[[for:.*]] = scf.for {{.*}} iter_args(%[[iter:.*]] = %[[t]])
279-
// This alloc is for the bufferization.alloc_tensor.
280-
// CHECK-DAG: %[[alloc2:.*]] = memref.alloc(%{{.*}})
281-
// This alloc is for the scf.yield.
282-
// CHECK: %[[alloc3:.*]] = memref.alloc(%{{.*}})
283-
// CHECK: memref.copy %[[alloc2]], %[[alloc3]]
284-
// CHECK: %[[casted3:.*]] = memref.cast %[[alloc3]]
285-
// CHECK: scf.yield %[[casted3]]
275+
// CHECK-DAG: %[[alloc:.*]] = memref.alloc(%{{.*}})
276+
// CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
277+
// CHECK: scf.yield %[[casted]]
286278
// CHECK: return %[[for]]
287279
func.func @scf_for_yield_allocation(%t: tensor<?xf32>, %lb : index, %ub : index,
288280
%step : index) -> tensor<?xf32> {

0 commit comments

Comments
 (0)