Skip to content

Commit 24e5cf4

Browse files
maerhartZijunZhaoCCK
authored andcommitted
[mlir][bufferization] BufferDeallocationOpInterface: support custom ownership update logic (llvm#66350)
Add a method to the BufferDeallocationOpInterface that allows operations to implement the interface and provide custom logic to compute the ownership indicators of values it defines. As a demonstrating example, this new method is implemented by the `arith.select` operation.
1 parent e869359 commit 24e5cf4

File tree

9 files changed

+187
-16
lines changed

9 files changed

+187
-16
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- BufferDeallocationOpInterfaceImpl.h ----------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
14+
class DialectRegistry;
15+
16+
namespace arith {
17+
void registerBufferDeallocationOpInterfaceExternalModels(
18+
DialectRegistry &registry);
19+
} // namespace arith
20+
} // namespace mlir
21+
22+
#endif // MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H

mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ class DeallocationState {
142142
/// a new SSA value, returned as the first element of the pair, which has
143143
/// 'Unique' ownership and can be used instead of the passed Value with the
144144
/// the ownership indicator returned as the second element of the pair.
145-
std::pair<Value, Value> getMemrefWithUniqueOwnership(OpBuilder &builder,
146-
Value memref);
145+
std::pair<Value, Value>
146+
getMemrefWithUniqueOwnership(OpBuilder &builder, Value memref, Block *block);
147147

148148
/// Given two basic blocks and the values passed via block arguments to the
149149
/// destination block, compute the list of MemRefs that have to be retained in

mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,34 @@ def BufferDeallocationOpInterface :
3939
/*retType=*/"FailureOr<Operation *>",
4040
/*methodName=*/"process",
4141
/*args=*/(ins "DeallocationState &":$state,
42-
"const DeallocationOptions &":$options)>
42+
"const DeallocationOptions &":$options)>,
43+
InterfaceMethod<
44+
/*desc=*/[{
45+
This method allows the implementing operation to specify custom logic
46+
to materialize an ownership indicator value for the given MemRef typed
47+
value it defines (including block arguments of nested regions). Since
48+
the operation itself has more information about its semantics the
49+
materialized IR can be more efficient compared to the default
50+
implementation and avoid cloning MemRefs and/or doing alias checking
51+
at runtime.
52+
Note that the same logic could also be implemented in the 'process'
53+
method above, however, the IR is always materialized then. If
54+
it's desirable to only materialize the IR to compute an updated
55+
ownership indicator when needed, it should be implemented using this
56+
method (which is especially important if operations are created that
57+
cannot be easily canonicalized away anymore).
58+
}],
59+
/*retType=*/"std::pair<Value, Value>",
60+
/*methodName=*/"materializeUniqueOwnershipForMemref",
61+
/*args=*/(ins "DeallocationState &":$state,
62+
"const DeallocationOptions &":$options,
63+
"OpBuilder &":$builder,
64+
"Value":$memref),
65+
/*methodBody=*/[{}],
66+
/*defaultImplementation=*/[{
67+
return state.getMemrefWithUniqueOwnership(
68+
builder, memref, memref.getParentBlock());
69+
}]>,
4370
];
4471
}
4572

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
2121
#include "mlir/Dialect/Arith/IR/Arith.h"
2222
#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
23+
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
2324
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
2425
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
2526
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
@@ -133,6 +134,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
133134

134135
// Register all external models.
135136
affine::registerValueBoundsOpInterfaceExternalModels(registry);
137+
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
136138
arith::registerBufferizableOpInterfaceExternalModels(registry);
137139
arith::registerValueBoundsOpInterfaceExternalModels(registry);
138140
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
10+
#include "mlir/Dialect/Arith/IR/Arith.h"
11+
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
12+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
13+
#include "mlir/IR/Dialect.h"
14+
#include "mlir/IR/Operation.h"
15+
16+
using namespace mlir;
17+
using namespace mlir::bufferization;
18+
19+
namespace {
20+
/// Provides custom logic to materialize ownership indicator values for the
21+
/// result value of 'arith.select'. Instead of cloning or runtime alias
22+
/// checking, this implementation inserts another `arith.select` to choose the
23+
/// ownership indicator of the operand in the same way the original
24+
/// `arith.select` chooses the MemRef operand. If at least one of the operand's
25+
/// ownerships is 'Unknown', fall back to the default implementation.
26+
///
27+
/// Example:
28+
/// ```mlir
29+
/// // let ownership(%m0) := %o0
30+
/// // let ownership(%m1) := %o1
31+
/// %res = arith.select %cond, %m0, %m1
32+
/// ```
33+
/// The default implementation would insert a clone and replace all uses of the
34+
/// result of `arith.select` with that clone:
35+
/// ```mlir
36+
/// %res = arith.select %cond, %m0, %m1
37+
/// %clone = bufferization.clone %res
38+
/// // let ownership(%res) := 'Unknown'
39+
/// // let ownership(%clone) := %true
40+
/// // replace all uses of %res with %clone
41+
/// ```
42+
/// This implementation, on the other hand, materializes the following:
43+
/// ```mlir
44+
/// %res = arith.select %cond, %m0, %m1
45+
/// %res_ownership = arith.select %cond, %o0, %o1
46+
/// // let ownership(%res) := %res_ownership
47+
/// ```
48+
struct SelectOpInterface
49+
: public BufferDeallocationOpInterface::ExternalModel<SelectOpInterface,
50+
arith::SelectOp> {
51+
FailureOr<Operation *> process(Operation *op, DeallocationState &state,
52+
const DeallocationOptions &options) const {
53+
return op; // nothing to do
54+
}
55+
56+
std::pair<Value, Value>
57+
materializeUniqueOwnershipForMemref(Operation *op, DeallocationState &state,
58+
const DeallocationOptions &options,
59+
OpBuilder &builder, Value value) const {
60+
auto selectOp = cast<arith::SelectOp>(op);
61+
assert(value == selectOp.getResult() &&
62+
"Value not defined by this operation");
63+
64+
Block *block = value.getParentBlock();
65+
if (!state.getOwnership(selectOp.getTrueValue(), block).isUnique() ||
66+
!state.getOwnership(selectOp.getFalseValue(), block).isUnique())
67+
return state.getMemrefWithUniqueOwnership(builder, value,
68+
value.getParentBlock());
69+
70+
Value ownership = builder.create<arith::SelectOp>(
71+
op->getLoc(), selectOp.getCondition(),
72+
state.getOwnership(selectOp.getTrueValue(), block).getIndicator(),
73+
state.getOwnership(selectOp.getFalseValue(), block).getIndicator());
74+
return {selectOp.getResult(), ownership};
75+
}
76+
};
77+
78+
} // namespace
79+
80+
void mlir::arith::registerBufferDeallocationOpInterfaceExternalModels(
81+
DialectRegistry &registry) {
82+
registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
83+
SelectOp::attachInterface<SelectOpInterface>(*ctx);
84+
});
85+
}

mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRArithTransforms
2+
BufferDeallocationOpInterfaceImpl.cpp
23
BufferizableOpInterfaceImpl.cpp
34
Bufferize.cpp
45
EmulateUnsupportedFloats.cpp

mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ void DeallocationState::getLiveMemrefsIn(Block *block,
134134

135135
std::pair<Value, Value>
136136
DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder,
137-
Value memref) {
138-
auto iter = ownershipMap.find({memref, memref.getParentBlock()});
137+
Value memref, Block *block) {
138+
auto iter = ownershipMap.find({memref, block});
139139
assert(iter != ownershipMap.end() &&
140140
"Value must already have been registered in the ownership map");
141141

mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -376,13 +376,24 @@ class BufferDeallocation {
376376
/// Given an SSA value of MemRef type, returns the same of a new SSA value
377377
/// which has 'Unique' ownership where the ownership indicator is guaranteed
378378
/// to be always 'true'.
379-
Value getMemrefWithGuaranteedOwnership(OpBuilder &builder, Value memref);
379+
Value materializeMemrefWithGuaranteedOwnership(OpBuilder &builder,
380+
Value memref, Block *block);
380381

381382
/// Returns whether the given operation implements FunctionOpInterface, has
382383
/// private visibility, and the private-function-dynamic-ownership pass option
383384
/// is enabled.
384385
bool isFunctionWithoutDynamicOwnership(Operation *op);
385386

387+
/// Given an SSA value of MemRef type, this function queries the
388+
/// BufferDeallocationOpInterface of the defining operation of 'memref' for a
389+
/// materialized ownership indicator for 'memref'. If the op does not
390+
/// implement the interface or if the block for which the materialized value
391+
/// is requested does not match the block in which 'memref' is defined, the
392+
/// default implementation in
393+
/// `DeallocationState::getMemrefWithUniqueOwnership` is queried instead.
394+
std::pair<Value, Value>
395+
materializeUniqueOwnership(OpBuilder &builder, Value memref, Block *block);
396+
386397
/// Checks all the preconditions for operations implementing the
387398
/// FunctionOpInterface that have to hold for the deallocation to be
388399
/// applicable:
@@ -428,6 +439,28 @@ class BufferDeallocation {
428439
// BufferDeallocation Implementation
429440
//===----------------------------------------------------------------------===//
430441

442+
std::pair<Value, Value>
443+
BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref,
444+
Block *block) {
445+
// The interface can only materialize ownership indicators in the same block
446+
// as the defining op.
447+
if (memref.getParentBlock() != block)
448+
return state.getMemrefWithUniqueOwnership(builder, memref, block);
449+
450+
Operation *owner = memref.getDefiningOp();
451+
if (!owner)
452+
owner = memref.getParentBlock()->getParentOp();
453+
454+
// If the op implements the interface, query it for a materialized ownership
455+
// value.
456+
if (auto deallocOpInterface = dyn_cast<BufferDeallocationOpInterface>(owner))
457+
return deallocOpInterface.materializeUniqueOwnershipForMemref(
458+
state, options, builder, memref);
459+
460+
// Otherwise use the default implementation.
461+
return state.getMemrefWithUniqueOwnership(builder, memref, block);
462+
}
463+
431464
static bool regionOperatesOnMemrefValues(Region &region) {
432465
WalkResult result = region.walk([](Block *block) {
433466
if (llvm::any_of(block->getArguments(), isMemref))
@@ -677,11 +710,11 @@ BufferDeallocation::handleInterface(RegionBranchOpInterface op) {
677710
return newOp.getOperation();
678711
}
679712

680-
Value BufferDeallocation::getMemrefWithGuaranteedOwnership(OpBuilder &builder,
681-
Value memref) {
713+
Value BufferDeallocation::materializeMemrefWithGuaranteedOwnership(
714+
OpBuilder &builder, Value memref, Block *block) {
682715
// First, make sure we at least have 'Unique' ownership already.
683716
std::pair<Value, Value> newMemrefAndOnwership =
684-
state.getMemrefWithUniqueOwnership(builder, memref);
717+
materializeUniqueOwnership(builder, memref, block);
685718
Value newMemref = newMemrefAndOnwership.first;
686719
Value condition = newMemrefAndOnwership.second;
687720

@@ -785,7 +818,7 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
785818
continue;
786819
}
787820
auto [memref, condition] =
788-
state.getMemrefWithUniqueOwnership(builder, operand);
821+
materializeUniqueOwnership(builder, operand, op->getBlock());
789822
newOperands.push_back(memref);
790823
ownershipIndicatorsToAdd.push_back(condition);
791824
}
@@ -868,7 +901,8 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
868901
if (!isMemref(val.get()))
869902
continue;
870903

871-
val.set(getMemrefWithGuaranteedOwnership(builder, val.get()));
904+
val.set(materializeMemrefWithGuaranteedOwnership(builder, val.get(),
905+
op->getBlock()));
872906
}
873907
}
874908

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,15 @@ func.func @function_call_requries_merged_ownership_mid_block(%arg0: i1) {
9595
// CHECK-NEXT: return
9696

9797
// CHECK-DYNAMIC-LABEL: func @function_call_requries_merged_ownership_mid_block
98+
// CHECK-DYNAMIC-SAME: ([[ARG0:%.+]]: i1)
9899
// CHECK-DYNAMIC: [[ALLOC0:%.+]] = memref.alloc(
99100
// CHECK-DYNAMIC-NEXT: [[ALLOC1:%.+]] = memref.alloca(
100-
// CHECK-DYNAMIC-NEXT: [[SELECT:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]]
101-
// CHECK-DYNAMIC-NEXT: [[CLONE:%.+]] = bufferization.clone [[SELECT]]
102-
// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[CLONE]], %true{{[0-9_]*}})
101+
// CHECK-DYNAMIC-NEXT: [[SELECT:%.+]] = arith.select [[ARG0]], [[ALLOC0]], [[ALLOC1]]
102+
// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[SELECT]], [[ARG0]])
103103
// CHECK-DYNAMIC-NEXT: test.copy
104104
// CHECK-DYNAMIC-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
105-
// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC0]], [[CLONE]], [[BASE]] :
106-
// CHECK-DYNAMIC-SAME: if (%true{{[0-9_]*}}, %true{{[0-9_]*}}, [[RET]]#1)
105+
// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC0]], [[BASE]] :
106+
// CHECK-DYNAMIC-SAME: if (%true{{[0-9_]*}}, [[RET]]#1)
107107
// CHECK-DYNAMIC-NOT: retain
108108
// CHECK-DYNAMIC-NEXT: return
109109

0 commit comments

Comments
 (0)