Skip to content

[llvm][aarch64] Fix Arm64EC name mangling algorithm #115567

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/Demangle/Demangle.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define LLVM_DEMANGLE_DEMANGLE_H

#include <cstddef>
#include <optional>
#include <string>
#include <string_view>

Expand Down Expand Up @@ -54,6 +55,9 @@ enum MSDemangleFlags {
char *microsoftDemangle(std::string_view mangled_name, size_t *n_read,
int *status, MSDemangleFlags Flags = MSDF_None);

// Finds the offset within an MSVC-style mangled C++ symbol name at which the
// Arm64EC "$$h" tag belongs, i.e. the position just after the fully qualified
// symbol name.
// Returns std::nullopt if MangledName does not begin with '?' or if its
// fully qualified name cannot be parsed by the Microsoft demangler.
std::optional<size_t>
getArm64ECInsertionPointInMangledName(std::string_view MangledName);

// Demangles a Rust v0 mangled symbol.
char *rustDemangle(std::string_view MangledName);

Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/Demangle/MicrosoftDemangle.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef LLVM_DEMANGLE_MICROSOFTDEMANGLE_H
#define LLVM_DEMANGLE_MICROSOFTDEMANGLE_H

#include "llvm/Demangle/Demangle.h"
#include "llvm/Demangle/MicrosoftDemangleNodes.h"

#include <cassert>
Expand Down Expand Up @@ -141,6 +142,9 @@ enum class FunctionIdentifierCodeGroup { Basic, Under, DoubleUnder };
// It has a set of functions to parse mangled symbols into Type instances.
// It also has a set of functions to convert Type instances to strings.
class Demangler {
friend std::optional<size_t>
llvm::getArm64ECInsertionPointInMangledName(std::string_view MangledName);

public:
Demangler() = default;
virtual ~Demangler() = default;
Expand Down
2 changes: 1 addition & 1 deletion llvm/include/llvm/IR/Mangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ std::optional<std::string> getArm64ECDemangledFunctionName(StringRef Name);
/// Check if an ARM64EC function name is mangled.
///
/// A name counts as mangled when it either starts with '#', or starts with
/// '?' and contains the "@$$h" marker. An empty name is never considered
/// mangled.
bool inline isArm64ECMangledFunctionName(StringRef Name) {
  // Guard the Name[0] accesses below: indexing an empty StringRef is invalid.
  if (Name.empty())
    return false;
  return Name[0] == '#' ||
         (Name[0] == '?' && Name.find("@$$h") != StringRef::npos);
}

} // End llvm namespace
Expand Down
19 changes: 19 additions & 0 deletions llvm/lib/Demangle/MicrosoftDemangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <array>
#include <cctype>
#include <cstdio>
#include <optional>
#include <string_view>
#include <tuple>

Expand Down Expand Up @@ -2428,6 +2429,24 @@ void Demangler::dumpBackReferences() {
std::printf("\n");
}

std::optional<size_t>
llvm::getArm64ECInsertionPointInMangledName(std::string_view MangledName) {
std::string_view ProcessedName{MangledName};

// We only support this for MSVC-style C++ symbols.
if (!consumeFront(ProcessedName, '?'))
return std::nullopt;

// The insertion point is just after the name of the symbol, so parse that to
// remove it from the processed name.
Demangler D;
D.demangleFullyQualifiedSymbolName(ProcessedName);
if (D.Error)
return std::nullopt;

return MangledName.length() - ProcessedName.length();
}

char *llvm::microsoftDemangle(std::string_view MangledName, size_t *NMangled,
int *Status, MSDemangleFlags Flags) {
Demangler D;
Expand Down
19 changes: 8 additions & 11 deletions llvm/lib/IR/Mangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Demangle/Demangle.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
Expand Down Expand Up @@ -299,21 +300,17 @@ std::optional<std::string> llvm::getArm64ECMangledFunctionName(StringRef Name) {
return std::optional<std::string>(("#" + Name).str());
}

// Insert the ARM64EC "$$h" tag after the mangled function name.
// If the name contains $$h, then it is already mangled.
if (Name.contains("$$h"))
return std::nullopt;
size_t InsertIdx = Name.find("@@");
size_t ThreeAtSignsIdx = Name.find("@@@");
if (InsertIdx != std::string::npos && InsertIdx != ThreeAtSignsIdx) {
InsertIdx += 2;
} else {
InsertIdx = Name.find("@");
if (InsertIdx != std::string::npos)
InsertIdx++;
}

// Ask the demangler where we should insert "$$h".
auto InsertIdx = getArm64ECInsertionPointInMangledName(Name);
if (!InsertIdx)
return std::nullopt;

return std::optional<std::string>(
(Name.substr(0, InsertIdx) + "$$h" + Name.substr(InsertIdx)).str());
(Name.substr(0, *InsertIdx) + "$$h" + Name.substr(*InsertIdx)).str());
}

std::optional<std::string>
Expand Down
77 changes: 77 additions & 0 deletions llvm/unittests/IR/ManglerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,81 @@ TEST(ManglerTest, GOFF) {
"L#foo");
}

// Verifies the Arm64EC mangling helpers on a set of already-mangled names:
// each name must be recognised as mangled, must not be mangled a second time,
// must demangle to a name that is itself not mangled, and must round-trip back
// to the original when mangled again.
TEST(ManglerTest, Arm64EC) {
  constexpr std::string_view Arm64ECNames[] = {
      // Basic C name.
      "#Foo",

      // Basic C++ name.
      "?foo@@$$hYAHXZ",

      // Regression test: https://github.com/llvm/llvm-project/issues/115231
      "?GetValue@?$Wrapper@UA@@@@$$hQEBAHXZ",

      // Symbols from:
      // ```
      // namespace A::B::C::D {
      // struct Base {
      //   virtual int f() { return 0; }
      // };
      // }
      // struct Derived : public A::B::C::D::Base {
      //   virtual int f() override { return 1; }
      // };
      // A::B::C::D::Base* MakeObj() { return new Derived(); }
      // ```
      // void * __cdecl operator new(unsigned __int64)
      "??2@$$hYAPEAX_K@Z",
      // public: virtual int __cdecl A::B::C::D::Base::f(void)
      "?f@Base@D@C@B@A@@$$hUEAAHXZ",
      // public: __cdecl A::B::C::D::Base::Base(void)
      "??0Base@D@C@B@A@@$$hQEAA@XZ",
      // public: virtual int __cdecl Derived::f(void)
      "?f@Derived@@$$hUEAAHXZ",
      // public: __cdecl Derived::Derived(void)
      "??0Derived@@$$hQEAA@XZ",
      // struct A::B::C::D::Base * __cdecl MakeObj(void)
      "?MakeObj@@$$hYAPEAUBase@D@C@B@A@@XZ",

      // Symbols from:
      // ```
      // template <typename T> struct WW { struct Z{}; };
      // template <typename X> struct Wrapper {
      //   int GetValue(typename WW<X>::Z) const;
      // };
      // struct A { };
      // template <typename X> int Wrapper<X>::GetValue(typename WW<X>::Z) const
      // { return 3; }
      // template class Wrapper<A>;
      // ```
      // public: int __cdecl Wrapper<struct A>::GetValue(struct WW<struct
      // A>::Z)const
      "?GetValue@?$Wrapper@UA@@@@$$hQEBAHUZ@?$WW@UA@@@@@Z",
  };

  for (const auto &Arm64ECName : Arm64ECNames) {
    // Check that this is a mangled name.
    EXPECT_TRUE(isArm64ECMangledFunctionName(Arm64ECName))
        << "Test case: " << Arm64ECName;
    // Refuse to mangle it again.
    EXPECT_FALSE(getArm64ECMangledFunctionName(Arm64ECName).has_value())
        << "Test case: " << Arm64ECName;

    // Demangle.
    auto Arm64Name = getArm64ECDemangledFunctionName(Arm64ECName);
    // The remaining checks dereference the demangled name, so record the
    // failure and move on to the next case rather than calling value() on an
    // empty optional.
    if (!Arm64Name.has_value()) {
      ADD_FAILURE() << "Failed to demangle. Test case: " << Arm64ECName;
      continue;
    }
    // Check that it is not mangled.
    EXPECT_FALSE(isArm64ECMangledFunctionName(Arm64Name.value()))
        << "Test case: " << Arm64ECName;
    // Refuse to demangle it again.
    EXPECT_FALSE(getArm64ECDemangledFunctionName(Arm64Name.value()).has_value())
        << "Test case: " << Arm64ECName;

    // Round-trip.
    auto RoundTripArm64ECName =
        getArm64ECMangledFunctionName(Arm64Name.value());
    EXPECT_EQ(RoundTripArm64ECName, Arm64ECName)
        << "Test case: " << Arm64ECName;
  }
}

} // end anonymous namespace
Loading