Skip to content

Commit 4206f58

Browse files
[mlir] Integrate OpAsmTypeInterface with AsmPrinter
1 parent 3a439e2 commit 4206f58

File tree

3 files changed

+95
-0
lines changed

3 files changed

+95
-0
lines changed

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,10 +1536,13 @@ StringRef maybeGetValueNameFromLoc(Value value, StringRef name) {
15361536
} // namespace
15371537

15381538
void SSANameState::numberValuesInRegion(Region &region) {
1539+
// indicate whether OpAsmOpInterface set a name
1540+
bool opAsmOpInterfaceUsed = false;
15391541
auto setBlockArgNameFn = [&](Value arg, StringRef name) {
15401542
assert(!valueIDs.count(arg) && "arg numbered multiple times");
15411543
assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
15421544
"arg not defined in current region");
1545+
opAsmOpInterfaceUsed = true;
15431546
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
15441547
name = maybeGetValueNameFromLoc(arg, name);
15451548
setValueName(arg, name);
@@ -1549,6 +1552,23 @@ void SSANameState::numberValuesInRegion(Region &region) {
15491552
if (Operation *op = region.getParentOp()) {
15501553
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
15511554
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
1555+
if (!opAsmOpInterfaceUsed) {
1556+
// If the OpAsmOpInterface didn't set a name, and when
1557+
// all arguments have OpAsmTypeInterface, get names from the type
1558+
bool allHaveOpAsmTypeInterface =
1559+
llvm::all_of(region.getArguments(), [&](Value arg) {
1560+
return mlir::isa<OpAsmTypeInterface>(arg.getType());
1561+
});
1562+
if (allHaveOpAsmTypeInterface) {
1563+
for (auto arg : region.getArguments()) {
1564+
auto typeInterface = mlir::cast<OpAsmTypeInterface>(arg.getType());
1565+
auto setNameFn = [&](StringRef name) {
1566+
setBlockArgNameFn(arg, name);
1567+
};
1568+
typeInterface.getAsmName(setNameFn);
1569+
}
1570+
}
1571+
}
15521572
}
15531573
}
15541574

@@ -1598,9 +1618,12 @@ void SSANameState::numberValuesInBlock(Block &block) {
15981618
void SSANameState::numberValuesInOp(Operation &op) {
15991619
// Function used to set the special result names for the operation.
16001620
SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
1621+
// indicating whether OpAsmOpInterface set a name
1622+
bool opAsmOpInterfaceUsed = false;
16011623
auto setResultNameFn = [&](Value result, StringRef name) {
16021624
assert(!valueIDs.count(result) && "result numbered multiple times");
16031625
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
1626+
opAsmOpInterfaceUsed = true;
16041627
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
16051628
name = maybeGetValueNameFromLoc(result, name);
16061629
setValueName(result, name);
@@ -1630,6 +1653,23 @@ void SSANameState::numberValuesInOp(Operation &op) {
16301653
asmInterface.getAsmBlockNames(setBlockNameFn);
16311654
asmInterface.getAsmResultNames(setResultNameFn);
16321655
}
1656+
if (!opAsmOpInterfaceUsed) {
1657+
// If the OpAsmOpInterface didn't set a name, and when
1658+
// all results have OpAsmTypeInterface, get names from the type
1659+
bool allHaveOpAsmTypeInterface =
1660+
llvm::all_of(op.getResults(), [&](Value result) {
1661+
return mlir::isa<OpAsmTypeInterface>(result.getType());
1662+
});
1663+
if (allHaveOpAsmTypeInterface) {
1664+
for (auto result : op.getResults()) {
1665+
auto typeInterface = mlir::cast<OpAsmTypeInterface>(result.getType());
1666+
auto setNameFn = [&](StringRef name) {
1667+
setResultNameFn(result, name);
1668+
};
1669+
typeInterface.getAsmName(setNameFn);
1670+
}
1671+
}
1672+
}
16331673
}
16341674

16351675
unsigned numResults = op.getNumResults();

mlir/test/IR/op-asm-interface.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,39 @@ func.func @block_argument_name_from_op_asm_type_interface() {
2222
}
2323
return
2424
}
25+
26+
// -----
27+
28+
//===----------------------------------------------------------------------===//
29+
// Test OpAsmTypeInterface
30+
//===----------------------------------------------------------------------===//
31+
32+
func.func @result_name_from_op_asm_type_interface_asmprinter() {
33+
// CHECK-LABEL: @result_name_from_op_asm_type_interface_asmprinter
34+
// CHECK: %op_asm_type_interface
35+
%0 = "test.result_name_from_type_interface"() : () -> !test.op_asm_type_interface
36+
return
37+
}
38+
39+
// -----
40+
41+
// i1 does not have OpAsmTypeInterface, should not get named.
42+
func.func @result_name_from_op_asm_type_interface_not_all() {
43+
// CHECK-LABEL: @result_name_from_op_asm_type_interface_not_all
44+
// CHECK-NOT: %op_asm_type_interface
45+
// CHECK: %0:2
46+
%0:2 = "test.result_name_from_type_interface"() : () -> (!test.op_asm_type_interface, i1)
47+
return
48+
}
49+
50+
// -----
51+
52+
func.func @block_argument_name_from_op_asm_type_interface_asmprinter() {
53+
// CHECK-LABEL: @block_argument_name_from_op_asm_type_interface_asmprinter
54+
// CHECK: ^bb0(%op_asm_type_interface
55+
test.block_argument_name_from_type_interface {
56+
^bb0(%arg0: !test.op_asm_type_interface):
57+
"test.terminator"() : ()->()
58+
}
59+
return
60+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,25 @@ def BlockArgumentNameFromTypeOp
939939
let assemblyFormat = "regions attr-dict-with-keyword";
940940
}
941941

942+
// This is used to test OpAsmTypeInterface::getAsmName's integration with AsmPrinter
943+
// for op result name when OpAsmOpInterface::getAsmResultNames is the default implementation
944+
// i.e. does nothing
945+
def ResultNameFromTypeInterfaceOp
946+
: TEST_Op<"result_name_from_type_interface",
947+
[OpAsmOpInterface]> {
948+
let results = (outs Variadic<AnyType>:$r);
949+
}
950+
951+
// This is used to test OpAsmTypeInterface::getAsmName's integration with AsmPrinter
952+
// for block argument name when OpAsmOpInterface::getAsmBlockArgumentNames is the default implementation
953+
// i.e. does nothing
954+
def BlockArgumentNameFromTypeInterfaceOp
955+
: TEST_Op<"block_argument_name_from_type_interface",
956+
[OpAsmOpInterface]> {
957+
let regions = (region AnyRegion:$body);
958+
let assemblyFormat = "regions attr-dict-with-keyword";
959+
}
960+
942961
// This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
943962
// operations nested in a region under this op will drop the "test." dialect
944963
// prefix.

0 commit comments

Comments
 (0)