Skip to content

Commit a45e58a

Browse files
[mlir][bufferization] Add BufferViewFlowOpInterface (#78718)
This commit adds the `BufferViewFlowOpInterface` to the bufferization dialect. This interface can be implemented by ops that operate on buffers to indicate that a buffer op result and/or region entry block argument may be the same buffer as a buffer operand (or a view thereof). This interface is queried by the `BufferViewFlowAnalysis`. The new interface has two interface methods: * `populateDependencies`: Implementations use the provided callback to declare dependencies between operands and op results/region entry block arguments. E.g., for `%r = arith.select %c, %m1, %m2 : memref<5xf32>`, the interface implementation should declare two dependencies: %m1 -> %r and %m2 -> %r. * `mayBeTerminalBuffer`: An SSA value is a terminal buffer if the buffer view flow analysis stops at the specified value. E.g., because the value is a newly allocated buffer or because no further information is available about the origin of the buffer. Ops that implement the `RegionBranchOpInterface` or `BranchOpInterface` do not have to implement the `BufferViewFlowOpInterface`. The buffer dependencies can be inferred from those two interfaces. This commit makes the `BufferViewFlowAnalysis` more accurate. For unknown ops, it conservatively used to declare all combinations of operands and op results/region entry block arguments as dependencies (false positives). This is no longer the case. While the analysis is still a "maybe" analysis with false positives (e.g., when analyzing ops such as `arith.select` or `scf.if` where the taken branch is not known at compile time), results and region entry block arguments of unknown ops are now marked as terminal buffers. This commit addresses a TODO in `BufferViewFlowAnalysis.cpp`: ``` // TODO: We should have an op interface instead of a hard-coded list of // interfaces/ops. ``` It is no longer needed to hard-code ops.
1 parent 74799f4 commit a45e58a

File tree

15 files changed

+363
-14
lines changed

15 files changed

+363
-14
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace arith {
16+
void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry &registry);
17+
} // namespace arith
18+
} // namespace mlir
19+
20+
#endif // MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- BufferViewFlowOpInterface.h - Buffer View Flow Analysis --*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
10+
#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
11+
12+
#include "mlir/IR/OpDefinition.h"
13+
#include "mlir/Support/LLVM.h"
14+
15+
namespace mlir {
16+
class ValueRange;
17+
18+
namespace bufferization {
19+
20+
using RegisterDependenciesFn = std::function<void(ValueRange, ValueRange)>;
21+
22+
} // namespace bufferization
23+
} // namespace mlir
24+
25+
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h.inc"
26+
27+
#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
//===-- BufferViewFlowOpInterface.td - Buffer View Flow ----*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef BUFFER_VIEW_FLOW_OP_INTERFACE
10+
#define BUFFER_VIEW_FLOW_OP_INTERFACE
11+
12+
include "mlir/IR/OpBase.td"
13+
14+
def BufferViewFlowOpInterface :
15+
OpInterface<"BufferViewFlowOpInterface"> {
16+
let description = [{
17+
An op interface for the buffer view flow analysis. This interface describes
18+
buffer dependencies between operands and op results/region entry block
19+
arguments.
20+
}];
21+
let cppNamespace = "::mlir::bufferization";
22+
let methods = [
23+
InterfaceMethod<
24+
/*desc=*/[{
25+
Populate buffer dependencies between operands and op results/region
26+
entry block arguments.
27+
28+
Implementations should register dependencies between an operand ("X")
29+
and an op result/region entry block argument ("Y") if Y may depend
30+
on X. Y depends on X if Y and X are the same buffer or if Y is a
31+
subview of X.
32+
33+
Example:
34+
```
35+
%r = arith.select %c, %m1, %m2 : memref<5xf32>
36+
```
37+
In the above example, %0 may depend on %m1 or %m2 and a correct
38+
interface implementation should call:
39+
- "registerDependenciesFn(%m1, %r)".
40+
- "registerDependenciesFn(%m2, %r)"
41+
}],
42+
/*retType=*/"void",
43+
/*methodName=*/"populateDependencies",
44+
/*args=*/(ins
45+
"::mlir::bufferization::RegisterDependenciesFn"
46+
:$registerDependenciesFn)
47+
>,
48+
InterfaceMethod<
49+
/*desc=*/[{
50+
Return "true" if the given value may be a terminal buffer. A buffer
51+
value is "terminal" if it cannot be traced back any further in the
52+
buffer view flow analysis.
53+
54+
Examples: A buffer could be terminal because:
55+
- it is a newly allocated buffer (e.g., "memref.alloc"),
56+
- or: because there is not enough compile-time information available
57+
to make a definite decision (e.g., "memref.realloc" may reallocate
58+
but we do not know for sure; another example are call ops where we
59+
would have to analyze the body of the callee).
60+
61+
Implementations can assume that the given SSA value is an OpResult of
62+
this operation or a region entry block argument of this operation.
63+
}],
64+
/*retType=*/"bool",
65+
/*methodName=*/"mayBeTerminalBuffer",
66+
/*args=*/(ins "Value":$value),
67+
/*methodBody=*/"",
68+
/*defaultImplementation=*/"return false;"
69+
>,
70+
];
71+
}
72+
73+
#endif // BUFFER_VIEW_FLOW_OP_INTERFACE

mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
33
add_mlir_interface(AllocationOpInterface)
44
add_mlir_interface(BufferDeallocationOpInterface)
55
add_mlir_interface(BufferizableOpInterface)
6+
add_mlir_interface(BufferViewFlowOpInterface)
67

78
set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
89
mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)

mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,19 @@ class BufferViewFlowAnalysis {
6363
/// results have to be changed.
6464
void rename(Value from, Value to);
6565

66+
/// Returns "true" if the given value may be a terminal.
67+
bool mayBeTerminalBuffer(Value value) const;
68+
6669
private:
6770
/// This function constructs a mapping from values to its immediate
6871
/// dependencies.
6972
void build(Operation *op);
7073

7174
/// Maps values to all immediate dependencies this value can have.
7275
ValueMapT dependencies;
76+
77+
/// A set of all SSA values that may be terminal buffers.
78+
DenseSet<Value> terminals;
7379
};
7480

7581
} // namespace mlir
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_MEMREF_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace memref {
16+
void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry &registry);
17+
} // namespace memref
18+
} // namespace mlir
19+
20+
#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H

mlir/include/mlir/InitAllDialects.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Dialect/Arith/IR/Arith.h"
2222
#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
2323
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
24+
#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
2425
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
2526
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
2627
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
@@ -52,6 +53,7 @@
5253
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
5354
#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
5455
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
56+
#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
5557
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
5658
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
5759
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -148,6 +150,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
148150
affine::registerValueBoundsOpInterfaceExternalModels(registry);
149151
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
150152
arith::registerBufferizableOpInterfaceExternalModels(registry);
153+
arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
151154
arith::registerValueBoundsOpInterfaceExternalModels(registry);
152155
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
153156
registry);
@@ -157,6 +160,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
157160
gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
158161
linalg::registerAllDialectInterfaceImplementations(registry);
159162
memref::registerAllocationOpInterfaceExternalModels(registry);
163+
memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
160164
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
161165
memref::registerValueBoundsOpInterfaceExternalModels(registry);
162166
memref::registerMemorySlotExternalModels(registry);
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
10+
11+
#include "mlir/Dialect/Arith/IR/Arith.h"
12+
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
13+
14+
using namespace mlir;
15+
using namespace mlir::bufferization;
16+
17+
namespace mlir {
18+
namespace arith {
19+
namespace {
20+
21+
struct SelectOpInterface
22+
: public BufferViewFlowOpInterface::ExternalModel<SelectOpInterface,
23+
SelectOp> {
24+
void
25+
populateDependencies(Operation *op,
26+
RegisterDependenciesFn registerDependenciesFn) const {
27+
auto selectOp = cast<SelectOp>(op);
28+
29+
// Either one of the true/false value may be selected at runtime.
30+
registerDependenciesFn(selectOp.getTrueValue(), selectOp.getResult());
31+
registerDependenciesFn(selectOp.getFalseValue(), selectOp.getResult());
32+
}
33+
};
34+
35+
} // namespace
36+
} // namespace arith
37+
} // namespace mlir
38+
39+
void arith::registerBufferViewFlowOpInterfaceExternalModels(
40+
DialectRegistry &registry) {
41+
registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
42+
SelectOp::attachInterface<SelectOpInterface>(*ctx);
43+
});
44+
}

mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithTransforms
22
BufferDeallocationOpInterfaceImpl.cpp
33
BufferizableOpInterfaceImpl.cpp
44
Bufferize.cpp
5+
BufferViewFlowOpInterfaceImpl.cpp
56
EmulateUnsupportedFloats.cpp
67
EmulateWideInt.cpp
78
EmulateNarrowType.cpp
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//===- BufferViewFlowOpInterface.cpp - Buffer View Flow Analysis ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
10+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11+
12+
namespace mlir {
13+
namespace bufferization {
14+
15+
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp.inc"
16+
17+
} // namespace bufferization
18+
} // namespace mlir

mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
44
BufferDeallocationOpInterface.cpp
55
BufferizationOps.cpp
66
BufferizationDialect.cpp
7+
BufferViewFlowOpInterface.cpp
78
UnstructuredControlFlow.cpp
89

910
ADDITIONAL_HEADER_DIRS

mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,16 @@
88

99
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
1010

11+
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
12+
#include "mlir/Interfaces/CallInterfaces.h"
1113
#include "mlir/Interfaces/ControlFlowInterfaces.h"
14+
#include "mlir/Interfaces/FunctionInterfaces.h"
1215
#include "mlir/Interfaces/ViewLikeInterface.h"
1316
#include "llvm/ADT/SetOperations.h"
1417
#include "llvm/ADT/SetVector.h"
1518

1619
using namespace mlir;
20+
using namespace mlir::bufferization;
1721

1822
/// Constructs a new alias analysis using the op provided.
1923
BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
@@ -65,18 +69,44 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) {
6569
void BufferViewFlowAnalysis::build(Operation *op) {
6670
// Registers all dependencies of the given values.
6771
auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
68-
for (auto [value, dep] : llvm::zip(values, dependencies))
72+
for (auto [value, dep] : llvm::zip_equal(values, dependencies))
6973
this->dependencies[value].insert(dep);
7074
};
7175

76+
// Mark all buffer results and buffer region entry block arguments of the
77+
// given op as terminals.
78+
auto populateTerminalValues = [&](Operation *op) {
79+
for (Value v : op->getResults())
80+
if (isa<BaseMemRefType>(v.getType()))
81+
this->terminals.insert(v);
82+
for (Region &r : op->getRegions())
83+
for (BlockArgument v : r.getArguments())
84+
if (isa<BaseMemRefType>(v.getType()))
85+
this->terminals.insert(v);
86+
};
87+
7288
op->walk([&](Operation *op) {
73-
// TODO: We should have an op interface instead of a hard-coded list of
74-
// interfaces/ops.
89+
// Query BufferViewFlowOpInterface. If the op does not implement that
90+
// interface, try to infer the dependencies from other interfaces that the
91+
// op may implement.
92+
if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
93+
bufferViewFlowOp.populateDependencies(registerDependencies);
94+
for (Value v : op->getResults())
95+
if (isa<BaseMemRefType>(v.getType()) &&
96+
bufferViewFlowOp.mayBeTerminalBuffer(v))
97+
this->terminals.insert(v);
98+
for (Region &r : op->getRegions())
99+
for (BlockArgument v : r.getArguments())
100+
if (isa<BaseMemRefType>(v.getType()) &&
101+
bufferViewFlowOp.mayBeTerminalBuffer(v))
102+
this->terminals.insert(v);
103+
return WalkResult::advance();
104+
}
75105

76106
// Add additional dependencies created by view changes to the alias list.
77107
if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
78-
dependencies[viewInterface.getViewSource()].insert(
79-
viewInterface->getResult(0));
108+
registerDependencies(viewInterface.getViewSource(),
109+
viewInterface->getResult(0));
80110
return WalkResult::advance();
81111
}
82112

@@ -131,16 +161,30 @@ void BufferViewFlowAnalysis::build(Operation *op) {
131161
return WalkResult::advance();
132162
}
133163

134-
// Unknown op: Assume that all operands alias with all results.
135-
for (Value operand : op->getOperands()) {
136-
if (!isa<BaseMemRefType>(operand.getType()))
137-
continue;
138-
for (Value result : op->getResults()) {
139-
if (!isa<BaseMemRefType>(result.getType()))
140-
continue;
141-
registerDependencies({operand}, {result});
142-
}
164+
// Region terminators are handled together with RegionBranchOpInterface.
165+
if (isa<RegionBranchTerminatorOpInterface>(op))
166+
return WalkResult::advance();
167+
168+
if (isa<CallOpInterface>(op)) {
169+
// This is an intra-function analysis. We have no information about other
170+
// functions. Conservatively assume that each operand may alias with each
171+
// result. Also mark the results are terminals because the function could
172+
// return newly allocated buffers.
173+
populateTerminalValues(op);
174+
for (Value operand : op->getOperands())
175+
for (Value result : op->getResults())
176+
registerDependencies({operand}, {result});
177+
return WalkResult::advance();
143178
}
179+
180+
// We have no information about unknown ops.
181+
populateTerminalValues(op);
182+
144183
return WalkResult::advance();
145184
});
146185
}
186+
187+
bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
188+
assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
189+
return terminals.contains(value);
190+
}

0 commit comments

Comments
 (0)