Skip to content

Commit 444abb3

Browse files
authored
[mlir][gpu] Add a symbol table field to TargetOptions and adjust GpuModuleToBinary (#65797)
This patch adds the option of building an optional symbol table for the top operation in the `gpu-module-to-binary` pass. The table is not created by default as most targets don't need it; instead, it is lazily built. The table is passed through a callback in `TargetOptions`. This patch is required to integrate #65539 .
1 parent 18b6724 commit 444abb3

File tree

4 files changed

+56
-18
lines changed

4 files changed

+56
-18
lines changed

mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class IRBuilderBase;
2020
}
2121

2222
namespace mlir {
23+
class SymbolTable;
2324
namespace LLVM {
2425
class ModuleTranslation;
2526
}
@@ -55,11 +56,13 @@ class TargetOptions {
5556
} CompilationTarget;
5657

5758
/// Constructor initializing the toolkit path, the list of files to link to,
58-
/// extra command line options & the compilation target. The default
59-
/// compilation target is `binary`.
59+
/// extra command line options, the compilation target and a callback for
60+
/// obtaining the parent symbol table. The default compilation target is
61+
/// `binOrFatbin`.
6062
TargetOptions(StringRef toolkitPath = {},
6163
ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
62-
CompilationTarget compilationTarget = binOrFatbin);
64+
CompilationTarget compilationTarget = binOrFatbin,
65+
function_ref<SymbolTable *()> getSymbolTableCallback = {});
6366

6467
/// Returns the typeID.
6568
TypeID getTypeID() const;
@@ -80,12 +83,20 @@ class TargetOptions {
8083
/// Returns the compilation target.
8184
CompilationTarget getCompilationTarget() const;
8285

86+
/// Returns the result of the `getSymbolTableCallback` callback or a nullptr
87+
/// if no callback was provided.
88+
/// Note: The callback itself can return nullptr. It is up to the target how
89+
/// to react to getting a nullptr, e.g., emitting an error or constructing the
90+
/// table.
91+
SymbolTable *getSymbolTable() const;
92+
8393
protected:
8494
/// Derived classes must use this constructor to initialize `typeID` to the
8595
/// appropiate value: ie. `TargetOptions(TypeID::get<DerivedClass>())`.
8696
TargetOptions(TypeID typeID, StringRef toolkitPath = {},
8797
ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
88-
CompilationTarget compilationTarget = binOrFatbin);
98+
CompilationTarget compilationTarget = binOrFatbin,
99+
function_ref<SymbolTable *()> getSymbolTableCallback = {});
89100

90101
/// Path to the target toolkit.
91102
std::string toolkitPath;
@@ -100,6 +111,10 @@ class TargetOptions {
100111
/// Compilation process target representation.
101112
CompilationTarget compilationTarget;
102113

114+
/// Callback for obtaining the parent symbol table of all the GPU modules
115+
/// being serialized.
116+
function_ref<SymbolTable *()> getSymbolTableCallback;
117+
103118
private:
104119
TypeID typeID;
105120
};

mlir/include/mlir/Dialect/GPU/Transforms/Passes.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@ def GpuModuleToBinaryPass
6464
with an object for every target.
6565

6666
The `format` argument can have the following values:
67-
1. `offloading`, `llvm`: producing an offloading representation.
68-
2. `assembly`, `isa`: producing assembly code.
69-
3. `binary`, `bin`: producing binaries.
67+
1. `offloading`, `llvm`: produces an offloading representation.
68+
2. `assembly`, `isa`: produces assembly code.
69+
3. `binary`, `bin`: produces binaries.
70+
4. `fatbinary`, `fatbin`: produces fatbinaries.
71+
5. `binOrFatbin`: produces bins or fatbins, the target decides which.
7072
}];
7173
let options = [
7274
Option<"offloadingHandler", "handler", "Attribute", "nullptr",

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1993,20 +1993,20 @@ gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
19931993
// GPU target options
19941994
//===----------------------------------------------------------------------===//
19951995

1996-
TargetOptions::TargetOptions(StringRef toolkitPath,
1997-
ArrayRef<std::string> linkFiles,
1998-
StringRef cmdOptions,
1999-
CompilationTarget compilationTarget)
1996+
TargetOptions::TargetOptions(
1997+
StringRef toolkitPath, ArrayRef<std::string> linkFiles,
1998+
StringRef cmdOptions, CompilationTarget compilationTarget,
1999+
function_ref<SymbolTable *()> getSymbolTableCallback)
20002000
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
2001-
cmdOptions, compilationTarget) {}
2001+
cmdOptions, compilationTarget, getSymbolTableCallback) {}
20022002

2003-
TargetOptions::TargetOptions(TypeID typeID, StringRef toolkitPath,
2004-
ArrayRef<std::string> linkFiles,
2005-
StringRef cmdOptions,
2006-
CompilationTarget compilationTarget)
2003+
TargetOptions::TargetOptions(
2004+
TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2005+
StringRef cmdOptions, CompilationTarget compilationTarget,
2006+
function_ref<SymbolTable *()> getSymbolTableCallback)
20072007
: toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
20082008
cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
2009-
typeID(typeID) {}
2009+
getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
20102010

20112011
TypeID TargetOptions::getTypeID() const { return typeID; }
20122012

@@ -2016,6 +2016,10 @@ ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
20162016

20172017
StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
20182018

2019+
SymbolTable *TargetOptions::getSymbolTable() const {
2020+
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2021+
}
2022+
20192023
std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
20202024
TargetOptions::tokenizeCmdOptions() const {
20212025
std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;

mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,26 @@ void GpuModuleToBinaryPass::runOnOperation() {
6666
.Default(-1);
6767
if (targetFormat == -1)
6868
getOperation()->emitError() << "Invalid format specified.";
69+
70+
// Lazy symbol table builder callback.
71+
std::optional<SymbolTable> parentTable;
72+
auto lazyTableBuilder = [&]() -> SymbolTable * {
73+
// Build the table if it has not been built.
74+
if (!parentTable) {
75+
Operation *table = SymbolTable::getNearestSymbolTable(getOperation());
76+
// It's up to the target attribute to determine if failing to find a
77+
// symbol table is an error.
78+
if (!table)
79+
return nullptr;
80+
parentTable = SymbolTable(table);
81+
}
82+
return &parentTable.value();
83+
};
84+
6985
TargetOptions targetOptions(
7086
toolkitPath, linkFiles, cmdOptions,
71-
static_cast<TargetOptions::CompilationTarget>(targetFormat));
87+
static_cast<TargetOptions::CompilationTarget>(targetFormat),
88+
lazyTableBuilder);
7289
if (failed(transformGpuModulesToBinaries(
7390
getOperation(),
7491
offloadingHandler ? dyn_cast<OffloadingLLVMTranslationAttrInterface>(

0 commit comments

Comments
 (0)