@@ -45,6 +45,28 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
45
45
return b.create <memref::CastOp>(buffer.getLoc (), type, buffer).getResult ();
46
46
}
47
47
48
+ // / Helper function for loop bufferization. Return "true" if the given value
49
+ // / is guaranteed to not alias with an external tensor apart from values in
50
+ // / `exceptions`. A value is external if it is defined outside of the given
51
+ // / region or if it is an entry block argument of the region.
52
+ static bool doesNotAliasExternalValue (Value value, Region *region,
53
+ ValueRange exceptions,
54
+ const OneShotAnalysisState &state) {
55
+ assert (region->getBlocks ().size () == 1 &&
56
+ " expected region with single block" );
57
+ bool result = true ;
58
+ state.applyOnAliases (value, [&](Value alias) {
59
+ if (llvm::is_contained (exceptions, alias))
60
+ return ;
61
+ Region *aliasRegion = alias.getParentRegion ();
62
+ if (isa<BlockArgument>(alias) && !region->isProperAncestor (aliasRegion))
63
+ result = false ;
64
+ if (isa<OpResult>(alias) && !region->isAncestor (aliasRegion))
65
+ result = false ;
66
+ });
67
+ return result;
68
+ }
69
+
48
70
/// Bufferization of scf.condition.
49
71
struct ConditionOpInterface
50
72
: public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
@@ -633,12 +655,10 @@ struct ForOpInterface
633
655
return success ();
634
656
635
657
// According to the `getAliasing...` implementations, a bufferized OpResult
636
- // may alias only with the corresponding bufferized init_arg and with no
637
- // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
638
- // but not with any other OpOperand. If a corresponding OpResult/init_arg
639
- // pair bufferizes to equivalent buffers, this aliasing requirement is
640
- // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
641
- // (New buffer copies do not alias with any buffer.)
658
+ // may alias only with the corresponding bufferized init_arg (or with a
659
+ // newly allocated buffer) and not with other buffers defined outside of the
660
+ // loop. I.e., the i-th OpResult may alias with the i-th init_arg;
661
+ // but not with any other OpOperand.
642
662
auto forOp = cast<scf::ForOp>(op);
643
663
auto yieldOp = cast<scf::YieldOp>(forOp.getBody ()->getTerminator ());
644
664
OpBuilder::InsertionGuard g (rewriter);
@@ -647,20 +667,24 @@ struct ForOpInterface
647
667
// Indices of all iter_args that have tensor type. These are the ones that
648
668
// are bufferized.
649
669
DenseSet<int64_t > indices = getTensorIndices (forOp.getInitArgs ());
650
- // For every yielded value, is the value equivalent to its corresponding
651
- // bbArg?
652
- DenseSet<int64_t > equivalentYields = getEquivalentBuffers (
653
- forOp.getRegionIterArgs (), yieldOp.getResults (), state);
670
+ // For every yielded value, does it alias with something defined outside of
671
+ // the loop?
654
672
SmallVector<Value> yieldValues;
655
- for (int64_t idx = 0 ;
656
- idx < static_cast <int64_t >(yieldOp.getResults ().size ()); ++idx) {
657
- Value value = yieldOp.getResults ()[idx];
658
- if (!indices.contains (idx) || equivalentYields.contains (idx)) {
659
- yieldValues.push_back (value);
673
+ for (const auto it : llvm::enumerate (yieldOp.getResults ())) {
674
+ // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
675
+ // type cannot be used in the signature of `resolveConflicts` because the
676
+ // op interface is in the "IR" build unit and the `OneShotAnalysisState`
677
+ // is defined in the "Transforms" build unit.
678
+ if (!indices.contains (it.index ()) ||
679
+ doesNotAliasExternalValue (
680
+ it.value (), &forOp.getRegion (),
681
+ /* exceptions=*/ forOp.getRegionIterArg (it.index ()),
682
+ static_cast <const OneShotAnalysisState &>(state))) {
683
+ yieldValues.push_back (it.value ());
660
684
continue ;
661
685
}
662
686
FailureOr<Value> alloc = allocateTensorForShapedValue (
663
- rewriter, yieldOp.getLoc (), value, state.getOptions ());
687
+ rewriter, yieldOp.getLoc (), it. value () , state.getOptions ());
664
688
if (failed (alloc))
665
689
return failure ();
666
690
yieldValues.push_back (*alloc);
0 commit comments