Skip to content

Commit 0b82313

Browse files
[mlir] Integrate OpAsmTypeInterface with AsmPrinter
1 parent 3a439e2 commit 0b82313

File tree

3 files changed

+90
-0
lines changed

3 files changed

+90
-0
lines changed

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,10 +1536,13 @@ StringRef maybeGetValueNameFromLoc(Value value, StringRef name) {
15361536
} // namespace
15371537

15381538
void SSANameState::numberValuesInRegion(Region &region) {
1539+
// indicate whether OpAsmOpInterface set a name.
1540+
bool opAsmOpInterfaceUsed = false;
15391541
auto setBlockArgNameFn = [&](Value arg, StringRef name) {
15401542
assert(!valueIDs.count(arg) && "arg numbered multiple times");
15411543
assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
15421544
"arg not defined in current region");
1545+
opAsmOpInterfaceUsed = true;
15431546
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
15441547
name = maybeGetValueNameFromLoc(arg, name);
15451548
setValueName(arg, name);
@@ -1549,6 +1552,18 @@ void SSANameState::numberValuesInRegion(Region &region) {
15491552
if (Operation *op = region.getParentOp()) {
15501553
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
15511554
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
1555+
if (!opAsmOpInterfaceUsed) {
1556+
// If the OpAsmOpInterface didn't set a name, get name from the type.
1557+
for (auto arg : region.getArguments()) {
1558+
if (auto typeInterface =
1559+
mlir::dyn_cast<OpAsmTypeInterface>(arg.getType())) {
1560+
auto setNameFn = [&](StringRef name) {
1561+
setBlockArgNameFn(arg, name);
1562+
};
1563+
typeInterface.getAsmName(setNameFn);
1564+
}
1565+
}
1566+
}
15521567
}
15531568
}
15541569

@@ -1598,9 +1613,12 @@ void SSANameState::numberValuesInBlock(Block &block) {
15981613
void SSANameState::numberValuesInOp(Operation &op) {
15991614
// Function used to set the special result names for the operation.
16001615
SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
1616+
// indicating whether OpAsmOpInterface set a name.
1617+
bool opAsmOpInterfaceUsed = false;
16011618
auto setResultNameFn = [&](Value result, StringRef name) {
16021619
assert(!valueIDs.count(result) && "result numbered multiple times");
16031620
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
1621+
opAsmOpInterfaceUsed = true;
16041622
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
16051623
name = maybeGetValueNameFromLoc(result, name);
16061624
setValueName(result, name);
@@ -1630,6 +1648,23 @@ void SSANameState::numberValuesInOp(Operation &op) {
16301648
asmInterface.getAsmBlockNames(setBlockNameFn);
16311649
asmInterface.getAsmResultNames(setResultNameFn);
16321650
}
1651+
if (!opAsmOpInterfaceUsed) {
1652+
// If the OpAsmOpInterface didn't set a name, and
1653+
// all results have OpAsmTypeInterface, get names from types.
1654+
bool allHaveOpAsmTypeInterface =
1655+
llvm::all_of(op.getResultTypes(), [&](Type type) {
1656+
return mlir::isa<OpAsmTypeInterface>(type);
1657+
});
1658+
if (allHaveOpAsmTypeInterface) {
1659+
for (auto result : op.getResults()) {
1660+
auto typeInterface = mlir::cast<OpAsmTypeInterface>(result.getType());
1661+
auto setNameFn = [&](StringRef name) {
1662+
setResultNameFn(result, name);
1663+
};
1664+
typeInterface.getAsmName(setNameFn);
1665+
}
1666+
}
1667+
}
16331668
}
16341669

16351670
unsigned numResults = op.getNumResults();

mlir/test/IR/op-asm-interface.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,39 @@ func.func @block_argument_name_from_op_asm_type_interface() {
2222
}
2323
return
2424
}
25+
26+
// -----
27+
28+
//===----------------------------------------------------------------------===//
29+
// Test OpAsmTypeInterface
30+
//===----------------------------------------------------------------------===//
31+
32+
func.func @result_name_from_op_asm_type_interface_asmprinter() {
33+
// CHECK-LABEL: @result_name_from_op_asm_type_interface_asmprinter
34+
// CHECK: %op_asm_type_interface
35+
%0 = "test.result_name_from_type_interface"() : () -> !test.op_asm_type_interface
36+
return
37+
}
38+
39+
// -----
40+
41+
// i1 does not have OpAsmTypeInterface, should not get named.
42+
func.func @result_name_from_op_asm_type_interface_not_all() {
43+
// CHECK-LABEL: @result_name_from_op_asm_type_interface_not_all
44+
// CHECK-NOT: %op_asm_type_interface
45+
// CHECK: %0:2
46+
%0:2 = "test.result_name_from_type_interface"() : () -> (!test.op_asm_type_interface, i1)
47+
return
48+
}
49+
50+
// -----
51+
52+
func.func @block_argument_name_from_op_asm_type_interface_asmprinter() {
53+
// CHECK-LABEL: @block_argument_name_from_op_asm_type_interface_asmprinter
54+
// CHECK: ^bb0(%op_asm_type_interface
55+
test.block_argument_name_from_type_interface {
56+
^bb0(%arg0: !test.op_asm_type_interface):
57+
"test.terminator"() : ()->()
58+
}
59+
return
60+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,25 @@ def BlockArgumentNameFromTypeOp
939939
let assemblyFormat = "regions attr-dict-with-keyword";
940940
}
941941

942+
// This is used to test OpAsmTypeInterface::getAsmName's integration with AsmPrinter
943+
// for op result name when OpAsmOpInterface::getAsmResultNames is the default implementation
944+
// i.e. does nothing.
945+
def ResultNameFromTypeInterfaceOp
946+
: TEST_Op<"result_name_from_type_interface",
947+
[OpAsmOpInterface]> {
948+
let results = (outs Variadic<AnyType>:$r);
949+
}
950+
951+
// This is used to test OpAsmTypeInterface::getAsmName's integration with AsmPrinter
952+
// for block argument name when OpAsmOpInterface::getAsmBlockArgumentNames is the default implementation
953+
// i.e. does nothing.
954+
def BlockArgumentNameFromTypeInterfaceOp
955+
: TEST_Op<"block_argument_name_from_type_interface",
956+
[OpAsmOpInterface]> {
957+
let regions = (region AnyRegion:$body);
958+
let assemblyFormat = "regions attr-dict-with-keyword";
959+
}
960+
942961
// This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
943962
// operations nested in a region under this op will drop the "test." dialect
944963
// prefix.

0 commit comments

Comments
 (0)