|
| 1 | +//===- Offload.cpp - LLVM Target Offload ------------------------*- C++ -*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// This file defines LLVM target offload utility classes. |
| 10 | +// |
| 11 | +//===----------------------------------------------------------------------===// |
| 12 | + |
| 13 | +#include "mlir/Target/LLVM/Offload.h" |
| 14 | +#include "llvm/Frontend/Offloading/Utility.h" |
| 15 | +#include "llvm/IR/Constants.h" |
| 16 | +#include "llvm/IR/Module.h" |
| 17 | + |
| 18 | +using namespace mlir; |
| 19 | +using namespace mlir::LLVM; |
| 20 | + |
| 21 | +std::string OffloadHandler::getBeginSymbol(StringRef suffix) { |
| 22 | + return ("__begin_offload_" + suffix).str(); |
| 23 | +} |
| 24 | + |
| 25 | +std::string OffloadHandler::getEndSymbol(StringRef suffix) { |
| 26 | + return ("__end_offload_" + suffix).str(); |
| 27 | +} |
| 28 | + |
| 29 | +namespace { |
| 30 | +/// Returns the type of the entry array. |
| 31 | +llvm::ArrayType *getEntryArrayType(llvm::Module &module, size_t numElems) { |
| 32 | + return llvm::ArrayType::get(llvm::offloading::getEntryTy(module), numElems); |
| 33 | +} |
| 34 | + |
| 35 | +/// Creates the initializer of the entry array. |
| 36 | +llvm::Constant *getEntryArrayBegin(llvm::Module &module, |
| 37 | + ArrayRef<llvm::Constant *> entries) { |
| 38 | + // If there are no entries return a constant zero initializer. |
| 39 | + llvm::ArrayType *arrayTy = getEntryArrayType(module, entries.size()); |
| 40 | + return entries.empty() ? llvm::ConstantAggregateZero::get(arrayTy) |
| 41 | + : llvm::ConstantArray::get(arrayTy, entries); |
| 42 | +} |
| 43 | + |
| 44 | +/// Computes the end position of the entry array. |
| 45 | +llvm::Constant *getEntryArrayEnd(llvm::Module &module, |
| 46 | + llvm::GlobalVariable *begin, size_t numElems) { |
| 47 | + llvm::Type *intTy = module.getDataLayout().getIntPtrType(module.getContext()); |
| 48 | + return llvm::ConstantExpr::getGetElementPtr( |
| 49 | + llvm::offloading::getEntryTy(module), begin, |
| 50 | + ArrayRef<llvm::Constant *>({llvm::ConstantInt::get(intTy, numElems)}), |
| 51 | + true); |
| 52 | +} |
| 53 | +} // namespace |
| 54 | + |
| 55 | +OffloadHandler::OffloadEntryArray |
| 56 | +OffloadHandler::getEntryArray(StringRef suffix) { |
| 57 | + llvm::GlobalVariable *beginGV = |
| 58 | + module.getGlobalVariable(getBeginSymbol(suffix), true); |
| 59 | + llvm::GlobalVariable *endGV = |
| 60 | + module.getGlobalVariable(getEndSymbol(suffix), true); |
| 61 | + return {beginGV, endGV}; |
| 62 | +} |
| 63 | + |
| 64 | +OffloadHandler::OffloadEntryArray |
| 65 | +OffloadHandler::emitEmptyEntryArray(StringRef suffix) { |
| 66 | + llvm::ArrayType *arrayTy = getEntryArrayType(module, 0); |
| 67 | + auto *beginGV = new llvm::GlobalVariable( |
| 68 | + module, arrayTy, /*isConstant=*/true, llvm::GlobalValue::InternalLinkage, |
| 69 | + getEntryArrayBegin(module, {}), getBeginSymbol(suffix)); |
| 70 | + auto *endGV = new llvm::GlobalVariable( |
| 71 | + module, llvm::PointerType::get(module.getContext(), 0), |
| 72 | + /*isConstant=*/true, llvm::GlobalValue::InternalLinkage, |
| 73 | + getEntryArrayEnd(module, beginGV, 0), getEndSymbol(suffix)); |
| 74 | + return {beginGV, endGV}; |
| 75 | +} |
| 76 | + |
| 77 | +LogicalResult OffloadHandler::insertOffloadEntry(StringRef suffix, |
| 78 | + llvm::Constant *entry) { |
| 79 | + // Get the begin and end symbols to the entry array. |
| 80 | + std::string beginSymId = getBeginSymbol(suffix); |
| 81 | + llvm::GlobalVariable *beginGV = module.getGlobalVariable(beginSymId, true); |
| 82 | + llvm::GlobalVariable *endGV = |
| 83 | + module.getGlobalVariable(getEndSymbol(suffix), true); |
| 84 | + // Fail if the symbols are missing. |
| 85 | + if (!beginGV || !endGV) |
| 86 | + return failure(); |
| 87 | + // Create the entry initializer. |
| 88 | + assert(beginGV->getInitializer() && "entry array initializer is missing."); |
| 89 | + // Add existing entries into the new entry array. |
| 90 | + SmallVector<llvm::Constant *> entries; |
| 91 | + if (auto beginInit = dyn_cast_or_null<llvm::ConstantAggregate>( |
| 92 | + beginGV->getInitializer())) { |
| 93 | + for (unsigned i = 0; i < beginInit->getNumOperands(); ++i) |
| 94 | + entries.push_back(beginInit->getOperand(i)); |
| 95 | + } |
| 96 | + // Add the new entry. |
| 97 | + entries.push_back(entry); |
| 98 | + // Create a global holding the new updated set of entries. |
| 99 | + auto *arrayTy = llvm::ArrayType::get(llvm::offloading::getEntryTy(module), |
| 100 | + entries.size()); |
| 101 | + auto *entryArr = new llvm::GlobalVariable( |
| 102 | + module, arrayTy, /*isConstant=*/true, llvm::GlobalValue::InternalLinkage, |
| 103 | + getEntryArrayBegin(module, entries), beginSymId, endGV); |
| 104 | + // Replace the old entry array variable withe new one. |
| 105 | + beginGV->replaceAllUsesWith(entryArr); |
| 106 | + beginGV->eraseFromParent(); |
| 107 | + entryArr->setName(beginSymId); |
| 108 | + // Update the end symbol. |
| 109 | + endGV->setInitializer(getEntryArrayEnd(module, entryArr, entries.size())); |
| 110 | + return success(); |
| 111 | +} |
0 commit comments