@@ -60,23 +60,30 @@ void ExternalNameConversionPass::runOnOperation() {
60
60
61
61
llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr> remappings;
62
62
63
+ auto processFctOrGlobal = [&](mlir::Operation &funcOrGlobal) {
64
+ auto symName = funcOrGlobal.getAttrOfType <mlir::StringAttr>(
65
+ mlir::SymbolTable::getSymbolAttrName ());
66
+ auto deconstructedName = fir::NameUniquer::deconstruct (symName);
67
+ if (fir::NameUniquer::isExternalFacingUniquedName (deconstructedName)) {
68
+ auto newName = mangleExternalName (deconstructedName, appendUnderscoreOpt);
69
+ auto newAttr = mlir::StringAttr::get (context, newName);
70
+ mlir::SymbolTable::setSymbolName (&funcOrGlobal, newAttr);
71
+ auto newSymRef = mlir::FlatSymbolRefAttr::get (newAttr);
72
+ remappings.try_emplace (symName, newSymRef);
73
+ if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal))
74
+ funcOrGlobal.setAttr (fir::getInternalFuncNameAttrName (), symName);
75
+ }
76
+ };
77
+
63
78
auto renameFuncOrGlobalInModule = [&](mlir::Operation *module) {
64
- for (auto &funcOrGlobal : module->getRegion (0 ).front ()) {
65
- if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal) ||
66
- llvm::isa<fir::GlobalOp>(funcOrGlobal)) {
67
- auto symName = funcOrGlobal.getAttrOfType <mlir::StringAttr>(
68
- mlir::SymbolTable::getSymbolAttrName ());
69
- auto deconstructedName = fir::NameUniquer::deconstruct (symName);
70
- if (fir::NameUniquer::isExternalFacingUniquedName (deconstructedName)) {
71
- auto newName =
72
- mangleExternalName (deconstructedName, appendUnderscoreOpt);
73
- auto newAttr = mlir::StringAttr::get (context, newName);
74
- mlir::SymbolTable::setSymbolName (&funcOrGlobal, newAttr);
75
- auto newSymRef = mlir::FlatSymbolRefAttr::get (newAttr);
76
- remappings.try_emplace (symName, newSymRef);
77
- if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal))
78
- funcOrGlobal.setAttr (fir::getInternalFuncNameAttrName (), symName);
79
- }
79
+ for (auto &op : module->getRegion (0 ).front ()) {
80
+ if (mlir::isa<mlir::func::FuncOp, fir::GlobalOp>(op)) {
81
+ processFctOrGlobal (op);
82
+ } else if (auto gpuMod = mlir::dyn_cast<mlir::gpu::GPUModuleOp>(op)) {
83
+ for (auto &gpuOp : gpuMod.getBodyRegion ().front ())
84
+ if (mlir::isa<mlir::func::FuncOp, fir::GlobalOp,
85
+ mlir::gpu::GPUFuncOp>(gpuOp))
86
+ processFctOrGlobal (gpuOp);
80
87
}
81
88
}
82
89
};
@@ -85,23 +92,25 @@ void ExternalNameConversionPass::runOnOperation() {
85
92
// globals.
86
93
renameFuncOrGlobalInModule (op);
87
94
88
- // Do the same in GPU modules.
89
- if (auto mod = mlir::dyn_cast_or_null<mlir::ModuleOp>(*op))
90
- for (auto gpuMod : mod.getOps <mlir::gpu::GPUModuleOp>())
91
- renameFuncOrGlobalInModule (gpuMod);
92
-
93
95
if (remappings.empty ())
94
96
return ;
95
97
96
98
// Update all uses of the functions and globals that have been renamed.
97
99
op.walk ([&remappings](mlir::Operation *nestedOp) {
98
100
llvm::SmallVector<std::pair<mlir::StringAttr, mlir::SymbolRefAttr>> updates;
99
101
for (const mlir::NamedAttribute &attr : nestedOp->getAttrDictionary ())
100
- if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue ()))
101
- if (auto remap = remappings.find (symRef.getRootReference ());
102
- remap != remappings.end ())
102
+ if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue ())) {
103
+ if (auto remap = remappings.find (symRef.getLeafReference ());
104
+ remap != remappings.end ()) {
105
+ mlir::SymbolRefAttr symAttr = mlir::FlatSymbolRefAttr (remap->second );
106
+ if (mlir::isa<mlir::gpu::LaunchFuncOp>(nestedOp))
107
+ symAttr = mlir::SymbolRefAttr::get (
108
+ symRef.getRootReference (),
109
+ {mlir::FlatSymbolRefAttr (remap->second )});
103
110
updates.emplace_back (std::pair<mlir::StringAttr, mlir::SymbolRefAttr>{
104
- attr.getName (), mlir::SymbolRefAttr (remap->second )});
111
+ attr.getName (), symAttr});
112
+ }
113
+ }
105
114
for (auto update : updates)
106
115
nestedOp->setAttr (update.first , update.second );
107
116
});
0 commit comments