Skip to content

[MLIR][LLVM] Block address support #134335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,25 @@ def LLVM_DSOLocalEquivalentAttr : LLVM_Attr<"DSOLocalEquivalent",
let assemblyFormat = "$sym";
}

//===----------------------------------------------------------------------===//
// BlockAddressAttr
//===----------------------------------------------------------------------===//

// Attribute carrying the numeric `id` of a block tag. With the struct
// assembly format below it is written as `<id = N>`.
def LLVM_BlockTagAttr : LLVM_Attr<"BlockTag", "blocktag"> {
let parameters = (ins "uint32_t":$id);
let assemblyFormat = "`<` struct(params) `>`";
}

/// Attribute form of a block address: a function symbol paired with a block
/// tag. The `llvm.blockaddress` operation folds to this attribute.
def LLVM_BlockAddressAttr : LLVM_Attr<"BlockAddress", "blockaddress"> {
let description = [{
Describes a block address identified by a pair of `$function` and `$tag`.
}];
let parameters = (ins "FlatSymbolRefAttr":$function,
"BlockTagAttr":$tag);
let assemblyFormat = "`<` struct(params) `>`";
}

//===----------------------------------------------------------------------===//
// VecTypeHintAttr
//===----------------------------------------------------------------------===//
Expand Down
78 changes: 78 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,84 @@ def LLVM_DSOLocalEquivalentOp : LLVM_Op<"dso_local_equivalent",
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// BlockAddressOp & BlockTagOp
//===----------------------------------------------------------------------===//

// Side-effect free, constant-like operation producing a pointer from a
// `BlockAddressAttr`. It implements `SymbolUserOpInterface` (the attribute
// references a function symbol) and enables both a verifier and a folder.
def LLVM_BlockAddressOp : LLVM_Op<"blockaddress",
[Pure, ConstantLike, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let arguments = (ins LLVM_BlockAddressAttr:$block_addr);
let results = (outs LLVM_AnyPointer:$res);

let summary = "Creates a LLVM blockaddress ptr";

let description = [{
Creates an SSA value containing a pointer to a basic block. The block
address information (function and block) is given by the `BlockAddressAttr`
attribute. This operation assumes an existing `llvm.blocktag` operation
identifying an existing MLIR block within a function. Example:

```mlir
llvm.mlir.global private @g() : !llvm.ptr {
%0 = llvm.blockaddress <function = @fn, tag = <id = 0>> : !llvm.ptr
llvm.return %0 : !llvm.ptr
}

llvm.func @fn() {
llvm.br ^bb1
^bb1: // pred: ^bb0
llvm.blocktag <id = 0>
llvm.return
}
```
}];

let assemblyFormat = [{
$block_addr
attr-dict `:` qualified(type($res))
}];

let extraClassDeclaration = [{
/// Return the llvm.func operation that is referenced here.
LLVMFuncOp getFunction(SymbolTableCollection &symbolTable);

/// Search for the matching `llvm.blocktag` operation. This is performed
/// by walking the function in `block_addr`.
BlockTagOp getBlockTagOp();
}];

let hasVerifier = 1;
let hasFolder = 1;
}

// Marker operation holding the `tag` attribute that `llvm.blockaddress`
// refers to. It declares no traits, operands, or results.
def LLVM_BlockTagOp : LLVM_Op<"blocktag"> {
let description = [{
This operation uses a `tag` to uniquely identify an MLIR block in a
function. The same tag is used by `llvm.blockaddress` in order to compute
the target address.

A given function should have at most one `llvm.blocktag` operation with a
given `tag`. This operation cannot be used as a terminator but can be
placed everywhere else in a block.

Example:

```mlir
llvm.func @f() -> !llvm.ptr {
%addr = llvm.blockaddress <function = @f, tag = <id = 1>> : !llvm.ptr
llvm.br ^bb1
^bb1:
llvm.blocktag <id = 1>
llvm.return %addr : !llvm.ptr
}
```
}];
let arguments = (ins LLVM_BlockTagAttr:$tag);
let assemblyFormat = [{ $tag attr-dict }];
// No op-level verifier: tag uniqueness is checked per function as part of
// the LLVMFuncOp verifier.
let hasVerifier = 0;
}

def LLVM_ComdatSelectorOp : LLVM_Op<"comdat_selector", [Symbol]> {
let arguments = (ins
SymbolNameAttr:$sym_name,
Expand Down
35 changes: 35 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,29 @@ class ModuleTranslation {
return callMapping.lookup(op);
}

/// Records `cst` as the placeholder LLVM value standing in for the
/// blockaddress operation `op`. Each operation may be registered only once;
/// registering it a second time trips the assertion below.
void mapUnresolvedBlockAddress(BlockAddressOp op, llvm::Value *cst) {
  bool wasInserted = unresolvedBlockAddressMapping.try_emplace(op, cst).second;
  // Silence the unused-variable warning when assertions are compiled out.
  (void)wasInserted;
  assert(wasInserted &&
         "attempting to map a blockaddress that is already mapped");
}

/// Caches `blockTag` as the blocktag operation that the block address
/// attribute `attr` refers to.
void mapBlockTag(BlockAddressAttr attr, BlockTagOp blockTag) {
// An existing entry for `attr` is overwritten rather than rejected. That is
// fine as long as block tags have been verified to be unique, because the
// new entry then names the same operation.
blockTagMapping[attr] = blockTag;
}

/// Returns the blocktag operation cached for the block address attribute
/// `attr`. Callers test the result for null to detect an attribute that has
/// not been mapped yet.
BlockTagOp lookupBlockTag(BlockAddressAttr attr) const {
return blockTagMapping.lookup(attr);
}

/// Removes the mapping for blocks contained in the region and values defined
/// in these blocks.
void forgetMapping(Region &region);
Expand Down Expand Up @@ -338,6 +361,8 @@ class ModuleTranslation {
LogicalResult convertFunctions();
LogicalResult convertComdats();

LogicalResult convertUnresolvedBlockAddress();

/// Handle conversion for both globals and global aliases.
///
/// - Create named global variables that correspond to llvm.mlir.global
Expand Down Expand Up @@ -433,6 +458,16 @@ class ModuleTranslation {
/// This map is populated on module entry.
DenseMap<ComdatSelectorOp, llvm::Comdat *> comdatMapping;

/// Mapping from llvm.blockaddress operations to their corresponding LLVM
/// constant placeholders. After all basic blocks are translated, this
/// mapping is used to replace the placeholders with the LLVM block addresses.
DenseMap<BlockAddressOp, llvm::Value *> unresolvedBlockAddressMapping;

/// Mapping from a BlockAddressAttr attribute to a matching BlockTagOp. This
/// is used to cache BlockTagOp locations instead of walking a LLVMFuncOp in
/// search for those.
DenseMap<BlockAddressAttr, BlockTagOp> blockTagMapping;

/// Stack of user-specified state elements, useful when translating operations
/// with regions.
SmallVector<std::unique_ptr<StackFrame>> stack;
Expand Down
75 changes: 75 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2305,6 +2305,28 @@ static LogicalResult verifyComdat(Operation *op,
return success();
}

/// Checks that no two `llvm.blocktag` operations nested in `funcOp` carry the
/// same tag. Emits an error on the first repeated tag encountered during the
/// walk and returns failure; returns success when every tag is distinct.
static LogicalResult verifyBlockTags(LLVMFuncOp funcOp) {
  llvm::DenseSet<BlockTagAttr> seenTags;
  BlockTagOp duplicateOp;

  WalkResult walkResult = funcOp.walk([&](BlockTagOp tagOp) {
    // `insert` reports whether the tag was new, so one call both records the
    // tag and detects a repeat.
    if (seenTags.insert(tagOp.getTag()).second)
      return WalkResult::advance();
    duplicateOp = tagOp;
    return WalkResult::interrupt();
  });

  if (!walkResult.wasInterrupted())
    return success();

  duplicateOp.emitError() << "duplicate block tag '"
                          << duplicateOp.getTag().getId()
                          << "' in the same function: ";
  return failure();
}

/// Parse common attributes that might show up in the same order in both
/// GlobalOp and AliasOp.
template <typename OpType>
Expand Down Expand Up @@ -3060,6 +3082,9 @@ LogicalResult LLVMFuncOp::verify() {
return emitError(diagnosticMessage);
}

if (failed(verifyBlockTags(*this)))
return failure();

return success();
}

Expand Down Expand Up @@ -3815,6 +3840,56 @@ void InlineAsmOp::getEffects(
}
}

//===----------------------------------------------------------------------===//
// BlockAddressOp
//===----------------------------------------------------------------------===//

/// Verifies that the function symbol stored in `block_addr` resolves, within
/// the parent LLVM module, to an operation that is an `llvm.func`.
LogicalResult
BlockAddressOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  FlatSymbolRefAttr fnRef = getBlockAddr().getFunction();
  Operation *resolved =
      symbolTable.lookupSymbolIn(parentLLVMModule(*this), fnRef);

  // A missing symbol and a symbol of the wrong kind are reported the same way.
  if (dyn_cast_or_null<LLVMFuncOp>(resolved))
    return success();

  return emitOpError("must reference a function defined by 'llvm.func'");
}

/// Looks up the function symbol stored in `block_addr` inside the parent LLVM
/// module. Returns a null `LLVMFuncOp` when the symbol is absent or does not
/// name an `llvm.func`.
LLVMFuncOp BlockAddressOp::getFunction(SymbolTableCollection &symbolTable) {
  FlatSymbolRefAttr fnRef = getBlockAddr().getFunction();
  Operation *resolved =
      symbolTable.lookupSymbolIn(parentLLVMModule(*this), fnRef);
  return dyn_cast_or_null<LLVMFuncOp>(resolved);
}

/// Finds the `llvm.blocktag` operation this block address refers to by
/// walking the function named in `block_addr`.
///
/// Returns a null `BlockTagOp` when the function symbol cannot be resolved to
/// an `llvm.func`, or when that function contains no blocktag with a matching
/// tag.
BlockTagOp BlockAddressOp::getBlockTagOp() {
  // The lookup may come back empty (dangling symbol). `dyn_cast` requires a
  // non-null input, so use the null-tolerant variant; this matches the other
  // symbol lookups of this op and makes the null check below reachable.
  auto funcOp = dyn_cast_or_null<LLVMFuncOp>(
      mlir::SymbolTable::lookupNearestSymbolFrom(
          parentLLVMModule(*this), getBlockAddr().getFunction()));
  if (!funcOp)
    return nullptr;

  // Read the tag once instead of re-fetching it for every visited op.
  BlockTagAttr targetTag = getBlockAddr().getTag();
  BlockTagOp blockTagOp = nullptr;
  funcOp.walk([&](LLVM::BlockTagOp labelOp) {
    if (labelOp.getTag() == targetTag) {
      blockTagOp = labelOp;
      // Stop at the first match.
      return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  return blockTagOp;
}

/// Verifies that a matching `llvm.blocktag` can be found for this block
/// address, as determined by `getBlockTagOp()`.
LogicalResult BlockAddressOp::verify() {
  if (getBlockTagOp())
    return success();

  return emitOpError(
      "expects an existing block label target in the referenced function");
}

/// Folds the operation to its `block_addr` attribute. The fold adaptor is not
/// consulted.
OpFoldResult BlockAddressOp::fold(FoldAdaptor) {
  BlockAddressAttr blockAddr = getBlockAddr();
  return blockAddr;
}

//===----------------------------------------------------------------------===//
// AssumeOp (intrinsic)
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -731,8 +731,10 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
}

bool isLegalToInline(Operation *op, Region *, bool, IRMapping &) const final {
// The inliner cannot handle variadic function arguments.
return !isa<LLVM::VaStartOp>(op);
// The inliner cannot handle variadic function arguments, and blocktag
// operations prevent inlining since the blockaddress operations
// reference them via the callee symbol.
return !(isa<LLVM::VaStartOp>(op) || isa<LLVM::BlockTagOp>(op));
}

/// Handle the given inlined return by replacing it with a branch. This
Expand Down
53 changes: 53 additions & 0 deletions mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,59 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
return success();
}

// Emit blockaddress. We first need to find the LLVM block referenced by this
// operation and then create a LLVM block address for it.
if (auto blockAddressOp = dyn_cast<LLVM::BlockAddressOp>(opInst)) {
// getBlockTagOp() walks a function to search for block labels. Check
// whether it's in cache first.
BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr();
BlockTagOp blockTagOp = moduleTranslation.lookupBlockTag(blockAddressAttr);
if (!blockTagOp) {
blockTagOp = blockAddressOp.getBlockTagOp();
moduleTranslation.mapBlockTag(blockAddressAttr, blockTagOp);
}

llvm::Value *llvmValue = nullptr;
StringRef fnName = blockAddressAttr.getFunction().getValue();
if (llvm::BasicBlock *llvmBlock =
moduleTranslation.lookupBlock(blockTagOp->getBlock())) {
llvm::Function *llvmFn = moduleTranslation.lookupFunction(fnName);
llvmValue = llvm::BlockAddress::get(llvmFn, llvmBlock);
} else {
// The matching LLVM block is not yet emitted, a placeholder is created
// in its place. When the LLVM block is emitted later in translation,
// the llvmValue is replaced with the actual llvm::BlockAddress.
// A GlobalVariable is chosen as placeholder because in general LLVM
// constants are uniqued and are not proper for RAUW, since that could
// harm unrelated uses of the constant.
llvmValue = new llvm::GlobalVariable(
*moduleTranslation.getLLVMModule(),
llvm::PointerType::getUnqual(moduleTranslation.getLLVMContext()),
/*isConstant=*/true, llvm::GlobalValue::LinkageTypes::ExternalLinkage,
/*Initializer=*/nullptr,
Twine("__mlir_block_address_")
.concat(Twine(fnName))
.concat(Twine((uint64_t)blockAddressOp.getOperation())));
moduleTranslation.mapUnresolvedBlockAddress(blockAddressOp, llvmValue);
}

moduleTranslation.mapValue(blockAddressOp.getResult(), llvmValue);
return success();
}

// Emit block label. If this label is seen before BlockAddressOp is
// translated, go ahead and already map it.
if (auto blockTagOp = dyn_cast<LLVM::BlockTagOp>(opInst)) {
auto funcOp = blockTagOp->getParentOfType<LLVMFuncOp>();
BlockAddressAttr blockAddressAttr = BlockAddressAttr::get(
&moduleTranslation.getContext(),
FlatSymbolRefAttr::get(&moduleTranslation.getContext(),
funcOp.getName()),
blockTagOp.getTag());
moduleTranslation.mapBlockTag(blockAddressAttr, blockTagOp);
return success();
}

return failure();
}

Expand Down
27 changes: 24 additions & 3 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1381,9 +1381,18 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
return builder.create<LLVM::ZeroOp>(loc, targetExtType).getRes();
}

if (auto *blockAddr = dyn_cast<llvm::BlockAddress>(constant)) {
auto fnSym =
FlatSymbolRefAttr::get(context, blockAddr->getFunction()->getName());
auto blockTag =
BlockTagAttr::get(context, blockAddr->getBasicBlock()->getNumber());
return builder
.create<BlockAddressOp>(loc, convertType(blockAddr->getType()),
BlockAddressAttr::get(context, fnSym, blockTag))
.getRes();
}

StringRef error = "";
if (isa<llvm::BlockAddress>(constant))
error = " since blockaddress(...) is unsupported";

if (isa<llvm::ConstantPtrAuth>(constant))
error = " since ptrauth(...) is unsupported";
Expand Down Expand Up @@ -2448,8 +2457,13 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
SmallVector<llvm::BasicBlock *> reachableBasicBlocks;
for (llvm::BasicBlock &basicBlock : *func) {
// Skip unreachable blocks.
if (!reachable.contains(&basicBlock))
if (!reachable.contains(&basicBlock)) {
if (basicBlock.hasAddressTaken())
return emitError(funcOp.getLoc())
<< "unreachable block '" << basicBlock.getName()
<< "' with address taken";
continue;
}
Region &body = funcOp.getBody();
Block *block = builder.createBlock(&body, body.end());
mapBlock(&basicBlock, block);
Expand Down Expand Up @@ -2606,6 +2620,13 @@ LogicalResult ModuleImport::processBasicBlock(llvm::BasicBlock *bb,
}
}
}

if (bb->hasAddressTaken()) {
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(block);
builder.create<BlockTagOp>(block->getParentOp()->getLoc(),
BlockTagAttr::get(context, bb->getNumber()));
}
return success();
}

Expand Down
Loading