Skip to content

Commit 0db13b8

Browse files
committed
[mlir][llvm] support new-struct-path-tbaa
1 parent 1e25c92 commit 0db13b8

File tree

5 files changed

+125
-6
lines changed

5 files changed

+125
-6
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td

+77-1
Original file line numberDiff line numberDiff line change
@@ -1080,8 +1080,84 @@ def LLVM_TBAATagAttr : LLVM_Attr<"TBAATag", "tbaa_tag"> {
10801080
let assemblyFormat = "`<` struct(params) `>`";
10811081
}
10821082

1083+
def LLVM_TBAAStructFieldAttr : LLVM_Attr<"TBAAStructField", "tbaa_struct_field"> {
1084+
let parameters = (ins
1085+
"TBAANodeAttr":$typeDesc,
1086+
"int64_t":$offset,
1087+
"int64_t":$size
1088+
);
1089+
let assemblyFormat = "`<` struct(params) `>`";
1090+
}
1091+
1092+
1093+
def LLVM_TBAAStructFieldAttrArray : ArrayRefParameter<"TBAAStructFieldAttr"> {
1094+
let printer = [{
1095+
$_printer << '{';
1096+
llvm::interleaveComma($_self, $_printer, [&](TBAAStructFieldAttr attr) {
1097+
$_printer.printStrippedAttrOrType(attr);
1098+
});
1099+
$_printer << '}';
1100+
}];
1101+
1102+
let parser = [{
1103+
[&]() -> FailureOr<SmallVector<TBAAStructFieldAttr>> {
1104+
using Result = SmallVector<TBAAStructFieldAttr>;
1105+
if ($_parser.parseLBrace())
1106+
return failure();
1107+
FailureOr<Result> result = FieldParser<Result>::parse($_parser);
1108+
if (failed(result))
1109+
return failure();
1110+
if ($_parser.parseRBrace())
1111+
return failure();
1112+
return result;
1113+
}()
1114+
}];
1115+
}
1116+
1117+
def LLVM_TBAATypeNodeAttr : LLVM_Attr<"TBAATypeNode", "tbaa_type_node", [], "TBAANodeAttr"> {
1118+
let parameters = (ins
1119+
"TBAANodeAttr":$parent,
1120+
"int64_t":$size,
1121+
StringRefParameter<>:$id,
1122+
LLVM_TBAAStructFieldAttrArray:$fields
1123+
);
1124+
let assemblyFormat = "`<` struct(params) `>`";
1125+
}
1126+
1127+
def LLVM_TBAAAccessTagAttr : LLVM_Attr<"TBAAAccessTag", "tbaa_access_tag"> {
1128+
let parameters = (ins
1129+
"TBAATypeNodeAttr":$base_type,
1130+
"TBAATypeNodeAttr":$access_type,
1131+
"int64_t":$offset,
1132+
"int64_t":$size
1133+
);
1134+
let builders = [
1135+
AttrBuilderWithInferredContext<(ins "TBAATypeNodeAttr":$baseType,
1136+
"TBAATypeNodeAttr":$accessType,
1137+
"int64_t":$offset,
1138+
"int64_t":$size), [{
1139+
return $_get(baseType.getContext(), baseType, accessType, offset, size);
1140+
}]>
1141+
];
1142+
let assemblyFormat = "`<` struct(params) `>`";
1143+
}
1144+
1145+
def LLVM_TBAAAccessTagArrayAttr
1146+
: TypedArrayAttrBase<LLVM_TBAAAccessTagAttr,
1147+
LLVM_TBAAAccessTagAttr.summary # " array"> {
1148+
let constBuilderCall = ?;
1149+
}
1150+
1151+
// def LLVM_TBAATagAttr2 : AnyAttrOf<[
1152+
// LLVM_TBAATagAttr,
1153+
// LLVM_TBAAAccessTagAttr
1154+
// ]>;
1155+
10831156
def LLVM_TBAATagArrayAttr
1084-
: TypedArrayAttrBase<LLVM_TBAATagAttr,
1157+
: TypedArrayAttrBase<AnyAttrOf<[
1158+
LLVM_TBAATagAttr,
1159+
LLVM_TBAAAccessTagAttr
1160+
]>,
10851161
LLVM_TBAATagAttr.summary # " array"> {
10861162
let constBuilderCall = ?;
10871163
}

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ class ModuleTranslation {
323323

324324
/// Returns the LLVM metadata corresponding to the given mlir LLVM dialect
325325
/// TBAATagAttr.
326-
llvm::MDNode *getTBAANode(TBAATagAttr tbaaAttr) const;
326+
llvm::MDNode *getTBAANode(Attribute tbaaAttr) const;
327327

328328
/// Process tbaa LLVM Metadata operations and create LLVM
329329
/// metadata nodes for them.

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -3401,7 +3401,8 @@ struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface {
34013401
LoopVectorizeAttr, LoopInterleaveAttr, LoopUnrollAttr,
34023402
LoopUnrollAndJamAttr, LoopLICMAttr, LoopDistributeAttr,
34033403
LoopPipelineAttr, LoopPeeledAttr, LoopUnswitchAttr, TBAARootAttr,
3404-
TBAATagAttr, TBAATypeDescriptorAttr>([&](auto attr) {
3404+
TBAATagAttr, TBAATypeDescriptorAttr, TBAAAccessTagAttr,
3405+
TBAATypeNodeAttr>([&](auto attr) {
34053406
os << decltype(attr)::getMnemonic();
34063407
return AliasResult::OverridableAlias;
34073408
})

mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp

+9-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,15 @@ mlir::LLVM::detail::verifyAliasAnalysisOpInterface(Operation *op) {
5858
ArrayAttr tags = iface.getTBAATagsOrNull();
5959
if (!tags)
6060
return success();
61-
61+
if (tags.size() > 0) {
62+
if (mlir::isa<TBAATagAttr>(tags[0])) {
63+
return isArrayOf<TBAATagAttr>(op, tags);
64+
}
65+
66+
if (mlir::isa<TBAAAccessTagAttr>(tags[0])) {
67+
return isArrayOf<TBAAAccessTagAttr>(op, tags);
68+
}
69+
}
6270
return isArrayOf<TBAATagAttr>(op, tags);
6371
}
6472

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

+36-2
Original file line numberDiff line numberDiff line change
@@ -1766,7 +1766,8 @@ void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op,
17661766
llvm::LLVMContext::MD_noalias);
17671767
}
17681768

1769-
llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const {
1769+
// llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const {
1770+
llvm::MDNode *ModuleTranslation::getTBAANode(Attribute tbaaAttr) const {
17701771
return tbaaMetadataMapping.lookup(tbaaAttr);
17711772
}
17721773

@@ -1786,7 +1787,8 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
17861787
return;
17871788
}
17881789

1789-
llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0]));
1790+
// llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0]));
1791+
llvm::MDNode *node = getTBAANode(tagRefs[0]);
17901792
inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
17911793
}
17921794

@@ -1806,6 +1808,7 @@ void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
18061808
LogicalResult ModuleTranslation::createTBAAMetadata() {
18071809
llvm::LLVMContext &ctx = llvmModule->getContext();
18081810
llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64);
1811+
llvm::IntegerType *sizeTy = llvm::IntegerType::get(ctx, 64);
18091812

18101813
// Walk the entire module and create all metadata nodes for the TBAA
18111814
// attributes. The code below relies on two invariants of the
@@ -1833,6 +1836,23 @@ LogicalResult ModuleTranslation::createTBAAMetadata() {
18331836
tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)});
18341837
});
18351838

1839+
walker.addWalk([&](TBAATypeNodeAttr descriptor) {
1840+
SmallVector<llvm::Metadata *> operands;
1841+
operands.push_back(tbaaMetadataMapping.lookup(descriptor.getParent()));
1842+
operands.push_back(llvm::ConstantAsMetadata::get(
1843+
llvm::ConstantInt::get(sizeTy, descriptor.getSize())));
1844+
operands.push_back(llvm::MDString::get(ctx, descriptor.getId()));
1845+
for (auto field : descriptor.getFields()) {
1846+
operands.push_back(tbaaMetadataMapping.lookup(field.getTypeDesc()));
1847+
operands.push_back(llvm::ConstantAsMetadata::get(
1848+
llvm::ConstantInt::get(offsetTy, field.getOffset())));
1849+
operands.push_back(llvm::ConstantAsMetadata::get(
1850+
llvm::ConstantInt::get(sizeTy, field.getSize())));
1851+
}
1852+
1853+
tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)});
1854+
});
1855+
18361856
walker.addWalk([&](TBAATagAttr tag) {
18371857
SmallVector<llvm::Metadata *> operands;
18381858

@@ -1848,6 +1868,20 @@ LogicalResult ModuleTranslation::createTBAAMetadata() {
18481868
tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)});
18491869
});
18501870

1871+
walker.addWalk([&](TBAAAccessTagAttr tag) {
1872+
SmallVector<llvm::Metadata *> operands;
1873+
1874+
operands.push_back(tbaaMetadataMapping.lookup(tag.getBaseType()));
1875+
operands.push_back(tbaaMetadataMapping.lookup(tag.getAccessType()));
1876+
1877+
operands.push_back(llvm::ConstantAsMetadata::get(
1878+
llvm::ConstantInt::get(offsetTy, tag.getOffset())));
1879+
operands.push_back(llvm::ConstantAsMetadata::get(
1880+
llvm::ConstantInt::get(sizeTy, tag.getSize())));
1881+
1882+
tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)});
1883+
});
1884+
18511885
mlirModule->walk([&](AliasAnalysisOpInterface analysisOpInterface) {
18521886
if (auto attr = analysisOpInterface.getTBAATagsOrNull())
18531887
walker.walk(attr);

0 commit comments

Comments
 (0)