Skip to content

Commit 72a8893

Browse files
authored
Revert "[MLIR] Add bufferization state class to OneShotBufferization pass" (#141012)
Reverts #138143 The PR for the BufferizationState is temporarily reverted due to API incompatibilities that have been initially missed during the update and were not catched by PR checks.
1 parent 11953c6 commit 72a8893

27 files changed

+86
-214
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -578,20 +578,6 @@ class AnalysisState {
578578
insideMutuallyExclusiveRegionsCache;
579579
};
580580

581-
/// BufferizationState provides information about the state of the IR during the
582-
/// bufferization process.
583-
class BufferizationState {
584-
public:
585-
/// Get a reference to the collection of cached symbol tables.
586-
SymbolTableCollection &getSymbolTables();
587-
588-
private:
589-
/// The cached symbol tables.
590-
/// The user is expected to update / invalidate the cached symbol tables if
591-
/// the bufferized operation has the Symbol or SymbolTable traits.
592-
SymbolTableCollection symbolTables;
593-
};
594-
595581
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
596582
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
597583
/// undefined contents is allocated.

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,8 +426,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
426426
/*retType=*/"::llvm::LogicalResult",
427427
/*methodName=*/"bufferize",
428428
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
429-
"const ::mlir::bufferization::BufferizationOptions &":$options,
430-
"::mlir::bufferization::BufferizationState &":$state),
429+
"const ::mlir::bufferization::BufferizationOptions &":$options),
431430
/*methodBody=*/"",
432431
/*defaultImplementation=*/[{
433432
llvm_unreachable("bufferize not implemented");

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
9393

9494
let extraClassDeclaration = [{
9595
LogicalResult bufferize(RewriterBase &rewriter,
96-
const BufferizationOptions &options,
97-
BufferizationState &state);
96+
const BufferizationOptions &options);
9897

9998
bool resultBufferizesToMemoryWrite(OpResult opResult,
10099
const AnalysisState &state);
@@ -283,8 +282,7 @@ def Bufferization_MaterializeInDestinationOp
283282

284283
let extraClassDeclaration = [{
285284
LogicalResult bufferize(RewriterBase &rewriter,
286-
const BufferizationOptions &options,
287-
BufferizationState &state);
285+
const BufferizationOptions &options);
288286

289287
bool bufferizesToMemoryRead(OpOperand &opOperand,
290288
const AnalysisState &state);
@@ -377,8 +375,7 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
377375
}
378376

379377
LogicalResult bufferize(RewriterBase &rewriter,
380-
const BufferizationOptions &options,
381-
BufferizationState &state);
378+
const BufferizationOptions &options);
382379
}];
383380
}
384381

@@ -461,8 +458,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
461458
//===------------------------------------------------------------------===//
462459

463460
LogicalResult bufferize(RewriterBase &rewriter,
464-
const BufferizationOptions &options,
465-
BufferizationState &state) const {
461+
const BufferizationOptions &options) const {
466462
// to_tensor/to_buffer pairs fold away after bufferization.
467463
return success();
468464
}
@@ -554,8 +550,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
554550
}
555551

556552
LogicalResult bufferize(RewriterBase &rewriter,
557-
const BufferizationOptions &options,
558-
BufferizationState &state);
553+
const BufferizationOptions &options);
559554
}];
560555

561556
let assemblyFormat = [{

mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ class GlobalOp;
2929
} // namespace memref
3030

3131
namespace bufferization {
32-
class BufferizationState;
3332

3433
/// A simple analysis that detects allocation operations.
3534
class BufferPlacementAllocs {
@@ -123,14 +122,9 @@ class BufferPlacementTransformationBase {
123122
// Globals are created lazily at the top of the enclosing ModuleOp with pretty
124123
// names. Duplicates are avoided.
125124
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
126-
SymbolTableCollection &symbolTables,
127125
uint64_t alignment,
128126
Attribute memorySpace = {});
129127

130-
void removeSymbol(Operation *op, BufferizationState &state);
131-
132-
void insertSymbol(Operation *op, BufferizationState &state);
133-
134128
} // namespace bufferization
135129
} // namespace mlir
136130

mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ struct BufferizationStatistics {
4545
/// additional buffer copies or set "options.copyBeforeWrite = true". The
4646
/// general bufferization entry point is `runOneShotBufferize`.
4747
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
48-
BufferizationState &bufferizationState,
4948
BufferizationStatistics *statistics = nullptr);
5049

5150
/// Bufferize the signature of `block` and its callers (i.e., ops that have the

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,6 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
270270
/// Run One-Shot Bufferize on the given op: Analysis + Bufferization
271271
LogicalResult
272272
runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
273-
BufferizationState &state,
274273
BufferizationStatistics *statistics = nullptr);
275274

276275
} // namespace bufferization

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ namespace bufferization {
2020
struct BufferizationStatistics;
2121
class OneShotAnalysisState;
2222
struct OneShotBufferizationOptions;
23-
class BufferizationState;
2423

2524
/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
2625
/// `state`.
@@ -39,7 +38,6 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
3938
/// will be inserted only to these FuncOps.
4039
llvm::LogicalResult
4140
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
42-
BufferizationState &state,
4341
BufferizationStatistics *statistics = nullptr);
4442

4543
/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -52,7 +50,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
5250
llvm::LogicalResult runOneShotModuleBufferize(
5351
ModuleOp moduleOp,
5452
const bufferization::OneShotBufferizationOptions &options,
55-
BufferizationState &state, BufferizationStatistics *statistics = nullptr);
53+
BufferizationStatistics *statistics = nullptr);
5654

5755
} // namespace bufferization
5856
} // namespace mlir

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ namespace mlir {
3030
namespace bufferization {
3131
class AllocTensorOp;
3232
class OneShotAnalysisState;
33-
class BufferizationState;
3433
} // namespace bufferization
3534

3635
namespace linalg {

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ struct ConstantOpInterface
2424
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
2525
arith::ConstantOp> {
2626
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
27-
const BufferizationOptions &options,
28-
BufferizationState &state) const {
27+
const BufferizationOptions &options) const {
2928
auto constantOp = cast<arith::ConstantOp>(op);
3029
auto type = dyn_cast<RankedTensorType>(constantOp.getType());
3130

@@ -47,8 +46,7 @@ struct ConstantOpInterface
4746
// Create global memory segment and replace tensor with memref pointing to
4847
// that memory segment.
4948
FailureOr<memref::GlobalOp> globalOp =
50-
getGlobalFor(constantOp, state.getSymbolTables(),
51-
options.bufferAlignment, memorySpace);
49+
getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
5250
if (failed(globalOp))
5351
return failure();
5452
memref::GlobalOp globalMemref = *globalOp;
@@ -85,8 +83,7 @@ struct IndexCastOpInterface
8583
}
8684

8785
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
88-
const BufferizationOptions &options,
89-
BufferizationState &state) const {
86+
const BufferizationOptions &options) const {
9087
auto castOp = cast<arith::IndexCastOp>(op);
9188
auto resultTensorType = cast<TensorType>(castOp.getType());
9289

@@ -134,8 +131,7 @@ struct SelectOpInterface
134131
}
135132

136133
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
137-
const BufferizationOptions &options,
138-
BufferizationState &state) const {
134+
const BufferizationOptions &options) const {
139135
auto selectOp = cast<arith::SelectOp>(op);
140136
Location loc = selectOp.getLoc();
141137

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,6 @@ void AnalysisState::resetCache() {
125125
insideMutuallyExclusiveRegionsCache.clear();
126126
}
127127

128-
SymbolTableCollection &BufferizationState::getSymbolTables() {
129-
return symbolTables;
130-
}
131-
132128
Region *bufferization::getNextEnclosingRepetitiveRegion(
133129
Region *region, const BufferizationOptions &options) {
134130
assert(isRepetitiveRegion(region, options) && "expected repetitive region");

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,7 @@ void mlir::bufferization::populateDynamicDimSizes(
149149
//===----------------------------------------------------------------------===//
150150

151151
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
152-
const BufferizationOptions &options,
153-
BufferizationState &state) {
152+
const BufferizationOptions &options) {
154153
OpBuilder::InsertionGuard g(rewriter);
155154
Location loc = getLoc();
156155

@@ -530,8 +529,7 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
530529
//===----------------------------------------------------------------------===//
531530

532531
LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
533-
const BufferizationOptions &options,
534-
BufferizationState &state) {
532+
const BufferizationOptions &options) {
535533
FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
536534
if (failed(buffer))
537535
return failure();
@@ -578,8 +576,7 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
578576

579577
LogicalResult
580578
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
581-
const BufferizationOptions &options,
582-
BufferizationState &state) {
579+
const BufferizationOptions &options) {
583580
bool tensorDest = isa<TensorType>(getDest().getType());
584581
Value buffer;
585582
if (tensorDest) {
@@ -864,8 +861,7 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
864861
}
865862

866863
LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
867-
const BufferizationOptions &options,
868-
BufferizationState &state) {
864+
const BufferizationOptions &options) {
869865
// Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
870866
(void)foldToBufferToTensorPair(rewriter, *this, options);
871867
// Note: The return value of `bufferize` indicates whether there was an error

mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,21 +83,17 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
8383
}
8484

8585
auto payloadOps = state.getPayloadOps(getTarget());
86-
BufferizationState bufferizationState;
87-
8886
for (Operation *target : payloadOps) {
8987
if (!isa<ModuleOp, FunctionOpInterface>(target))
9088
return emitSilenceableError() << "expected module or function target";
9189
auto moduleOp = dyn_cast<ModuleOp>(target);
9290
if (options.bufferizeFunctionBoundaries) {
9391
if (!moduleOp)
9492
return emitSilenceableError() << "expected module target";
95-
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
96-
bufferizationState)))
93+
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
9794
return emitSilenceableError() << "bufferization failed";
9895
} else {
99-
if (failed(bufferization::runOneShotBufferize(target, options,
100-
bufferizationState)))
96+
if (failed(bufferization::runOneShotBufferize(target, options)))
10197
return emitSilenceableError() << "bufferization failed";
10298
}
10399
}
@@ -166,7 +162,6 @@ class BufferizationTransformDialectExtension
166162
registerTransformOps<
167163
#define GET_OP_LIST
168164
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
169-
170165
>();
171166
}
172167
};

mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,8 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
103103
//===----------------------------------------------------------------------===//
104104

105105
FailureOr<memref::GlobalOp>
106-
bufferization::getGlobalFor(arith::ConstantOp constantOp,
107-
SymbolTableCollection &symbolTables,
108-
uint64_t alignment, Attribute memorySpace) {
106+
bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
107+
Attribute memorySpace) {
109108
auto type = cast<RankedTensorType>(constantOp.getType());
110109
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
111110
if (!moduleOp)
@@ -128,7 +127,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
128127
// Create a builder without an insertion point. We will insert using the
129128
// symbol table to guarantee unique names.
130129
OpBuilder globalBuilder(moduleOp.getContext());
131-
SymbolTable &symbolTable = symbolTables.getSymbolTable(moduleOp);
130+
SymbolTable symbolTable(moduleOp);
132131

133132
// Create a pretty name.
134133
SmallString<64> buf;
@@ -159,19 +158,3 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
159158
global->moveBefore(&moduleOp.front());
160159
return global;
161160
}
162-
163-
namespace mlir::bufferization {
164-
void removeSymbol(Operation *op, BufferizationState &state) {
165-
SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
166-
op->getParentWithTrait<OpTrait::SymbolTable>());
167-
168-
symbolTable.remove(op);
169-
}
170-
171-
void insertSymbol(Operation *op, BufferizationState &state) {
172-
SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
173-
op->getParentWithTrait<OpTrait::SymbolTable>());
174-
175-
symbolTable.insert(op);
176-
}
177-
} // namespace mlir::bufferization

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,10 @@ struct OneShotBufferizePass
161161
return signalPassFailure();
162162
}
163163

164-
BufferizationState state;
165-
166164
BufferizationStatistics statistics;
167165
ModuleOp moduleOp = getOperation();
168166
if (opt.bufferizeFunctionBoundaries) {
169-
if (failed(
170-
runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
167+
if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
171168
signalPassFailure();
172169
return;
173170
}
@@ -178,7 +175,7 @@ struct OneShotBufferizePass
178175
"'bufferize-function-boundaries'");
179176
return signalPassFailure();
180177
}
181-
if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
178+
if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
182179
signalPassFailure();
183180
return;
184181
}
@@ -278,7 +275,6 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
278275

279276
LogicalResult bufferization::bufferizeOp(Operation *op,
280277
const BufferizationOptions &options,
281-
BufferizationState &bufferizationState,
282278
BufferizationStatistics *statistics) {
283279
if (options.copyBeforeWrite) {
284280
AnalysisState state(options);
@@ -335,8 +331,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
335331
<< "//===-------------------------------------------===//\n"
336332
<< "IR after bufferizing: " << nextOp->getName() << "\n");
337333
rewriter.setInsertionPoint(nextOp);
338-
if (failed(
339-
bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
334+
if (failed(bufferizableOp.bufferize(rewriter, options))) {
340335
LLVM_DEBUG(llvm::dbgs()
341336
<< "failed to bufferize\n"
342337
<< "//===-------------------------------------------===//\n");

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,7 @@ struct CallOpInterface
239239
/// All function arguments are writable. It is the responsibility of the
240240
/// CallOp to insert buffer copies where necessary.
241241
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
242-
const BufferizationOptions &options,
243-
BufferizationState &state) const {
242+
const BufferizationOptions &options) const {
244243
func::CallOp callOp = cast<func::CallOp>(op);
245244

246245
// 1. Compute the result types of the new CallOp.
@@ -350,8 +349,7 @@ struct ReturnOpInterface
350349
}
351350

352351
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
353-
const BufferizationOptions &options,
354-
BufferizationState &state) const {
352+
const BufferizationOptions &options) const {
355353
#ifndef NDEBUG
356354
auto returnOp = cast<func::ReturnOp>(op);
357355
assert(isa<FuncOp>(returnOp->getParentOp()) &&
@@ -420,8 +418,7 @@ struct FuncOpInterface
420418
/// All function bbArgs are writable unless they are explicitly marked as
421419
/// read-only. Callers must insert copies when needed.
422420
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
423-
const BufferizationOptions &options,
424-
BufferizationState &state) const {
421+
const BufferizationOptions &options) const {
425422
auto funcOp = cast<FuncOp>(op);
426423
FunctionType funcType = funcOp.getFunctionType();
427424

0 commit comments

Comments
 (0)