Skip to content

Commit 5ae19fa

Browse files
[mlir] Allow trailing digit for alias in AsmPrinter (#127993)
When generating aliases from `OpAsm{Dialect,Type,Attr}Interface`, the result would be sanitized, and if the alias provided by the interface had a trailing digit, AsmPrinter would append an underscore to it, presumably to prevent conflicts with autogenerated suffix IDs. #### Motivation There are two reasons motivating the change from the old behavior to the proposed behavior. 1. If the type/attribute can generate a unique alias from its content, then the extra trailing underscore added by AsmPrinter looks strange: ```mlir func.func @add(%ct: !ct_L0_) -> !ct_L0_ { %ct_0 = bgv.add %ct, %ct : (!ct_L0_, !ct_L0_) -> !ct_L0_ %ct_1 = bgv.add %ct_0, %ct_0 : (!ct_L0_, !ct_L0_) -> !ct_L0_ %ct_2 = bgv.add %ct_1, %ct_1 : (!ct_L0_, !ct_L0_) -> !ct_L0_ return %ct_2 : !ct_L0_ } ``` Aesthetically, it would be better to have `(!ct_L0, !ct_L0) -> !ct_L0`. 2. The Value naming behavior is that the first instance uses no `_N` suffix, and the same rule can be applied to alias names. See the IR above, where the first value is called `%ct` and the others are called `%ct_N`. See `uniqueValueName` for details. #### Conflict detection ```mlir !test.type<a = 3> // suggest !name0 !test.type<a = 4> // suggest !name0 !test.another<b = 3> // suggest !name0_ !test.another<b = 4> // suggest !name0_ ``` Previously, conflict detection was based solely on `nameCounts` in `initializeAliases`: the first two suggestions were sanitized to `!name0_`, so all four shared one name and `initializeAliases` could assign the unique ids `0, 1, 2, 3` to them. With this change, `initializeAliases` additionally uses `usedAliases` to track which names have already been used, and uses that information to generate a suffix id that makes the printed alias name unique. The result for the above example is now `!name0, !name0_1, !name0_, !name0_2`.
1 parent 35842f3 commit 5ae19fa

File tree

4 files changed

+80
-28
lines changed

4 files changed

+80
-28
lines changed

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -552,8 +552,11 @@ class SymbolAlias {
552552
/// Print this alias to the given stream: the `!` (type) or `#` (attribute)
/// sigil, followed by the name, followed by `suffixIndex` when it is
/// non-zero. If the name itself ends in a digit, a `_` is emitted before the
/// index so that the index digits are not fused with the name's own trailing
/// digits.
553553
void print(raw_ostream &os) const {
554554
os << (isType ? "!" : "#") << name;
555-
if (suffixIndex)
555+
if (suffixIndex) {
556+
if (isdigit(name.back()))
557+
os << '_';
556558
os << suffixIndex;
559+
}
557560
}
558561

559562
/// Returns true if this is a type alias.
@@ -659,6 +662,12 @@ class AliasInitializer {
659662
template <typename T>
660663
void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred);
661664

665+
/// Uniques the given alias name within the printer by generating name index
666+
/// used as alias name suffix.
667+
static unsigned
668+
uniqueAliasNameIndex(StringRef alias, llvm::StringMap<unsigned> &nameCounts,
669+
llvm::StringSet<llvm::BumpPtrAllocator &> &usedAliases);
670+
662671
/// Given a collection of aliases and symbols, initialize a mapping from a
663672
/// symbol to a given alias.
664673
static void initializeAliases(
@@ -1025,8 +1034,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
10251034
/// the string needs to be modified in any way, the provided buffer is used to
10261035
/// store the new copy,
10271036
static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
1028-
StringRef allowedPunctChars = "$._-",
1029-
bool allowTrailingDigit = true) {
1037+
StringRef allowedPunctChars = "$._-") {
10301038
assert(!name.empty() && "Shouldn't have an empty name here");
10311039

10321040
auto validChar = [&](char ch) {
@@ -1053,14 +1061,6 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
10531061
return buffer;
10541062
}
10551063

1056-
// If the name ends with a trailing digit, add a '_' to avoid potential
1057-
// conflicts with autogenerated ID's.
1058-
if (!allowTrailingDigit && isdigit(name.back())) {
1059-
copyNameToBuffer();
1060-
buffer.push_back('_');
1061-
return buffer;
1062-
}
1063-
10641064
// Check to see that the name consists of only valid identifier characters.
10651065
for (char ch : name) {
10661066
if (!validChar(ch)) {
@@ -1073,6 +1073,36 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
10731073
return name;
10741074
}
10751075

1076+
/// Returns the suffix index that makes `alias` unique among the names recorded
/// in `usedAliases`, and records the resulting printed name there.
///
/// - `alias`: the (already sanitized, non-empty) alias name to unique.
/// - `nameCounts`: per-prefix counters holding the next suffix index to try;
///   updated in place.
/// - `usedAliases`: the set of printed alias names claimed so far; updated in
///   place.
///
/// Returns 0 when `alias` can be used verbatim (SymbolAlias does not print a
/// zero index); otherwise returns an index N >= 1 such that the printed form
/// `alias` + N (or `alias` + "_" + N when `alias` ends in a digit) was not
/// previously in `usedAliases`.
unsigned AliasInitializer::uniqueAliasNameIndex(
    StringRef alias, llvm::StringMap<unsigned> &nameCounts,
    llvm::StringSet<llvm::BumpPtrAllocator &> &usedAliases) {
  assert(!alias.empty() && "Shouldn't have an empty alias here");

  if (!usedAliases.count(alias)) {
    usedAliases.insert(alias);
    // 0 is not printed in SymbolAlias.
    return 0;
  }

  // Otherwise, we had a conflict - probe until we find a unique name.
  // An alias with a trailing digit is printed as `<alias>_<N>`, so the probe
  // prefix must include the separator to mirror the printed form.
  const bool needsSeparator = isdigit(alias.back()) != 0;
  SmallString<64> probeAlias(alias);
  if (needsSeparator)
    probeAlias.push_back('_');
  // Length of the fixed prefix; the probe buffer is truncated back to this
  // after every failed attempt.
  const size_t prefixSize = probeAlias.size();

  // nameCounts start from 1 because 0 is not printed in SymbolAlias.
  unsigned &nextIndex = nameCounts[probeAlias];
  if (nextIndex == 0)
    nextIndex = 1;

  // This is guaranteed to terminate (and usually in a single iteration)
  // because it generates new names by incrementing the counter.
  while (true) {
    unsigned nameIndex = nextIndex++;
    probeAlias += llvm::utostr(nameIndex);
    if (!usedAliases.count(probeAlias)) {
      usedAliases.insert(probeAlias);
      return nameIndex;
    }
    // Drop the numeric suffix for the next iteration. The original expression
    // `alias.size() + isdigit(alias.back()) ? 1 : 0` parsed as
    // `(alias.size() + isdigit(...)) ? 1 : 0` and always truncated to one
    // character.
    probeAlias.resize(prefixSize);
  }
}
1105+
10761106
/// Given a collection of aliases and symbols, initialize a mapping from a
10771107
/// symbol to a given alias.
10781108
void AliasInitializer::initializeAliases(
@@ -1084,12 +1114,17 @@ void AliasInitializer::initializeAliases(
10841114
return lhs.second < rhs.second;
10851115
});
10861116

1117+
// This keeps track of all of the non-numeric names that are in flight,
1118+
// allowing us to check for duplicates.
1119+
llvm::BumpPtrAllocator usedAliasAllocator;
1120+
llvm::StringSet<llvm::BumpPtrAllocator &> usedAliases(usedAliasAllocator);
1121+
10871122
llvm::StringMap<unsigned> nameCounts;
10881123
for (auto &[symbol, aliasInfo] : unprocessedAliases) {
10891124
if (!aliasInfo.alias)
10901125
continue;
10911126
StringRef alias = *aliasInfo.alias;
1092-
unsigned nameIndex = nameCounts[alias]++;
1127+
unsigned nameIndex = uniqueAliasNameIndex(alias, nameCounts, usedAliases);
10931128
symbolToAlias.insert(
10941129
{symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
10951130
aliasInfo.canBeDeferred)});
@@ -1196,8 +1231,7 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
11961231

11971232
SmallString<16> tempBuffer;
11981233
StringRef name =
1199-
sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
1200-
/*allowTrailingDigit=*/false);
1234+
sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-");
12011235
name = name.copy(aliasAllocator);
12021236
alias = InProgressAliasInfo(name);
12031237
}

mlir/test/IR/print-attr-type-aliases.mlir

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// CHECK-DAG: #test2Ealias = "alias_test:dot_in_name"
66
"test.op"() {alias_test = "alias_test:dot_in_name"} : () -> ()
77

8-
// CHECK-DAG: #test_alias0_ = "alias_test:trailing_digit"
8+
// CHECK-DAG: #test_alias0 = "alias_test:trailing_digit"
99
"test.op"() {alias_test = "alias_test:trailing_digit"} : () -> ()
1010

1111
// CHECK-DAG: #_0_test_alias = "alias_test:prefixed_digit"
@@ -14,9 +14,15 @@
1414
// CHECK-DAG: #_25test = "alias_test:prefixed_symbol"
1515
"test.op"() {alias_test = "alias_test:prefixed_symbol"} : () -> ()
1616

17-
// CHECK-DAG: #test_alias_conflict0_ = "alias_test:sanitize_conflict_a"
18-
// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:sanitize_conflict_b"
19-
"test.op"() {alias_test = ["alias_test:sanitize_conflict_a", "alias_test:sanitize_conflict_b"]} : () -> ()
17+
// CHECK-DAG: #test_alias_conflict0 = "alias_test:trailing_digit_conflict_b"
18+
// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:trailing_digit_conflict_c"
19+
// CHECK-DAG: #test_alias_conflict0_ = "alias_test:trailing_digit_conflict_d"
20+
// CHECK-DAG: #test_alias_conflict0_1_1 = "alias_test:trailing_digit_conflict_e"
21+
// CHECK-DAG: #test_alias_conflict0_1_2 = "alias_test:trailing_digit_conflict_f"
22+
// CHECK-DAG: #test_alias_conflict0_1_ = "alias_test:trailing_digit_conflict_g"
23+
// CHECK-DAG: #test_alias_conflict0_1_1_1 = "alias_test:trailing_digit_conflict_h"
24+
// CHECK-DAG: #test_alias_conflict0_1_1_1_1 = "alias_test:trailing_digit_conflict_a"
25+
"test.op"() {alias_test = ["alias_test:trailing_digit_conflict_a", "alias_test:trailing_digit_conflict_b", "alias_test:trailing_digit_conflict_c", "alias_test:trailing_digit_conflict_d", "alias_test:trailing_digit_conflict_e", "alias_test:trailing_digit_conflict_f", "alias_test:trailing_digit_conflict_g", "alias_test:trailing_digit_conflict_h"]} : () -> ()
2026

2127
// CHECK-DAG: !tuple = tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>
2228
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>)
@@ -28,8 +34,8 @@
2834
// CHECK-DAG: tensor<32xf32, #test_encoding>
2935
"test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
3036

31-
// CHECK-DAG: !test_ui8_ = !test.int<unsigned, 8>
32-
// CHECK-DAG: tensor<32x!test_ui8_>
37+
// CHECK-DAG: !test_ui8 = !test.int<unsigned, 8>
38+
// CHECK-DAG: tensor<32x!test_ui8>
3339
"test.op"() : () -> tensor<32x!test.int<unsigned, 8>>
3440

3541
// CHECK-DAG: #[[LOC_NESTED:.+]] = loc("nested")
@@ -47,8 +53,8 @@
4753
// -----
4854

4955
// Ensure self type parameters get considered for aliases.
50-
// CHECK: !test_ui8_ = !test.int<unsigned, 8>
51-
// CHECK: #test.attr_with_self_type_param : !test_ui8_
56+
// CHECK: !test_ui8 = !test.int<unsigned, 8>
57+
// CHECK: #test.attr_with_self_type_param : !test_ui8
5258
"test.op"() {alias_test = #test.attr_with_self_type_param : !test.int<unsigned, 8> } : () -> ()
5359

5460
// -----

mlir/test/IR/recursive-type.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
55
// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
66
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
7-
// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5_>
8-
// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4_>
7+
// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5>
8+
// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4>
99

1010
// CHECK-LABEL: @roundtrip
1111
func.func @roundtrip() {

mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,12 +187,24 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
187187
StringSwitch<std::optional<StringRef>>(strAttr.getValue())
188188
.Case("alias_test:dot_in_name", StringRef("test.alias"))
189189
.Case("alias_test:trailing_digit", StringRef("test_alias0"))
190-
.Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
191-
.Case("alias_test:prefixed_symbol", StringRef("%test"))
192-
.Case("alias_test:sanitize_conflict_a",
190+
.Case("alias_test:trailing_digit_conflict_a",
191+
StringRef("test_alias_conflict0_1_1_1"))
192+
.Case("alias_test:trailing_digit_conflict_b",
193+
StringRef("test_alias_conflict0"))
194+
.Case("alias_test:trailing_digit_conflict_c",
193195
StringRef("test_alias_conflict0"))
194-
.Case("alias_test:sanitize_conflict_b",
196+
.Case("alias_test:trailing_digit_conflict_d",
195197
StringRef("test_alias_conflict0_"))
198+
.Case("alias_test:trailing_digit_conflict_e",
199+
StringRef("test_alias_conflict0_1"))
200+
.Case("alias_test:trailing_digit_conflict_f",
201+
StringRef("test_alias_conflict0_1"))
202+
.Case("alias_test:trailing_digit_conflict_g",
203+
StringRef("test_alias_conflict0_1_"))
204+
.Case("alias_test:trailing_digit_conflict_h",
205+
StringRef("test_alias_conflict0_1_1"))
206+
.Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
207+
.Case("alias_test:prefixed_symbol", StringRef("%test"))
196208
.Case("alias_test:tensor_encoding", StringRef("test_encoding"))
197209
.Default(std::nullopt);
198210
if (!aliasName)

0 commit comments

Comments
 (0)