Skip to content

Commit a1d71c3

Browse files
authored
[flang][cuda] Additional update to ExternalNameConversion (#119276)
1 parent 650e736 commit a1d71c3

File tree

2 files changed

+48
-27
lines changed

2 files changed

+48
-27
lines changed

flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -60,23 +60,30 @@ void ExternalNameConversionPass::runOnOperation() {
6060

6161
llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr> remappings;
6262

63+
auto processFctOrGlobal = [&](mlir::Operation &funcOrGlobal) {
64+
auto symName = funcOrGlobal.getAttrOfType<mlir::StringAttr>(
65+
mlir::SymbolTable::getSymbolAttrName());
66+
auto deconstructedName = fir::NameUniquer::deconstruct(symName);
67+
if (fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) {
68+
auto newName = mangleExternalName(deconstructedName, appendUnderscoreOpt);
69+
auto newAttr = mlir::StringAttr::get(context, newName);
70+
mlir::SymbolTable::setSymbolName(&funcOrGlobal, newAttr);
71+
auto newSymRef = mlir::FlatSymbolRefAttr::get(newAttr);
72+
remappings.try_emplace(symName, newSymRef);
73+
if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal))
74+
funcOrGlobal.setAttr(fir::getInternalFuncNameAttrName(), symName);
75+
}
76+
};
77+
6378
auto renameFuncOrGlobalInModule = [&](mlir::Operation *module) {
64-
for (auto &funcOrGlobal : module->getRegion(0).front()) {
65-
if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal) ||
66-
llvm::isa<fir::GlobalOp>(funcOrGlobal)) {
67-
auto symName = funcOrGlobal.getAttrOfType<mlir::StringAttr>(
68-
mlir::SymbolTable::getSymbolAttrName());
69-
auto deconstructedName = fir::NameUniquer::deconstruct(symName);
70-
if (fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) {
71-
auto newName =
72-
mangleExternalName(deconstructedName, appendUnderscoreOpt);
73-
auto newAttr = mlir::StringAttr::get(context, newName);
74-
mlir::SymbolTable::setSymbolName(&funcOrGlobal, newAttr);
75-
auto newSymRef = mlir::FlatSymbolRefAttr::get(newAttr);
76-
remappings.try_emplace(symName, newSymRef);
77-
if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal))
78-
funcOrGlobal.setAttr(fir::getInternalFuncNameAttrName(), symName);
79-
}
79+
for (auto &op : module->getRegion(0).front()) {
80+
if (mlir::isa<mlir::func::FuncOp, fir::GlobalOp>(op)) {
81+
processFctOrGlobal(op);
82+
} else if (auto gpuMod = mlir::dyn_cast<mlir::gpu::GPUModuleOp>(op)) {
83+
for (auto &gpuOp : gpuMod.getBodyRegion().front())
84+
if (mlir::isa<mlir::func::FuncOp, fir::GlobalOp,
85+
mlir::gpu::GPUFuncOp>(gpuOp))
86+
processFctOrGlobal(gpuOp);
8087
}
8188
}
8289
};
@@ -85,23 +92,25 @@ void ExternalNameConversionPass::runOnOperation() {
8592
// globals.
8693
renameFuncOrGlobalInModule(op);
8794

88-
// Do the same in GPU modules.
89-
if (auto mod = mlir::dyn_cast_or_null<mlir::ModuleOp>(*op))
90-
for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>())
91-
renameFuncOrGlobalInModule(gpuMod);
92-
9395
if (remappings.empty())
9496
return;
9597

9698
// Update all uses of the functions and globals that have been renamed.
9799
op.walk([&remappings](mlir::Operation *nestedOp) {
98100
llvm::SmallVector<std::pair<mlir::StringAttr, mlir::SymbolRefAttr>> updates;
99101
for (const mlir::NamedAttribute &attr : nestedOp->getAttrDictionary())
100-
if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue()))
101-
if (auto remap = remappings.find(symRef.getRootReference());
102-
remap != remappings.end())
102+
if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue())) {
103+
if (auto remap = remappings.find(symRef.getLeafReference());
104+
remap != remappings.end()) {
105+
mlir::SymbolRefAttr symAttr = mlir::FlatSymbolRefAttr(remap->second);
106+
if (mlir::isa<mlir::gpu::LaunchFuncOp>(nestedOp))
107+
symAttr = mlir::SymbolRefAttr::get(
108+
symRef.getRootReference(),
109+
{mlir::FlatSymbolRefAttr(remap->second)});
103110
updates.emplace_back(std::pair<mlir::StringAttr, mlir::SymbolRefAttr>{
104-
attr.getName(), mlir::SymbolRefAttr(remap->second)});
111+
attr.getName(), symAttr});
112+
}
113+
}
105114
for (auto update : updates)
106115
nestedOp->setAttr(update.first, update.second);
107116
});
Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,25 @@
11
// RUN: fir-opt --split-input-file --external-name-interop %s | FileCheck %s
22

3+
module @mod attributes {gpu.container_module} {
4+
35
gpu.module @cuda_device_mod {
4-
gpu.func @_QPfoo() {
6+
gpu.func @_QPfoo() kernel {
57
fir.call @_QPthreadfence() fastmath<contract> : () -> ()
68
gpu.return
79
}
810
func.func private @_QPthreadfence() attributes {cuf.proc_attr = #cuf.cuda_proc<device>}
911
}
1012

11-
// CHECK-LABEL: gpu.func @_QPfoo
13+
func.func @test() -> () {
14+
%0 = llvm.mlir.constant(0 : i64) : i64
15+
%1 = llvm.mlir.constant(0 : i32) : i32
16+
gpu.launch_func @cuda_device_mod::@_QPfoo blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64 dynamic_shared_memory_size %1
17+
return
18+
}
19+
20+
// CHECK-LABEL: gpu.func @foo_()
1221
// CHECK: fir.call @threadfence_()
1322
// CHECK: func.func private @threadfence_()
23+
// CHECK: gpu.launch_func @cuda_device_mod::@foo_
24+
25+
}

0 commit comments

Comments
 (0)