Skip to content

Commit b21dc21

Browse files
[mlir] Allow trailing digit for alias in AsmPrinter
1 parent 747588d commit b21dc21

File tree

4 files changed

+70
-28
lines changed

4 files changed

+70
-28
lines changed

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -547,8 +547,11 @@ class SymbolAlias {
547547
/// Print this alias to the given stream.
548548
void print(raw_ostream &os) const {
549549
os << (isType ? "!" : "#") << name;
550-
if (suffixIndex)
550+
if (suffixIndex) {
551+
if (isdigit(name.back()))
552+
os << '_';
551553
os << suffixIndex;
554+
}
552555
}
553556

554557
/// Returns true if this is a type alias.
@@ -1020,8 +1023,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
10201023
/// the string needs to be modified in any way, the provided buffer is used to
10211024
/// store the new copy,
10221025
static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
1023-
StringRef allowedPunctChars = "$._-",
1024-
bool allowTrailingDigit = true) {
1026+
StringRef allowedPunctChars = "$._-") {
10251027
assert(!name.empty() && "Shouldn't have an empty name here");
10261028

10271029
auto validChar = [&](char ch) {
@@ -1048,14 +1050,6 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
10481050
return buffer;
10491051
}
10501052

1051-
// If the name ends with a trailing digit, add a '_' to avoid potential
1052-
// conflicts with autogenerated ID's.
1053-
if (!allowTrailingDigit && isdigit(name.back())) {
1054-
copyNameToBuffer();
1055-
buffer.push_back('_');
1056-
return buffer;
1057-
}
1058-
10591053
// Check to see that the name consists of only valid identifier characters.
10601054
for (char ch : name) {
10611055
if (!validChar(ch)) {
@@ -1079,12 +1073,43 @@ void AliasInitializer::initializeAliases(
10791073
return lhs.second < rhs.second;
10801074
});
10811075

1076+
// This keeps track of all of the non-numeric names that are in flight,
1077+
// allowing us to check for duplicates.
1078+
llvm::BumpPtrAllocator usedAliasAllocator;
1079+
llvm::StringSet<llvm::BumpPtrAllocator &> usedAliases(usedAliasAllocator);
1080+
10821081
llvm::StringMap<unsigned> nameCounts;
10831082
for (auto &[symbol, aliasInfo] : unprocessedAliases) {
10841083
if (!aliasInfo.alias)
10851084
continue;
10861085
StringRef alias = *aliasInfo.alias;
1087-
unsigned nameIndex = nameCounts[alias]++;
1086+
// Get nameIndex that will not generate conflicting name.
1087+
unsigned nameIndex = 0;
1088+
if (!usedAliases.count(alias)) {
1089+
usedAliases.insert(alias);
1090+
} else {
1091+
// Otherwise, we had a conflict - probe until we find a unique name.
1092+
SmallString<64> probeAlias(alias);
1093+
// alias with trailing digit will be printed as _N
1094+
if (isdigit(alias.back())) {
1095+
probeAlias.push_back('_');
1096+
}
1097+
// nameCounts start from 1 because 0 is not printed in SymbolAlias.
1098+
if (nameCounts[probeAlias] == 0) {
1099+
nameCounts[probeAlias] = 1;
1100+
}
1101+
// This is guaranteed to terminate (and usually in a single iteration)
1102+
// because it generates new names by incrementing nameCounts.
1103+
while (true) {
1104+
nameIndex = nameCounts[probeAlias]++;
1105+
probeAlias += llvm::utostr(nameIndex);
1106+
if (!usedAliases.count(probeAlias)) {
1107+
usedAliases.insert(probeAlias);
1108+
break;
1109+
}
1110+
probeAlias.resize(alias.size() + isdigit(alias.back()) ? 1 : 0);
1111+
}
1112+
}
10881113
symbolToAlias.insert(
10891114
{symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
10901115
aliasInfo.canBeDeferred)});
@@ -1191,8 +1216,7 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
11911216

11921217
SmallString<16> tempBuffer;
11931218
StringRef name =
1194-
sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
1195-
/*allowTrailingDigit=*/false);
1219+
sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-");
11961220
name = name.copy(aliasAllocator);
11971221
alias = InProgressAliasInfo(name);
11981222
}

mlir/test/IR/print-attr-type-aliases.mlir

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// CHECK-DAG: #test2Ealias = "alias_test:dot_in_name"
66
"test.op"() {alias_test = "alias_test:dot_in_name"} : () -> ()
77

8-
// CHECK-DAG: #test_alias0_ = "alias_test:trailing_digit"
8+
// CHECK-DAG: #test_alias0 = "alias_test:trailing_digit"
99
"test.op"() {alias_test = "alias_test:trailing_digit"} : () -> ()
1010

1111
// CHECK-DAG: #_0_test_alias = "alias_test:prefixed_digit"
@@ -14,9 +14,15 @@
1414
// CHECK-DAG: #_25test = "alias_test:prefixed_symbol"
1515
"test.op"() {alias_test = "alias_test:prefixed_symbol"} : () -> ()
1616

17-
// CHECK-DAG: #test_alias_conflict0_ = "alias_test:sanitize_conflict_a"
18-
// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:sanitize_conflict_b"
19-
"test.op"() {alias_test = ["alias_test:sanitize_conflict_a", "alias_test:sanitize_conflict_b"]} : () -> ()
17+
// CHECK-DAG: #test_alias_conflict0 = "alias_test:trailing_digit_conflict_b"
18+
// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:trailing_digit_conflict_c"
19+
// CHECK-DAG: #test_alias_conflict0_ = "alias_test:trailing_digit_conflict_d"
20+
// CHECK-DAG: #test_alias_conflict0_1_1 = "alias_test:trailing_digit_conflict_e"
21+
// CHECK-DAG: #test_alias_conflict0_1_2 = "alias_test:trailing_digit_conflict_f"
22+
// CHECK-DAG: #test_alias_conflict0_1_ = "alias_test:trailing_digit_conflict_g"
23+
// CHECK-DAG: #test_alias_conflict0_1_1_1 = "alias_test:trailing_digit_conflict_h"
24+
// CHECK-DAG: #test_alias_conflict0_1_1_1_1 = "alias_test:trailing_digit_conflict_a"
25+
"test.op"() {alias_test = ["alias_test:trailing_digit_conflict_a", "alias_test:trailing_digit_conflict_b", "alias_test:trailing_digit_conflict_c", "alias_test:trailing_digit_conflict_d", "alias_test:trailing_digit_conflict_e", "alias_test:trailing_digit_conflict_f", "alias_test:trailing_digit_conflict_g", "alias_test:trailing_digit_conflict_h"]} : () -> ()
2026

2127
// CHECK-DAG: !tuple = tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>
2228
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>)
@@ -28,8 +34,8 @@
2834
// CHECK-DAG: tensor<32xf32, #test_encoding>
2935
"test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
3036

31-
// CHECK-DAG: !test_ui8_ = !test.int<unsigned, 8>
32-
// CHECK-DAG: tensor<32x!test_ui8_>
37+
// CHECK-DAG: !test_ui8 = !test.int<unsigned, 8>
38+
// CHECK-DAG: tensor<32x!test_ui8>
3339
"test.op"() : () -> tensor<32x!test.int<unsigned, 8>>
3440

3541
// CHECK-DAG: #[[LOC_NESTED:.+]] = loc("nested")
@@ -47,8 +53,8 @@
4753
// -----
4854

4955
// Ensure self type parameters get considered for aliases.
50-
// CHECK: !test_ui8_ = !test.int<unsigned, 8>
51-
// CHECK: #test.attr_with_self_type_param : !test_ui8_
56+
// CHECK: !test_ui8 = !test.int<unsigned, 8>
57+
// CHECK: #test.attr_with_self_type_param : !test_ui8
5258
"test.op"() {alias_test = #test.attr_with_self_type_param : !test.int<unsigned, 8> } : () -> ()
5359

5460
// -----

mlir/test/IR/recursive-type.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
55
// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
66
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
7-
// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5_>
8-
// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4_>
7+
// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5>
8+
// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4>
99

1010
// CHECK-LABEL: @roundtrip
1111
func.func @roundtrip() {

mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,12 +187,24 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
187187
StringSwitch<std::optional<StringRef>>(strAttr.getValue())
188188
.Case("alias_test:dot_in_name", StringRef("test.alias"))
189189
.Case("alias_test:trailing_digit", StringRef("test_alias0"))
190-
.Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
191-
.Case("alias_test:prefixed_symbol", StringRef("%test"))
192-
.Case("alias_test:sanitize_conflict_a",
190+
.Case("alias_test:trailing_digit_conflict_a",
191+
StringRef("test_alias_conflict0_1_1_1"))
192+
.Case("alias_test:trailing_digit_conflict_b",
193+
StringRef("test_alias_conflict0"))
194+
.Case("alias_test:trailing_digit_conflict_c",
193195
StringRef("test_alias_conflict0"))
194-
.Case("alias_test:sanitize_conflict_b",
196+
.Case("alias_test:trailing_digit_conflict_d",
195197
StringRef("test_alias_conflict0_"))
198+
.Case("alias_test:trailing_digit_conflict_e",
199+
StringRef("test_alias_conflict0_1"))
200+
.Case("alias_test:trailing_digit_conflict_f",
201+
StringRef("test_alias_conflict0_1"))
202+
.Case("alias_test:trailing_digit_conflict_g",
203+
StringRef("test_alias_conflict0_1_"))
204+
.Case("alias_test:trailing_digit_conflict_h",
205+
StringRef("test_alias_conflict0_1_1"))
206+
.Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
207+
.Case("alias_test:prefixed_symbol", StringRef("%test"))
196208
.Case("alias_test:tensor_encoding", StringRef("test_encoding"))
197209
.Default(std::nullopt);
198210
if (!aliasName)

0 commit comments

Comments
 (0)