Commit f586b1e

Tzung-Han Juang and erick-xanadu authored
[MLIR] Make OneShotModuleBufferize use OpInterface (#107295)
**Description:** `OneShotModuleBufferize` handles the bufferization of `FuncOp`, `CallOp`, and `ReturnOp`, but these ops are hard-coded, so custom function-like operations are not handled. This PR replaces part of the `FuncOp`- and `CallOp`-specific handling in `OneShotModuleBufferize` with `FunctionOpInterface` and `CallOpInterface`, so that custom function ops and call ops can be bufferized as well.

**Related Discord Discussion:** [Link](https://discord.com/channels/636084430946959380/642426447167881246/1280556809911799900)

---------

Co-authored-by: erick-xanadu <[email protected]>
1 parent 614aeda commit f586b1e
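In practical terms, every `func::FuncOp`-specific accessor in the module bufferization driver is swapped for its generic `FunctionOpInterface` counterpart. The sketch below is an editorial illustration of that accessor mapping (not code from this commit); it works for any op implementing `FunctionOpInterface`, whether `func.func` or a custom function-like op.

```cpp
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

// The generic accessors used throughout this patch, replacing the
// func::FuncOp-specific ones:
//   getBody()                          -> getFunctionBody()
//   getFunctionType().getInputs()      -> getArgumentTypes()
//   getFunctionType().getResults()     -> getResultTypes()
//   getSymName()                       -> getName()
static void inspectFunction(FunctionOpInterface funcOp) {
  llvm::errs() << funcOp.getName() << ": " << funcOp.getNumArguments()
               << " argument(s), " << funcOp.getResultTypes().size()
               << " result(s), "
               << (funcOp.getFunctionBody().empty() ? "declaration only"
                                                    : "has a body")
               << "\n";
}
```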

File tree

11 files changed: +315, -281 lines


mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 4 additions & 3 deletions

@@ -11,6 +11,7 @@
 
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMapInfoVariant.h"
 #include "llvm/ADT/SetVector.h"
@@ -260,9 +261,9 @@ struct BufferizationOptions {
   using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
   /// Tensor -> MemRef type converter.
   /// Parameters: Value, memory space, func op, bufferization options
-  using FunctionArgTypeConverterFn =
-      std::function<BaseMemRefType(TensorType, Attribute memorySpace,
-                                   func::FuncOp, const BufferizationOptions &)>;
+  using FunctionArgTypeConverterFn = std::function<BaseMemRefType(
+      TensorType, Attribute memorySpace, FunctionOpInterface,
+      const BufferizationOptions &)>;
   /// Tensor -> MemRef type converter.
   /// Parameters: Value, memory space, bufferization options
   using UnknownTypeConverterFn = std::function<BaseMemRefType(
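For context (an illustrative example, not part of this diff): with the interface-based `FunctionArgTypeConverterFn` signature above, a downstream pipeline can install a boundary type converter that also fires for custom function-like ops. The helper name `configureBoundaryTypes` and the layout policy are hypothetical; the option member and the two layout helpers are the ones referenced in this patch.

```cpp
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"

using namespace mlir;
using namespace mlir::bufferization;

// Hypothetical downstream configuration. The converter now receives a
// FunctionOpInterface instead of func::FuncOp, so it also applies to custom
// function-like ops. The policy below (identity layout for declarations,
// fully dynamic layout otherwise) is illustrative only.
static void configureBoundaryTypes(BufferizationOptions &options) {
  options.functionArgTypeConverterFn =
      [](TensorType tensorType, Attribute memorySpace,
         FunctionOpInterface funcOp,
         const BufferizationOptions &) -> BaseMemRefType {
    if (funcOp.getFunctionBody().empty())
      return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
  };
}
```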

mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h

Lines changed: 6 additions & 6 deletions

@@ -50,24 +50,24 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {
 
   /// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
   /// indices.
-  DenseMap<FuncOp, IndexMapping> equivalentFuncArgs;
+  DenseMap<FunctionOpInterface, IndexMapping> equivalentFuncArgs;
 
   /// A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
-  DenseMap<FuncOp, IndexToIndexListMapping> aliasingReturnVals;
+  DenseMap<FunctionOpInterface, IndexToIndexListMapping> aliasingReturnVals;
 
   /// A set of all read BlockArguments of FuncOps.
-  DenseMap<FuncOp, BbArgIndexSet> readBbArgs;
+  DenseMap<FunctionOpInterface, BbArgIndexSet> readBbArgs;
 
   /// A set of all written-to BlockArguments of FuncOps.
-  DenseMap<FuncOp, BbArgIndexSet> writtenBbArgs;
+  DenseMap<FunctionOpInterface, BbArgIndexSet> writtenBbArgs;
 
   /// Keep track of which FuncOps are fully analyzed or currently being
   /// analyzed.
-  DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
+  DenseMap<FunctionOpInterface, FuncOpAnalysisState> analyzedFuncOps;
 
   /// This function is called right before analyzing the given FuncOp. It
   /// initializes the data structures for the FuncOp in this state object.
-  void startFunctionAnalysis(FuncOp funcOp);
+  void startFunctionAnalysis(FunctionOpInterface funcOp);
 };
 
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 3 additions & 2 deletions

@@ -18,6 +18,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Debug.h"
 
@@ -314,7 +315,7 @@ namespace {
 /// Default function arg type converter: Use a fully dynamic layout map.
 BaseMemRefType
 defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
-                                func::FuncOp funcOp,
+                                FunctionOpInterface funcOp,
                                 const BufferizationOptions &options) {
   return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
 }
@@ -361,7 +362,7 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
 void BufferizationOptions::setFunctionBoundaryTypeConversion(
     LayoutMapOption layoutMapOption) {
   functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
-                                   func::FuncOp funcOp,
+                                   FunctionOpInterface funcOp,
                                    const BufferizationOptions &options) {
     if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
       return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ namespace mlir {
 namespace bufferization {
 namespace func_ext {
 
-void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
+void FuncAnalysisState::startFunctionAnalysis(FunctionOpInterface funcOp) {
   analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
   auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
   auto createdAliasingResults =

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 55 additions & 56 deletions

@@ -75,7 +75,7 @@ using namespace mlir::bufferization;
 using namespace mlir::bufferization::func_ext;
 
 /// A mapping of FuncOps to their callers.
-using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;
+using FuncCallerMap = DenseMap<FunctionOpInterface, DenseSet<Operation *>>;
 
 /// Get or create FuncAnalysisState.
 static FuncAnalysisState &
@@ -88,10 +88,11 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
 
 /// Return the unique ReturnOp that terminates `funcOp`.
 /// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
+static Operation *getAssumedUniqueReturnOp(FunctionOpInterface funcOp) {
+  Operation *returnOp = nullptr;
+  for (Block &b : funcOp.getFunctionBody()) {
+    auto candidateOp = b.getTerminator();
+    if (candidateOp && candidateOp->hasTrait<OpTrait::ReturnLike>()) {
       if (returnOp)
         return nullptr;
       returnOp = candidateOp;
@@ -126,16 +127,16 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
 /// Store function BlockArguments that are equivalent to/aliasing a returned
 /// value in FuncAnalysisState.
 static LogicalResult
-aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
+                             OneShotAnalysisState &state,
                              FuncAnalysisState &funcState) {
-  if (funcOp.getBody().empty()) {
+  if (funcOp.getFunctionBody().empty()) {
     // No function body available. Conservatively assume that every tensor
     // return value may alias with any tensor bbArg.
-    FunctionType type = funcOp.getFunctionType();
-    for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
+    for (const auto &inputIt : llvm::enumerate(funcOp.getArgumentTypes())) {
      if (!isa<TensorType>(inputIt.value()))
        continue;
-      for (const auto &resultIt : llvm::enumerate(type.getResults())) {
+      for (const auto &resultIt : llvm::enumerate(funcOp.getResultTypes())) {
        if (!isa<TensorType>(resultIt.value()))
          continue;
        int64_t returnIdx = resultIt.index();
@@ -147,7 +148,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
   }
 
   // Support only single return-terminated block in the function.
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
   assert(returnOp && "expected func with single return op");
 
   for (OpOperand &returnVal : returnOp->getOpOperands())
@@ -168,8 +169,8 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
   return success();
 }
 
-static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
-                                  bool isWritten) {
+static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx,
+                                  bool isRead, bool isWritten) {
   OpBuilder b(funcOp.getContext());
   Attribute accessType;
   if (isRead && isWritten) {
@@ -189,12 +190,12 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
 /// function with unknown ops, we conservatively assume that such ops bufferize
 /// to a read + write.
 static LogicalResult
-funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,
+                             OneShotAnalysisState &state,
                              FuncAnalysisState &funcState) {
-  for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
-       ++idx) {
+  for (int64_t idx = 0, e = funcOp.getNumArguments(); idx < e; ++idx) {
     // Skip non-tensor arguments.
-    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
+    if (!isa<TensorType>(funcOp.getArgumentTypes()[idx]))
       continue;
     bool isRead;
     bool isWritten;
@@ -204,7 +205,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
       StringRef str = accessAttr.getValue();
       isRead = str == "read" || str == "read-write";
       isWritten = str == "write" || str == "read-write";
-    } else if (funcOp.getBody().empty()) {
+    } else if (funcOp.getFunctionBody().empty()) {
       // If the function has no body, conservatively assume that all args are
       // read + written.
       isRead = true;
@@ -230,33 +231,32 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
 
 /// Remove bufferization attributes on FuncOp arguments.
 static void removeBufferizationAttributes(BlockArgument bbArg) {
-  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
+  auto funcOp = cast<FunctionOpInterface>(bbArg.getOwner()->getParentOp());
   funcOp.removeArgAttr(bbArg.getArgNumber(),
                        BufferizationDialect::kBufferLayoutAttrName);
   funcOp.removeArgAttr(bbArg.getArgNumber(),
                        BufferizationDialect::kWritableAttrName);
 }
 
-/// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(func::CallOp callOp) {
+static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
   SymbolRefAttr sym =
       llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
-  return dyn_cast_or_null<func::FuncOp>(
+  return dyn_cast_or_null<FunctionOpInterface>(
       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
 }
 
 /// Gather equivalence info of CallOps.
 /// Note: This only adds new equivalence info if the called function was already
 /// analyzed.
 // TODO: This does not handle cyclic function call graphs etc.
-static void equivalenceAnalysis(func::FuncOp funcOp,
+static void equivalenceAnalysis(FunctionOpInterface funcOp,
                                 OneShotAnalysisState &state,
                                 FuncAnalysisState &funcState) {
-  funcOp->walk([&](func::CallOp callOp) {
-    func::FuncOp calledFunction = getCalledFunction(callOp);
-    assert(calledFunction && "could not retrieved called func::FuncOp");
+  funcOp->walk([&](CallOpInterface callOp) {
+    FunctionOpInterface calledFunction = getCalledFunction(callOp);
+    assert(calledFunction && "could not retrieved called FunctionOpInterface");
 
     // No equivalence info available for the called function.
     if (!funcState.equivalentFuncArgs.count(calledFunction))
@@ -267,7 +267,7 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
      int64_t bbargIdx = it.second;
      if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
        continue;
-      Value returnVal = callOp.getResult(returnIdx);
+      Value returnVal = callOp->getResult(returnIdx);
      Value argVal = callOp->getOperand(bbargIdx);
      state.unionEquivalenceClasses(returnVal, argVal);
    }
@@ -277,11 +277,9 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
 }
 
 /// Return "true" if the given function signature has tensor semantics.
-static bool hasTensorSignature(func::FuncOp funcOp) {
-  return llvm::any_of(funcOp.getFunctionType().getInputs(),
-                      llvm::IsaPred<TensorType>) ||
-         llvm::any_of(funcOp.getFunctionType().getResults(),
-                      llvm::IsaPred<TensorType>);
+static bool hasTensorSignature(FunctionOpInterface funcOp) {
+  return llvm::any_of(funcOp.getArgumentTypes(), llvm::IsaPred<TensorType>) ||
+         llvm::any_of(funcOp.getResultTypes(), llvm::IsaPred<TensorType>);
 }
 
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
@@ -291,16 +289,16 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
 /// retrieve the called FuncOp from any func::CallOp.
 static LogicalResult
 getFuncOpsOrderedByCalls(ModuleOp moduleOp,
-                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+                         SmallVectorImpl<FunctionOpInterface> &orderedFuncOps,
                          FuncCallerMap &callerMap) {
   // For each FuncOp, the set of functions called by it (i.e. the union of
   // symbols of all nested func::CallOp).
-  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
+  DenseMap<FunctionOpInterface, DenseSet<FunctionOpInterface>> calledBy;
   // For each FuncOp, the number of func::CallOp it contains.
-  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
-    if (!funcOp.getBody().empty()) {
-      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
+  WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
+    if (!funcOp.getFunctionBody().empty()) {
+      Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
       if (!returnOp)
         return funcOp->emitError()
                << "cannot bufferize a FuncOp with tensors and "
@@ -309,9 +307,10 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
 
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
-    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
-      func::FuncOp calledFunction = getCalledFunction(callOp);
-      assert(calledFunction && "could not retrieved called func::FuncOp");
+    return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
+      FunctionOpInterface calledFunction = getCalledFunction(callOp);
+      assert(calledFunction &&
+             "could not retrieved called FunctionOpInterface");
      // If the called function does not have any tensors in its signature, then
      // it is not necessary to bufferize the callee before the caller.
      if (!hasTensorSignature(calledFunction))
@@ -349,11 +348,11 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
 /// most generic layout map as function return types. After bufferizing the
 /// entire function body, a more concise memref type can potentially be used for
 /// the return type of the function.
-static void foldMemRefCasts(func::FuncOp funcOp) {
-  if (funcOp.getBody().empty())
+static void foldMemRefCasts(FunctionOpInterface funcOp) {
+  if (funcOp.getFunctionBody().empty())
     return;
 
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
   SmallVector<Type> resultTypes;
 
   for (OpOperand &operand : returnOp->getOpOperands()) {
@@ -365,8 +364,8 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
     }
   }
 
-  auto newFuncType = FunctionType::get(
-      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
+  auto newFuncType = FunctionType::get(funcOp.getContext(),
+                                       funcOp.getArgumentTypes(), resultTypes);
   funcOp.setType(newFuncType);
 }
 
@@ -379,7 +378,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
   FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
 
   // A list of functions in the order in which they are analyzed + bufferized.
-  SmallVector<func::FuncOp> orderedFuncOps;
+  SmallVector<FunctionOpInterface> orderedFuncOps;
 
   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;
@@ -388,7 +387,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
     return failure();
 
   // Analyze ops.
-  for (func::FuncOp funcOp : orderedFuncOps) {
+  for (FunctionOpInterface funcOp : orderedFuncOps) {
     if (!state.getOptions().isOpAllowed(funcOp))
       continue;
 
@@ -416,7 +415,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
 
 void mlir::bufferization::removeBufferizationAttributesInModule(
     ModuleOp moduleOp) {
-  moduleOp.walk([&](func::FuncOp op) {
+  moduleOp.walk([&](FunctionOpInterface op) {
     for (BlockArgument bbArg : op.getArguments())
       removeBufferizationAttributes(bbArg);
   });
@@ -430,7 +429,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   IRRewriter rewriter(moduleOp.getContext());
 
   // A list of functions in the order in which they are analyzed + bufferized.
-  SmallVector<func::FuncOp> orderedFuncOps;
+  SmallVector<FunctionOpInterface> orderedFuncOps;
 
   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;
@@ -439,11 +438,11 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
     return failure();
 
   // Bufferize functions.
-  for (func::FuncOp funcOp : orderedFuncOps) {
+  for (FunctionOpInterface funcOp : orderedFuncOps) {
     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
     // would be invalidated.
 
-    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
+    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) {
       // This function was not analyzed and RaW conflicts were not resolved.
       // Buffer copies must be inserted before every write.
       OneShotBufferizationOptions updatedOptions = options;
@@ -463,7 +462,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   // Bufferize all other ops.
   for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
     // Functions were already bufferized.
-    if (isa<func::FuncOp>(&op))
+    if (isa<FunctionOpInterface>(&op))
       continue;
     if (failed(bufferizeOp(&op, options, statistics)))
       return failure();
@@ -490,12 +489,12 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
   // FuncOps whose names are specified in options.noAnalysisFuncFilter will
   // not be analyzed. Ops in these FuncOps will not be analyzed as well.
   OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
-    auto func = dyn_cast<func::FuncOp>(op);
+    auto func = dyn_cast<FunctionOpInterface>(op);
    if (!func)
-      func = op->getParentOfType<func::FuncOp>();
+      func = op->getParentOfType<FunctionOpInterface>();
    if (func)
      return llvm::is_contained(options.noAnalysisFuncFilter,
-                                func.getSymName());
+                                func.getName());
    return false;
  };
  OneShotBufferizationOptions updatedOptions(options);
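A hedged usage sketch (not from this commit) of driving the whole-module bufferization after this change; it assumes the usual one-shot options and only uses entry points referenced in the diff.

```cpp
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;
using namespace mlir::bufferization;

// Minimal driver sketch. After this patch, the module walks collect
// FunctionOpInterface / CallOpInterface ops, so custom function-like ops are
// analyzed and bufferized too, provided they implement those interfaces.
static LogicalResult bufferizeWholeModule(ModuleOp moduleOp) {
  OneShotBufferizationOptions options;
  options.bufferizeFunctionBoundaries = true;
  // Names in options.noAnalysisFuncFilter are now matched against
  // FunctionOpInterface::getName() instead of func::FuncOp::getSymName().
  return runOneShotModuleBufferize(moduleOp, options);
}
```

The bufferization semantics of the individual ops are unchanged: custom function-like and call-like ops still need the relevant `BufferizableOpInterface` implementations; this commit only removes the `func` dialect assumption from the module-level driver.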
