Skip to content

Commit d951418

Browse files
[mlir] Allow trailing digit for alias in AsmPrinter
1 parent 747588d commit d951418

File tree

4 files changed

+36
-31
lines changed

4 files changed

+36
-31
lines changed

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -547,8 +547,11 @@ class SymbolAlias {
547547
/// Print this alias to the given stream.
548548
void print(raw_ostream &os) const {
549549
os << (isType ? "!" : "#") << name;
550-
if (suffixIndex)
550+
if (suffixIndex) {
551+
if (isdigit(name.back()))
552+
os << '_';
551553
os << suffixIndex;
554+
}
552555
}
553556

554557
/// Returns true if this is a type alias.
@@ -1020,8 +1023,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
10201023
/// the string needs to be modified in any way, the provided buffer is used to
10211024
/// store the new copy,
10221025
static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
1023-
StringRef allowedPunctChars = "$._-",
1024-
bool allowTrailingDigit = true) {
1026+
StringRef allowedPunctChars = "$._-") {
10251027
assert(!name.empty() && "Shouldn't have an empty name here");
10261028

10271029
auto validChar = [&](char ch) {
@@ -1048,14 +1050,6 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
10481050
return buffer;
10491051
}
10501052

1051-
// If the name ends with a trailing digit, add a '_' to avoid potential
1052-
// conflicts with autogenerated ID's.
1053-
if (!allowTrailingDigit && isdigit(name.back())) {
1054-
copyNameToBuffer();
1055-
buffer.push_back('_');
1056-
return buffer;
1057-
}
1058-
10591053
// Check to see that the name consists of only valid identifier characters.
10601054
for (char ch : name) {
10611055
if (!validChar(ch)) {
@@ -1084,7 +1078,16 @@ void AliasInitializer::initializeAliases(
10841078
if (!aliasInfo.alias)
10851079
continue;
10861080
StringRef alias = *aliasInfo.alias;
1087-
unsigned nameIndex = nameCounts[alias]++;
1081+
unsigned nameIndex;
1082+
// If the alias ends with a digit, we need to pretend as if it has trailing
1083+
// underscore to get a unique nameIndex.
1084+
if (isdigit(alias.back())) {
1085+
SmallString<16> aliasBuffer(alias);
1086+
aliasBuffer.push_back('_');
1087+
nameIndex = nameCounts[aliasBuffer]++;
1088+
} else {
1089+
nameIndex = nameCounts[alias]++;
1090+
}
10881091
symbolToAlias.insert(
10891092
{symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
10901093
aliasInfo.canBeDeferred)});
@@ -1191,8 +1194,7 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
11911194

11921195
SmallString<16> tempBuffer;
11931196
StringRef name =
1194-
sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
1195-
/*allowTrailingDigit=*/false);
1197+
sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-");
11961198
name = name.copy(aliasAllocator);
11971199
alias = InProgressAliasInfo(name);
11981200
}

mlir/test/IR/print-attr-type-aliases.mlir

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// CHECK-DAG: #test2Ealias = "alias_test:dot_in_name"
66
"test.op"() {alias_test = "alias_test:dot_in_name"} : () -> ()
77

8-
// CHECK-DAG: #test_alias0_ = "alias_test:trailing_digit"
8+
// CHECK-DAG: #test_alias0 = "alias_test:trailing_digit"
99
"test.op"() {alias_test = "alias_test:trailing_digit"} : () -> ()
1010

1111
// CHECK-DAG: #_0_test_alias = "alias_test:prefixed_digit"
@@ -14,12 +14,11 @@
1414
// CHECK-DAG: #_25test = "alias_test:prefixed_symbol"
1515
"test.op"() {alias_test = "alias_test:prefixed_symbol"} : () -> ()
1616

17-
// CHECK-DAG: #test_alias_conflict0_ = "alias_test:sanitize_conflict_a"
18-
// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:sanitize_conflict_b"
19-
"test.op"() {alias_test = ["alias_test:sanitize_conflict_a", "alias_test:sanitize_conflict_b"]} : () -> ()
20-
21-
// CHECK-DAG: !tuple = tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>
22-
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>)
17+
// CHECK-DAG: #test_alias_conflict0 = "alias_test:trailing_digit_conflict_a"
18+
// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:trailing_digit_conflict_b"
19+
// CHECK-DAG: #test_alias_conflict0_2 = "alias_test:trailing_digit_conflict_c"
20+
// CHECK-DAG: #test_alias_conflict0_3 = "alias_test:trailing_digit_conflict_d"
21+
"test.op"() {alias_test = ["alias_test:trailing_digit_conflict_a", "alias_test:trailing_digit_conflict_b", "alias_test:trailing_digit_conflict_c", "alias_test:trailing_digit_conflict_d"]} : () -> ()
2322

2423
// CHECK-DAG: !test_tuple = tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>
2524
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>)
@@ -28,8 +27,8 @@
2827
// CHECK-DAG: tensor<32xf32, #test_encoding>
2928
"test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
3029

31-
// CHECK-DAG: !test_ui8_ = !test.int<unsigned, 8>
32-
// CHECK-DAG: tensor<32x!test_ui8_>
30+
// CHECK-DAG: !test_ui8 = !test.int<unsigned, 8>
31+
// CHECK-DAG: tensor<32x!test_ui8>
3332
"test.op"() : () -> tensor<32x!test.int<unsigned, 8>>
3433

3534
// CHECK-DAG: #[[LOC_NESTED:.+]] = loc("nested")
@@ -47,8 +46,8 @@
4746
// -----
4847

4948
// Ensure self type parameters get considered for aliases.
50-
// CHECK: !test_ui8_ = !test.int<unsigned, 8>
51-
// CHECK: #test.attr_with_self_type_param : !test_ui8_
49+
// CHECK: !test_ui8 = !test.int<unsigned, 8>
50+
// CHECK: #test.attr_with_self_type_param : !test_ui8
5251
"test.op"() {alias_test = #test.attr_with_self_type_param : !test.int<unsigned, 8> } : () -> ()
5352

5453
// -----

mlir/test/IR/recursive-type.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
55
// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
66
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
7-
// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5_>
8-
// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4_>
7+
// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5>
8+
// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4>
99

1010
// CHECK-LABEL: @roundtrip
1111
func.func @roundtrip() {

mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,12 +187,16 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
187187
StringSwitch<std::optional<StringRef>>(strAttr.getValue())
188188
.Case("alias_test:dot_in_name", StringRef("test.alias"))
189189
.Case("alias_test:trailing_digit", StringRef("test_alias0"))
190-
.Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
191-
.Case("alias_test:prefixed_symbol", StringRef("%test"))
192-
.Case("alias_test:sanitize_conflict_a",
190+
.Case("alias_test:trailing_digit_conflict_a",
191+
StringRef("test_alias_conflict0"))
192+
.Case("alias_test:trailing_digit_conflict_b",
193193
StringRef("test_alias_conflict0"))
194-
.Case("alias_test:sanitize_conflict_b",
194+
.Case("alias_test:trailing_digit_conflict_c",
195195
StringRef("test_alias_conflict0_"))
196+
.Case("alias_test:trailing_digit_conflict_d",
197+
StringRef("test_alias_conflict0_"))
198+
.Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
199+
.Case("alias_test:prefixed_symbol", StringRef("%test"))
196200
.Case("alias_test:tensor_encoding", StringRef("test_encoding"))
197201
.Default(std::nullopt);
198202
if (!aliasName)

0 commit comments

Comments
 (0)