Skip to content

Commit fc8b2bf

Browse files
[MLIR][LLVM] Import dereferenceable metadata from LLVM IR (#130974)
Add support for importing `dereferenceable` and `dereferenceable_or_null` metadata into LLVM dialect. Add a new attribute which models these two metadata nodes and a new OpInterface.
1 parent bddf24d commit fc8b2bf

File tree

15 files changed

+261
-3
lines changed

15 files changed

+261
-3
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,4 +1267,28 @@ def WorkgroupAttributionAttr
12671267
let assemblyFormat = "`<` $num_elements `,` $element_type `>`";
12681268
}
12691269

1270+
//===----------------------------------------------------------------------===//
1271+
// DereferenceableAttr
1272+
//===----------------------------------------------------------------------===//
1273+
1274+
def LLVM_DereferenceableAttr : LLVM_Attr<"Dereferenceable", "dereferenceable"> {
1275+
let summary = "LLVM dereferenceable attribute";
1276+
let description = [{
1277+
Defines `dereferenceable` or `dereferenceable_or_null` metadata that can
1278+
be set via the `DereferenceableOpInterface` on an `inttoptr` operation or
1279+
on a `load` operation which loads a pointer. The attribute is used to
1280+
denote that the result of these operations is dereferenceable up to a
1281+
certain number of bytes, represented by `$bytes`. The optional `$mayBeNull`
1282+
parameter is set to true if the attribute defines `dereferenceable_or_null`
1283+
metadata.
1284+
1285+
See the following links for more details:
1286+
https://llvm.org/docs/LangRef.html#dereferenceable-metadata
1287+
https://llvm.org/docs/LangRef.html#dereferenceable-or-null-metadata
1288+
}];
1289+
let parameters = (ins "uint64_t":$bytes,
1290+
DefaultValuedParameter<"bool", "false">:$mayBeNull);
1291+
let assemblyFormat = "`<` struct(params) `>`";
1292+
}
1293+
12701294
#endif // LLVMIR_ATTRDEFS

mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ LogicalResult verifyAccessGroupOpInterface(Operation *op);
2727
/// the alias analysis interface.
2828
LogicalResult verifyAliasAnalysisOpInterface(Operation *op);
2929

30+
/// Verifies that the operation implementing the dereferenceable interface has
31+
/// exactly one result of LLVM pointer type.
32+
LogicalResult verifyDereferenceableOpInterface(Operation *op);
33+
3034
} // namespace detail
3135
} // namespace LLVM
3236
} // namespace mlir

mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,43 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
330330
];
331331
}
332332

333+
def DereferenceableOpInterface : OpInterface<"DereferenceableOpInterface"> {
334+
let description = [{
335+
An interface for memory operations that can carry dereferenceable metadata.
336+
It provides setters and getters for the operation's dereferenceable
337+
attributes. The default implementations of the interface methods expect
338+
the operation to have an attribute of type DereferenceableAttr.
339+
}];
340+
341+
let cppNamespace = "::mlir::LLVM";
342+
let verify = [{ return detail::verifyDereferenceableOpInterface($_op); }];
343+
344+
let methods = [
345+
InterfaceMethod<
346+
/*desc=*/ "Returns the dereferenceable attribute or nullptr",
347+
/*returnType=*/ "::mlir::LLVM::DereferenceableAttr",
348+
/*methodName=*/ "getDereferenceableOrNull",
349+
/*args=*/ (ins),
350+
/*methodBody=*/ [{}],
351+
/*defaultImpl=*/ [{
352+
auto op = cast<ConcreteOp>(this->getOperation());
353+
return op.getDereferenceableAttr();
354+
}]
355+
>,
356+
InterfaceMethod<
357+
/*desc=*/ "Sets the dereferenceable attribute",
358+
/*returnType=*/ "void",
359+
/*methodName=*/ "setDereferenceable",
360+
/*args=*/ (ins "::mlir::LLVM::DereferenceableAttr":$attr),
361+
/*methodBody=*/ [{}],
362+
/*defaultImpl=*/ [{
363+
auto op = cast<ConcreteOp>(this->getOperation());
364+
op.setDereferenceableAttr(attr);
365+
}]
366+
>
367+
];
368+
}
369+
333370
def FPExceptionBehaviorOpInterface : OpInterface<"FPExceptionBehaviorOpInterface"> {
334371
let description = [{
335372
An interface for operations receiving an exception behavior attribute

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
364364
[DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
365365
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
366366
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
367-
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
367+
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>,
368+
DeclareOpInterfaceMethods<DereferenceableOpInterface>]> {
368369
dag args = (ins LLVM_AnyPointer:$addr,
369370
OptionalAttr<I64Attr>:$alignment,
370371
UnitAttr:$volatile_,
@@ -373,7 +374,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
373374
UnitAttr:$invariantGroup,
374375
DefaultValuedAttr<
375376
AtomicOrdering, "AtomicOrdering::not_atomic">:$ordering,
376-
OptionalAttr<StrAttr>:$syncscope);
377+
OptionalAttr<StrAttr>:$syncscope,
378+
OptionalAttr<LLVM_DereferenceableAttr>:$dereferenceable);
377379
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
378380
let arguments = !con(args, aliasAttrs);
379381
let results = (outs LLVM_LoadableType:$res);
@@ -407,6 +409,7 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
407409
(`atomic` (`syncscope` `(` $syncscope^ `)`)? $ordering^)?
408410
(`invariant` $invariant^)?
409411
(`invariant_group` $invariantGroup^)?
412+
(`dereferenceable` `` $dereferenceable^)?
410413
attr-dict `:` qualified(type($addr)) `->` type($res)
411414
}];
412415
string llvmBuilder = [{
@@ -416,6 +419,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
416419
llvm::MDNode *metadata = llvm::MDNode::get(inst->getContext(), std::nullopt);
417420
inst->setMetadata(llvm::LLVMContext::MD_invariant_load, metadata);
418421
}
422+
if ($dereferenceable)
423+
moduleTranslation.setDereferenceableMetadata(op, inst);
419424
}] # setOrderingCode
420425
# setSyncScopeCode
421426
# setAlignmentCode
@@ -571,6 +576,29 @@ class LLVM_CastOpWithOverflowFlag<string mnemonic, string instName, Type type,
571576
}];
572577
}
573578

579+
class LLVM_DereferenceableCastOp<string mnemonic, string instName, Type type,
580+
Type resultType, list<Trait> traits = []> :
581+
LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<DereferenceableOpInterface>], traits)> {
582+
let arguments = (ins type:$arg, OptionalAttr<LLVM_DereferenceableAttr>:$dereferenceable);
583+
let results = (outs resultType:$res);
584+
let builders = [LLVM_OneResultOpBuilder];
585+
let assemblyFormat = "$arg (`dereferenceable` `` $dereferenceable^)? attr-dict `:` type($arg) `to` type($res)";
586+
string llvmInstName = instName;
587+
string llvmBuilder = [{
588+
auto *val = builder.Create}] # instName # [{($arg, $_resultType);
589+
$res = val;
590+
if ($dereferenceable) {
591+
llvm::Instruction *inst = dyn_cast<llvm::Instruction>(val);
592+
moduleTranslation.setDereferenceableMetadata(op, inst);
593+
}
594+
}];
595+
string mlirBuilder = [{
596+
auto op = $_builder.create<$_qualCppClassName>(
597+
$_location, $_resultType, $arg);
598+
$res = op;
599+
}];
600+
}
601+
574602
def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "BitCast", LLVM_AnyNonAggregate,
575603
LLVM_AnyNonAggregate, [DeclareOpInterfaceMethods<PromotableOpInterface>]> {
576604
let hasFolder = 1;
@@ -583,7 +611,7 @@ def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "AddrSpaceCast",
583611
DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
584612
let hasFolder = 1;
585613
}
586-
def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "IntToPtr",
614+
def LLVM_IntToPtrOp : LLVM_DereferenceableCastOp<"inttoptr", "IntToPtr",
587615
LLVM_ScalarOrVectorOf<AnySignlessInteger>,
588616
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;
589617
def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "PtrToInt",

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,13 @@ class ModuleImport {
248248
LoopAnnotationAttr translateLoopAnnotationAttr(const llvm::MDNode *node,
249249
Location loc) const;
250250

251+
/// Returns the dereferenceable attribute that corresponds to the given LLVM
252+
/// dereferenceable or dereferenceable_or_null metadata `node`. `kindID`
253+
/// specifies the kind of the metadata node (dereferenceable or
254+
/// dereferenceable_or_null).
255+
FailureOr<DereferenceableAttr>
256+
translateDereferenceableAttr(const llvm::MDNode *node, unsigned kindID);
257+
251258
/// Returns the alias scope attributes that map to the alias scope nodes
252259
/// starting from the metadata `node`. Returns failure, if any of the
253260
/// attributes cannot be found.

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ class ModuleTranslation {
161161
/// Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
162162
void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst);
163163

164+
/// Sets LLVM dereferenceable metadata for operations that have
165+
/// dereferenceable attributes.
166+
void setDereferenceableMetadata(DereferenceableOpInterface op,
167+
llvm::Instruction *inst);
168+
164169
/// Sets LLVM profiling metadata for operations that have branch weights.
165170
void setBranchWeightsMetadata(BranchWeightOpInterface op);
166171

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,7 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
940940
alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
941941
isNonTemporal, isInvariant, isInvariantGroup, ordering,
942942
syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
943+
/*dereferenceable=*/nullptr,
943944
/*access_groups=*/nullptr,
944945
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
945946
/*tbaa=*/nullptr);

mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,23 @@ mlir::LLVM::detail::verifyAliasAnalysisOpInterface(Operation *op) {
6262
return isArrayOf<TBAATagAttr>(op, tags);
6363
}
6464

65+
//===----------------------------------------------------------------------===//
66+
// DereferenceableOpInterface
67+
//===----------------------------------------------------------------------===//
68+
69+
LogicalResult
70+
mlir::LLVM::detail::verifyDereferenceableOpInterface(Operation *op) {
71+
auto iface = cast<DereferenceableOpInterface>(op);
72+
73+
if (auto derefAttr = iface.getDereferenceableOrNull())
74+
if (op->getNumResults() != 1 ||
75+
!mlir::isa<LLVMPointerType>(op->getResult(0).getType()))
76+
return op->emitOpError(
77+
"expected op to return a single LLVM pointer type");
78+
79+
return success();
80+
}
81+
6582
SmallVector<Value> mlir::LLVM::AtomicCmpXchgOp::getAccessedOperands() {
6683
return {getPtr()};
6784
}

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
9090
llvm::LLVMContext::MD_loop,
9191
llvm::LLVMContext::MD_noalias,
9292
llvm::LLVMContext::MD_alias_scope,
93+
llvm::LLVMContext::MD_dereferenceable,
94+
llvm::LLVMContext::MD_dereferenceable_or_null,
9395
context.getMDKindID(vecTypeHintMDName),
9496
context.getMDKindID(workGroupSizeHintMDName),
9597
context.getMDKindID(reqdWorkGroupSizeMDName),
@@ -188,6 +190,25 @@ static LogicalResult setAccessGroupsAttr(const llvm::MDNode *node,
188190
return success();
189191
}
190192

193+
/// Converts the given dereferenceable metadata node to a dereferenceable
194+
/// attribute, and attaches it to the imported operation if the translation
195+
/// succeeds. Returns failure if the LLVM IR metadata node is ill-formed.
196+
static LogicalResult setDereferenceableAttr(const llvm::MDNode *node,
197+
unsigned kindID, Operation *op,
198+
LLVM::ModuleImport &moduleImport) {
199+
auto dereferenceable =
200+
moduleImport.translateDereferenceableAttr(node, kindID);
201+
if (failed(dereferenceable))
202+
return failure();
203+
204+
auto iface = dyn_cast<DereferenceableOpInterface>(op);
205+
if (!iface)
206+
return failure();
207+
208+
iface.setDereferenceable(*dereferenceable);
209+
return success();
210+
}
211+
191212
/// Converts the given loop metadata node to an MLIR loop annotation attribute
192213
/// and attaches it to the imported operation if the translation succeeds.
193214
/// Returns failure otherwise.
@@ -401,6 +422,13 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
401422
return setAliasScopesAttr(node, op, moduleImport);
402423
if (kind == llvm::LLVMContext::MD_noalias)
403424
return setNoaliasScopesAttr(node, op, moduleImport);
425+
if (kind == llvm::LLVMContext::MD_dereferenceable)
426+
return setDereferenceableAttr(node, llvm::LLVMContext::MD_dereferenceable,
427+
op, moduleImport);
428+
if (kind == llvm::LLVMContext::MD_dereferenceable_or_null)
429+
return setDereferenceableAttr(
430+
node, llvm::LLVMContext::MD_dereferenceable_or_null, op,
431+
moduleImport);
404432

405433
llvm::LLVMContext &context = node->getContext();
406434
if (kind == context.getMDKindID(vecTypeHintMDName))

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2527,6 +2527,31 @@ ModuleImport::translateLoopAnnotationAttr(const llvm::MDNode *node,
25272527
return loopAnnotationImporter->translateLoopAnnotation(node, loc);
25282528
}
25292529

2530+
FailureOr<DereferenceableAttr>
2531+
ModuleImport::translateDereferenceableAttr(const llvm::MDNode *node,
2532+
unsigned kindID) {
2533+
Location loc = mlirModule.getLoc();
2534+
2535+
// The only operand should be a constant integer representing the number of
2536+
// dereferenceable bytes.
2537+
if (node->getNumOperands() != 1)
2538+
return emitError(loc) << "dereferenceable metadata must have one operand: "
2539+
<< diagMD(node, llvmModule.get());
2540+
2541+
auto *numBytesMD = dyn_cast<llvm::ConstantAsMetadata>(node->getOperand(0));
2542+
auto *numBytesCst = dyn_cast<llvm::ConstantInt>(numBytesMD->getValue());
2543+
if (!numBytesCst || !numBytesCst->getValue().isNonNegative())
2544+
return emitError(loc) << "dereferenceable metadata operand must be a "
2545+
"non-negative constant integer: "
2546+
<< diagMD(node, llvmModule.get());
2547+
2548+
bool mayBeNull = kindID == llvm::LLVMContext::MD_dereferenceable_or_null;
2549+
auto derefAttr = builder.getAttr<DereferenceableAttr>(
2550+
numBytesCst->getZExtValue(), mayBeNull);
2551+
2552+
return derefAttr;
2553+
}
2554+
25302555
OwningOpRef<ModuleOp>
25312556
mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
25322557
MLIRContext *context, bool emitExpensiveWarnings,

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1925,6 +1925,22 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
19251925
inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
19261926
}
19271927

1928+
void ModuleTranslation::setDereferenceableMetadata(
1929+
DereferenceableOpInterface op, llvm::Instruction *inst) {
1930+
DereferenceableAttr derefAttr = op.getDereferenceableOrNull();
1931+
if (!derefAttr)
1932+
return;
1933+
1934+
llvm::MDNode *derefSizeNode = llvm::MDNode::get(
1935+
getLLVMContext(),
1936+
llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
1937+
llvm::IntegerType::get(getLLVMContext(), 64), derefAttr.getBytes())));
1938+
unsigned kindId = derefAttr.getMayBeNull()
1939+
? llvm::LLVMContext::MD_dereferenceable_or_null
1940+
: llvm::LLVMContext::MD_dereferenceable;
1941+
inst->setMetadata(kindId, derefSizeNode);
1942+
}
1943+
19281944
void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
19291945
DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
19301946
if (!weightsAttr)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
// RUN: mlir-opt --allow-unregistered-dialect -split-input-file -verify-diagnostics %s
2+
3+
llvm.func @deref(%arg0: !llvm.ptr) {
4+
// expected-error @below {{op expected op to return a single LLVM pointer type}}
5+
%0 = llvm.load %arg0 dereferenceable<bytes = 8> {alignment = 8 : i64} : !llvm.ptr -> i64
6+
llvm.return
7+
}

mlir/test/Target/LLVMIR/Import/import-failure.ll

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,17 @@ declare void @llvm.experimental.noalias.scope.decl(metadata)
338338

339339
; // -----
340340

341+
; CHECK: import-failure.ll
342+
; CHECK-SAME: dereferenceable metadata operand must be a non-negative constant integer
343+
define void @deref(i64 %0) {
344+
%2 = inttoptr i64 %0 to ptr, !dereferenceable !0
345+
ret void
346+
}
347+
348+
!0 = !{i64 -4}
349+
350+
; // -----
351+
341352
; CHECK: import-failure.ll
342353
; CHECK-SAME: warning: unhandled data layout token: ni:42
343354
target datalayout = "e-ni:42-i64:64"
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
2+
3+
define void @deref(i64 %0, ptr %1) {
4+
; CHECK: llvm.inttoptr
5+
; CHECK-SAME: dereferenceable<bytes = 4>
6+
%3 = inttoptr i64 %0 to ptr, !dereferenceable !0
7+
; CHECK: llvm.load
8+
; CHECK-SAME: dereferenceable<bytes = 8>
9+
%4 = load ptr, ptr %1, align 8, !dereferenceable !1
10+
ret void
11+
}
12+
13+
define void @deref_or_null(i64 %0, ptr %1) {
14+
; CHECK: llvm.inttoptr
15+
; CHECK-SAME: dereferenceable<bytes = 4, mayBeNull = true>
16+
%3 = inttoptr i64 %0 to ptr, !dereferenceable_or_null !0
17+
; CHECK: llvm.load
18+
; CHECK-SAME: dereferenceable<bytes = 8, mayBeNull = true>
19+
%4 = load ptr, ptr %1, align 8, !dereferenceable_or_null !1
20+
ret void
21+
}
22+
23+
!0 = !{i64 4}
24+
!1 = !{i64 8}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
2+
3+
llvm.func @deref(%arg0: i64, %arg1: !llvm.ptr) {
4+
// CHECK: inttoptr {{.*}} !dereferenceable [[D0:![0-9]+]]
5+
%0 = llvm.inttoptr %arg0 dereferenceable<bytes = 4> : i64 to !llvm.ptr
6+
%1 = llvm.load %0 {alignment = 4 : i64} : !llvm.ptr -> i32
7+
// CHECK: load {{.*}} !dereferenceable [[D1:![0-9]+]]
8+
%2 = llvm.load %arg1 dereferenceable<bytes = 8> {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
9+
llvm.store %1, %2 {alignment = 4 : i64} : i32, !llvm.ptr
10+
llvm.return
11+
}
12+
13+
llvm.func @deref_or_null(%arg0: i64, %arg1: !llvm.ptr) {
14+
// CHECK: inttoptr {{.*}} !dereferenceable_or_null [[D0]]
15+
%0 = llvm.inttoptr %arg0 dereferenceable<bytes = 4, mayBeNull = true> : i64 to !llvm.ptr
16+
%1 = llvm.load %0 {alignment = 4 : i64} : !llvm.ptr -> i32
17+
// CHECK: load {{.*}} !dereferenceable_or_null [[D1]]
18+
%2 = llvm.load %arg1 dereferenceable<bytes = 8, mayBeNull = true> {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
19+
llvm.store %1, %2 {alignment = 4 : i64} : i32, !llvm.ptr
20+
llvm.return
21+
}
22+
23+
// CHECK: [[D0]] = !{i64 4}
24+
// CHECK: [[D1]] = !{i64 8}

0 commit comments

Comments
 (0)