[MLIR][LLVM] Block address support #134335
@llvm/pr-subscribers-mlir-llvm
Author: Bruno Cardoso Lopes (bcardosolopes)
Changes
Add support for import and translate. MLIR does not support using basic block references outside a function (as LLVM does), and this PR does not consider changes to MLIR in that respect. It instead introduces two new ops: `llvm.blockaddress` and `llvm.blocktag`.
Patch is 24.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/134335.diff
14 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 41c30b81770bc..4648412a2c093 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1224,6 +1224,25 @@ def LLVM_DSOLocalEquivalentAttr : LLVM_Attr<"DSOLocalEquivalent",
let assemblyFormat = "$sym";
}
+//===----------------------------------------------------------------------===//
+// BlockAddressAttr
+//===----------------------------------------------------------------------===//
+
+def LLVM_BlockTagAttr : LLVM_Attr<"BlockTag", "blocktag"> {
+ let parameters = (ins "uint32_t":$id);
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+/// Attribute that LLVM_BlockAddressOp folds into.
+def LLVM_BlockAddressAttr : LLVM_Attr<"BlockAddress", "blockaddress"> {
+ let description = [{
+ Describes a block address identified by a pair of `$function` and `$tag`.
+ }];
+ let parameters = (ins "FlatSymbolRefAttr":$function,
+ "BlockTagAttr":$tag);
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
//===----------------------------------------------------------------------===//
// VecTypeHintAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 423cf948b03e1..b107b64e55b46 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1625,6 +1625,84 @@ def LLVM_DSOLocalEquivalentOp : LLVM_Op<"dso_local_equivalent",
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// BlockAddressOp & BlockTagOp
+//===----------------------------------------------------------------------===//
+
+def LLVM_BlockAddressOp : LLVM_Op<"blockaddress",
+ [Pure, ConstantLike, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ let arguments = (ins LLVM_BlockAddressAttr:$block_addr);
+ let results = (outs LLVM_AnyPointer:$res);
+
+ let summary = "Creates an LLVM blockaddress ptr";
+
+ let description = [{
+ Creates an SSA value containing a pointer to a basic block. The block
+ address information (function and block) is given by the `BlockAddressAttr`
+ attribute. This operation assumes an existing `llvm.blocktag` operation
+ identifying an existing MLIR block within a function. Example:
+
+ ```mlir
+ llvm.mlir.global private @g() : !llvm.ptr {
+ %0 = llvm.blockaddress <function = @fn, tag = <id = 0>> : !llvm.ptr
+ llvm.return %0 : !llvm.ptr
+ }
+
+ llvm.func @fn() {
+ llvm.br ^bb1
+ ^bb1: // pred: ^bb0
+ llvm.blocktag <id = 0>
+ llvm.return
+ }
+ ```
+ }];
+
+ let assemblyFormat = [{
+ $block_addr
+ attr-dict `:` qualified(type($res))
+ }];
+
+ let extraClassDeclaration = [{
+ /// Return the llvm.func operation that is referenced here.
+ LLVMFuncOp getFunction(SymbolTableCollection &symbolTable);
+
+ /// Search for the matching `llvm.blocktag` operation. This is performed
+ /// by walking the function in `block_addr`.
+ BlockTagOp getBlockTagOp();
+ }];
+
+ let hasVerifier = 1;
+ let hasFolder = 1;
+}
+
+def LLVM_BlockTagOp : LLVM_Op<"blocktag"> {
+ let description = [{
+ This operation uses a `tag` to uniquely identify an MLIR block in a
+ function. The same tag is used by `llvm.blockaddress` in order to compute
+ the target address.
+
+ A given function should have at most one `llvm.blocktag` operation with a
+ given `tag`. This operation cannot be used as a terminator but can be
+ placed anywhere else in a block.
+
+ Example:
+
+ ```mlir
+ llvm.func @f() -> !llvm.ptr {
+ %addr = llvm.blockaddress <function = @f, tag = <id = 1>> : !llvm.ptr
+ llvm.br ^bb1
+ ^bb1:
+ llvm.blocktag <id = 1>
+ llvm.return %addr : !llvm.ptr
+ }
+ ```
+ }];
+ let arguments = (ins LLVM_BlockTagAttr:$tag);
+ let assemblyFormat = [{ $tag attr-dict }];
+ // Covered as part of LLVMFuncOp verifier.
+ let hasVerifier = 0;
+}
+
def LLVM_ComdatSelectorOp : LLVM_Op<"comdat_selector", [Symbol]> {
let arguments = (ins
SymbolNameAttr:$sym_name,
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 01dda6238d8f3..8263bb0ac42d5 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -136,6 +136,29 @@ class ModuleTranslation {
return callMapping.lookup(op);
}
+ /// Maps a blockaddress operation to its corresponding placeholder LLVM
+ /// value.
+ void mapUnresolvedBlockAddress(BlockAddressOp op, llvm::Value *cst) {
+ auto result = unresolvedBlockAddressMapping.try_emplace(op, cst);
+ (void)result;
+ assert(result.second &&
+ "attempting to map a blockaddress that is already mapped");
+ }
+
+ /// Maps a block address attribute to the llvm.blocktag operation that
+ /// identifies its target block.
+ void mapBlockTag(BlockAddressAttr attr, BlockTagOp blockTag) {
+ // Attempts to map already mapped block labels are fine given labels are
+ // verified to be unique.
+ blockTagMapping[attr] = blockTag;
+ }
+
+ /// Finds the llvm.blocktag operation that corresponds to the given block
+ /// address attribute, or a null op if none has been mapped.
+ BlockTagOp lookupBlockTag(BlockAddressAttr attr) const {
+ return blockTagMapping.lookup(attr);
+ }
+
/// Removes the mapping for blocks contained in the region and values defined
/// in these blocks.
void forgetMapping(Region &region);
@@ -338,6 +361,8 @@ class ModuleTranslation {
LogicalResult convertFunctions();
LogicalResult convertComdats();
+ LogicalResult convertUnresolvedBlockAddress();
+
/// Handle conversion for both globals and global aliases.
///
/// - Create named global variables that correspond to llvm.mlir.global
@@ -430,6 +455,16 @@ class ModuleTranslation {
/// This map is populated on module entry.
DenseMap<ComdatSelectorOp, llvm::Comdat *> comdatMapping;
+ /// Mapping from llvm.blockaddress operations to their corresponding LLVM
+ /// constant placeholders. After all basic blocks are translated, this
+ /// mapping is used to replace the placeholders with the LLVM block addresses.
+ DenseMap<BlockAddressOp, llvm::Value *> unresolvedBlockAddressMapping;
+
+ /// Mapping from a BlockAddressAttr attribute to a matching BlockTagOp. This
+ /// is used to cache BlockTagOp locations instead of walking an LLVMFuncOp in
+ /// search of those.
+ DenseMap<BlockAddressAttr, BlockTagOp> blockTagMapping;
+
/// Stack of user-specified state elements, useful when translating operations
/// with regions.
SmallVector<std::unique_ptr<StackFrame>> stack;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 252bdd1425d5e..0deb2900fdfe5 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2305,6 +2305,28 @@ static LogicalResult verifyComdat(Operation *op,
return success();
}
+static LogicalResult verifyBlockTags(LLVMFuncOp funcOp) {
+ llvm::DenseSet<BlockTagAttr> blockTags;
+ BlockTagOp badBlockTagOp;
+ if (funcOp
+ .walk([&](BlockTagOp blockTagOp) {
+ if (blockTags.count(blockTagOp.getTag())) {
+ badBlockTagOp = blockTagOp;
+ return WalkResult::interrupt();
+ }
+ blockTags.insert(blockTagOp.getTag());
+ return WalkResult::advance();
+ })
+ .wasInterrupted()) {
+ badBlockTagOp.emitError()
+ << "duplicate block tag '" << badBlockTagOp.getTag().getId()
+ << "' in the same function: ";
+ return failure();
+ }
+
+ return success();
+}
+
/// Parse common attributes that might show up in the same order in both
/// GlobalOp and AliasOp.
template <typename OpType>
@@ -3060,6 +3082,9 @@ LogicalResult LLVMFuncOp::verify() {
return emitError(diagnosticMessage);
}
+ if (failed(verifyBlockTags(*this)))
+ return failure();
+
return success();
}
@@ -3815,6 +3840,55 @@ void InlineAsmOp::getEffects(
}
}
+//===----------------------------------------------------------------------===//
+// BlockAddressOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+BlockAddressOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ Operation *symbol = symbolTable.lookupSymbolIn(parentLLVMModule(*this),
+ getBlockAddr().getFunction());
+ auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
+
+ if (!function)
+ return emitOpError("must reference a function defined by 'llvm.func'");
+
+ return success();
+}
+
+LLVMFuncOp BlockAddressOp::getFunction(SymbolTableCollection &symbolTable) {
+ return dyn_cast_or_null<LLVMFuncOp>(symbolTable.lookupSymbolIn(
+ parentLLVMModule(*this), getBlockAddr().getFunction()));
+}
+
+BlockTagOp BlockAddressOp::getBlockTagOp() {
+ auto m = (*this)->getParentOfType<ModuleOp>();
+ auto funcOp = cast<LLVMFuncOp>(mlir::SymbolTable::lookupNearestSymbolFrom(
+ m, getBlockAddr().getFunction()));
+
+ BlockTagOp blockTagOp = nullptr;
+ funcOp.walk([&](LLVM::BlockTagOp labelOp) {
+ if (labelOp.getTag() == getBlockAddr().getTag()) {
+ blockTagOp = labelOp;
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ return blockTagOp;
+}
+
+LogicalResult BlockAddressOp::verify() {
+ if (!getBlockTagOp())
+ return emitOpError(
+ "expects an existing block label target in the referenced function");
+
+ return success();
+}
+
+/// Fold a blockaddress operation to a dedicated blockaddress
+/// attribute.
+OpFoldResult BlockAddressOp::fold(FoldAdaptor) { return getBlockAddr(); }
+
//===----------------------------------------------------------------------===//
// AssumeOp (intrinsic)
//===----------------------------------------------------------------------===//
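Note on the duplicate-tag check: `verifyBlockTags` in the hunk above comes down to a set-insertion test performed while walking the function. The following is a minimal standalone sketch of that idea, not the patch's code. It assumes tags are plain integers collected in walk order; `findDuplicateTag` and its input vector are hypothetical names used only for illustration.

```cpp
#include <cstdint>
#include <optional>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for the llvm.blocktag ops found while walking one
// function: only the numeric tag id matters for this check.
// Returns the first tag id that repeats, or std::nullopt if all are unique.
std::optional<uint32_t>
findDuplicateTag(const std::vector<uint32_t> &tagsInWalkOrder) {
  std::unordered_set<uint32_t> seen;
  for (uint32_t tag : tagsInWalkOrder) {
    // insert() reports whether the id was new; a failed insert is a repeat.
    if (!seen.insert(tag).second)
      return tag;
  }
  return std::nullopt;
}
```

As in the patch, the walk stops at the first repeated tag, so only one diagnostic would be reported per function.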
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 10b68a333bcbd..738f036bb376a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -555,6 +555,59 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
return success();
}
+ // Emit blockaddress. We first need to find the LLVM block referenced by this
+ // operation and then create an LLVM block address for it.
+ if (auto blockAddressOp = dyn_cast<LLVM::BlockAddressOp>(opInst)) {
+ // getBlockTagOp() walks a function to search for block labels. Check
+ // whether it's in cache first.
+ BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr();
+ BlockTagOp blockTagOp = moduleTranslation.lookupBlockTag(blockAddressAttr);
+ if (!blockTagOp) {
+ blockTagOp = blockAddressOp.getBlockTagOp();
+ moduleTranslation.mapBlockTag(blockAddressAttr, blockTagOp);
+ }
+
+ llvm::Value *llvmValue = nullptr;
+ StringRef fnName = blockAddressAttr.getFunction().getValue();
+ if (llvm::BasicBlock *llvmBlock =
+ moduleTranslation.lookupBlock(blockTagOp->getBlock())) {
+ llvm::Function *llvmFn = moduleTranslation.lookupFunction(fnName);
+ llvmValue = llvm::BlockAddress::get(llvmFn, llvmBlock);
+ } else {
+ // The matching LLVM block is not yet emitted, so a placeholder is created
+ // in its place. When the LLVM block is emitted later in translation,
+ // the llvmValue is replaced with the actual llvm::BlockAddress.
+ // A GlobalVariable is chosen as placeholder because in general LLVM
+ // constants are uniqued and are not proper for RAUW, since that could
+ // harm unrelated uses of the constant.
+ llvmValue = new llvm::GlobalVariable(
+ *moduleTranslation.getLLVMModule(),
+ llvm::PointerType::getUnqual(moduleTranslation.getLLVMContext()),
+ /*isConstant=*/true, llvm::GlobalValue::LinkageTypes::ExternalLinkage,
+ /*Initializer=*/nullptr,
+ Twine("__mlir_block_address_")
+ .concat(Twine(fnName))
+ .concat(Twine((uint64_t)blockAddressOp.getOperation())));
+ moduleTranslation.mapUnresolvedBlockAddress(blockAddressOp, llvmValue);
+ }
+
+ moduleTranslation.mapValue(blockAddressOp.getResult(), llvmValue);
+ return success();
+ }
+
+ // Emit block label. If this label is seen before BlockAddressOp is
+ // translated, map it right away.
+ if (auto blockTagOp = dyn_cast<LLVM::BlockTagOp>(opInst)) {
+ auto funcOp = blockTagOp->getParentOfType<LLVMFuncOp>();
+ BlockAddressAttr blockAddressAttr = BlockAddressAttr::get(
+ &moduleTranslation.getContext(),
+ FlatSymbolRefAttr::get(&moduleTranslation.getContext(),
+ funcOp.getName()),
+ blockTagOp.getTag());
+ moduleTranslation.mapBlockTag(blockAddressAttr, blockTagOp);
+ return success();
+ }
+
return failure();
}
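Note on the tag lookup: the blockaddress translation above checks `blockTagMapping` first and only falls back to walking the function when the tag has not been seen yet. Below is a rough standalone sketch of that check-the-cache-then-walk pattern. It is an illustration under simplifying assumptions: it models a single function whose body is just a list of tag ids, and `TagLookup`, `find`, and `walks` are hypothetical names that do not appear in the patch.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <unordered_map>
#include <vector>

// Hypothetical miniature: a function body is the list of tag ids in walk
// order, and a "tag op" is identified by its position in that list.
struct TagLookup {
  std::unordered_map<uint32_t, std::size_t> cache; // tag id -> position
  int walks = 0;                                   // counts slow-path walks

  std::optional<std::size_t> find(const std::vector<uint32_t> &body,
                                  uint32_t tag) {
    // Fast path: the tag was already looked up once.
    if (auto it = cache.find(tag); it != cache.end())
      return it->second;
    // Slow path: linear walk over the body, memoizing a hit.
    ++walks;
    for (std::size_t i = 0; i < body.size(); ++i) {
      if (body[i] == tag) {
        cache.emplace(tag, i);
        return i;
      }
    }
    return std::nullopt;
  }
};
```

The patch keys its cache on the whole `BlockAddressAttr` (function symbol plus tag), which lets one map serve every function in the module; the sketch drops the function part for brevity.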
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index ea141d8b07284..7635b21d01d4b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1362,9 +1362,19 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
return builder.create<LLVM::ZeroOp>(loc, targetExtType).getRes();
}
+ if (auto *blockAddr = dyn_cast<llvm::BlockAddress>(constant))
+ return builder
+ .create<BlockAddressOp>(
+ loc, convertType(blockAddr->getType()),
+ BlockAddressAttr::get(
+ context,
+ FlatSymbolRefAttr::get(context,
+ blockAddr->getFunction()->getName()),
+ BlockTagAttr::get(context,
+ blockAddr->getBasicBlock()->getNumber())))
+ .getRes();
+
StringRef error = "";
- if (isa<llvm::BlockAddress>(constant))
- error = " since blockaddress(...) is unsupported";
if (isa<llvm::ConstantPtrAuth>(constant))
error = " since ptrauth(...) is unsupported";
@@ -2429,8 +2439,13 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
SmallVector<llvm::BasicBlock *> reachableBasicBlocks;
for (llvm::BasicBlock &basicBlock : *func) {
// Skip unreachable blocks.
- if (!reachable.contains(&basicBlock))
+ if (!reachable.contains(&basicBlock)) {
+ if (basicBlock.hasAddressTaken())
+ emitWarning(funcOp.getLoc())
+ << "unreachable block '" << basicBlock.getName()
+ << "' with address taken";
continue;
+ }
Region &body = funcOp.getBody();
Block *block = builder.createBlock(&body, body.end());
mapBlock(&basicBlock, block);
@@ -2587,6 +2602,13 @@ LogicalResult ModuleImport::processBasicBlock(llvm::BasicBlock *bb,
}
}
}
+
+ if (bb->hasAddressTaken()) {
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(block);
+ builder.create<BlockTagOp>(block->getParentOp()->getLoc(),
+ BlockTagAttr::get(context, bb->getNumber()));
+ }
return success();
}
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 1e2f2c0468045..475d8b72c52e2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1824,6 +1824,27 @@ LogicalResult ModuleTranslation::convertComdats() {
return success();
}
+LogicalResult ModuleTranslation::convertUnresolvedBlockAddress() {
+ for (auto &[blockAddressOp, llvmCst] : unresolvedBlockAddressMapping) {
+ BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr();
+ BlockTagOp blockTagOp = lookupBlockTag(blockAddressAttr);
+ assert(blockTagOp && "expected all block tags to be already seen");
+
+ llvm::BasicBlock *llvmBlock = lookupBlock(blockTagOp->getBlock());
+ assert(llvmBlock && "expected LLVM blocks to be already translated");
+
+ // Update mapping with new block address constant.
+ auto *llvmBlockAddr = llvm::BlockAddress::get(
+ lookupFunction(blockAddressAttr.getFunction().getValue()), llvmBlock);
+ llvmCst->replaceAllUsesWith(llvmBlockAddr);
+ mapValue(blockAddressOp.getResult(), llvmBlockAddr);
+ assert(llvmCst->use_empty() && "expected all uses to be replaced");
+ cast<llvm::GlobalVariable>(llvmCst)->eraseFromParent();
+ }
+ unresolvedBlockAddressMapping.clear();
+ return success();
+}
+
void ModuleTranslation::setAccessGroupsMetadata(AccessGroupOpInterface op,
llvm::Instruction *inst) {
if (llvm::MDNode *node = loopAnnotationTranslation->getAccessGroups(op))
@@ -2218,6 +2239,11 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
if (failed(translator.convertFunctions()))
return nullptr;
+ // Now that all MLIR blocks are resolved into LLVM ones, patch block address
+ // constants to point to the correct blocks.
+ if (failed(translator.convertUnresolvedBlockAddress()))
+ return nullptr;
+
// Once we've finished constructing elements in the module, we should convert
// it to use the debug info format desired by LLVM.
// See https://llvm.org/docs/RemoveDIsDebugInfo.html
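Note on forward references: together with the translation code earlier in the patch, `convertUnresolvedBlockAddress` amounts to a placeholder-then-patch scheme. A blockaddress whose target block has not been emitted yet gets a temporary value, and a final pass replaces every temporary once all blocks exist. The sketch below models that flow with plain standard-library types. It is only an approximation of the idea: `MiniTranslation`, its members, and the integer "addresses" are hypothetical, and the real code uses a temporary `llvm::GlobalVariable` plus `replaceAllUsesWith` rather than vector slots.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Key = (function name, tag id), mirroring the function/tag pair of the
// block address attribute.
using BlockKey = std::pair<std::string, uint32_t>;

struct MiniTranslation {
  // Blocks that have already been emitted, mapped to a stand-in "address".
  std::map<BlockKey, int> emittedBlocks;
  // One slot per blockaddress use; -1 marks an unresolved placeholder.
  std::vector<int> uses;
  // Placeholder slots still waiting for their block to be emitted.
  std::vector<std::pair<std::size_t, BlockKey>> unresolved;

  // Translate one blockaddress use and return the index of its slot.
  std::size_t emitBlockAddress(const BlockKey &key) {
    auto it = emittedBlocks.find(key);
    if (it != emittedBlocks.end()) {
      uses.push_back(it->second); // Target already exists: use it directly.
    } else {
      uses.push_back(-1);         // Target missing: record a placeholder.
      unresolved.emplace_back(uses.size() - 1, key);
    }
    return uses.size() - 1;
  }

  void emitBlock(const BlockKey &key, int address) {
    emittedBlocks[key] = address;
  }

  // Runs once after every function has been translated.
  void resolveAll() {
    for (const auto &[slot, key] : unresolved) {
      auto it = emittedBlocks.find(key);
      assert(it != emittedBlocks.end() && "expected block to be emitted");
      uses[slot] = it->second;
    }
    unresolved.clear();
  }
};
```

Running the patch step once at the end, as the diff does after `convertFunctions()`, keeps the per-operation translation simple because it never has to wait on a block that comes later in the module.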
diff --git a/mlir/test/Dialect/LLVMIR/blockaddress-canonicalize.mlir b/mlir/test/Dialect/LLVMIR/blockaddress-canonicalize.mlir
new file mode 100644
index 0000000000000..0f18cc6d8cd3e
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/blockaddress-canonicalize.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -canonicalize | FileCheck %s
+
+// CHECK-LABEL: llvm.func @ba()
+llvm.func @ba() -> !llvm.ptr {
+ %0 = llvm.blockaddress <function = @ba, tag = <id = 1>> : !llvm.ptr
+ llvm.br ^bb1
+^bb1:
+ // CHECK: llvm.blocktag <id = 1>
+ llvm.blocktag <id = 1>
+ llvm.return %0 : !llvm.ptr
+}
diff --git a/mlir/test/Dialect/LLVMIR/constant-folding.mlir b/mlir/test/Dialect/LLVMIR/constant-folding.mlir
index 99f657f0aefec..0616f19b8fddb 100644
--- a/mlir/test/Dialect/LLVMIR/constant-folding.mlir
+++ b/mlir/test/Dialect/LLVMIR/constant-folding.mlir
@@ -196,3 +196,18 @@ llvm.func @dso_local_equivalent_select(%arg: i1) -> !llvm.ptr {
}
llvm.func @yay()
+
+// -----
+
+// CHECK-LABEL: llvm.func @blockaddress_select
+llvm.func @blockaddress_select(%arg: i1) -> !llvm.ptr {
+ // CHECK-NEXT: %[[ADDR:.+]] = llvm.blockaddress <function = @blockaddress_select, tag = <id = 1>>
+ %0 = llvm.blockaddress <function = @blockaddress_select, tag = <id = 1>> : !llvm.ptr
+ %1 = llvm.blockaddress <function = @blockaddress_select, tag = <id = 1>> : !llvm.ptr
+ %2 = arith.select %arg, %0, %1 : !llvm.ptr
+ // CHECK-NEXT: llvm.br ^bb1
+ llvm.br ^bb1
+^bb1:
+ llvm.blocktag <id = 1>
+ llvm.return %1 : !llvm.ptr
+}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index fb9631d99b91a..e70e3185af236 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1780,3 +1780,25 @@ module {
// expected-error@+1 {{failed to parse ModuleFlagAttr parameter 'value' which is to be a `uint32_t`}}
llvm.module_flags [#llvm.mlir.module_flag<error, "wchar_size", "yolo">]
}
+
+// -----
+
+llvm.func @t0() -> !llvm.ptr {
+ %0 = llvm.blockaddress <function = @t0, tag = <id = 1>> : !llvm.ptr
+ llvm.blocktag <id = 1>
+ llvm.br ^bb1
+^bb1:
+ // expected-error@+1 {{duplicate block tag '1' in the same function}}
+ llvm.blocktag <id = 1>
+ llvm.return %0 : !llvm.ptr
+}
+
+// -----
+
+llvm.func @t1() -> !llvm.ptr {
+ // expected-error@+1 {{expects an existing block label target in the referenced function}}
+ %0 = llvm.blockaddress <function = @t1, tag = <id = 1>> : !llvm.ptr
+ llvm.br ^bb1
+^bb1:
+ llvm.return %0 : !llvm.ptr
+}
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index d0aa65d14a176..88460fe374d87 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -1002,3 +1002,24 @@ llvm.func @intrinsic_call_arg_attrs_bundles(%arg0: i32) -> i32 {
%0 = llvm.call_intrinsic "llvm.riscv.sha256s...
[truncated]
@llvm/pr-subscribers-mlir
Author: Bruno Cardoso Lopes (bcardosolopes)
Changes
Add support for import and translate. MLIR does not support using basic block references outside a function (as LLVM does), and this PR does not consider changes to MLIR in that respect. It instead introduces two new ops: `llvm.blockaddress` and `llvm.blocktag`.
Patch is 24.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/134335.diff
14 Files Affected:
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index ea141d8b07284..7635b21d01d4b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1362,9 +1362,19 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
return builder.create<LLVM::ZeroOp>(loc, targetExtType).getRes();
}
+ if (auto *blockAddr = dyn_cast<llvm::BlockAddress>(constant))
+ return builder
+ .create<BlockAddressOp>(
+ loc, convertType(blockAddr->getType()),
+ BlockAddressAttr::get(
+ context,
+ FlatSymbolRefAttr::get(context,
+ blockAddr->getFunction()->getName()),
+ BlockTagAttr::get(context,
+ blockAddr->getBasicBlock()->getNumber())))
+ .getRes();
+
StringRef error = "";
- if (isa<llvm::BlockAddress>(constant))
- error = " since blockaddress(...) is unsupported";
if (isa<llvm::ConstantPtrAuth>(constant))
error = " since ptrauth(...) is unsupported";
@@ -2429,8 +2439,13 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
SmallVector<llvm::BasicBlock *> reachableBasicBlocks;
for (llvm::BasicBlock &basicBlock : *func) {
// Skip unreachable blocks.
- if (!reachable.contains(&basicBlock))
+ if (!reachable.contains(&basicBlock)) {
+ if (basicBlock.hasAddressTaken())
+ emitWarning(funcOp.getLoc())
+ << "unreachable block '" << basicBlock.getName()
+ << "' with address taken";
continue;
+ }
Region &body = funcOp.getBody();
Block *block = builder.createBlock(&body, body.end());
mapBlock(&basicBlock, block);
@@ -2587,6 +2602,13 @@ LogicalResult ModuleImport::processBasicBlock(llvm::BasicBlock *bb,
}
}
}
+
+ if (bb->hasAddressTaken()) {
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(block);
+ builder.create<BlockTagOp>(block->getParentOp()->getLoc(),
+ BlockTagAttr::get(context, bb->getNumber()));
+ }
return success();
}
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 1e2f2c0468045..475d8b72c52e2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1824,6 +1824,27 @@ LogicalResult ModuleTranslation::convertComdats() {
return success();
}
+LogicalResult ModuleTranslation::convertUnresolvedBlockAddress() {
+ for (auto &[blockAddressOp, llvmCst] : unresolvedBlockAddressMapping) {
+ BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr();
+ BlockTagOp blockTagOp = lookupBlockTag(blockAddressAttr);
+ assert(blockTagOp && "expected all block tags to be already seen");
+
+ llvm::BasicBlock *llvmBlock = lookupBlock(blockTagOp->getBlock());
+ assert(llvmBlock && "expected LLVM blocks to be already translated");
+
+ // Update mapping with new block address constant.
+ auto *llvmBlockAddr = llvm::BlockAddress::get(
+ lookupFunction(blockAddressAttr.getFunction().getValue()), llvmBlock);
+ llvmCst->replaceAllUsesWith(llvmBlockAddr);
+ mapValue(blockAddressOp.getResult(), llvmBlockAddr);
+ assert(llvmCst->use_empty() && "expected all uses to be replaced");
+ cast<llvm::GlobalVariable>(llvmCst)->eraseFromParent();
+ }
+ unresolvedBlockAddressMapping.clear();
+ return success();
+}
+
void ModuleTranslation::setAccessGroupsMetadata(AccessGroupOpInterface op,
llvm::Instruction *inst) {
if (llvm::MDNode *node = loopAnnotationTranslation->getAccessGroups(op))
@@ -2218,6 +2239,11 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
if (failed(translator.convertFunctions()))
return nullptr;
+ // Now that all MLIR blocks are resolved into LLVM ones, patch block address
+ // constants to point to the correct blocks.
+ if (failed(translator.convertUnresolvedBlockAddress()))
+ return nullptr;
+
// Once we've finished constructing elements in the module, we should convert
// it to use the debug info format desired by LLVM.
// See https://llvm.org/docs/RemoveDIsDebugInfo.html
diff --git a/mlir/test/Dialect/LLVMIR/blockaddress-canonicalize.mlir b/mlir/test/Dialect/LLVMIR/blockaddress-canonicalize.mlir
new file mode 100644
index 0000000000000..0f18cc6d8cd3e
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/blockaddress-canonicalize.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -canonicalize | FileCheck %s
+
+// CHECK-LABEL: llvm.func @ba()
+llvm.func @ba() -> !llvm.ptr {
+ %0 = llvm.blockaddress <function = @ba, tag = <id = 1>> : !llvm.ptr
+ llvm.br ^bb1
+^bb1:
+ // CHECK: llvm.blocktag <id = 1>
+ llvm.blocktag <id = 1>
+ llvm.return %0 : !llvm.ptr
+}
diff --git a/mlir/test/Dialect/LLVMIR/constant-folding.mlir b/mlir/test/Dialect/LLVMIR/constant-folding.mlir
index 99f657f0aefec..0616f19b8fddb 100644
--- a/mlir/test/Dialect/LLVMIR/constant-folding.mlir
+++ b/mlir/test/Dialect/LLVMIR/constant-folding.mlir
@@ -196,3 +196,18 @@ llvm.func @dso_local_equivalent_select(%arg: i1) -> !llvm.ptr {
}
llvm.func @yay()
+
+// -----
+
+// CHECK-LABEL: llvm.func @blockaddress_select
+llvm.func @blockaddress_select(%arg: i1) -> !llvm.ptr {
+ // CHECK-NEXT: %[[ADDR:.+]] = llvm.blockaddress <function = @blockaddress_select, tag = <id = 1>>
+ %0 = llvm.blockaddress <function = @blockaddress_select, tag = <id = 1>> : !llvm.ptr
+ %1 = llvm.blockaddress <function = @blockaddress_select, tag = <id = 1>> : !llvm.ptr
+ %2 = arith.select %arg, %0, %1 : !llvm.ptr
+ // CHECK-NEXT: llvm.br ^bb1
+ llvm.br ^bb1
+^bb1:
+ llvm.blocktag <id = 1>
+ llvm.return %1 : !llvm.ptr
+}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index fb9631d99b91a..e70e3185af236 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1780,3 +1780,25 @@ module {
// expected-error@+1 {{failed to parse ModuleFlagAttr parameter 'value' which is to be a `uint32_t`}}
llvm.module_flags [#llvm.mlir.module_flag<error, "wchar_size", "yolo">]
}
+
+// -----
+
+llvm.func @t0() -> !llvm.ptr {
+ %0 = llvm.blockaddress <function = @t0, tag = <id = 1>> : !llvm.ptr
+ llvm.blocktag <id = 1>
+ llvm.br ^bb1
+^bb1:
+ // expected-error@+1 {{duplicate block tag '1' in the same function}}
+ llvm.blocktag <id = 1>
+ llvm.return %0 : !llvm.ptr
+}
+
+// -----
+
+llvm.func @t1() -> !llvm.ptr {
+ // expected-error@+1 {{expects an existing block label target in the referenced function}}
+ %0 = llvm.blockaddress <function = @t1, tag = <id = 1>> : !llvm.ptr
+ llvm.br ^bb1
+^bb1:
+ llvm.return %0 : !llvm.ptr
+}
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index d0aa65d14a176..88460fe374d87 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -1002,3 +1002,24 @@ llvm.func @intrinsic_call_arg_attrs_bundles(%arg0: i32) -> i32 {
%0 = llvm.call_intrinsic "llvm.riscv.sha256s...
[truncated]
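The forward-reference handling in `LLVMToLLVMIRTranslation.cpp` and `convertUnresolvedBlockAddress()` above boils down to a placeholder-then-patch scheme. A minimal standalone C++ sketch of just that bookkeeping follows. It is not the patch's code: `Translator`, `BlockKey`, `Use`, and every member name are hypothetical, and plain integers stand in for `llvm::BasicBlock` and for the `llvm::GlobalVariable` placeholder that the real code later swaps out with `replaceAllUsesWith`.

```cpp
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for illustration; none of these are MLIR/LLVM types.
using BlockKey = std::pair<std::string, uint32_t>; // (function name, tag id)

struct Use {
  bool resolved = false; // false while the use still points at a placeholder
  int blockIndex = -1;   // index of the emitted block once it is known
};

class Translator {
public:
  // Models translating a block address: resolve it right away if the target
  // block was already emitted, otherwise record it as pending.
  std::size_t emitBlockAddress(const BlockKey &key) {
    uses.push_back(Use{});
    std::size_t slot = uses.size() - 1;
    auto it = emitted.find(key);
    if (it != emitted.end())
      uses[slot] = Use{true, it->second};
    else
      pending.emplace_back(key, slot);
    return slot;
  }

  // Models emitting the block that carries the given tag.
  void emitBlock(const BlockKey &key, int blockIndex) {
    emitted[key] = blockIndex;
  }

  // Runs after all functions are translated: patch every pending use.
  // Returns false (and leaves the pending list intact) if a target is missing.
  bool resolvePending() {
    for (const auto &[key, slot] : pending) {
      auto it = emitted.find(key);
      if (it == emitted.end())
        return false;
      uses[slot] = Use{true, it->second};
    }
    pending.clear();
    return true;
  }

  const Use &use(std::size_t slot) const { return uses.at(slot); }
  std::size_t pendingCount() const { return pending.size(); }

private:
  std::map<BlockKey, int> emitted;
  std::vector<std::pair<BlockKey, std::size_t>> pending;
  std::vector<Use> uses;
};
```

In this model, a block address emitted before its block stays unresolved until `resolvePending()` runs, which mirrors why the patch calls `convertUnresolvedBlockAddress()` only after `convertFunctions()` has finished.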
Yet another try to comment on the correct PR...
As you have probably realized, the block address semantics are quite tricky to replicate in MLIR, and we probably have to make sure certain optimizations don't apply:
- Inlining should not happen if a function contains a tag (I commented, hopefully on the right PR, on how to prevent inlining).
- The unreachable block issue should probably result in a hard error, since there is not much we can do if there is a tag in an unreachable block?
- What happens if MLIR merges two blocks (e.g. when running region simplify) with different tags? I guess this could still be exported, but the LLVM IR before and after MLIR would be semantically different, since two block addresses would be merged into one block address. Does LLVM proper also merge block addresses if two blocks can be merged?
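To make the first two concerns concrete, here is a rough standalone C++ sketch of the checks they imply. It does not use MLIR's inliner interface or the importer; `BlockModel`, `FunctionModel`, `mayInline`, and `hasTaggedUnreachableBlock` are hypothetical names, and the reachability flag is assumed to come from some earlier analysis.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical model types for illustration; these are not MLIR classes.
struct BlockModel {
  bool hasTag = false;   // block contains an llvm.blocktag
  bool reachable = true; // assumed result of a prior reachability analysis
};

struct FunctionModel {
  std::string name;
  std::vector<BlockModel> blocks;
};

// True if any block of the function carries a tag.
inline bool containsBlockTag(const FunctionModel &fn) {
  return std::any_of(fn.blocks.begin(), fn.blocks.end(),
                     [](const BlockModel &b) { return b.hasTag; });
}

// Conservative rule from the first bullet: refuse to inline a callee that
// contains a tag, since a block address names a (function, tag) pair.
inline bool mayInline(const FunctionModel &callee) {
  return !containsBlockTag(callee);
}

// Condition from the second bullet: a tag sitting in an unreachable block.
// Whether this becomes a warning or a hard error is a policy decision.
inline bool hasTaggedUnreachableBlock(const FunctionModel &fn) {
  return std::any_of(fn.blocks.begin(), fn.blocks.end(),
                     [](const BlockModel &b) {
                       return b.hasTag && !b.reachable;
                     });
}
```

Both predicates are deliberately conservative: they reject more than strictly necessary so that a `(function, tag)` pair never ends up pointing at a block that moved or disappeared.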
Given the snippet:
Region simplify works nicely:
If I introduce the blocktags, the opt is blocked:
Seems like the optimization is already conservative enough? LLVM proper is slightly smarter, but it also "blocks" the full optimization in the face of blockaddresses: https://godbolt.org/z/TP4K35zo8 (comment out the globals to double check). I would also like to write a test for even simpler block merging (something like https://godbolt.org/z/c5o78o6Y9), but canonicalize can't seem to do the job, how do call
into
?
Left a question for discussion above, but updated the PR to address all other concerns.
It seems only the control flow dialect branch operations support block merging:
I had in mind that region simplify merges consecutive blocks if there is only one control flow edge between them. However, this actually requires the operation to implement a canonicalization pattern, which is not the case in the LLVM dialect. In that sense your PR is fine, since whoever would like to implement a canonicalization for llvm.br would have to check whether there is a tag operation.
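As a rough illustration of the check such a future `llvm.br` canonicalization would need, here is a standalone C++ sketch of a merge legality test. It is not an MLIR rewrite pattern: `CfgBlock` and `canMergeIntoPredecessor` are hypothetical, and it assumes that only a tag in the successor block makes the merge unsafe.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical CFG model for illustration; not an MLIR data structure.
struct CfgBlock {
  std::vector<std::size_t> successors; // indices into the CFG vector
  std::size_t numPredecessors = 0;
  bool hasTag = false;                 // block contains an llvm.blocktag
};

// Returns true if `succ` could be folded into `pred`: there must be exactly
// one control flow edge between the two, and `succ` must not carry a tag.
inline bool canMergeIntoPredecessor(const std::vector<CfgBlock> &cfg,
                                    std::size_t pred, std::size_t succ) {
  if (pred >= cfg.size() || succ >= cfg.size() || pred == succ)
    return false;
  const CfgBlock &p = cfg[pred];
  const CfgBlock &s = cfg[succ];
  bool singleEdge = p.successors.size() == 1 && p.successors[0] == succ &&
                    s.numPredecessors == 1;
  // Extra guard (an assumption of this sketch): a tagged successor stays a
  // separate block, otherwise its address would no longer name a block start.
  return singleEdge && !s.hasTag;
}
```

The guard costs one flag lookup per candidate edge, so a conservative pattern along these lines would leave untagged blocks mergeable as before.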
Thanks for adding the extra guardrails.
LGTM modulo last comments.
It looks like block merging does not happen since the LLVM branch ops do not have the necessary canonicalization patterns. Together with the inlining change we should be on the safe side.
Add support for import and translate.

MLIR does not support using basic block references outside a function (like LLVM does). This PR does not consider changes to MLIR in that respect. It instead introduces two new ops: `llvm.blockaddress` and `llvm.blocktag`. Here's an example:

```
llvm.func @ba() -> !llvm.ptr {
  %0 = llvm.blockaddress <function = @ba, tag = <id = 1>> : !llvm.ptr
  llvm.br ^bb1
^bb1:  // pred: ^bb0
  llvm.blocktag <id = 1>
  llvm.return %0 : !llvm.ptr
}
```

Value `%0` holds the address of the block tagged as `id = 1` in function `@ba`. Block tags need to be unique within a function, and use of `llvm.blockaddress` requires a matching tag in an `llvm.blocktag`.
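The two invariants stated above (unique tags per function, and a matching tag for every block address) can be sketched as a small standalone C++ checker. This is not the dialect's verifier: `TaggedFunction` and `verifyTags` are hypothetical, the model assumes every block address refers to the enclosing function, and the diagnostic strings are borrowed from the `invalid.mlir` test cases in this PR.

```cpp
#include <cstdint>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical flattened view of one function; not an MLIR class.
struct TaggedFunction {
  std::vector<uint32_t> tags;        // ids of llvm.blocktag ops, in order
  std::vector<uint32_t> addressRefs; // tag ids used by llvm.blockaddress ops
};

// Returns a diagnostic message on failure, or std::nullopt if both
// invariants hold.
inline std::optional<std::string> verifyTags(const TaggedFunction &fn) {
  std::set<uint32_t> seen;
  // Invariant 1: tag ids are unique within the function.
  for (uint32_t id : fn.tags)
    if (!seen.insert(id).second)
      return "duplicate block tag '" + std::to_string(id) +
             "' in the same function";
  // Invariant 2: every block address has a matching tag.
  for (uint32_t id : fn.addressRefs)
    if (seen.count(id) == 0)
      return std::string(
          "expects an existing block label target in the referenced function");
  return std::nullopt;
}
```

For the `@ba` example, the model would be `TaggedFunction{{1}, {1}}`, which passes both checks.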
55a5bfa to a523b7d
Thanks @gysit, addressed all review comments and rebased against main to get rid of the unrelated Linux failure.