Skip to content

Commit 9f252c4

Browse files
[mlir] Allow trailing digit for alias in AsmPrinter
1 parent aea7403 commit 9f252c4

File tree

4 files changed

+81
-28
lines changed

4 files changed

+81
-28
lines changed

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -552,8 +552,11 @@ class SymbolAlias {
552552
/// Print this alias to the given stream.
553553
void print(raw_ostream &os) const {
554554
os << (isType ? "!" : "#") << name;
555-
if (suffixIndex)
555+
if (suffixIndex) {
556+
if (isdigit(name.back()))
557+
os << '_';
556558
os << suffixIndex;
559+
}
557560
}
558561

559562
/// Returns true if this is a type alias.
@@ -659,6 +662,12 @@ class AliasInitializer {
659662
template <typename T>
660663
void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred);
661664

665+
/// Uniques the given alias name within the printer by generating name index
666+
/// used as alias name suffix.
667+
static unsigned
668+
uniqueAliasNameIndex(StringRef alias, llvm::StringMap<unsigned> &nameCounts,
669+
llvm::StringSet<llvm::BumpPtrAllocator &> &usedAliases);
670+
662671
/// Given a collection of aliases and symbols, initialize a mapping from a
663672
/// symbol to a given alias.
664673
static void initializeAliases(
@@ -1025,8 +1034,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
10251034
/// the string needs to be modified in any way, the provided buffer is used to
10261035
/// store the new copy,
10271036
static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
1028-
StringRef allowedPunctChars = "$._-",
1029-
bool allowTrailingDigit = true) {
1037+
StringRef allowedPunctChars = "$._-") {
10301038
assert(!name.empty() && "Shouldn't have an empty name here");
10311039

10321040
auto validChar = [&](char ch) {
@@ -1053,14 +1061,6 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
10531061
return buffer;
10541062
}
10551063

1056-
// If the name ends with a trailing digit, add a '_' to avoid potential
1057-
// conflicts with autogenerated ID's.
1058-
if (!allowTrailingDigit && isdigit(name.back())) {
1059-
copyNameToBuffer();
1060-
buffer.push_back('_');
1061-
return buffer;
1062-
}
1063-
10641064
// Check to see that the name consists of only valid identifier characters.
10651065
for (char ch : name) {
10661066
if (!validChar(ch)) {
@@ -1073,6 +1073,37 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
10731073
return name;
10741074
}
10751075

1076+
unsigned AliasInitializer::uniqueAliasNameIndex(
1077+
StringRef alias, llvm::StringMap<unsigned> &nameCounts,
1078+
llvm::StringSet<llvm::BumpPtrAllocator &> &usedAliases) {
1079+
// Get nameIndex that will not generate conflicting name.
1080+
unsigned nameIndex = 0;
1081+
if (!usedAliases.count(alias)) {
1082+
usedAliases.insert(alias);
1083+
} else {
1084+
// Otherwise, we had a conflict - probe until we find a unique name.
1085+
SmallString<64> probeAlias(alias);
1086+
// alias with trailing digit will be printed as _N
1087+
if (isdigit(alias.back()))
1088+
probeAlias.push_back('_');
1089+
// nameCounts start from 1 because 0 is not printed in SymbolAlias.
1090+
if (nameCounts[probeAlias] == 0)
1091+
nameCounts[probeAlias] = 1;
1092+
// This is guaranteed to terminate (and usually in a single iteration)
1093+
// because it generates new names by incrementing nameCounts.
1094+
while (true) {
1095+
nameIndex = nameCounts[probeAlias]++;
1096+
probeAlias += llvm::utostr(nameIndex);
1097+
if (!usedAliases.count(probeAlias)) {
1098+
usedAliases.insert(probeAlias);
1099+
break;
1100+
}
1101+
probeAlias.resize(alias.size() + isdigit(alias.back()) ? 1 : 0);
1102+
}
1103+
}
1104+
return nameIndex;
1105+
}
1106+
10761107
/// Given a collection of aliases and symbols, initialize a mapping from a
10771108
/// symbol to a given alias.
10781109
void AliasInitializer::initializeAliases(
@@ -1084,12 +1115,17 @@ void AliasInitializer::initializeAliases(
10841115
return lhs.second < rhs.second;
10851116
});
10861117

1118+
// This keeps track of all of the non-numeric names that are in flight,
1119+
// allowing us to check for duplicates.
1120+
llvm::BumpPtrAllocator usedAliasAllocator;
1121+
llvm::StringSet<llvm::BumpPtrAllocator &> usedAliases(usedAliasAllocator);
1122+
10871123
llvm::StringMap<unsigned> nameCounts;
10881124
for (auto &[symbol, aliasInfo] : unprocessedAliases) {
10891125
if (!aliasInfo.alias)
10901126
continue;
10911127
StringRef alias = *aliasInfo.alias;
1092-
unsigned nameIndex = nameCounts[alias]++;
1128+
unsigned nameIndex = uniqueAliasNameIndex(alias, nameCounts, usedAliases);
10931129
symbolToAlias.insert(
10941130
{symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
10951131
aliasInfo.canBeDeferred)});
@@ -1196,8 +1232,7 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
11961232

11971233
SmallString<16> tempBuffer;
11981234
StringRef name =
1199-
sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
1200-
/*allowTrailingDigit=*/false);
1235+
sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-");
12011236
name = name.copy(aliasAllocator);
12021237
alias = InProgressAliasInfo(name);
12031238
}

mlir/test/IR/print-attr-type-aliases.mlir

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// CHECK-DAG: #test2Ealias = "alias_test:dot_in_name"
66
"test.op"() {alias_test = "alias_test:dot_in_name"} : () -> ()
77

8-
// CHECK-DAG: #test_alias0_ = "alias_test:trailing_digit"
8+
// CHECK-DAG: #test_alias0 = "alias_test:trailing_digit"
99
"test.op"() {alias_test = "alias_test:trailing_digit"} : () -> ()
1010

1111
// CHECK-DAG: #_0_test_alias = "alias_test:prefixed_digit"
@@ -14,9 +14,15 @@
1414
// CHECK-DAG: #_25test = "alias_test:prefixed_symbol"
1515
"test.op"() {alias_test = "alias_test:prefixed_symbol"} : () -> ()
1616

17-
// CHECK-DAG: #test_alias_conflict0_ = "alias_test:sanitize_conflict_a"
18-
// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:sanitize_conflict_b"
19-
"test.op"() {alias_test = ["alias_test:sanitize_conflict_a", "alias_test:sanitize_conflict_b"]} : () -> ()
17+
// CHECK-DAG: #test_alias_conflict0 = "alias_test:trailing_digit_conflict_b"
18+
// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:trailing_digit_conflict_c"
19+
// CHECK-DAG: #test_alias_conflict0_ = "alias_test:trailing_digit_conflict_d"
20+
// CHECK-DAG: #test_alias_conflict0_1_1 = "alias_test:trailing_digit_conflict_e"
21+
// CHECK-DAG: #test_alias_conflict0_1_2 = "alias_test:trailing_digit_conflict_f"
22+
// CHECK-DAG: #test_alias_conflict0_1_ = "alias_test:trailing_digit_conflict_g"
23+
// CHECK-DAG: #test_alias_conflict0_1_1_1 = "alias_test:trailing_digit_conflict_h"
24+
// CHECK-DAG: #test_alias_conflict0_1_1_1_1 = "alias_test:trailing_digit_conflict_a"
25+
"test.op"() {alias_test = ["alias_test:trailing_digit_conflict_a", "alias_test:trailing_digit_conflict_b", "alias_test:trailing_digit_conflict_c", "alias_test:trailing_digit_conflict_d", "alias_test:trailing_digit_conflict_e", "alias_test:trailing_digit_conflict_f", "alias_test:trailing_digit_conflict_g", "alias_test:trailing_digit_conflict_h"]} : () -> ()
2026

2127
// CHECK-DAG: !tuple = tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>
2228
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>)
@@ -28,8 +34,8 @@
2834
// CHECK-DAG: tensor<32xf32, #test_encoding>
2935
"test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
3036

31-
// CHECK-DAG: !test_ui8_ = !test.int<unsigned, 8>
32-
// CHECK-DAG: tensor<32x!test_ui8_>
37+
// CHECK-DAG: !test_ui8 = !test.int<unsigned, 8>
38+
// CHECK-DAG: tensor<32x!test_ui8>
3339
"test.op"() : () -> tensor<32x!test.int<unsigned, 8>>
3440

3541
// CHECK-DAG: #[[LOC_NESTED:.+]] = loc("nested")
@@ -47,8 +53,8 @@
4753
// -----
4854

4955
// Ensure self type parameters get considered for aliases.
50-
// CHECK: !test_ui8_ = !test.int<unsigned, 8>
51-
// CHECK: #test.attr_with_self_type_param : !test_ui8_
56+
// CHECK: !test_ui8 = !test.int<unsigned, 8>
57+
// CHECK: #test.attr_with_self_type_param : !test_ui8
5258
"test.op"() {alias_test = #test.attr_with_self_type_param : !test.int<unsigned, 8> } : () -> ()
5359

5460
// -----

mlir/test/IR/recursive-type.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
55
// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
66
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
7-
// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5_>
8-
// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4_>
7+
// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5>
8+
// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4>
99

1010
// CHECK-LABEL: @roundtrip
1111
func.func @roundtrip() {

mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,12 +187,24 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
187187
StringSwitch<std::optional<StringRef>>(strAttr.getValue())
188188
.Case("alias_test:dot_in_name", StringRef("test.alias"))
189189
.Case("alias_test:trailing_digit", StringRef("test_alias0"))
190-
.Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
191-
.Case("alias_test:prefixed_symbol", StringRef("%test"))
192-
.Case("alias_test:sanitize_conflict_a",
190+
.Case("alias_test:trailing_digit_conflict_a",
191+
StringRef("test_alias_conflict0_1_1_1"))
192+
.Case("alias_test:trailing_digit_conflict_b",
193+
StringRef("test_alias_conflict0"))
194+
.Case("alias_test:trailing_digit_conflict_c",
193195
StringRef("test_alias_conflict0"))
194-
.Case("alias_test:sanitize_conflict_b",
196+
.Case("alias_test:trailing_digit_conflict_d",
195197
StringRef("test_alias_conflict0_"))
198+
.Case("alias_test:trailing_digit_conflict_e",
199+
StringRef("test_alias_conflict0_1"))
200+
.Case("alias_test:trailing_digit_conflict_f",
201+
StringRef("test_alias_conflict0_1"))
202+
.Case("alias_test:trailing_digit_conflict_g",
203+
StringRef("test_alias_conflict0_1_"))
204+
.Case("alias_test:trailing_digit_conflict_h",
205+
StringRef("test_alias_conflict0_1_1"))
206+
.Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
207+
.Case("alias_test:prefixed_symbol", StringRef("%test"))
196208
.Case("alias_test:tensor_encoding", StringRef("test_encoding"))
197209
.Default(std::nullopt);
198210
if (!aliasName)

0 commit comments

Comments
 (0)