Skip to content

[mlir] Allow trailing digit for alias in AsmPrinter #127993

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 48 additions & 14 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,8 +552,11 @@ class SymbolAlias {
/// Print this alias to the given stream.
///
/// The output is the sigil (`!` for a type alias, `#` otherwise) followed by
/// the alias name and, when `suffixIndex` is nonzero, the index itself. If the
/// name ends in a digit, a `_` is emitted between the name and the index so
/// the two digit runs stay distinct (name `a0` with index 1 prints as `a0_1`,
/// not `a01`).
void print(raw_ostream &os) const {
  os << (isType ? "!" : "#") << name;
  // A zero index means the name is printed without any suffix.
  if (!suffixIndex)
    return;
  // Keep a trailing digit of the name from running into the suffix digits.
  if (isdigit(name.back()))
    os << '_';
  os << suffixIndex;
}

/// Returns true if this is a type alias.
Expand Down Expand Up @@ -659,6 +662,12 @@ class AliasInitializer {
template <typename T>
void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred);

/// Uniques the given alias name within the printer by generating a name
/// index that is used as the alias name suffix.
///
/// Returns 0 when `alias` is not yet present in `usedAliases`; otherwise
/// probes suffix indices, drawing them from the per-name counter in
/// `nameCounts`, until the suffixed name is not in `usedAliases`. The name
/// that is finally chosen is inserted into `usedAliases`, and `nameCounts`
/// is advanced for every index that is tried.
static unsigned
uniqueAliasNameIndex(StringRef alias, llvm::StringMap<unsigned> &nameCounts,
llvm::StringSet<llvm::BumpPtrAllocator &> &usedAliases);

/// Given a collection of aliases and symbols, initialize a mapping from a
/// symbol to a given alias.
static void initializeAliases(
Expand Down Expand Up @@ -1025,8 +1034,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
/// the string needs to be modified in any way, the provided buffer is used to
/// store the new copy,
static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
StringRef allowedPunctChars = "$._-",
bool allowTrailingDigit = true) {
StringRef allowedPunctChars = "$._-") {
assert(!name.empty() && "Shouldn't have an empty name here");

auto validChar = [&](char ch) {
Expand All @@ -1053,14 +1061,6 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
return buffer;
}

// If the name ends with a trailing digit, add a '_' to avoid potential
// conflicts with autogenerated ID's.
if (!allowTrailingDigit && isdigit(name.back())) {
copyNameToBuffer();
buffer.push_back('_');
return buffer;
}

// Check to see that the name consists of only valid identifier characters.
for (char ch : name) {
if (!validChar(ch)) {
Expand All @@ -1073,6 +1073,36 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
return name;
}

/// Returns the suffix index that makes `alias` unique among `usedAliases`.
///
/// If `alias` has not been used yet it is claimed as-is and 0 is returned
/// (SymbolAlias does not print a zero suffix). Otherwise candidate names of
/// the form `<alias><N>` -- or `<alias>_<N>` when `alias` ends in a digit,
/// mirroring how SymbolAlias prints the suffix -- are probed with increasing
/// `N` until one is free. That candidate is recorded in `usedAliases` and `N`
/// is returned. `nameCounts` remembers, per probe prefix, the next index to
/// try so later conflicts on the same name do not re-probe from 1.
unsigned AliasInitializer::uniqueAliasNameIndex(
    StringRef alias, llvm::StringMap<unsigned> &nameCounts,
    llvm::StringSet<llvm::BumpPtrAllocator &> &usedAliases) {
  if (!usedAliases.count(alias)) {
    usedAliases.insert(alias);
    // 0 is not printed in SymbolAlias.
    return 0;
  }
  // Otherwise, we had a conflict - probe until we find a unique name.
  SmallString<64> probeAlias(alias);
  // An alias with a trailing digit is printed as `<alias>_<N>`, so the probed
  // candidates must contain the same separator.
  if (isdigit(alias.back()))
    probeAlias.push_back('_');
  // Length of the probe prefix (alias plus optional '_'). Each failed probe
  // truncates the buffer back to exactly this length. Computing it once here
  // avoids the precedence trap of `a + cond ? 1 : 0`, which parses as
  // `(a + cond) ? 1 : 0` and would shrink the buffer to a single character.
  const size_t prefixSize = probeAlias.size();
  // Counters start from 1 because 0 is not printed in SymbolAlias. The
  // counter is keyed by the probe prefix and bound once, so every iteration
  // advances the same entry.
  unsigned &nextIndex = nameCounts[probeAlias];
  if (nextIndex == 0)
    nextIndex = 1;
  // This is guaranteed to terminate (and usually in a single iteration)
  // because it generates new names by incrementing the counter.
  while (true) {
    unsigned nameIndex = nextIndex++;
    probeAlias += llvm::utostr(nameIndex);
    if (!usedAliases.count(probeAlias)) {
      usedAliases.insert(probeAlias);
      return nameIndex;
    }
    // Drop the rejected index and retry with the bare prefix.
    probeAlias.resize(prefixSize);
  }
}

/// Given a collection of aliases and symbols, initialize a mapping from a
/// symbol to a given alias.
void AliasInitializer::initializeAliases(
Expand All @@ -1084,12 +1114,17 @@ void AliasInitializer::initializeAliases(
return lhs.second < rhs.second;
});

// This keeps track of all of the non-numeric names that are in flight,
// allowing us to check for duplicates.
llvm::BumpPtrAllocator usedAliasAllocator;
llvm::StringSet<llvm::BumpPtrAllocator &> usedAliases(usedAliasAllocator);

llvm::StringMap<unsigned> nameCounts;
for (auto &[symbol, aliasInfo] : unprocessedAliases) {
if (!aliasInfo.alias)
continue;
StringRef alias = *aliasInfo.alias;
unsigned nameIndex = nameCounts[alias]++;
unsigned nameIndex = uniqueAliasNameIndex(alias, nameCounts, usedAliases);
symbolToAlias.insert(
{symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
aliasInfo.canBeDeferred)});
Expand Down Expand Up @@ -1196,8 +1231,7 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,

SmallString<16> tempBuffer;
StringRef name =
sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
/*allowTrailingDigit=*/false);
sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-");
name = name.copy(aliasAllocator);
alias = InProgressAliasInfo(name);
}
Expand Down
22 changes: 14 additions & 8 deletions mlir/test/IR/print-attr-type-aliases.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// CHECK-DAG: #test2Ealias = "alias_test:dot_in_name"
"test.op"() {alias_test = "alias_test:dot_in_name"} : () -> ()

// CHECK-DAG: #test_alias0_ = "alias_test:trailing_digit"
// CHECK-DAG: #test_alias0 = "alias_test:trailing_digit"
"test.op"() {alias_test = "alias_test:trailing_digit"} : () -> ()

// CHECK-DAG: #_0_test_alias = "alias_test:prefixed_digit"
Expand All @@ -14,9 +14,15 @@
// CHECK-DAG: #_25test = "alias_test:prefixed_symbol"
"test.op"() {alias_test = "alias_test:prefixed_symbol"} : () -> ()

// CHECK-DAG: #test_alias_conflict0_ = "alias_test:sanitize_conflict_a"
// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:sanitize_conflict_b"
"test.op"() {alias_test = ["alias_test:sanitize_conflict_a", "alias_test:sanitize_conflict_b"]} : () -> ()
// CHECK-DAG: #test_alias_conflict0 = "alias_test:trailing_digit_conflict_b"
// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:trailing_digit_conflict_c"
// CHECK-DAG: #test_alias_conflict0_ = "alias_test:trailing_digit_conflict_d"
// CHECK-DAG: #test_alias_conflict0_1_1 = "alias_test:trailing_digit_conflict_e"
// CHECK-DAG: #test_alias_conflict0_1_2 = "alias_test:trailing_digit_conflict_f"
// CHECK-DAG: #test_alias_conflict0_1_ = "alias_test:trailing_digit_conflict_g"
// CHECK-DAG: #test_alias_conflict0_1_1_1 = "alias_test:trailing_digit_conflict_h"
// CHECK-DAG: #test_alias_conflict0_1_1_1_1 = "alias_test:trailing_digit_conflict_a"
"test.op"() {alias_test = ["alias_test:trailing_digit_conflict_a", "alias_test:trailing_digit_conflict_b", "alias_test:trailing_digit_conflict_c", "alias_test:trailing_digit_conflict_d", "alias_test:trailing_digit_conflict_e", "alias_test:trailing_digit_conflict_f", "alias_test:trailing_digit_conflict_g", "alias_test:trailing_digit_conflict_h"]} : () -> ()

// CHECK-DAG: !tuple = tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>)
Expand All @@ -28,8 +34,8 @@
// CHECK-DAG: tensor<32xf32, #test_encoding>
"test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">

// CHECK-DAG: !test_ui8_ = !test.int<unsigned, 8>
// CHECK-DAG: tensor<32x!test_ui8_>
// CHECK-DAG: !test_ui8 = !test.int<unsigned, 8>
// CHECK-DAG: tensor<32x!test_ui8>
"test.op"() : () -> tensor<32x!test.int<unsigned, 8>>

// CHECK-DAG: #[[LOC_NESTED:.+]] = loc("nested")
Expand All @@ -47,8 +53,8 @@
// -----

// Ensure self type parameters get considered for aliases.
// CHECK: !test_ui8_ = !test.int<unsigned, 8>
// CHECK: #test.attr_with_self_type_param : !test_ui8_
// CHECK: !test_ui8 = !test.int<unsigned, 8>
// CHECK: #test.attr_with_self_type_param : !test_ui8
"test.op"() {alias_test = #test.attr_with_self_type_param : !test.int<unsigned, 8> } : () -> ()

// -----
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/IR/recursive-type.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5_>
// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4_>
// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5>
// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4>

// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
Expand Down
20 changes: 16 additions & 4 deletions mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,24 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
StringSwitch<std::optional<StringRef>>(strAttr.getValue())
.Case("alias_test:dot_in_name", StringRef("test.alias"))
.Case("alias_test:trailing_digit", StringRef("test_alias0"))
.Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
.Case("alias_test:prefixed_symbol", StringRef("%test"))
.Case("alias_test:sanitize_conflict_a",
.Case("alias_test:trailing_digit_conflict_a",
StringRef("test_alias_conflict0_1_1_1"))
.Case("alias_test:trailing_digit_conflict_b",
StringRef("test_alias_conflict0"))
.Case("alias_test:trailing_digit_conflict_c",
StringRef("test_alias_conflict0"))
.Case("alias_test:sanitize_conflict_b",
.Case("alias_test:trailing_digit_conflict_d",
StringRef("test_alias_conflict0_"))
.Case("alias_test:trailing_digit_conflict_e",
StringRef("test_alias_conflict0_1"))
.Case("alias_test:trailing_digit_conflict_f",
StringRef("test_alias_conflict0_1"))
.Case("alias_test:trailing_digit_conflict_g",
StringRef("test_alias_conflict0_1_"))
.Case("alias_test:trailing_digit_conflict_h",
StringRef("test_alias_conflict0_1_1"))
.Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
.Case("alias_test:prefixed_symbol", StringRef("%test"))
.Case("alias_test:tensor_encoding", StringRef("test_encoding"))
.Default(std::nullopt);
if (!aliasName)
Expand Down