Skip to content

Commit 810c043

Browse files
committed
Implement iterator logic in RawIter
1 parent 7ee722e commit 810c043

File tree

1 file changed

+80
-102
lines changed

1 file changed

+80
-102
lines changed

src/lib.rs

Lines changed: 80 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal};
7676
use std::cell::UnsafeCell;
7777
use std::fmt;
7878
use std::iter::FusedIterator;
79-
use std::marker::PhantomData;
8079
use std::mem;
8180
use std::mem::MaybeUninit;
8281
use std::panic::UnwindSafe;
@@ -274,20 +273,7 @@ impl<T: Send> ThreadLocal<T> {
274273
{
275274
Iter {
276275
thread_local: self,
277-
yielded: 0,
278-
bucket: 0,
279-
bucket_size: 1,
280-
index: 0,
281-
}
282-
}
283-
284-
fn raw_iter_mut(&mut self) -> RawIterMut<T> {
285-
RawIterMut {
286-
remaining: *self.values.get_mut(),
287-
buckets: unsafe { *(&self.buckets as *const _ as *const [*mut Entry<T>; BUCKETS]) },
288-
bucket: 0,
289-
bucket_size: 1,
290-
index: 0,
276+
raw: RawIter::new(),
291277
}
292278
}
293279

@@ -299,8 +285,8 @@ impl<T: Send> ThreadLocal<T> {
299285
/// threads are currently accessing their associated values.
300286
pub fn iter_mut(&mut self) -> IterMut<T> {
301287
IterMut {
302-
raw: self.raw_iter_mut(),
303-
marker: PhantomData,
288+
thread_local: self,
289+
raw: RawIter::new(),
304290
}
305291
}
306292

@@ -319,10 +305,10 @@ impl<T: Send> IntoIterator for ThreadLocal<T> {
319305
type Item = T;
320306
type IntoIter = IntoIter<T>;
321307

322-
fn into_iter(mut self) -> IntoIter<T> {
308+
fn into_iter(self) -> IntoIter<T> {
323309
IntoIter {
324-
raw: self.raw_iter_mut(),
325-
_thread_local: self,
310+
thread_local: self,
311+
raw: RawIter::new(),
326312
}
327313
}
328314
}
@@ -361,22 +347,26 @@ impl<T: Send + fmt::Debug> fmt::Debug for ThreadLocal<T> {
361347

362348
impl<T: Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {}
363349

364-
/// Iterator over the contents of a `ThreadLocal`.
365350
#[derive(Debug)]
366-
pub struct Iter<'a, T: Send + Sync> {
367-
thread_local: &'a ThreadLocal<T>,
351+
struct RawIter {
368352
yielded: usize,
369353
bucket: usize,
370354
bucket_size: usize,
371355
index: usize,
372356
}
357+
impl RawIter {
358+
fn new() -> Self {
359+
Self {
360+
yielded: 0,
361+
bucket: 0,
362+
bucket_size: 1,
363+
index: 0,
364+
}
365+
}
373366

374-
impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
375-
type Item = &'a T;
376-
377-
fn next(&mut self) -> Option<Self::Item> {
367+
fn next<'a, T: Send + Sync>(&mut self, thread_local: &'a ThreadLocal<T>) -> Option<&'a T> {
378368
while self.bucket < BUCKETS {
379-
let bucket = unsafe { self.thread_local.buckets.get_unchecked(self.bucket) };
369+
let bucket = unsafe { thread_local.buckets.get_unchecked(self.bucket) };
380370
let bucket = bucket.load(Ordering::Relaxed);
381371

382372
if !bucket.is_null() {
@@ -390,140 +380,128 @@ impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
390380
}
391381
}
392382

393-
if self.bucket != 0 {
394-
self.bucket_size <<= 1;
395-
}
396-
self.bucket += 1;
397-
398-
self.index = 0;
383+
self.next_bucket();
399384
}
400385
None
401386
}
402-
403-
fn size_hint(&self) -> (usize, Option<usize>) {
404-
let total = self.thread_local.values.load(Ordering::Acquire);
405-
(total - self.yielded, None)
406-
}
407-
}
408-
impl<T: Send + Sync> FusedIterator for Iter<'_, T> {}
409-
410-
struct RawIterMut<T: Send> {
411-
remaining: usize,
412-
buckets: [*mut Entry<T>; BUCKETS],
413-
bucket: usize,
414-
bucket_size: usize,
415-
index: usize,
416-
}
417-
418-
impl<T: Send> Iterator for RawIterMut<T> {
419-
type Item = *mut MaybeUninit<T>;
420-
421-
fn next(&mut self) -> Option<Self::Item> {
422-
if self.remaining == 0 {
387+
fn next_mut<'a, T: Send>(
388+
&mut self,
389+
thread_local: &'a mut ThreadLocal<T>,
390+
) -> Option<&'a mut Entry<T>> {
391+
if *thread_local.values.get_mut() == self.yielded {
423392
return None;
424393
}
425394

426395
loop {
427-
let bucket = unsafe { *self.buckets.get_unchecked(self.bucket) };
396+
let bucket = unsafe { thread_local.buckets.get_unchecked_mut(self.bucket) };
397+
let bucket = *bucket.get_mut();
428398

429399
if !bucket.is_null() {
430400
while self.index < self.bucket_size {
431401
let entry = unsafe { &mut *bucket.add(self.index) };
432402
self.index += 1;
433403
if *entry.present.get_mut() {
434-
self.remaining -= 1;
435-
return Some(entry.value.get());
404+
self.yielded += 1;
405+
return Some(entry);
436406
}
437407
}
438408
}
439409

440-
if self.bucket != 0 {
441-
self.bucket_size <<= 1;
442-
}
443-
self.bucket += 1;
410+
self.next_bucket();
411+
}
412+
}
444413

445-
self.index = 0;
414+
fn next_bucket(&mut self) {
415+
if self.bucket != 0 {
416+
self.bucket_size <<= 1;
446417
}
418+
self.bucket += 1;
419+
self.index = 0;
447420
}
448421

449-
fn size_hint(&self) -> (usize, Option<usize>) {
450-
(self.remaining, Some(self.remaining))
422+
fn size_hint<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
423+
let total = thread_local.values.load(Ordering::Acquire);
424+
(total - self.yielded, None)
425+
}
426+
fn size_hint_frozen<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
427+
let total = unsafe { *(&thread_local.values as *const AtomicUsize as *const usize) };
428+
let remaining = total - self.yielded;
429+
(remaining, Some(remaining))
451430
}
452431
}
453432

454-
unsafe impl<T: Send> Send for RawIterMut<T> {}
455-
unsafe impl<T: Send + Sync> Sync for RawIterMut<T> {}
433+
/// Iterator over the contents of a `ThreadLocal`.
434+
#[derive(Debug)]
435+
pub struct Iter<'a, T: Send + Sync> {
436+
thread_local: &'a ThreadLocal<T>,
437+
raw: RawIter,
438+
}
439+
440+
impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
441+
type Item = &'a T;
442+
fn next(&mut self) -> Option<Self::Item> {
443+
self.raw.next(self.thread_local)
444+
}
445+
fn size_hint(&self) -> (usize, Option<usize>) {
446+
self.raw.size_hint(self.thread_local)
447+
}
448+
}
449+
impl<T: Send + Sync> FusedIterator for Iter<'_, T> {}
456450

457451
/// Mutable iterator over the contents of a `ThreadLocal`.
458452
pub struct IterMut<'a, T: Send> {
459-
raw: RawIterMut<T>,
460-
marker: PhantomData<&'a mut ThreadLocal<T>>,
453+
thread_local: &'a mut ThreadLocal<T>,
454+
raw: RawIter,
461455
}
462456

463457
impl<'a, T: Send> Iterator for IterMut<'a, T> {
464458
type Item = &'a mut T;
465-
466459
fn next(&mut self) -> Option<&'a mut T> {
467460
self.raw
468-
.next()
469-
.map(|x| unsafe { &mut *(&mut *x).as_mut_ptr() })
461+
.next_mut(self.thread_local)
462+
.map(|entry| unsafe { &mut *(&mut *entry.value.get()).as_mut_ptr() })
470463
}
471-
472464
fn size_hint(&self) -> (usize, Option<usize>) {
473-
self.raw.size_hint()
465+
self.raw.size_hint_frozen(self.thread_local)
474466
}
475467
}
476468

477469
impl<T: Send> ExactSizeIterator for IterMut<'_, T> {}
478470
impl<T: Send> FusedIterator for IterMut<'_, T> {}
479471

480-
// The Debug bound is technically unnecessary but makes the API more consistent and future-proof.
481-
impl<T: Send + fmt::Debug> fmt::Debug for IterMut<'_, T> {
472+
// Manual impl so we don't call Debug on the ThreadLocal, as doing so would create a reference to
473+
// this thread's value that potentially aliases with a mutable reference we have given out.
474+
impl<'a, T: Send + fmt::Debug> fmt::Debug for IterMut<'a, T> {
482475
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
483-
f.debug_struct("IterMut")
484-
.field("remaining", &self.raw.remaining)
485-
.field("bucket", &self.raw.bucket)
486-
.field("bucket_size", &self.raw.bucket_size)
487-
.field("index", &self.raw.index)
488-
.finish()
476+
f.debug_struct("IterMut").field("raw", &self.raw).finish()
489477
}
490478
}
491479

492480
/// An iterator that moves out of a `ThreadLocal`.
481+
#[derive(Debug)]
493482
pub struct IntoIter<T: Send> {
494-
raw: RawIterMut<T>,
495-
_thread_local: ThreadLocal<T>,
483+
thread_local: ThreadLocal<T>,
484+
raw: RawIter,
496485
}
497486

498487
impl<T: Send> Iterator for IntoIter<T> {
499488
type Item = T;
500-
501489
fn next(&mut self) -> Option<T> {
502-
self.raw
503-
.next()
504-
.map(|x| unsafe { std::mem::replace(&mut *x, MaybeUninit::uninit()).assume_init() })
490+
self.raw.next_mut(&mut self.thread_local).map(|entry| {
491+
*entry.present.get_mut() = false;
492+
unsafe {
493+
std::mem::replace(&mut *entry.value.get(), MaybeUninit::uninit()).assume_init()
494+
}
495+
})
505496
}
506-
507497
fn size_hint(&self) -> (usize, Option<usize>) {
508-
self.raw.size_hint()
498+
self.raw.size_hint_frozen(&self.thread_local)
509499
}
510500
}
511501

512502
impl<T: Send> ExactSizeIterator for IntoIter<T> {}
513503
impl<T: Send> FusedIterator for IntoIter<T> {}
514504

515-
// The Debug bound is technically unnecessary but makes the API more consistent and future-proof.
516-
impl<T: Send + fmt::Debug> fmt::Debug for IntoIter<T> {
517-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
518-
f.debug_struct("IntoIter")
519-
.field("remaining", &self.raw.remaining)
520-
.field("bucket", &self.raw.bucket)
521-
.field("bucket_size", &self.raw.bucket_size)
522-
.field("index", &self.raw.index)
523-
.finish()
524-
}
525-
}
526-
527505
fn allocate_bucket<T>(size: usize) -> *mut Entry<T> {
528506
Box::into_raw(
529507
(0..size)

0 commit comments

Comments
 (0)