Skip to content

[MLIR] Use cached symbol tables in getFuncOpsOrderedByCalls #141967

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 3, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -310,21 +310,19 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// any func::CallOp.
static LogicalResult getFuncOpsOrderedByCalls(
ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
SymbolTableCollection &symbolTables) {
// For each FuncOp, the set of functions called by it (i.e. the union of
// symbols of all nested func::CallOp).
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;

// TODO Avoid recomputing the symbol tables every time.
mlir::SymbolTableCollection symbolTable;

for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable);
func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.
Expand Down Expand Up @@ -458,7 +456,8 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
FuncCallerMap callerMap;

if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
remainingFuncOps, callerMap)))
remainingFuncOps, callerMap,
funcState.symbolTables)))
return failure();

// Analyze functions in order. Starting with functions that are not calling
Expand Down Expand Up @@ -534,7 +533,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
// each other recursively are bufferized in an unspecified order at the end.
// We may use unnecessarily "complex" (in terms of layout map) buffer types.
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
remainingFuncOps, callerMap)))
remainingFuncOps, callerMap,
state.getSymbolTables())))
return failure();
llvm::append_range(orderedFuncOps, remainingFuncOps);

Expand Down