[libc][stdlib] Fix UB in freelist #95330

Merged on Jun 13, 2024 (1 commit)

PiJoules
Contributor

Some of the freelist code uses type punning through a union, which is UB in C++ because it reads from a union member that is not the active member. Where we did this, we should instead use memcpy for storing or reinterpret_cast<cpp::byte*> for comparing.
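
As a rough illustration of the two replacements named above, here is a minimal standalone sketch. It is not the libc code: `Node`, `byte`, `store_node`, `load_node`, and `same_address` are hypothetical stand-ins for `FreeListNode`, `cpp::byte`, and the logic in `freelist.h`, and `std::memcpy` stands in for `LIBC_NAMESPACE::memcpy`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

// Hypothetical stand-ins for cpp::byte and FreeListNode (assumptions).
using byte = unsigned char;

struct Node {
  Node *next;
  std::size_t size;
};

// Storing: build the node as a real object, then copy its bytes into the
// raw buffer. No union member is read while inactive.
inline void store_node(byte *buffer, Node *next, std::size_t size) {
  Node node{next, size};
  std::memcpy(buffer, &node, sizeof(node));
}

// Reading back: copy the bytes into a real Node instead of punning.
inline Node load_node(const byte *buffer) {
  Node node;
  std::memcpy(&node, buffer, sizeof(node));
  return node;
}

// Comparing: convert the node pointer to a byte pointer and compare
// addresses. Nothing is dereferenced here.
inline bool same_address(const Node *node, const byte *data) {
  return reinterpret_cast<const byte *>(node) == data;
}

// Small usage example.
inline void example_usage() {
  alignas(Node) byte buffer[sizeof(Node)];
  store_node(buffer, nullptr, 64);
  Node copy = load_node(buffer);
  assert(copy.next == nullptr);
  assert(copy.size == 64);
}
```

The byte-pointer comparison only inspects the address, which is why a `reinterpret_cast` is enough there, while the store goes through `memcpy` so that no object is accessed through the wrong type.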

llvmbot commented Jun 12, 2024

@llvm/pr-subscribers-libc

Author: None (PiJoules)

Full diff: https://github.com/llvm/llvm-project/pull/95330.diff

2 Files Affected:

  • (modified) libc/src/stdlib/CMakeLists.txt (+1)
  • (modified) libc/src/stdlib/freelist.h (+21-30)
diff --git a/libc/src/stdlib/CMakeLists.txt b/libc/src/stdlib/CMakeLists.txt
index d4aa50a43d186..a67aeb32be920 100644
--- a/libc/src/stdlib/CMakeLists.txt
+++ b/libc/src/stdlib/CMakeLists.txt
@@ -401,6 +401,7 @@ else()
       libc.src.__support.CPP.cstddef
       libc.src.__support.CPP.array
       libc.src.__support.CPP.span
+      libc.src.string.memcpy
   )
   add_entrypoint_external(
     malloc
diff --git a/libc/src/stdlib/freelist.h b/libc/src/stdlib/freelist.h
index c01ed6eddb7d4..b7deaba475a21 100644
--- a/libc/src/stdlib/freelist.h
+++ b/libc/src/stdlib/freelist.h
@@ -13,6 +13,7 @@
 #include "src/__support/CPP/cstddef.h"
 #include "src/__support/CPP/span.h"
 #include "src/__support/fixedvector.h"
+#include "src/string/memcpy.h"
 
 namespace LIBC_NAMESPACE {
 
@@ -92,19 +93,14 @@ bool FreeList<NUM_BUCKETS>::add_chunk(span<cpp::byte> chunk) {
   if (chunk.size() < sizeof(FreeListNode))
     return false;
 
-  union {
-    FreeListNode *node;
-    cpp::byte *bytes;
-  } aliased;
-
-  aliased.bytes = chunk.data();
-
+  // Add it to the correct list.
   size_t chunk_ptr = find_chunk_ptr_for_size(chunk.size(), false);
 
-  // Add it to the correct list.
-  aliased.node->size = chunk.size();
-  aliased.node->next = chunks_[chunk_ptr];
-  chunks_[chunk_ptr] = aliased.node;
+  FreeListNode node;
+  node.next = chunks_[chunk_ptr];
+  node.size = chunk.size();
+  LIBC_NAMESPACE::memcpy(chunk.data(), &node, sizeof(node));
+  chunks_[chunk_ptr] = reinterpret_cast<FreeListNode *>(chunk.data());
 
   return true;
 }
@@ -123,17 +119,13 @@ span<cpp::byte> FreeList<NUM_BUCKETS>::find_chunk(size_t size) const {
 
   // Now iterate up the buckets, walking each list to find a good candidate
   for (size_t i = chunk_ptr; i < chunks_.size(); i++) {
-    union {
-      FreeListNode *node;
-      cpp::byte *data;
-    } aliased;
-    aliased.node = chunks_[static_cast<unsigned short>(i)];
+    FreeListNode *node = chunks_[static_cast<unsigned short>(i)];
 
-    while (aliased.node != nullptr) {
-      if (aliased.node->size >= size)
-        return span<cpp::byte>(aliased.data, aliased.node->size);
+    while (node != nullptr) {
+      if (node->size >= size)
+        return span<cpp::byte>(reinterpret_cast<cpp::byte *>(node), node->size);
 
-      aliased.node = aliased.node->next;
+      node = node->next;
     }
   }
 
@@ -150,30 +142,29 @@ bool FreeList<NUM_BUCKETS>::remove_chunk(span<cpp::byte> chunk) {
   union {
     FreeListNode *node;
     cpp::byte *data;
-  } aliased, aliased_next;
+  } aliased_next;
 
   // Check head first.
   if (chunks_[chunk_ptr] == nullptr)
     return false;
 
-  aliased.node = chunks_[chunk_ptr];
-  if (aliased.data == chunk.data()) {
-    chunks_[chunk_ptr] = aliased.node->next;
+  FreeListNode *node = chunks_[chunk_ptr];
+  if (reinterpret_cast<cpp::byte *>(node) == chunk.data()) {
+    chunks_[chunk_ptr] = node->next;
     return true;
   }
 
   // No? Walk the nodes.
-  aliased.node = chunks_[chunk_ptr];
+  node = chunks_[chunk_ptr];
 
-  while (aliased.node->next != nullptr) {
-    aliased_next.node = aliased.node->next;
-    if (aliased_next.data == chunk.data()) {
+  while (node->next != nullptr) {
+    if (reinterpret_cast<cpp::byte *>(node->next) == chunk.data()) {
       // Found it, remove this node out of the chain
-      aliased.node->next = aliased_next.node->next;
+      node->next = node->next->next;
       return true;
     }
 
-    aliased.node = aliased.node->next;
+    node = node->next;
   }
 
   return false;

FreeListNode node;
node.next = chunks_[chunk_ptr];
node.size = chunk.size();
LIBC_NAMESPACE::memcpy(chunk.data(), &node, sizeof(node));

@PiJoules PiJoules Jun 13, 2024

Hmm, how might I use it in this case? The idea here is to store the FreeListNode onto the chunk buffer, but bit_cast returns an object rather than a reference, so I wouldn't be able to do something like:

FreeListNode &node = cpp::bit_cast<FreeListNode>(chunk.data());
node.next = chunks_[chunk_ptr];
node.size = chunk.size();

Syntax-wise, this works

FreeListNode *node = cpp::bit_cast<FreeListNode *>(chunk.data());
node->next = chunks_[chunk_ptr];
node->size = chunk.size();

but I think this would still be UB by breaking strict aliasing unless the original chunk.data() points to an actual FreeListNode.
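
To make the "returns an object rather than a reference" point concrete, here is a small sketch. `bit_cast_copy` is a hypothetical memcpy-based helper written for this example, not the actual `cpp::bit_cast` implementation; `Node` and `RawNode` are likewise assumed stand-ins. It shows that a value-returning cast produces a separate copy, so writes to the result never reach the buffer unless they are copied back.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

// Hypothetical stand-in for FreeListNode (assumption).
struct Node {
  Node *next;
  std::size_t size;
};

// Hypothetical value-returning cast: copies the bytes of `from` into a
// freshly created To and returns it by value.
template <typename To, typename From>
To bit_cast_copy(const From &from) {
  static_assert(sizeof(To) == sizeof(From), "sizes must match");
  To to;
  std::memcpy(&to, &from, sizeof(To));
  return to;
}

// Raw storage with the same size as Node, playing the role of the chunk.
struct RawNode {
  unsigned char bytes[sizeof(Node)];
};

inline void example_usage() {
  RawNode raw{};                         // zero-filled buffer
  Node copy = bit_cast_copy<Node>(raw);  // a separate object
  copy.size = 128;                       // modifies only the copy

  // The buffer is unchanged because `copy` never aliased it.
  assert(bit_cast_copy<Node>(raw).size == 0);

  // Persisting the change needs an explicit copy back into the buffer.
  std::memcpy(raw.bytes, &copy, sizeof(copy));
  assert(bit_cast_copy<Node>(raw).size == 128);
}
```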

@PiJoules PiJoules Jun 13, 2024

Actually, since this is the only place a FreeListNode is created, I think a better way to do this would be placement new via

FreeListNode *node = new (chunk.data()) FreeListNode(chunks_[chunk_ptr], chunk.size());

This way we start the lifetime of a FreeListNode object at chunk.data(). This only requires chunk.data() to be properly aligned, which I think should always be the case because a FreeListNode should always be stored after a Block, which will have alignment large enough to hold a pointer. We don't need to explicitly destroy the object: even if we reuse the same buffer space, a placement new automatically ends the lifetime of the old FreeListNode.
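
A minimal sketch of the placement-new idea, under stated assumptions: `Node` stands in for `FreeListNode`, and `emplace_node` is a hypothetical helper that makes the size and alignment preconditions explicit by returning `nullptr` when they do not hold. The real code may rely on the Block layout instead of checking at run time.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <new>

// Hypothetical stand-in for FreeListNode (assumption).
struct Node {
  Node *next;
  std::size_t size;
};

// Starts the lifetime of a Node at the front of `buffer`.
// Returns nullptr if the buffer is too small or not suitably aligned.
inline Node *emplace_node(unsigned char *buffer, std::size_t buffer_size,
                          Node *next) {
  if (buffer_size < sizeof(Node))
    return nullptr;
  if (reinterpret_cast<std::uintptr_t>(buffer) % alignof(Node) != 0)
    return nullptr;
  // Placement new constructs the object in the caller's storage; no
  // allocation happens here.
  return new (buffer) Node{next, buffer_size};
}

// Small usage example: chain two chunks into a list.
inline void example_usage() {
  alignas(Node) unsigned char a[64];
  alignas(Node) unsigned char b[64];
  Node *first = emplace_node(a, sizeof(a), nullptr);
  Node *second = emplace_node(b, sizeof(b), first);
  assert(first != nullptr && first->size == 64);
  assert(second != nullptr && second->next == first);
}
```

Because `Node` is trivially destructible, reusing the storage with another placement new needs no explicit destructor call.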

Some of the freelist code uses type punning which is UB in C++, namely
because we read from a union member that is not the active union
member.
@PiJoules PiJoules force-pushed the fix-ub-in-freelist branch from ba4509f to e304a81 Compare June 13, 2024 21:48
@PiJoules PiJoules merged commit 3106a23 into llvm:main Jun 13, 2024
4 of 5 checks passed
@PiJoules PiJoules deleted the fix-ub-in-freelist branch June 13, 2024 21:55
EthanLuisMcDonough pushed a commit to EthanLuisMcDonough/llvm-project that referenced this pull request Aug 13, 2024