Skip to content

Commit 0f052a9

Browse files
authored
[mlir] Make StringRefParameter roundtrippable (#65813)
The current printer of `StringRefParameter` simply prints out the content of the string as is without escaping it any way. This leads to it generating invalid syntax, causing parser errors when read in again. This PR fixes that by adding `printString` to `AsmPrinter`, allowing one to print a string that can be parsed with `parseString`, using the same escaping syntax as `StringAttr`.
1 parent 743659b commit 0f052a9

File tree

4 files changed

+17
-2
lines changed

4 files changed

+17
-2
lines changed

mlir/include/mlir/IR/AttrTypeBase.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ class DefaultValuedParameter<string type, string value, string desc = ""> :
363363
class StringRefParameter<string desc = "", string value = ""> :
364364
AttrOrTypeParameter<"::llvm::StringRef", desc> {
365365
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
366-
let printer = [{$_printer << '"' << $_self << '"';}];
366+
let printer = [{$_printer.printString($_self);}];
367367
let cppStorageType = "std::string";
368368
let defaultValue = value;
369369
}

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,10 @@ class AsmPrinter {
184184
/// has any special or non-printable characters in it.
185185
virtual void printKeywordOrString(StringRef keyword);
186186

187+
/// Print the given string as a quoted string, escaping any special or
188+
/// non-printable characters in it.
189+
virtual void printString(StringRef string);
190+
187191
/// Print the given string as a symbol reference, i.e. a form representable by
188192
/// a SymbolRefAttr. A symbol reference is represented as a string prefixed
189193
/// with '@'. The reference is surrounded with ""'s and escaped if it has any

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
779779
os << "%";
780780
}
781781
void printKeywordOrString(StringRef) override {}
782+
void printString(StringRef) override {}
782783
void printResourceHandle(const AsmDialectResourceHandle &) override {}
783784
void printSymbolName(StringRef) override {}
784785
void printSuccessor(Block *) override {}
@@ -919,6 +920,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
919920
/// determining potential aliases.
920921
void printFloat(const APFloat &) override {}
921922
void printKeywordOrString(StringRef) override {}
923+
void printString(StringRef) override {}
922924
void printSymbolName(StringRef) override {}
923925
void printResourceHandle(const AsmDialectResourceHandle &) override {}
924926

@@ -2767,6 +2769,13 @@ void AsmPrinter::printKeywordOrString(StringRef keyword) {
27672769
::printKeywordOrString(keyword, impl->getStream());
27682770
}
27692771

2772+
void AsmPrinter::printString(StringRef keyword) {
2773+
assert(impl && "expected AsmPrinter::printString to be overriden");
2774+
*this << '"';
2775+
printEscapedString(keyword, getStream());
2776+
*this << '"';
2777+
}
2778+
27702779
void AsmPrinter::printSymbolName(StringRef symbolRef) {
27712780
assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
27722781
::printSymbolReference(symbolRef, impl->getStream());

mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ attributes {
7070
// CHECK: !test.optional_type_string
7171
// CHECK: !test.optional_type_string
7272
// CHECK: !test.optional_type_string<"non default">
73+
// CHECK: !test.optional_type_string<"containing\0A \22escape\22 characters\0F">
7374

7475
func.func private @test_roundtrip_default_parsers_struct(
7576
!test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
@@ -111,5 +112,6 @@ func.func private @test_roundtrip_default_parsers_struct(
111112
!test.custom_type_string<"bar" bar>,
112113
!test.optional_type_string,
113114
!test.optional_type_string<"default">,
114-
!test.optional_type_string<"non default">
115+
!test.optional_type_string<"non default">,
116+
!test.optional_type_string<"containing\n \"escape\" characters\0f">
115117
)

0 commit comments

Comments
 (0)