Skip to content

Commit 2e426fe

Browse files
authored
Add unit tests for size returning new funcs in the MemProf use pass. (#105473)
We use a unit test to verify correctness since: a) we don't have a text format profile b) size returning new isn't supported natively c) a raw profile will need to be manipulated artificially The changes this test covers were made in #102258.
1 parent 625e929 commit 2e426fe

File tree

5 files changed

+199
-29
lines changed

5 files changed

+199
-29
lines changed

llvm/include/llvm/ProfileData/InstrProfReader.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -670,10 +670,11 @@ class IndexedMemProfReader {
670670

671671
public:
672672
IndexedMemProfReader() = default;
673+
virtual ~IndexedMemProfReader() = default;
673674

674675
Error deserialize(const unsigned char *Start, uint64_t MemProfOffset);
675676

676-
Expected<memprof::MemProfRecord>
677+
virtual Expected<memprof::MemProfRecord>
677678
getMemProfRecord(const uint64_t FuncNameHash) const;
678679
};
679680

@@ -768,11 +769,14 @@ class IndexedInstrProfReader : public InstrProfReader {
768769
uint64_t *MismatchedFuncSum = nullptr);
769770

770771
/// Return the memprof record for the function identified by
771-
/// llvm::md5(Name).
772+
/// llvm::md5(Name). Marked virtual so that unit tests can mock this function.
772773
Expected<memprof::MemProfRecord> getMemProfRecord(uint64_t FuncNameHash) {
773774
return MemProfReader.getMemProfRecord(FuncNameHash);
774775
}
775776

777+
/// Return the underlying memprof reader.
778+
IndexedMemProfReader &getIndexedMemProfReader() { return MemProfReader; }
779+
776780
/// Fill Counts with the profile data for the given function name.
777781
Error getFunctionCounts(StringRef FuncName, uint64_t FuncHash,
778782
std::vector<uint64_t> &Counts);

llvm/include/llvm/Transforms/Instrumentation/MemProfiler.h

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
#define LLVM_TRANSFORMS_INSTRUMENTATION_MEMPROFILER_H
1414

1515
#include "llvm/ADT/IntrusiveRefCntPtr.h"
16+
#include "llvm/IR/ModuleSummaryIndex.h"
1617
#include "llvm/IR/PassManager.h"
18+
#include "llvm/ProfileData/InstrProfReader.h"
19+
#include "llvm/Support/VirtualFileSystem.h"
1720

1821
namespace llvm {
1922
class Function;
2023
class Module;
21-
22-
namespace vfs {
23-
class FileSystem;
24-
} // namespace vfs
24+
class TargetLibraryInfo;
2525

2626
/// Public interface to the memory profiler pass for instrumenting code to
2727
/// profile memory accesses.
@@ -52,6 +52,17 @@ class MemProfUsePass : public PassInfoMixin<MemProfUsePass> {
5252
IntrusiveRefCntPtr<vfs::FileSystem> FS = nullptr);
5353
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
5454

55+
struct AllocMatchInfo {
56+
uint64_t TotalSize = 0;
57+
AllocationType AllocType = AllocationType::None;
58+
bool Matched = false;
59+
};
60+
61+
void
62+
readMemprof(Function &F, const IndexedMemProfReader &MemProfReader,
63+
const TargetLibraryInfo &TLI,
64+
std::map<uint64_t, AllocMatchInfo> &FullStackIdToAllocMatchInfo);
65+
5566
private:
5667
std::string MemoryProfileFileName;
5768
IntrusiveRefCntPtr<vfs::FileSystem> FS;

llvm/lib/Transforms/Instrumentation/MemProfiler.cpp

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
#include "llvm/Support/CommandLine.h"
4040
#include "llvm/Support/Debug.h"
4141
#include "llvm/Support/HashBuilder.h"
42-
#include "llvm/Support/VirtualFileSystem.h"
4342
#include "llvm/TargetParser/Triple.h"
4443
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
4544
#include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -55,6 +54,7 @@ namespace llvm {
5554
extern cl::opt<bool> PGOWarnMissing;
5655
extern cl::opt<bool> NoPGOWarnMismatch;
5756
extern cl::opt<bool> NoPGOWarnMismatchComdatWeak;
57+
using AllocMatchInfo = ::llvm::MemProfUsePass::AllocMatchInfo;
5858
} // namespace llvm
5959

6060
constexpr int LLVM_MEM_PROFILER_VERSION = 1;
@@ -148,10 +148,11 @@ static cl::opt<int> ClDebugMax("memprof-debug-max", cl::desc("Debug max inst"),
148148

149149
// By default disable matching of allocation profiles onto operator new that
150150
// already explicitly pass a hot/cold hint, since we don't currently
151-
// override these hints anyway.
152-
static cl::opt<bool> ClMemProfMatchHotColdNew(
151+
// override these hints anyway. Not static so that it can be set in the unit
152+
// test too.
153+
cl::opt<bool> ClMemProfMatchHotColdNew(
153154
"memprof-match-hot-cold-new",
154-
cl::desc(
155+
cl::desc(
155156
"Match allocation profiles onto existing hot/cold operator new calls"),
156157
cl::Hidden, cl::init(false));
157158

@@ -789,17 +790,11 @@ static bool isAllocationWithHotColdVariant(Function *Callee,
789790
}
790791
}
791792

792-
struct AllocMatchInfo {
793-
uint64_t TotalSize = 0;
794-
AllocationType AllocType = AllocationType::None;
795-
bool Matched = false;
796-
};
797-
798-
static void
799-
readMemprof(Module &M, Function &F, IndexedInstrProfReader *MemProfReader,
800-
const TargetLibraryInfo &TLI,
801-
std::map<uint64_t, AllocMatchInfo> &FullStackIdToAllocMatchInfo) {
802-
auto &Ctx = M.getContext();
793+
void MemProfUsePass::readMemprof(
794+
Function &F, const IndexedMemProfReader &MemProfReader,
795+
const TargetLibraryInfo &TLI,
796+
std::map<uint64_t, AllocMatchInfo> &FullStackIdToAllocMatchInfo) {
797+
auto &Ctx = F.getContext();
803798
// Previously we used getIRPGOFuncName() here. If F is local linkage,
804799
// getIRPGOFuncName() returns FuncName with prefix 'FileName;'. But
805800
// llvm-profdata uses FuncName in dwarf to create GUID which doesn't
@@ -810,7 +805,7 @@ readMemprof(Module &M, Function &F, IndexedInstrProfReader *MemProfReader,
810805
auto FuncName = F.getName();
811806
auto FuncGUID = Function::getGUID(FuncName);
812807
std::optional<memprof::MemProfRecord> MemProfRec;
813-
auto Err = MemProfReader->getMemProfRecord(FuncGUID).moveInto(MemProfRec);
808+
auto Err = MemProfReader.getMemProfRecord(FuncGUID).moveInto(MemProfRec);
814809
if (Err) {
815810
handleAllErrors(std::move(Err), [&](const InstrProfError &IPE) {
816811
auto Err = IPE.get();
@@ -838,8 +833,8 @@ readMemprof(Module &M, Function &F, IndexedInstrProfReader *MemProfReader,
838833
Twine(" Hash = ") + std::to_string(FuncGUID))
839834
.str();
840835

841-
Ctx.diagnose(
842-
DiagnosticInfoPGOProfile(M.getName().data(), Msg, DS_Warning));
836+
Ctx.diagnose(DiagnosticInfoPGOProfile(F.getParent()->getName().data(),
837+
Msg, DS_Warning));
843838
});
844839
return;
845840
}
@@ -1036,15 +1031,15 @@ PreservedAnalyses MemProfUsePass::run(Module &M, ModuleAnalysisManager &AM) {
10361031
return PreservedAnalyses::all();
10371032
}
10381033

1039-
std::unique_ptr<IndexedInstrProfReader> MemProfReader =
1034+
std::unique_ptr<IndexedInstrProfReader> IndexedReader =
10401035
std::move(ReaderOrErr.get());
1041-
if (!MemProfReader) {
1036+
if (!IndexedReader) {
10421037
Ctx.diagnose(DiagnosticInfoPGOProfile(
1043-
MemoryProfileFileName.data(), StringRef("Cannot get MemProfReader")));
1038+
MemoryProfileFileName.data(), StringRef("Cannot get IndexedReader")));
10441039
return PreservedAnalyses::all();
10451040
}
10461041

1047-
if (!MemProfReader->hasMemoryProfile()) {
1042+
if (!IndexedReader->hasMemoryProfile()) {
10481043
Ctx.diagnose(DiagnosticInfoPGOProfile(MemoryProfileFileName.data(),
10491044
"Not a memory profile"));
10501045
return PreservedAnalyses::all();
@@ -1057,12 +1052,13 @@ PreservedAnalyses MemProfUsePass::run(Module &M, ModuleAnalysisManager &AM) {
10571052
// it to an allocation in the IR.
10581053
std::map<uint64_t, AllocMatchInfo> FullStackIdToAllocMatchInfo;
10591054

1055+
const auto &MemProfReader = IndexedReader->getIndexedMemProfReader();
10601056
for (auto &F : M) {
10611057
if (F.isDeclaration())
10621058
continue;
10631059

10641060
const TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
1065-
readMemprof(M, F, MemProfReader.get(), TLI, FullStackIdToAllocMatchInfo);
1061+
readMemprof(F, MemProfReader, TLI, FullStackIdToAllocMatchInfo);
10661062
}
10671063

10681064
if (ClPrintMemProfMatchInfo) {

llvm/unittests/Transforms/Instrumentation/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ set(LLVM_LINK_COMPONENTS
99

1010
add_llvm_unittest(InstrumentationTests
1111
PGOInstrumentationTest.cpp
12+
MemProfilerTest.cpp
1213
)
1314

1415
target_link_libraries(InstrumentationTests PRIVATE LLVMTestingSupport)
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
//===- MemProfilerTest.cpp - MemProfiler unit tests ------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/Transforms/Instrumentation/MemProfiler.h"
10+
#include "llvm/ADT/StringRef.h"
11+
#include "llvm/Analysis/TargetLibraryInfo.h"
12+
#include "llvm/AsmParser/Parser.h"
13+
#include "llvm/IR/Attributes.h"
14+
#include "llvm/IR/Metadata.h"
15+
#include "llvm/IR/Module.h"
16+
#include "llvm/IR/PassManager.h"
17+
#include "llvm/Passes/PassBuilder.h"
18+
#include "llvm/ProfileData/InstrProfReader.h"
19+
#include "llvm/ProfileData/MemProf.h"
20+
#include "llvm/ProfileData/MemProfData.inc"
21+
#include "llvm/Support/Error.h"
22+
#include "llvm/Support/SourceMgr.h"
23+
24+
#include "gmock/gmock.h"
25+
#include "gtest/gtest.h"
26+
27+
extern llvm::cl::opt<bool> ClMemProfMatchHotColdNew;
28+
29+
namespace llvm {
30+
namespace memprof {
31+
namespace {
32+
33+
using ::testing::Return;
34+
using ::testing::SizeIs;
35+
36+
struct MemProfilerTest : public ::testing::Test {
37+
LLVMContext Context;
38+
std::unique_ptr<Module> M;
39+
40+
MemProfilerTest() { ClMemProfMatchHotColdNew = true; }
41+
42+
void parseAssembly(const StringRef IR) {
43+
SMDiagnostic Error;
44+
M = parseAssemblyString(IR, Error, Context);
45+
std::string ErrMsg;
46+
raw_string_ostream OS(ErrMsg);
47+
Error.print("", OS);
48+
49+
// A failure here means that the test itself is buggy.
50+
if (!M)
51+
report_fatal_error(OS.str().c_str());
52+
}
53+
};
54+
55+
// A mock memprof reader we can inject into the function we are testing.
56+
class MockMemProfReader : public IndexedMemProfReader {
57+
public:
58+
MOCK_METHOD(Expected<MemProfRecord>, getMemProfRecord,
59+
(const uint64_t FuncNameHash), (const, override));
60+
61+
// A helper function to create mock records from frames.
62+
static MemProfRecord makeRecord(ArrayRef<ArrayRef<Frame>> AllocFrames) {
63+
MemProfRecord Record;
64+
MemInfoBlock Info;
65+
// Mimic values which will be below the cold threshold.
66+
Info.AllocCount = 1, Info.TotalSize = 550;
67+
Info.TotalLifetime = 1000 * 1000, Info.TotalLifetimeAccessDensity = 1;
68+
for (const auto &Callstack : AllocFrames) {
69+
AllocationInfo AI;
70+
AI.Info = PortableMemInfoBlock(Info, getHotColdSchema());
71+
AI.CallStack = std::vector(Callstack.begin(), Callstack.end());
72+
Record.AllocSites.push_back(AI);
73+
}
74+
return Record;
75+
}
76+
};
77+
78+
TEST_F(MemProfilerTest, AnnotatesCall) {
79+
parseAssembly(R"IR(
80+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
81+
target triple = "x86_64-unknown-linux-gnu"
82+
83+
define void @_Z3foov() !dbg !10 {
84+
entry:
85+
%c1 = call {ptr, i64} @__size_returning_new(i64 32), !dbg !13
86+
%c2 = call {ptr, i64} @__size_returning_new_aligned(i64 32, i64 8), !dbg !14
87+
%c3 = call {ptr, i64} @__size_returning_new_hot_cold(i64 32, i8 254), !dbg !15
88+
%c4 = call {ptr, i64} @__size_returning_new_aligned_hot_cold(i64 32, i64 8, i8 254), !dbg !16
89+
ret void
90+
}
91+
92+
declare {ptr, i64} @__size_returning_new(i64)
93+
declare {ptr, i64} @__size_returning_new_aligned(i64, i64)
94+
declare {ptr, i64} @__size_returning_new_hot_cold(i64, i8)
95+
declare {ptr, i64} @__size_returning_new_aligned_hot_cold(i64, i64, i8)
96+
97+
!llvm.dbg.cu = !{!0}
98+
!llvm.module.flags = !{!2, !3}
99+
100+
!0 = distinct !DICompileUnit(language: DW_LANG_C_plus_plus_14, file: !1)
101+
!1 = !DIFile(filename: "mock_file.cc", directory: "mock_dir")
102+
!2 = !{i32 7, !"Dwarf Version", i32 5}
103+
!3 = !{i32 2, !"Debug Info Version", i32 3}
104+
!10 = distinct !DISubprogram(name: "foo", linkageName: "_Z3foov", scope: !1, file: !1, line: 4, type: !11, scopeLine: 4, unit: !0, retainedNodes: !12)
105+
!11 = !DISubroutineType(types: !12)
106+
!12 = !{}
107+
!13 = !DILocation(line: 5, column: 10, scope: !10)
108+
!14 = !DILocation(line: 6, column: 10, scope: !10)
109+
!15 = !DILocation(line: 7, column: 10, scope: !10)
110+
!16 = !DILocation(line: 8, column: 10, scope: !10)
111+
)IR");
112+
113+
auto *F = M->getFunction("_Z3foov");
114+
ASSERT_NE(F, nullptr);
115+
116+
TargetLibraryInfoWrapperPass WrapperPass;
117+
auto &TLI = WrapperPass.getTLI(*F);
118+
119+
auto Guid = Function::getGUID("_Z3foov");
120+
// All the allocation sites are in foo().
121+
MemProfRecord MockRecord =
122+
MockMemProfReader::makeRecord({{Frame(Guid, 1, 10, false)},
123+
{Frame(Guid, 2, 10, false)},
124+
{Frame(Guid, 3, 10, false)},
125+
{Frame(Guid, 4, 10, false)}});
126+
// Set up mocks for the reader.
127+
MockMemProfReader Reader;
128+
EXPECT_CALL(Reader, getMemProfRecord(Guid)).WillOnce(Return(MockRecord));
129+
130+
MemProfUsePass Pass("/unused/profile/path");
131+
std::map<uint64_t, MemProfUsePass::AllocMatchInfo> Unused;
132+
Pass.readMemprof(*F, Reader, TLI, Unused);
133+
134+
// Since we only have a single type of behaviour for each allocation site, we
135+
// only get function attributes.
136+
std::vector<llvm::Attribute> CallsiteAttrs;
137+
for (const auto &BB : *F) {
138+
for (const auto &I : BB) {
139+
if (auto *CI = dyn_cast<CallInst>(&I)) {
140+
if (!CI->getCalledFunction()->getName().starts_with(
141+
"__size_returning_new"))
142+
continue;
143+
Attribute Attr = CI->getFnAttr("memprof");
144+
// The attribute will be invalid if it didn't find one named memprof.
145+
ASSERT_TRUE(Attr.isValid());
146+
CallsiteAttrs.push_back(Attr);
147+
}
148+
}
149+
}
150+
151+
// We match all the variants including ones with the hint since we set
152+
// ClMemProfMatchHotColdNew to true.
153+
EXPECT_THAT(CallsiteAttrs, SizeIs(4));
154+
}
155+
156+
} // namespace
157+
} // namespace memprof
158+
} // namespace llvm

0 commit comments

Comments
 (0)