Skip to content

Commit ee374ee

Browse files
committed
Fix Parsing for DeviceParser
1 parent 35fb15b commit ee374ee

File tree

3 files changed

+113
-21
lines changed

3 files changed

+113
-21
lines changed

clang/lib/Interpreter/DeviceOffload.cpp

Lines changed: 66 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ IncrementalCUDADeviceParser::IncrementalCUDADeviceParser(
2828
std::unique_ptr<CompilerInstance> DeviceInstance,
2929
CompilerInstance &HostInstance,
3030
llvm::IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> FS,
31-
llvm::Error &Err, const std::list<PartialTranslationUnit> &PTUs)
31+
llvm::Error &Err, std::list<PartialTranslationUnit> &PTUs)
3232
: IncrementalParser(*DeviceInstance, Err), PTUs(PTUs), VFS(FS),
33-
CodeGenOpts(HostInstance.getCodeGenOpts()),
34-
TargetOpts(HostInstance.getTargetOpts()) {
33+
CodeGenOpts(DeviceInstance->getCodeGenOpts()),
34+
TargetOpts(DeviceInstance->getTargetOpts()) {
3535
if (Err)
3636
return;
3737
StringRef Arch = TargetOpts.CPU;
@@ -51,37 +51,61 @@ IncrementalCUDADeviceParser::Parse(llvm::StringRef Input) {
5151
if (!PTU)
5252
return PTU.takeError();
5353

54-
auto PTX = GeneratePTX();
55-
if (!PTX)
56-
return PTX.takeError();
54+
// auto PTX = GeneratePTX();
55+
// if (!PTX)
56+
// return PTX.takeError();
5757

58-
auto Err = GenerateFatbinary();
59-
if (Err)
60-
return std::move(Err);
58+
// auto Err = GenerateFatbinary();
59+
// if (Err)
60+
// return std::move(Err);
6161

62-
std::string FatbinFileName =
63-
"/incr_module_" + std::to_string(PTUs.size()) + ".fatbin";
64-
VFS->addFile(FatbinFileName, 0,
65-
llvm::MemoryBuffer::getMemBuffer(
66-
llvm::StringRef(FatbinContent.data(), FatbinContent.size()),
67-
"", false));
62+
// std::string FatbinFileName =
63+
// "/incr_module_" + std::to_string(PTUs.size()) + ".fatbin";
64+
// VFS->addFile(FatbinFileName, 0,
65+
// llvm::MemoryBuffer::getMemBuffer(
66+
// llvm::StringRef(FatbinContent.data(), FatbinContent.size()),
67+
// "", false));
6868

69-
CodeGenOpts.CudaGpuBinaryFileName = FatbinFileName;
69+
// CodeGenOpts.CudaGpuBinaryFileName = FatbinFileName;
7070

71-
FatbinContent.clear();
71+
// FatbinContent.clear();
7272

7373
return PTU;
7474
}
7575

76+
PartialTranslationUnit &
77+
IncrementalCUDADeviceParser::RegisterPTU(TranslationUnitDecl *TU) {
78+
llvm::errs() << "[CUDA] RegisterPTU called. TU = " << TU << "\n";
79+
PTUs.push_back(PartialTranslationUnit());
80+
llvm::errs() << "[CUDA] PTUs size after push: " << PTUs.size() << "\n";
81+
PartialTranslationUnit &LastPTU = PTUs.back();
82+
LastPTU.TUPart = TU;
83+
return LastPTU;
84+
}
85+
7686
llvm::Expected<llvm::StringRef> IncrementalCUDADeviceParser::GeneratePTX() {
87+
llvm::errs() << "[CUDA] Generating PTX. PTUs size: " << PTUs.size() << "\n";
88+
assert(!PTUs.empty() && "PTUs list is empty during PTX generation!");
7789
auto &PTU = PTUs.back();
7890
std::string Error;
7991

92+
if (!PTU.TheModule) {
93+
llvm::errs() << "[CUDA] Error: PTU has no associated Module!\n";
94+
} else {
95+
llvm::errs() << "[CUDA] Module Triple: " << PTU.TheModule->getTargetTriple().str() << "\n";
96+
}
97+
98+
llvm::errs() << ">>> PTU Module Target Triple: " << PTU.TheModule->getTargetTriple().str() << "\n";
99+
llvm::errs() << ">>> Using CPU: " << TargetOpts.CPU << "\n";
100+
80101
const llvm::Target *Target = llvm::TargetRegistry::lookupTarget(
81102
PTU.TheModule->getTargetTriple(), Error);
82-
if (!Target)
103+
if (!Target) {
104+
llvm::errs() << ">>> Failed to lookup target: " << Error << "\n";
83105
return llvm::make_error<llvm::StringError>(std::move(Error),
84106
std::error_code());
107+
}
108+
85109
llvm::TargetOptions TO = llvm::TargetOptions();
86110
llvm::TargetMachine *TargetMachine = Target->createTargetMachine(
87111
PTU.TheModule->getTargetTriple(), TargetOpts.CPU, "", TO,
@@ -173,9 +197,33 @@ llvm::Error IncrementalCUDADeviceParser::GenerateFatbinary() {
173197

174198
FatbinContent.append(PTXCode.begin(), PTXCode.end());
175199

200+
std::string FatbinFileName =
201+
"/incr_module_" + std::to_string(PTUs.size()) + ".fatbin";
202+
203+
VFS->addFile(FatbinFileName, 0,
204+
llvm::MemoryBuffer::getMemBuffer(
205+
llvm::StringRef(FatbinContent.data(), FatbinContent.size()),
206+
"", false));
207+
208+
CodeGenOpts.CudaGpuBinaryFileName = FatbinFileName;
209+
210+
FatbinContent.clear();
211+
176212
return llvm::Error::success();
177213
}
178214

215+
// void IncrementalCUDADeviceParser::EmitFatbinaryToVFS(std::string &FatbinFileName) {
216+
// std::string FatbinFileName = "/incr_module_" + std::to_string(PTUs.size()) + ".fatbin";
217+
218+
// VFS->addFile(FatbinFileName, 0,
219+
// llvm::MemoryBuffer::getMemBuffer(
220+
// llvm::StringRef(FatbinContent.data(), FatbinContent.size()),
221+
// "", false));
222+
223+
// CodeGenOpts.CudaGpuBinaryFileName = FatbinFileName;
224+
// FatbinContent.clear();
225+
// }
226+
179227
IncrementalCUDADeviceParser::~IncrementalCUDADeviceParser() {}
180228

181229
} // namespace clang

clang/lib/Interpreter/DeviceOffload.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ class CodeGenOptions;
2424
class TargetOptions;
2525

2626
class IncrementalCUDADeviceParser : public IncrementalParser {
27-
const std::list<PartialTranslationUnit> &PTUs;
27+
std::list<PartialTranslationUnit> &PTUs;
2828

2929
public:
3030
IncrementalCUDADeviceParser(
3131
std::unique_ptr<CompilerInstance> DeviceInstance,
3232
CompilerInstance &HostInstance,
3333
llvm::IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> VFS,
34-
llvm::Error &Err, const std::list<PartialTranslationUnit> &PTUs);
34+
llvm::Error &Err, std::list<PartialTranslationUnit> &PTUs);
3535

3636
llvm::Expected<TranslationUnitDecl *> Parse(llvm::StringRef Input) override;
3737

@@ -41,6 +41,9 @@ class IncrementalCUDADeviceParser : public IncrementalParser {
4141
// Generate fatbinary contents in memory
4242
llvm::Error GenerateFatbinary();
4343

44+
PartialTranslationUnit &RegisterPTU(TranslationUnitDecl *TU);
45+
// llvm::Expected<TranslationUnitDecl *> Parse(llvm::StringRef Input) override;
46+
4447
~IncrementalCUDADeviceParser();
4548

4649
protected:

clang/lib/Interpreter/Interpreter.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,9 +561,50 @@ Interpreter::Parse(llvm::StringRef Code) {
561561
// If we have a device parser, parse it first. The generated code will be
562562
// included in the host compilation
563563
if (DeviceParser) {
564+
llvm::errs() << "[CUDA] Parsing device code...\n";
564565
llvm::Expected<TranslationUnitDecl *> DeviceTU = DeviceParser->Parse(Code);
565-
if (auto E = DeviceTU.takeError())
566+
if (auto E = DeviceTU.takeError()) {
567+
llvm::errs() << "[CUDA] Device Parse failed!\n";
566568
return std::move(E);
569+
}
570+
llvm::errs() << "[CUDA] Device parse successful.\n";
571+
572+
auto *CudaParser = llvm::cast<IncrementalCUDADeviceParser>(DeviceParser.get());
573+
llvm::errs() << "[CUDA] Registering device PTU...\n";
574+
575+
PartialTranslationUnit &DevicePTU = CudaParser->RegisterPTU(*DeviceTU);
576+
FrontendAction *WrappedAct = Act->getWrapped();
577+
if (!WrappedAct->hasIRSupport()) {
578+
llvm::errs() << "[CUDA] Error: WrappedAct has no IR support!\n";
579+
return llvm::make_error<llvm::StringError>(
580+
"Device action has no IR support", llvm::inconvertibleErrorCode());
581+
}
582+
583+
CodeGenerator *CG = static_cast<CodeGenAction *>(WrappedAct)->getCodeGenerator();
584+
if (!CG) {
585+
llvm::errs() << "[CUDA] Error: CodeGen is null!\n";
586+
return llvm::make_error<llvm::StringError>(
587+
"Device CodeGen is null", llvm::inconvertibleErrorCode());
588+
}
589+
std::unique_ptr<llvm::Module> M(CG->ReleaseModule());
590+
if (!M) {
591+
llvm::errs() << "[CUDA] Error: Released module is null!\n";
592+
return llvm::make_error<llvm::StringError>(
593+
"Device LLVM module is null", llvm::inconvertibleErrorCode());
594+
}
595+
static unsigned ID = 0;
596+
CG->StartModule("incr_module_" + std::to_string(ID++), M->getContext());
597+
DevicePTU.TheModule = std::move(M);
598+
llvm::errs() << "[CUDA] Assigned LLVM module to DevicePTU\n";
599+
llvm::errs() << "[CUDA] Registered device PTU. TUPart=" << DevicePTU.TUPart << "\n";
600+
llvm::errs() << "[CUDA] Generating PTX...\n";
601+
llvm::Expected<llvm::StringRef> PTX = CudaParser->GeneratePTX();
602+
if (!PTX)
603+
return PTX.takeError();
604+
605+
llvm::Error Err = CudaParser->GenerateFatbinary();
606+
if (Err)
607+
return std::move(Err);
567608
}
568609

569610
// Tell the interpreter to silently ignore unused expressions since value

0 commit comments

Comments
 (0)