Skip to content

[mlir] Make StringRefParameter roundtrippable #65813

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/AttrTypeBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ class DefaultValuedParameter<string type, string value, string desc = ""> :
class StringRefParameter<string desc = "", string value = ""> :
AttrOrTypeParameter<"::llvm::StringRef", desc> {
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
let printer = [{$_printer << '"' << $_self << '"';}];
let printer = [{$_printer.printString($_self);}];
let cppStorageType = "std::string";
let defaultValue = value;
}
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ class AsmPrinter {
/// has any special or non-printable characters in it.
virtual void printKeywordOrString(StringRef keyword);

/// Print the given string as a quoted string, escaping any special or
/// non-printable characters in it.
virtual void printString(StringRef string);

/// Print the given string as a symbol reference, i.e. a form representable by
/// a SymbolRefAttr. A symbol reference is represented as a string prefixed
/// with '@'. The reference is surrounded with ""'s and escaped if it has any
Expand Down
9 changes: 9 additions & 0 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
os << "%";
}
void printKeywordOrString(StringRef) override {}
void printString(StringRef) override {}
void printResourceHandle(const AsmDialectResourceHandle &) override {}
void printSymbolName(StringRef) override {}
void printSuccessor(Block *) override {}
Expand Down Expand Up @@ -919,6 +920,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
/// determining potential aliases.
void printFloat(const APFloat &) override {}
void printKeywordOrString(StringRef) override {}
void printString(StringRef) override {}
void printSymbolName(StringRef) override {}
void printResourceHandle(const AsmDialectResourceHandle &) override {}

Expand Down Expand Up @@ -2767,6 +2769,13 @@ void AsmPrinter::printKeywordOrString(StringRef keyword) {
::printKeywordOrString(keyword, impl->getStream());
}

void AsmPrinter::printString(StringRef keyword) {
assert(impl && "expected AsmPrinter::printString to be overriden");
*this << '"';
printEscapedString(keyword, getStream());
*this << '"';
}

void AsmPrinter::printSymbolName(StringRef symbolRef) {
assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
::printSymbolReference(symbolRef, impl->getStream());
Expand Down
4 changes: 3 additions & 1 deletion mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ attributes {
// CHECK: !test.optional_type_string
// CHECK: !test.optional_type_string
// CHECK: !test.optional_type_string<"non default">
// CHECK: !test.optional_type_string<"containing\0A \22escape\22 characters\0F">

func.func private @test_roundtrip_default_parsers_struct(
!test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
Expand Down Expand Up @@ -111,5 +112,6 @@ func.func private @test_roundtrip_default_parsers_struct(
!test.custom_type_string<"bar" bar>,
!test.optional_type_string,
!test.optional_type_string<"default">,
!test.optional_type_string<"non default">
!test.optional_type_string<"non default">,
!test.optional_type_string<"containing\n \"escape\" characters\0f">
)