Skip to content

Commit dd626e7

Browse files
committed
Fix cloning Symbols not increasing their ref count
1 parent 3fe815b commit dd626e7

File tree

2 files changed

+63
-28
lines changed

2 files changed

+63
-28
lines changed

crates/intern/src/symbol.rs

Lines changed: 62 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::{
55
borrow::Borrow,
66
fmt,
77
hash::{BuildHasherDefault, Hash, Hasher},
8-
mem,
8+
mem::{self, ManuallyDrop},
99
ptr::NonNull,
1010
sync::OnceLock,
1111
};
@@ -25,6 +25,15 @@ const _: () = assert!(std::mem::align_of::<Box<str>>() == std::mem::align_of::<&
2525
const _: () = assert!(std::mem::size_of::<Arc<Box<str>>>() == std::mem::size_of::<&&str>());
2626
const _: () = assert!(std::mem::align_of::<Arc<Box<str>>>() == std::mem::align_of::<&&str>());
2727

28+
const _: () =
29+
assert!(std::mem::size_of::<*const *const str>() == std::mem::size_of::<TaggedArcPtr>());
30+
const _: () =
31+
assert!(std::mem::align_of::<*const *const str>() == std::mem::align_of::<TaggedArcPtr>());
32+
33+
const _: () = assert!(std::mem::size_of::<Arc<Box<str>>>() == std::mem::size_of::<TaggedArcPtr>());
34+
const _: () =
35+
assert!(std::mem::align_of::<Arc<Box<str>>>() == std::mem::align_of::<TaggedArcPtr>());
36+
2837
/// A pointer that points to a pointer to a `str`, it may be backed as a `&'static &'static str` or
2938
/// `Arc<Box<str>>` but its size is that of a thin pointer. The active variant is encoded as a tag
3039
/// in the LSB of the alignment niche.
@@ -40,19 +49,24 @@ impl TaggedArcPtr {
4049
const BOOL_BITS: usize = true as usize;
4150

4251
const fn non_arc(r: &'static &'static str) -> Self {
43-
Self {
44-
// SAFETY: The pointer is non-null as it is derived from a reference
45-
// Ideally we would call out to `pack_arc` but for a `false` tag, unfortunately the
46-
// packing stuff requires reading out the pointer to an integer which is not supported
47-
// in const contexts, so here we make use of the fact that for the non-arc version the
48-
// tag is false (0) and thus does not need touching the actual pointer value.ext)
49-
packed: unsafe {
50-
NonNull::new_unchecked((r as *const &str).cast::<*const str>().cast_mut())
51-
},
52-
}
52+
assert!(
53+
mem::align_of::<&'static &'static str>().trailing_zeros() as usize > Self::BOOL_BITS
54+
);
55+
// SAFETY: The pointer is non-null as it is derived from a reference
56+
// Ideally we would call out to `pack_arc` but for a `false` tag, unfortunately the
57+
// packing stuff requires reading out the pointer to an integer which is not supported
58+
// in const contexts, so here we make use of the fact that for the non-arc version the
59+
// tag is false (0) and thus does not need touching the actual pointer value.ext)
60+
61+
let packed =
62+
unsafe { NonNull::new_unchecked((r as *const &str).cast::<*const str>().cast_mut()) };
63+
Self { packed }
5364
}
5465

5566
fn arc(arc: Arc<Box<str>>) -> Self {
67+
assert!(
68+
mem::align_of::<&'static &'static str>().trailing_zeros() as usize > Self::BOOL_BITS
69+
);
5670
Self {
5771
packed: Self::pack_arc(
5872
// Safety: `Arc::into_raw` always returns a non null pointer
@@ -63,12 +77,14 @@ impl TaggedArcPtr {
6377

6478
/// Retrieves the tag.
6579
#[inline]
66-
pub(crate) fn try_as_arc_owned(self) -> Option<Arc<Box<str>>> {
80+
pub(crate) fn try_as_arc_owned(self) -> Option<ManuallyDrop<Arc<Box<str>>>> {
6781
// Unpack the tag from the alignment niche
6882
let tag = Strict::addr(self.packed.as_ptr()) & Self::BOOL_BITS;
6983
if tag != 0 {
7084
// Safety: We checked that the tag is non-zero -> true, so we are pointing to the data offset of an `Arc`
71-
Some(unsafe { Arc::from_raw(self.pointer().as_ptr().cast::<Box<str>>()) })
85+
Some(ManuallyDrop::new(unsafe {
86+
Arc::from_raw(self.pointer().as_ptr().cast::<Box<str>>())
87+
}))
7288
} else {
7389
None
7490
}
@@ -122,10 +138,11 @@ impl TaggedArcPtr {
122138
}
123139
}
124140

125-
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
141+
#[derive(PartialEq, Eq, Hash, Debug)]
126142
pub struct Symbol {
127143
repr: TaggedArcPtr,
128144
}
145+
129146
const _: () = assert!(std::mem::size_of::<Symbol>() == std::mem::size_of::<NonNull<()>>());
130147
const _: () = assert!(std::mem::align_of::<Symbol>() == std::mem::align_of::<NonNull<()>>());
131148

@@ -185,19 +202,27 @@ impl Symbol {
185202
fn drop_slow(arc: &Arc<Box<str>>) {
186203
let (mut shard, hash) = Self::select_shard(arc);
187204

188-
if Arc::count(arc) != 2 {
189-
// Another thread has interned another copy
190-
return;
205+
match Arc::count(arc) {
206+
0 => unreachable!(),
207+
1 => unreachable!(),
208+
2 => (),
209+
_ => {
210+
// Another thread has interned another copy
211+
return;
212+
}
191213
}
192214

193-
match shard.raw_entry_mut().from_key_hashed_nocheck::<str>(hash, arc.as_ref()) {
194-
RawEntryMut::Occupied(occ) => occ.remove_entry(),
195-
RawEntryMut::Vacant(_) => unreachable!(),
196-
}
197-
.0
198-
.0
199-
.try_as_arc_owned()
200-
.unwrap();
215+
ManuallyDrop::into_inner(
216+
match shard.raw_entry_mut().from_key_hashed_nocheck::<str>(hash, arc.as_ref()) {
217+
RawEntryMut::Occupied(occ) => occ.remove_entry(),
218+
RawEntryMut::Vacant(_) => unreachable!(),
219+
}
220+
.0
221+
.0
222+
.try_as_arc_owned()
223+
.unwrap(),
224+
);
225+
debug_assert_eq!(Arc::count(&arc), 1);
201226

202227
// Shrink the backing storage if the shard is less than 50% occupied.
203228
if shard.len() * 2 < shard.capacity() {
@@ -219,7 +244,13 @@ impl Drop for Symbol {
219244
Self::drop_slow(&arc);
220245
}
221246
// decrement the ref count
222-
drop(arc);
247+
ManuallyDrop::into_inner(arc);
248+
}
249+
}
250+
251+
impl Clone for Symbol {
252+
fn clone(&self) -> Self {
253+
Self { repr: increase_arc_refcount(self.repr) }
223254
}
224255
}
225256

@@ -228,8 +259,7 @@ fn increase_arc_refcount(repr: TaggedArcPtr) -> TaggedArcPtr {
228259
return repr;
229260
};
230261
// increase the ref count
231-
mem::forget(arc.clone());
232-
mem::forget(arc);
262+
mem::forget(Arc::clone(&arc));
233263
repr
234264
}
235265

@@ -265,6 +295,7 @@ mod tests {
265295
let base_len = MAP.get().unwrap().len();
266296
let hello = Symbol::intern("hello");
267297
let world = Symbol::intern("world");
298+
let more_worlds = world.clone();
268299
let bang = Symbol::intern("!");
269300
let q = Symbol::intern("?");
270301
assert_eq!(MAP.get().unwrap().len(), base_len + 4);
@@ -275,6 +306,7 @@ mod tests {
275306
drop(q);
276307
assert_eq!(MAP.get().unwrap().len(), base_len + 3);
277308
let default = Symbol::intern("default");
309+
let many_worlds = world.clone();
278310
assert_eq!(MAP.get().unwrap().len(), base_len + 3);
279311
assert_eq!(
280312
"hello default world!",
@@ -285,6 +317,8 @@ mod tests {
285317
"hello world!",
286318
format!("{} {}{}", hello.as_str(), world.as_str(), bang.as_str())
287319
);
320+
drop(many_worlds);
321+
drop(more_worlds);
288322
drop(hello);
289323
drop(world);
290324
drop(bang);

crates/intern/src/symbol/symbols.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use crate::{
1010
symbol::{SymbolProxy, TaggedArcPtr},
1111
Symbol,
1212
};
13+
1314
macro_rules! define_symbols {
1415
(@WITH_NAME: $($alias:ident = $value:literal),* $(,)? @PLAIN: $($name:ident),* $(,)?) => {
1516
$(

0 commit comments

Comments
 (0)