Skip to content

Commit 0fa78e4

Browse files
[mlir] Add OpAsmTypeInterface for pretty-print
1 parent f51db95 commit 0fa78e4

File tree

8 files changed

+148
-7
lines changed

8 files changed

+148
-7
lines changed

mlir/include/mlir/IR/CMakeLists.txt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1-
add_mlir_interface(OpAsmInterface)
21
add_mlir_interface(SymbolInterfaces)
32
add_mlir_interface(RegionKindInterface)
43

4+
set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
5+
mlir_tablegen(OpAsmOpInterface.h.inc -gen-op-interface-decls)
6+
mlir_tablegen(OpAsmOpInterface.cpp.inc -gen-op-interface-defs)
7+
mlir_tablegen(OpAsmTypeInterface.h.inc -gen-type-interface-decls)
8+
mlir_tablegen(OpAsmTypeInterface.cpp.inc -gen-type-interface-defs)
9+
add_public_tablegen_target(MLIROpAsmInterfaceIncGen)
10+
add_dependencies(mlir-generic-headers MLIROpAsmInterfaceIncGen)
11+
512
set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
613
mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
714
mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)

mlir/include/mlir/IR/OpAsmInterface.td

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,27 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
5050
```mlir
5151
%first_result, %middle_results:2, %0 = "my.op" ...
5252
```
53+
54+
The default implementation uses `OpAsmTypeInterface` to get the name for
55+
each result from its type.
56+
57+
If not all of the result types have `OpAsmTypeInterface`, the default implementation
58+
does nothing, as the packing behavior should be decided by the operation itself.
5359
}],
5460
"void", "getAsmResultNames",
5561
(ins "::mlir::OpAsmSetValueNameFn":$setNameFn),
56-
"", "return;"
62+
"", [{
63+
bool hasOpAsmTypeInterface = llvm::all_of($_op->getResults(), [&](Value result) {
64+
return ::mlir::isa<::mlir::OpAsmTypeInterface>(result.getType());
65+
});
66+
if (!hasOpAsmTypeInterface)
67+
return;
68+
for (auto result : $_op->getResults()) {
69+
auto setResultNameFn = [&](StringRef name) { setNameFn(result, name); };
70+
auto opAsmTypeInterface = ::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType());
71+
opAsmTypeInterface.getAsmName(setResultNameFn);
72+
}
73+
}]
5774
>,
5875
InterfaceMethod<[{
5976
Get a special name to use when printing the block arguments for a region
@@ -64,7 +81,16 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
6481
"::mlir::Region&":$region,
6582
"::mlir::OpAsmSetValueNameFn":$setNameFn
6683
),
67-
"", "return;"
84+
"", [{
85+
for (auto &block : region) {
86+
for (auto arg : block.getArguments()) {
87+
if (auto opAsmTypeInterface = ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) {
88+
auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); };
89+
opAsmTypeInterface.getAsmName(setArgNameFn);
90+
}
91+
}
92+
}
93+
}]
6894
>,
6995
InterfaceMethod<[{
7096
Get the name to use for a given block inside a region attached to this
@@ -109,6 +135,31 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
109135
];
110136
}
111137

138+
//===----------------------------------------------------------------------===//
139+
// OpAsmTypeInterface
140+
//===----------------------------------------------------------------------===//
141+
142+
def OpAsmTypeInterface : TypeInterface<"OpAsmTypeInterface"> {
143+
let description = [{
144+
This interface provides hooks to interact with the AsmPrinter and AsmParser
145+
classes.
146+
}];
147+
let cppNamespace = "::mlir";
148+
149+
let methods = [
150+
InterfaceMethod<[{
151+
Get a special name to use when printing value of this type.
152+
153+
For example, the default implementation of OpAsmOpInterface
154+
will respect this method when printing the results of an operation
155+
and/or block argument of it.
156+
}],
157+
"void", "getAsmName",
158+
(ins "::mlir::OpAsmSetNameFn":$setNameFn), "", ";"
159+
>,
160+
];
161+
}
162+
112163
//===----------------------------------------------------------------------===//
113164
// ResourceHandleParameter
114165
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,7 @@ class AsmParser {
734734
virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
735735
virtual OptionalParseResult parseOptionalDecimalInteger(APInt &result) = 0;
736736

737-
private:
737+
private:
738738
template <typename IntT, typename ParseFn>
739739
OptionalParseResult parseOptionalIntegerAndCheck(IntT &result,
740740
ParseFn &&parseFn) {
@@ -756,7 +756,7 @@ class AsmParser {
756756
return success();
757757
}
758758

759-
public:
759+
public:
760760
template <typename IntT>
761761
OptionalParseResult parseOptionalInteger(IntT &result) {
762762
return parseOptionalIntegerAndCheck(
@@ -1727,6 +1727,10 @@ class OpAsmParser : public AsmParser {
17271727
// Dialect OpAsm interface.
17281728
//===--------------------------------------------------------------------===//
17291729

1730+
/// A functor used to set the name of the result. See 'getAsmResultNames' below
1731+
/// for more details.
1732+
using OpAsmSetNameFn = function_ref<void(StringRef)>;
1733+
17301734
/// A functor used to set the name of the start of a result group of an
17311735
/// operation. See 'getAsmResultNames' below for more details.
17321736
using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
@@ -1820,7 +1824,9 @@ ParseResult parseDimensionList(OpAsmParser &parser,
18201824
//===--------------------------------------------------------------------===//
18211825

18221826
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
1823-
#include "mlir/IR/OpAsmInterface.h.inc"
1827+
#include "mlir/IR/OpAsmTypeInterface.h.inc"
1828+
// put Attr/Type before Op
1829+
#include "mlir/IR/OpAsmOpInterface.h.inc"
18241830

18251831
namespace llvm {
18261832
template <>

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
125125
//===----------------------------------------------------------------------===//
126126

127127
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
128-
#include "mlir/IR/OpAsmInterface.cpp.inc"
128+
#include "mlir/IR/OpAsmOpInterface.cpp.inc"
129+
#include "mlir/IR/OpAsmTypeInterface.cpp.inc"
129130

130131
LogicalResult
131132
OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {

mlir/test/IR/op-asm-interface.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
2+
3+
//===----------------------------------------------------------------------===//
4+
// Test OpAsmOpInterface
5+
//===----------------------------------------------------------------------===//
6+
7+
func.func @result_name_from_op_asm_type_interface() {
8+
// CHECK-LABEL: @result_name_from_op_asm_type_interface
9+
// CHECK: %op_asm_type_interface
10+
%0 = "test.default_result_name"() : () -> !test.op_asm_type_interface
11+
return
12+
}
13+
14+
// -----
15+
16+
func.func @result_name_pack_from_op_asm_type_interface() {
17+
// CHECK-LABEL: @result_name_pack_from_op_asm_type_interface
18+
// CHECK: %op_asm_type_interface{{.*}}, %op_asm_type_interface{{.*}}
19+
// CHECK-NOT: :2
20+
%0:2 = "test.default_result_name_packing"() : () -> (!test.op_asm_type_interface, !test.op_asm_type_interface)
21+
return
22+
}
23+
24+
// -----
25+
26+
func.func @result_name_pack_do_nothing() {
27+
// CHECK-LABEL: @result_name_pack_do_nothing
28+
// CHECK: %0:2
29+
%0:2 = "test.default_result_name_packing"() : () -> (i32, !test.op_asm_type_interface)
30+
return
31+
}
32+
33+
// -----
34+
35+
func.func @block_argument_name_from_op_asm_type_interface() {
36+
// CHECK-LABEL: @block_argument_name_from_op_asm_type_interface
37+
// CHECK: ^bb0(%op_asm_type_interface
38+
test.default_block_argument_name {
39+
^bb0(%arg0: !test.op_asm_type_interface):
40+
"test.terminator"() : ()->()
41+
}
42+
return
43+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,30 @@ def CustomResultsNameOp
924924
let results = (outs Variadic<AnyInteger>:$r);
925925
}
926926

927+
// This is used to test default implementation of OpAsmOpInterface::getAsmResultNames,
928+
// which uses OpAsmTypeInterface if available.
929+
def DefaultResultsNameOp
930+
: TEST_Op<"default_result_name",
931+
[OpAsmOpInterface]> {
932+
let results = (outs AnyType:$r);
933+
}
934+
935+
// This is used to test default implementation of OpAsmOpInterface::getAsmResultNames,
936+
// when there are multiple results, and not all of their type has OpAsmTypeInterface,
937+
// it should not set result name from OpAsmTypeInterface.
938+
def DefaultResultsNamePackingOp
939+
: TEST_Op<"default_result_name_packing",
940+
[OpAsmOpInterface]> {
941+
let results = (outs AnyType:$r, AnyType:$s);
942+
}
943+
944+
// This is used to test default implementation of OpAsmOpInterface::getAsmBlockArgumentNames,
945+
def DefaultBlockArgumentNameOp : TEST_Op<"default_block_argument_name",
946+
[OpAsmOpInterface]> {
947+
let regions = (region AnyRegion:$body);
948+
let assemblyFormat = "regions attr-dict-with-keyword";
949+
}
950+
927951
// This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
928952
// operations nested in a region under this op will drop the "test." dialect
929953
// prefix.

mlir/test/lib/Dialect/Test/TestTypeDefs.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,4 +398,8 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
398398
let assemblyFormat = "`<` $param `>`";
399399
}
400400

401+
def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface", [DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName"]>]> {
402+
let mnemonic = "op_asm_type_interface";
403+
}
404+
401405
#endif // TEST_TYPEDEFS

mlir/test/lib/Dialect/Test/TestTypes.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,3 +531,8 @@ void TestRecursiveAliasType::print(AsmPrinter &printer) const {
531531
}
532532
printer << ">";
533533
}
534+
535+
void TestTypeOpAsmTypeInterfaceType::getAsmName(
536+
OpAsmSetNameFn setNameFn) const {
537+
setNameFn("op_asm_type_interface");
538+
}

0 commit comments

Comments
 (0)