Skip to content

Commit 67fc166

Browse files
authored
[MLIR] Add bufferization state class to OneShotBufferization pass (#138143)
This PR is a follow-up on #138125, and adds a bufferization state class providing information about the IR. The information currently consists of a cached list of symbol tables, which aims to solve the quadratic scaling of the bufferization task with respect to the number of symbols. The PR breaks API compatibility: the `bufferize` method of the `BufferizableOpInterface` has been enriched with a reference to a `BufferizationState` object. The bufferization state must be kept in a valid state by the interface implementations. For example, if an operation with the `Symbol` trait is inserted or replaced, its parent `SymbolTable` must be updated accordingly (see, for example, the bufferization of `arith::ConstantOp`, where the symbol table of the module gets the new global symbol inserted). Similarly, the invalidation of a symbol table must be performed if an operation with the `SymbolTable` trait is removed (this can be performed using the `invalidateSymbolTable` method, introduced in #138014).
1 parent 4fdcde5 commit 67fc166

27 files changed

+214
-86
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,20 @@ class AnalysisState {
578578
insideMutuallyExclusiveRegionsCache;
579579
};
580580

581+
/// BufferizationState provides information about the state of the IR during the
582+
/// bufferization process.
583+
class BufferizationState {
584+
public:
585+
/// Get a reference to the collection of cached symbol tables.
586+
SymbolTableCollection &getSymbolTables();
587+
588+
private:
589+
/// The cached symbol tables.
590+
/// The user is expected to update / invalidate the cached symbol tables if
591+
/// the bufferized operation has the Symbol or SymbolTable traits.
592+
SymbolTableCollection symbolTables;
593+
};
594+
581595
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
582596
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
583597
/// undefined contents is allocated.

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
426426
/*retType=*/"::llvm::LogicalResult",
427427
/*methodName=*/"bufferize",
428428
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
429-
"const ::mlir::bufferization::BufferizationOptions &":$options),
429+
"const ::mlir::bufferization::BufferizationOptions &":$options,
430+
"::mlir::bufferization::BufferizationState &":$state),
430431
/*methodBody=*/"",
431432
/*defaultImplementation=*/[{
432433
llvm_unreachable("bufferize not implemented");

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
9393

9494
let extraClassDeclaration = [{
9595
LogicalResult bufferize(RewriterBase &rewriter,
96-
const BufferizationOptions &options);
96+
const BufferizationOptions &options,
97+
BufferizationState &state);
9798

9899
bool resultBufferizesToMemoryWrite(OpResult opResult,
99100
const AnalysisState &state);
@@ -282,7 +283,8 @@ def Bufferization_MaterializeInDestinationOp
282283

283284
let extraClassDeclaration = [{
284285
LogicalResult bufferize(RewriterBase &rewriter,
285-
const BufferizationOptions &options);
286+
const BufferizationOptions &options,
287+
BufferizationState &state);
286288

287289
bool bufferizesToMemoryRead(OpOperand &opOperand,
288290
const AnalysisState &state);
@@ -375,7 +377,8 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
375377
}
376378

377379
LogicalResult bufferize(RewriterBase &rewriter,
378-
const BufferizationOptions &options);
380+
const BufferizationOptions &options,
381+
BufferizationState &state);
379382
}];
380383
}
381384

@@ -458,7 +461,8 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
458461
//===------------------------------------------------------------------===//
459462

460463
LogicalResult bufferize(RewriterBase &rewriter,
461-
const BufferizationOptions &options) const {
464+
const BufferizationOptions &options,
465+
BufferizationState &state) const {
462466
// to_tensor/to_buffer pairs fold away after bufferization.
463467
return success();
464468
}
@@ -550,7 +554,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
550554
}
551555

552556
LogicalResult bufferize(RewriterBase &rewriter,
553-
const BufferizationOptions &options);
557+
const BufferizationOptions &options,
558+
BufferizationState &state);
554559
}];
555560

556561
let assemblyFormat = [{

mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class GlobalOp;
2929
} // namespace memref
3030

3131
namespace bufferization {
32+
class BufferizationState;
3233

3334
/// A simple analysis that detects allocation operations.
3435
class BufferPlacementAllocs {
@@ -122,9 +123,14 @@ class BufferPlacementTransformationBase {
122123
// Globals are created lazily at the top of the enclosing ModuleOp with pretty
123124
// names. Duplicates are avoided.
124125
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
126+
SymbolTableCollection &symbolTables,
125127
uint64_t alignment,
126128
Attribute memorySpace = {});
127129

130+
void removeSymbol(Operation *op, BufferizationState &state);
131+
132+
void insertSymbol(Operation *op, BufferizationState &state);
133+
128134
} // namespace bufferization
129135
} // namespace mlir
130136

mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ struct BufferizationStatistics {
4545
/// additional buffer copies or set "options.copyBeforeWrite = true". The
4646
/// general bufferization entry point is `runOneShotBufferize`.
4747
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
48+
BufferizationState &bufferizationState,
4849
BufferizationStatistics *statistics = nullptr);
4950

5051
/// Bufferize the signature of `block` and its callers (i.e., ops that have the

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
270270
/// Run One-Shot Bufferize on the given op: Analysis + Bufferization
271271
LogicalResult
272272
runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
273+
BufferizationState &state,
273274
BufferizationStatistics *statistics = nullptr);
274275

275276
} // namespace bufferization

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ namespace bufferization {
2020
struct BufferizationStatistics;
2121
class OneShotAnalysisState;
2222
struct OneShotBufferizationOptions;
23+
class BufferizationState;
2324

2425
/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
2526
/// `state`.
@@ -38,6 +39,7 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
3839
/// will be inserted only to these FuncOps.
3940
llvm::LogicalResult
4041
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
42+
BufferizationState &state,
4143
BufferizationStatistics *statistics = nullptr);
4244

4345
/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -50,7 +52,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
5052
llvm::LogicalResult runOneShotModuleBufferize(
5153
ModuleOp moduleOp,
5254
const bufferization::OneShotBufferizationOptions &options,
53-
BufferizationStatistics *statistics = nullptr);
55+
BufferizationState &state, BufferizationStatistics *statistics = nullptr);
5456

5557
} // namespace bufferization
5658
} // namespace mlir

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ namespace mlir {
3030
namespace bufferization {
3131
class AllocTensorOp;
3232
class OneShotAnalysisState;
33+
class BufferizationState;
3334
} // namespace bufferization
3435

3536
namespace linalg {

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ struct ConstantOpInterface
2424
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
2525
arith::ConstantOp> {
2626
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
27-
const BufferizationOptions &options) const {
27+
const BufferizationOptions &options,
28+
BufferizationState &state) const {
2829
auto constantOp = cast<arith::ConstantOp>(op);
2930
auto type = dyn_cast<RankedTensorType>(constantOp.getType());
3031

@@ -46,7 +47,8 @@ struct ConstantOpInterface
4647
// Create global memory segment and replace tensor with memref pointing to
4748
// that memory segment.
4849
FailureOr<memref::GlobalOp> globalOp =
49-
getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
50+
getGlobalFor(constantOp, state.getSymbolTables(),
51+
options.bufferAlignment, memorySpace);
5052
if (failed(globalOp))
5153
return failure();
5254
memref::GlobalOp globalMemref = *globalOp;
@@ -83,7 +85,8 @@ struct IndexCastOpInterface
8385
}
8486

8587
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
86-
const BufferizationOptions &options) const {
88+
const BufferizationOptions &options,
89+
BufferizationState &state) const {
8790
auto castOp = cast<arith::IndexCastOp>(op);
8891
auto resultTensorType = cast<TensorType>(castOp.getType());
8992

@@ -131,7 +134,8 @@ struct SelectOpInterface
131134
}
132135

133136
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
134-
const BufferizationOptions &options) const {
137+
const BufferizationOptions &options,
138+
BufferizationState &state) const {
135139
auto selectOp = cast<arith::SelectOp>(op);
136140
Location loc = selectOp.getLoc();
137141

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ void AnalysisState::resetCache() {
125125
insideMutuallyExclusiveRegionsCache.clear();
126126
}
127127

128+
SymbolTableCollection &BufferizationState::getSymbolTables() {
129+
return symbolTables;
130+
}
131+
128132
Region *bufferization::getNextEnclosingRepetitiveRegion(
129133
Region *region, const BufferizationOptions &options) {
130134
assert(isRepetitiveRegion(region, options) && "expected repetitive region");

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ void mlir::bufferization::populateDynamicDimSizes(
149149
//===----------------------------------------------------------------------===//
150150

151151
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
152-
const BufferizationOptions &options) {
152+
const BufferizationOptions &options,
153+
BufferizationState &state) {
153154
OpBuilder::InsertionGuard g(rewriter);
154155
Location loc = getLoc();
155156

@@ -529,7 +530,8 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
529530
//===----------------------------------------------------------------------===//
530531

531532
LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
532-
const BufferizationOptions &options) {
533+
const BufferizationOptions &options,
534+
BufferizationState &state) {
533535
FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
534536
if (failed(buffer))
535537
return failure();
@@ -576,7 +578,8 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
576578

577579
LogicalResult
578580
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
579-
const BufferizationOptions &options) {
581+
const BufferizationOptions &options,
582+
BufferizationState &state) {
580583
bool tensorDest = isa<TensorType>(getDest().getType());
581584
Value buffer;
582585
if (tensorDest) {
@@ -861,7 +864,8 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
861864
}
862865

863866
LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
864-
const BufferizationOptions &options) {
867+
const BufferizationOptions &options,
868+
BufferizationState &state) {
865869
// Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
866870
(void)foldToBufferToTensorPair(rewriter, *this, options);
867871
// Note: The return value of `bufferize` indicates whether there was an error

mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,21 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
8383
}
8484

8585
auto payloadOps = state.getPayloadOps(getTarget());
86+
BufferizationState bufferizationState;
87+
8688
for (Operation *target : payloadOps) {
8789
if (!isa<ModuleOp, FunctionOpInterface>(target))
8890
return emitSilenceableError() << "expected module or function target";
8991
auto moduleOp = dyn_cast<ModuleOp>(target);
9092
if (options.bufferizeFunctionBoundaries) {
9193
if (!moduleOp)
9294
return emitSilenceableError() << "expected module target";
93-
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
95+
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
96+
bufferizationState)))
9497
return emitSilenceableError() << "bufferization failed";
9598
} else {
96-
if (failed(bufferization::runOneShotBufferize(target, options)))
99+
if (failed(bufferization::runOneShotBufferize(target, options,
100+
bufferizationState)))
97101
return emitSilenceableError() << "bufferization failed";
98102
}
99103
}
@@ -162,6 +166,7 @@ class BufferizationTransformDialectExtension
162166
registerTransformOps<
163167
#define GET_OP_LIST
164168
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
169+
165170
>();
166171
}
167172
};

mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,9 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
103103
//===----------------------------------------------------------------------===//
104104

105105
FailureOr<memref::GlobalOp>
106-
bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
107-
Attribute memorySpace) {
106+
bufferization::getGlobalFor(arith::ConstantOp constantOp,
107+
SymbolTableCollection &symbolTables,
108+
uint64_t alignment, Attribute memorySpace) {
108109
auto type = cast<RankedTensorType>(constantOp.getType());
109110
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
110111
if (!moduleOp)
@@ -127,7 +128,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
127128
// Create a builder without an insertion point. We will insert using the
128129
// symbol table to guarantee unique names.
129130
OpBuilder globalBuilder(moduleOp.getContext());
130-
SymbolTable symbolTable(moduleOp);
131+
SymbolTable &symbolTable = symbolTables.getSymbolTable(moduleOp);
131132

132133
// Create a pretty name.
133134
SmallString<64> buf;
@@ -158,3 +159,19 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
158159
global->moveBefore(&moduleOp.front());
159160
return global;
160161
}
162+
163+
namespace mlir::bufferization {
164+
void removeSymbol(Operation *op, BufferizationState &state) {
165+
SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
166+
op->getParentWithTrait<OpTrait::SymbolTable>());
167+
168+
symbolTable.remove(op);
169+
}
170+
171+
void insertSymbol(Operation *op, BufferizationState &state) {
172+
SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
173+
op->getParentWithTrait<OpTrait::SymbolTable>());
174+
175+
symbolTable.insert(op);
176+
}
177+
} // namespace mlir::bufferization

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,13 @@ struct OneShotBufferizePass
161161
return signalPassFailure();
162162
}
163163

164+
BufferizationState state;
165+
164166
BufferizationStatistics statistics;
165167
ModuleOp moduleOp = getOperation();
166168
if (opt.bufferizeFunctionBoundaries) {
167-
if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
169+
if (failed(
170+
runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
168171
signalPassFailure();
169172
return;
170173
}
@@ -175,7 +178,7 @@ struct OneShotBufferizePass
175178
"'bufferize-function-boundaries'");
176179
return signalPassFailure();
177180
}
178-
if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
181+
if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
179182
signalPassFailure();
180183
return;
181184
}
@@ -275,6 +278,7 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
275278

276279
LogicalResult bufferization::bufferizeOp(Operation *op,
277280
const BufferizationOptions &options,
281+
BufferizationState &bufferizationState,
278282
BufferizationStatistics *statistics) {
279283
if (options.copyBeforeWrite) {
280284
AnalysisState state(options);
@@ -331,7 +335,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
331335
<< "//===-------------------------------------------===//\n"
332336
<< "IR after bufferizing: " << nextOp->getName() << "\n");
333337
rewriter.setInsertionPoint(nextOp);
334-
if (failed(bufferizableOp.bufferize(rewriter, options))) {
338+
if (failed(
339+
bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
335340
LLVM_DEBUG(llvm::dbgs()
336341
<< "failed to bufferize\n"
337342
<< "//===-------------------------------------------===//\n");

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,8 @@ struct CallOpInterface
239239
/// All function arguments are writable. It is the responsibility of the
240240
/// CallOp to insert buffer copies where necessary.
241241
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
242-
const BufferizationOptions &options) const {
242+
const BufferizationOptions &options,
243+
BufferizationState &state) const {
243244
func::CallOp callOp = cast<func::CallOp>(op);
244245

245246
// 1. Compute the result types of the new CallOp.
@@ -349,7 +350,8 @@ struct ReturnOpInterface
349350
}
350351

351352
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
352-
const BufferizationOptions &options) const {
353+
const BufferizationOptions &options,
354+
BufferizationState &state) const {
353355
#ifndef NDEBUG
354356
auto returnOp = cast<func::ReturnOp>(op);
355357
assert(isa<FuncOp>(returnOp->getParentOp()) &&
@@ -418,7 +420,8 @@ struct FuncOpInterface
418420
/// All function bbArgs are writable unless they are explicitly marked as
419421
/// read-only. Callers must insert copies when needed.
420422
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
421-
const BufferizationOptions &options) const {
423+
const BufferizationOptions &options,
424+
BufferizationState &state) const {
422425
auto funcOp = cast<FuncOp>(op);
423426
FunctionType funcType = funcOp.getFunctionType();
424427

0 commit comments

Comments
 (0)