Skip to content

[libc][bug] Fix out of bound write in memcpy w/ software prefetching #90591

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions libc/src/string/memory_utils/x86_64/inline_memcpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,13 @@ inline_memcpy_x86_sse2_ge64_sw_prefetching(Ptr __restrict dst,
offset += K_THREE_CACHELINES;
}
}
return builtin::Memcpy<32>::loop_and_tail_offset(dst, src, count, offset);
// We don't use 'loop_and_tail_offset' because it assumes at least one
// iteration of the loop.
while (offset + 32 <= count) {
builtin::Memcpy<32>::block_offset(dst, src, offset);
offset += 32;
}
return builtin::Memcpy<32>::tail(dst, src, count);
}

[[maybe_unused]] LIBC_INLINE void
Expand Down Expand Up @@ -139,7 +145,13 @@ inline_memcpy_x86_avx_ge64_sw_prefetching(Ptr __restrict dst,
builtin::Memcpy<K_THREE_CACHELINES>::block_offset(dst, src, offset);
offset += K_THREE_CACHELINES;
}
return builtin::Memcpy<64>::loop_and_tail_offset(dst, src, count, offset);
// We don't use 'loop_and_tail_offset' because it assumes at least one
// iteration of the loop.
while (offset + 64 <= count) {
builtin::Memcpy<64>::block_offset(dst, src, offset);
offset += 64;
}
return builtin::Memcpy<64>::tail(dst, src, count);
}

[[maybe_unused]] LIBC_INLINE void
Expand Down
41 changes: 41 additions & 0 deletions libc/test/src/string/memcpy_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@
//===----------------------------------------------------------------------===//

#include "memory_utils/memory_check_utils.h"
#include "src/__support/macros/properties/os.h" // LIBC_TARGET_OS_IS_LINUX
#include "src/string/memcpy.h"
#include "test/UnitTest/Test.h"

#if !defined(LIBC_FULL_BUILD) && defined(LIBC_TARGET_OS_IS_LINUX)
#include "memory_utils/protected_pages.h"
#endif // !defined(LIBC_FULL_BUILD) && defined(LIBC_TARGET_OS_IS_LINUX)

namespace LIBC_NAMESPACE {

// Adapt CheckMemcpy signature to memcpy.
Expand All @@ -30,4 +35,40 @@ TEST(LlvmLibcMemcpyTest, SizeSweep) {
}
}

#if !defined(LIBC_FULL_BUILD) && defined(LIBC_TARGET_OS_IS_LINUX)

TEST(LlvmLibcMemcpyTest, CheckAccess) {
static constexpr size_t MAX_SIZE = 1024;
LIBC_ASSERT(MAX_SIZE < GetPageSize());
ProtectedPages pages;
const Page write_buffer = pages.GetPageA().WithAccess(PROT_WRITE);
const Page read_buffer = [&]() {
// We fetch page B in write mode.
auto page = pages.GetPageB().WithAccess(PROT_WRITE);
// And fill it with random numbers.
for (size_t i = 0; i < page.page_size; ++i)
page.page_ptr[i] = rand();
// Then return it in read mode.
return page.WithAccess(PROT_READ);
}();
for (size_t size = 0; size < MAX_SIZE; ++size) {
// We cross-check the function with two sources and two destinations.
// - The first of them (bottom) is always page aligned and faults when
// accessing bytes before it.
// - The second one (top) is not necessarily aligned and faults when
// accessing bytes after it.
const uint8_t *sources[2] = {read_buffer.bottom(size),
read_buffer.top(size)};
uint8_t *destinations[2] = {write_buffer.bottom(size),
write_buffer.top(size)};
for (const uint8_t *src : sources) {
for (uint8_t *dst : destinations) {
LIBC_NAMESPACE::memcpy(dst, src, size);
}
}
}
}

#endif // !defined(LIBC_FULL_BUILD) && defined(LIBC_TARGET_OS_IS_LINUX)

} // namespace LIBC_NAMESPACE
99 changes: 99 additions & 0 deletions libc/test/src/string/memory_utils/protected_pages.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
//===-- protected_pages.h -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file provides protected pages that fault when accessing prior or past
// it. This is useful to check memory functions that must not access outside of
// the provided size limited buffer.
//===----------------------------------------------------------------------===//

#ifndef LIBC_TEST_SRC_STRING_MEMORY_UTILS_PROTECTED_PAGES_H
#define LIBC_TEST_SRC_STRING_MEMORY_UTILS_PROTECTED_PAGES_H

#include "src/__support/macros/properties/os.h" // LIBC_TARGET_OS_IS_LINUX
#if defined(LIBC_FULL_BUILD) || !defined(LIBC_TARGET_OS_IS_LINUX)
#error "Protected pages requires mmap and cannot be used in full build mode."
#endif // defined(LIBC_FULL_BUILD) || !defined(LIBC_TARGET_OS_IS_LINUX)

#include "src/__support/macros/attributes.h" // LIBC_INLINE
#include <stddef.h> // size_t
#include <stdint.h> // uint8_t
#include <sys/mman.h> // mmap, munmap
#include <unistd.h> // sysconf, _SC_PAGESIZE

// Returns mmap page size.
LIBC_INLINE size_t GetPageSize() {
static const size_t PAGE_SIZE = sysconf(_SC_PAGESIZE);
return PAGE_SIZE;
}

// Represents a page of memory whose access can be configured throught the
// 'WithAccess' function. Accessing data above or below this page will trap as
// it is sandwiched between two pages with no read / write access.
struct Page {
// Returns an aligned pointer that can be accessed up to page_size. Accessing
// data at ptr[-1] will fault.
LIBC_INLINE uint8_t *bottom(size_t size) const {
if (size >= page_size)
__builtin_trap();
return page_ptr;
}
// Returns a pointer to a buffer that can be accessed up to size. Accessing
// data at ptr[size] will trap.
LIBC_INLINE uint8_t *top(size_t size) const {
return page_ptr + page_size - size;
}

// protection is one of PROT_READ / PROT_WRITE.
LIBC_INLINE Page &WithAccess(int protection) {
if (mprotect(page_ptr, page_size, protection) != 0)
__builtin_trap();
return *this;
}

const size_t page_size;
uint8_t *const page_ptr;
};

// Allocates 5 consecutive pages that will trap if accessed.
// | page layout | access | page name |
// |-------------|--------|:---------:|
// | 0 | trap | |
// | 1 | custom | A |
// | 2 | trap | |
// | 3 | custom | B |
// | 4 | trap | |
//
// The pages A and B can be retrieved as with 'GetPageA' / 'GetPageB' and their
// accesses can be customized through the 'WithAccess' function.
struct ProtectedPages {
static constexpr size_t PAGES = 5;

ProtectedPages()
: page_size(GetPageSize()),
ptr(mmap(/*address*/ nullptr, /*length*/ PAGES * page_size,
/*protection*/ PROT_NONE,
/*flags*/ MAP_PRIVATE | MAP_ANONYMOUS, /*fd*/ -1,
/*offset*/ 0)) {
if (reinterpret_cast<intptr_t>(ptr) == -1)
__builtin_trap();
}
~ProtectedPages() { munmap(ptr, PAGES * page_size); }

LIBC_INLINE Page GetPageA() const { return Page{page_size, page<1>()}; }
LIBC_INLINE Page GetPageB() const { return Page{page_size, page<3>()}; }

private:
template <size_t index> LIBC_INLINE uint8_t *page() const {
static_assert(index < PAGES);
return static_cast<uint8_t *>(ptr) + (index * page_size);
}

const size_t page_size;
void *const ptr = nullptr;
};

#endif // LIBC_TEST_SRC_STRING_MEMORY_UTILS_PROTECTED_PAGES_H
31 changes: 31 additions & 0 deletions libc/test/src/string/memset_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@
//===----------------------------------------------------------------------===//

#include "memory_utils/memory_check_utils.h"
#include "src/__support/macros/properties/os.h" // LIBC_TARGET_OS_IS_LINUX
#include "src/string/memset.h"
#include "test/UnitTest/Test.h"

#if !defined(LIBC_FULL_BUILD) && defined(LIBC_TARGET_OS_IS_LINUX)
#include "memory_utils/protected_pages.h"
#endif // !defined(LIBC_FULL_BUILD) && defined(LIBC_TARGET_OS_IS_LINUX)

namespace LIBC_NAMESPACE {

// Adapt CheckMemset signature to memset.
Expand All @@ -27,4 +32,30 @@ TEST(LlvmLibcMemsetTest, SizeSweep) {
}
}

#if !defined(LIBC_FULL_BUILD) && defined(LIBC_TARGET_OS_IS_LINUX)

TEST(LlvmLibcMemsetTest, CheckAccess) {
static constexpr size_t MAX_SIZE = 1024;
LIBC_ASSERT(MAX_SIZE < GetPageSize());
ProtectedPages pages;
const Page write_buffer = pages.GetPageA().WithAccess(PROT_WRITE);
const cpp::array<int, 2> fill_chars = {0, 0x7F};
for (int fill_char : fill_chars) {
for (size_t size = 0; size < MAX_SIZE; ++size) {
// We cross-check the function with two destinations.
// - The first of them (bottom) is always page aligned and faults when
// accessing bytes before it.
// - The second one (top) is not necessarily aligned and faults when
// accessing bytes after it.
uint8_t *destinations[2] = {write_buffer.bottom(size),
write_buffer.top(size)};
for (uint8_t *dst : destinations) {
LIBC_NAMESPACE::memset(dst, fill_char, size);
}
}
}
}

#endif // !defined(LIBC_FULL_BUILD) && defined(LIBC_TARGET_OS_IS_LINUX)

} // namespace LIBC_NAMESPACE
Loading