Skip to content

[mlir] Add OpAsmTypeInterface for pretty-print #121187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion mlir/include/mlir/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
add_mlir_interface(OpAsmInterface)
add_mlir_interface(SymbolInterfaces)
add_mlir_interface(RegionKindInterface)

set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
mlir_tablegen(OpAsmOpInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(OpAsmOpInterface.cpp.inc -gen-op-interface-defs)
mlir_tablegen(OpAsmTypeInterface.h.inc -gen-type-interface-decls)
mlir_tablegen(OpAsmTypeInterface.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIROpAsmInterfaceIncGen)
add_dependencies(mlir-generic-headers MLIROpAsmInterfaceIncGen)

set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
Expand Down
21 changes: 21 additions & 0 deletions mlir/include/mlir/IR/OpAsmInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,27 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
];
}

//===----------------------------------------------------------------------===//
// OpAsmTypeInterface
//===----------------------------------------------------------------------===//

def OpAsmTypeInterface : TypeInterface<"OpAsmTypeInterface"> {
let description = [{
This interface provides hooks to interact with the AsmPrinter and AsmParser
classes.
}];
let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<[{
Get a name to use when printing a value of this type.
}],
"void", "getAsmName",
(ins "::mlir::OpAsmSetNameFn":$setNameFn), "", ";"
Comment on lines +127 to +128
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you instead follow the same API pattern as the getAlias function on OpAsmDialectInterface?

>,
];
}

//===----------------------------------------------------------------------===//
// ResourceHandleParameter
//===----------------------------------------------------------------------===//
Expand Down
11 changes: 8 additions & 3 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ class AsmParser {
virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
virtual OptionalParseResult parseOptionalDecimalInteger(APInt &result) = 0;

private:
private:
template <typename IntT, typename ParseFn>
OptionalParseResult parseOptionalIntegerAndCheck(IntT &result,
ParseFn &&parseFn) {
Expand All @@ -756,7 +756,7 @@ class AsmParser {
return success();
}

public:
public:
template <typename IntT>
OptionalParseResult parseOptionalInteger(IntT &result) {
return parseOptionalIntegerAndCheck(
Expand Down Expand Up @@ -1727,6 +1727,10 @@ class OpAsmParser : public AsmParser {
// Dialect OpAsm interface.
//===--------------------------------------------------------------------===//

/// A functor used to set the name of the result. See 'getAsmResultNames' below
/// for more details.
using OpAsmSetNameFn = function_ref<void(StringRef)>;

/// A functor used to set the name of the start of a result group of an
/// operation. See 'getAsmResultNames' below for more details.
using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
Expand Down Expand Up @@ -1820,7 +1824,8 @@ ParseResult parseDimensionList(OpAsmParser &parser,
//===--------------------------------------------------------------------===//

/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
#include "mlir/IR/OpAsmInterface.h.inc"
#include "mlir/IR/OpAsmOpInterface.h.inc"
#include "mlir/IR/OpAsmTypeInterface.h.inc"

namespace llvm {
template <>
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
//===----------------------------------------------------------------------===//

/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
#include "mlir/IR/OpAsmInterface.cpp.inc"
#include "mlir/IR/OpAsmOpInterface.cpp.inc"
#include "mlir/IR/OpAsmTypeInterface.cpp.inc"

LogicalResult
OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/IR/op-asm-interface.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s

//===----------------------------------------------------------------------===//
// Test OpAsmOpInterface
//===----------------------------------------------------------------------===//

func.func @result_name_from_op_asm_type_interface() {
// CHECK-LABEL: @result_name_from_op_asm_type_interface
// CHECK: %op_asm_type_interface
%0 = "test.result_name_from_type"() : () -> !test.op_asm_type_interface
return
}

// -----

func.func @block_argument_name_from_op_asm_type_interface() {
// CHECK-LABEL: @block_argument_name_from_op_asm_type_interface
// CHECK: ^bb0(%op_asm_type_interface
test.block_argument_name_from_type {
^bb0(%arg0: !test.op_asm_type_interface):
"test.terminator"() : ()->()
}
return
}
32 changes: 32 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOpDefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,38 @@ void CustomResultsNameOp::getAsmResultNames(
setNameFn(getResult(i), str.getValue());
}

//===----------------------------------------------------------------------===//
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was expecting this to be backed into the AsmPrinter when it determines aliases, what is your current plan for implementation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was expecting this to be backed into the AsmPrinter when it determines aliases, what is your current plan for implementation?

Can you split this into a different PR? It'd be nice to add the TypeInterface by itself before using it for value names.

Previous comment asked to split the integration with AsmPrinter into another PR. For testing purpose of this PR, I had to use a test op with custom OpAsmOpInterface to see if such interface is effective or not.

My plan after this PR is to

  • A PR on integrating OpAsmTypeInterface with AsmPrinter, for value names
  • A PR on the alias part for OpAsmTypeInterface
  • A PR on integrating OpAsmTypeInterface with AsmPrinter, for alias
  • A PR on definition of OpAsmAttrInterface, for alias
  • A PR on integrating OpAsmAttrInterface with AsmPrinter
  • A PR on documentation

// ResultNameFromTypeOp
//===----------------------------------------------------------------------===//

void ResultNameFromTypeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
auto result = getResult();
auto setResultNameFn = [&](::llvm::StringRef name) {
setNameFn(result, name);
};
auto opAsmTypeInterface =
::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType());
opAsmTypeInterface.getAsmName(setResultNameFn);
}

//===----------------------------------------------------------------------===//
// BlockArgumentNameFromTypeOp
//===----------------------------------------------------------------------===//

void BlockArgumentNameFromTypeOp::getAsmBlockArgumentNames(
::mlir::Region &region, ::mlir::OpAsmSetValueNameFn setNameFn) {
for (auto &block : region) {
for (auto arg : block.getArguments()) {
if (auto opAsmTypeInterface =
::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) {
auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); };
opAsmTypeInterface.getAsmName(setArgNameFn);
}
}
}
}

//===----------------------------------------------------------------------===//
// ResultTypeWithTraitOp
//===----------------------------------------------------------------------===//
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,21 @@ def CustomResultsNameOp
let results = (outs Variadic<AnyInteger>:$r);
}

// This is used to test OpAsmTypeInterface::getAsmName for op result name,
def ResultNameFromTypeOp
: TEST_Op<"result_name_from_type",
[DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let results = (outs AnyType:$r);
}

// This is used to test OpAsmTypeInterface::getAsmName for block argument,
def BlockArgumentNameFromTypeOp
: TEST_Op<"block_argument_name_from_type",
[DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let regions = (region AnyRegion:$body);
let assemblyFormat = "regions attr-dict-with-keyword";
}

// This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
// operations nested in a region under this op will drop the "test." dialect
// prefix.
Expand Down
5 changes: 5 additions & 0 deletions mlir/test/lib/Dialect/Test/TestTypeDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -398,4 +398,9 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
let assemblyFormat = "`<` $param `>`";
}

def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
[DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName"]>]> {
let mnemonic = "op_asm_type_interface";
}

#endif // TEST_TYPEDEFS
5 changes: 5 additions & 0 deletions mlir/test/lib/Dialect/Test/TestTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,3 +532,8 @@ void TestRecursiveAliasType::print(AsmPrinter &printer) const {
}
printer << ">";
}

void TestTypeOpAsmTypeInterfaceType::getAsmName(
OpAsmSetNameFn setNameFn) const {
setNameFn("op_asm_type_interface");
}
Loading