Skip to content

Commit 9d34c05

Browse files
[mlir][bufferization][NFC] Simplify bufferizeOp function signature (#68625)
Remove the `opFilter` and `copyBeforeWrite` function arguments. These options can already be configured in the `options` object.
1 parent 3d0ca2c commit 9d34c05

File tree

5 files changed

+37
-42
lines changed

5 files changed

+37
-42
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,12 @@ void populateEliminateBufferizeMaterializationsPatterns(
6363
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns);
6464

6565
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
66-
/// If `copyBeforeWrite`, buffers are duplicated and copied before any tensor
67-
/// use that bufferizes to a memory write.
6866
///
69-
/// Note: In the general case, it unsafe to run with `copyBeforeWrite = false`
70-
/// because read-after-write conflicts may materialize during bufferization.
71-
/// `copyBeforeWrite = false` is safe only if the input IR is guaranteed to
72-
/// *not* require any out-of-place bufferization.
73-
///
74-
/// Note: This function bufferizes ops without utilizing analysis results. It
75-
/// can be used to implement partial bufferization passes.
67+
/// Note: This function does not resolve read-after-write conflicts. Use this
68+
/// function only if it is guaranteed that the input IR can bufferize without
69+
/// additional buffer copies or set "options.copyBeforeWrite = true". The
70+
/// general bufferization entry point is `runOneShotBufferize`.
7671
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
77-
bool copyBeforeWrite = true,
78-
const OpFilter *opFilter = nullptr,
7972
BufferizationStatistics *statistics = nullptr);
8073

8174
/// Bufferize the signature of `block` and its callers (i.e., ops that have the
@@ -94,6 +87,9 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
9487
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
9588
const BufferizationOptions &options);
9689

90+
/// Return `BufferizationOptions` such that the `bufferizeOp` behaves like the
91+
/// old (deprecated) partial, dialect conversion-based bufferization passes. A
92+
/// copy will be inserted before every buffer write.
9793
BufferizationOptions getPartialBufferizationOptions();
9894

9995
} // namespace bufferization

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -383,11 +383,9 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
383383
DenseSet<Operation *> &toMemrefOps,
384384
SmallVector<Operation *> &worklist,
385385
const BufferizationOptions &options,
386-
const OpFilter *opFilter,
387386
BufferizationStatistics *statistics)
388387
: IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
389-
worklist(worklist), analysisState(options), opFilter(opFilter),
390-
statistics(statistics) {
388+
worklist(worklist), analysisState(options), statistics(statistics) {
391389
setListener(this);
392390
}
393391

@@ -424,7 +422,7 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
424422

425423
// Skip ops that are not allowed to be bufferized.
426424
auto const &options = analysisState.getOptions();
427-
if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
425+
if (!options.isOpAllowed(op))
428426
return;
429427

430428
// Add op to worklist.
@@ -445,20 +443,15 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
445443
/// bufferization options.
446444
const AnalysisState analysisState;
447445

448-
/// An extra op filter for bufferization.
449-
const OpFilter *opFilter;
450-
451446
/// Bufferization statistics for debugging.
452447
BufferizationStatistics *statistics;
453448
};
454449
} // namespace
455450

456451
LogicalResult bufferization::bufferizeOp(Operation *op,
457452
const BufferizationOptions &options,
458-
bool copyBeforeWrite,
459-
const OpFilter *opFilter,
460453
BufferizationStatistics *statistics) {
461-
if (copyBeforeWrite) {
454+
if (options.copyBeforeWrite) {
462455
AnalysisState state(options);
463456
if (failed(insertTensorCopies(op, state)))
464457
return failure();
@@ -486,7 +479,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
486479

487480
// Bufferize all ops.
488481
BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
489-
worklist, options, opFilter, statistics);
482+
worklist, options, statistics);
490483
for (unsigned i = 0; i < worklist.size(); ++i) {
491484
Operation *nextOp = worklist[i];
492485
// Skip ops that were erased.
@@ -496,7 +489,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
496489
auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
497490
if (!bufferizableOp)
498491
continue;
499-
if (opFilter && !opFilter->isOpAllowed(nextOp))
492+
if (!options.isOpAllowed(nextOp))
500493
continue;
501494
// Skip ops that no longer have tensor semantics.
502495
if (!hasTensorSemantics(nextOp))
@@ -558,8 +551,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
558551
// Continue ops that are not allowed.
559552
if (!options.isOpAllowed(op))
560553
continue;
561-
if (opFilter && !opFilter->isOpAllowed(op))
562-
continue;
563554
// Ops without any uses and no side effects will fold away.
564555
if (op->getUses().empty() && isMemoryEffectFree(op))
565556
continue;
@@ -662,6 +653,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
662653
BufferizationOptions bufferization::getPartialBufferizationOptions() {
663654
BufferizationOptions options;
664655
options.allowUnknownOps = true;
656+
options.copyBeforeWrite = true;
665657
options.enforceAliasingInvariants = false;
666658
options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
667659
const BufferizationOptions &options) {

mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,6 +1334,5 @@ bufferization::runOneShotBufferize(Operation *op,
13341334
}
13351335
if (options.testAnalysisOnly)
13361336
return success();
1337-
return bufferizeOp(op, options, /*copyBeforeWrite=*/options.copyBeforeWrite,
1338-
/*opFilter=*/nullptr, statistics);
1337+
return bufferizeOp(op, options, statistics);
13391338
}

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,8 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
238238

239239
/// Return the func::FuncOp called by `callOp`.
240240
static func::FuncOp getCalledFunction(func::CallOp callOp) {
241-
SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
241+
SymbolRefAttr sym =
242+
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
242243
if (!sym)
243244
return nullptr;
244245
return dyn_cast_or_null<func::FuncOp>(
@@ -439,12 +440,19 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
439440
for (func::FuncOp funcOp : orderedFuncOps) {
440441
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
441442
// would be invalidated.
442-
bool copyBeforeWrite =
443-
options.copyBeforeWrite ||
444-
llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName());
445-
if (failed(bufferizeOp(funcOp, options, copyBeforeWrite,
446-
/*opFilter=*/nullptr, statistics)))
447-
return failure();
443+
444+
if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
445+
// This function was not analyzed and RaW conflicts were not resolved.
446+
// Buffer copies must be inserted before every write.
447+
OneShotBufferizationOptions updatedOptions = options;
448+
updatedOptions.copyBeforeWrite = true;
449+
if (failed(bufferizeOp(funcOp, updatedOptions, statistics)))
450+
return failure();
451+
} else {
452+
if (failed(bufferizeOp(funcOp, options, statistics)))
453+
return failure();
454+
}
455+
448456
// Change buffer return types to more precise layout maps.
449457
if (options.inferFunctionResultLayout)
450458
foldMemRefCasts(funcOp);

mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,23 +81,23 @@ class SparsificationAndBufferizationPass
8181
/// and that all required buffer copies were already inserted by
8282
/// `insertTensorCopies` in the form of `bufferization.alloc_tensor` ops.
8383
LogicalResult runDenseBufferization() {
84-
bufferization::OpFilter denseOpFilter;
85-
denseOpFilter.allowOperation([&](Operation *op) {
84+
bufferization::OneShotBufferizationOptions updatedOptions =
85+
bufferizationOptions;
86+
// Skip all sparse ops.
87+
updatedOptions.opFilter.denyOperation([&](Operation *op) {
8688
if (containsSparseTensor(TypeRange(op->getResults())) ||
8789
containsSparseTensor(TypeRange(op->getOperands())))
88-
return false;
90+
return true;
8991
if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
9092
FunctionType funcType = funcOp.getFunctionType();
9193
if (containsSparseTensor(funcType.getInputs()) ||
9294
containsSparseTensor(funcType.getResults()))
93-
return false;
95+
return true;
9496
}
95-
return true;
97+
return false;
9698
});
9799

98-
if (failed(bufferization::bufferizeOp(getOperation(), bufferizationOptions,
99-
/*copyBeforeWrite=*/false,
100-
&denseOpFilter)))
100+
if (failed(bufferization::bufferizeOp(getOperation(), updatedOptions)))
101101
return failure();
102102

103103
bufferization::removeBufferizationAttributesInModule(getOperation());

0 commit comments

Comments
 (0)