Commit 26e800d

Update Tid implementation.
Gets rid of some unnecessary spinning and changes the `Tid` API to use `ThreadId` instead of raw `u64`s. Also changes `current_thread_id` to use `rtabort!` instead of `panic!`.
1 parent 897b4a8 commit 26e800d

1 file changed: +34 -25 lines

library/std/src/sync/reentrant_lock.rs
@@ -8,6 +8,7 @@ use crate::fmt;
 use crate::ops::Deref;
 use crate::panic::{RefUnwindSafe, UnwindSafe};
 use crate::sys::sync as sys;
+use crate::thread::ThreadId;
 
 /// A re-entrant mutual exclusion lock
 ///
@@ -92,18 +93,19 @@ cfg_if!(
         struct Tid(AtomicU64);
 
         impl Tid {
-            const fn new(tid: u64) -> Self {
-                Self(AtomicU64::new(tid))
+            const fn new() -> Self {
+                Self(AtomicU64::new(0))
             }
 
             #[inline]
-            fn get(&self) -> u64 {
-                self.0.load(Relaxed)
+            fn contains(&self, owner: ThreadId) -> bool {
+                owner.as_u64().get() == self.0.load(Relaxed)
             }
 
             #[inline]
-            fn set(&self, tid: u64) {
-                self.0.store(tid, Relaxed)
+            fn set(&self, tid: Option<ThreadId>) {
+                let value = tid.map_or(0, |tid| tid.as_u64().get());
+                self.0.store(value, Relaxed);
             }
         }
     } else if #[cfg(target_has_atomic = "32")] {
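On targets with 64-bit atomics, the owner is now stored as an `Option<ThreadId>` packed into a single `AtomicU64`: `ThreadId::as_u64()` is guaranteed non-zero, so 0 is free to mean "no owner". A minimal standalone sketch of that encoding, using `NonZeroU64` in place of `ThreadId` and hypothetical helper names (not part of the patch):

use std::num::NonZeroU64;

// Encode `None` as 0 and `Some(id)` as the raw (guaranteed non-zero) value,
// mirroring how the new `Tid::set` stores an `Option<ThreadId>` in an `AtomicU64`.
fn encode(tid: Option<NonZeroU64>) -> u64 {
    tid.map_or(0, |tid| tid.get())
}

// Decoding is just the `NonZeroU64` constructor: 0 maps back to `None`.
fn decode(value: u64) -> Option<NonZeroU64> {
    NonZeroU64::new(value)
}

fn main() {
    let id = NonZeroU64::new(42).unwrap();
    assert_eq!(encode(None), 0);
    assert_eq!(decode(encode(Some(id))), Some(id));
}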
@@ -116,16 +118,18 @@ cfg_if!(
         }
 
         impl Tid {
-            const fn new(tid: u64) -> Self {
+            const fn new() -> Self {
                 Self {
                     seq: AtomicU32::new(0),
-                    low: AtomicU32::new(tid as u32),
-                    high: AtomicU32::new((tid >> 32) as u32),
+                    low: AtomicU32::new(0),
+                    high: AtomicU32::new(0),
                 }
             }
 
             #[inline]
-            fn get(&self) -> u64 {
+            // NOTE: This assumes that `owner` is the ID of the current
+            // thread, and may spuriously return `false` if that's not the case.
+            fn contains(&self, owner: ThreadId) -> bool {
                 // Synchronizes with the release-increment in `set()` to ensure
                 // we only read the data after it's been fully written.
                 let mut seq = self.seq.load(Acquire);
@@ -137,23 +141,28 @@ cfg_if!(
                         // store to ensure that `get()` doesn't see data from a subsequent
                         // `set()` call.
                         match self.seq.compare_exchange_weak(seq, seq, Release, Acquire) {
-                            Ok(_) => return u64::from(low) | (u64::from(high) << 32),
+                            Ok(_) => {
+                                let tid = u64::from(low) | (u64::from(high) << 32);
+                                return owner.as_u64().get() == tid;
+                            },
                             Err(new) => seq = new,
                         }
                     } else {
-                        crate::hint::spin_loop();
-                        seq = self.seq.load(Acquire);
+                        // Another thread is currently writing to the seqlock. That thread
+                        // must also be holding the mutex, so we can't currently be the lock owner.
+                        return false;
                     }
                 }
             }
 
             #[inline]
             // This may only be called from one thread at a time, otherwise
             // concurrent `get()` calls may return teared data.
-            fn set(&self, tid: u64) {
+            fn set(&self, tid: Option<ThreadId>) {
+                let value = tid.map_or(0, |tid| tid.as_u64().get());
                 self.seq.fetch_add(1, Acquire);
-                self.low.store(tid as u32, Relaxed);
-                self.high.store((tid >> 32) as u32, Relaxed);
+                self.low.store(value as u32, Relaxed);
+                self.high.store((value >> 32) as u32, Relaxed);
                 self.seq.fetch_add(1, Release);
             }
         }
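The 32-bit fallback keeps the owner ID in two `AtomicU32` halves guarded by a sequence counter: the writer bumps the counter to an odd value before storing the halves and back to even afterwards, and a reader only trusts the halves if the counter was even and unchanged across the reads. A rough standalone sketch of that seqlock pattern (simplified names; the real `Tid` folds the owner comparison into the read path, as in the hunk above):

use std::sync::atomic::{AtomicU32, Ordering::{Acquire, Relaxed, Release}};

// A 64-bit value readable on platforms that only have 32-bit atomics.
struct SeqU64 {
    seq: AtomicU32,  // even = idle, odd = write in progress
    low: AtomicU32,
    high: AtomicU32,
}

impl SeqU64 {
    const fn new() -> Self {
        Self { seq: AtomicU32::new(0), low: AtomicU32::new(0), high: AtomicU32::new(0) }
    }

    // Writer side: must only be called from one thread at a time.
    fn set(&self, value: u64) {
        self.seq.fetch_add(1, Acquire); // counter becomes odd: readers back off
        self.low.store(value as u32, Relaxed);
        self.high.store((value >> 32) as u32, Relaxed);
        self.seq.fetch_add(1, Release); // counter even again: data is published
    }

    // Reader side: returns `None` instead of spinning when a write is in flight.
    fn try_get(&self) -> Option<u64> {
        let seq = self.seq.load(Acquire);
        if seq % 2 != 0 {
            return None; // a writer is mid-update
        }
        let low = self.low.load(Relaxed);
        let high = self.high.load(Relaxed);
        // The release RMW keeps the half-loads above from being reordered past
        // this check, the same trick `Tid::contains` uses.
        match self.seq.compare_exchange_weak(seq, seq, Release, Acquire) {
            Ok(_) => Some(u64::from(low) | (u64::from(high) << 32)),
            Err(_) => None,
        }
    }
}

fn main() {
    let v = SeqU64::new();
    v.set(0xDEAD_BEEF_0000_0042);
    assert_eq!(v.try_get(), Some(0xDEAD_BEEF_0000_0042));
}

Bailing out when the counter is odd is what lets the patch drop the old spin loop: the only possible writer also holds the mutex, so the calling thread cannot be the owner in that case.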
@@ -213,7 +222,7 @@ impl<T> ReentrantLock<T> {
     pub const fn new(t: T) -> ReentrantLock<T> {
         ReentrantLock {
             mutex: sys::Mutex::new(),
-            owner: Tid::new(0),
+            owner: Tid::new(),
             lock_count: UnsafeCell::new(0),
             data: t,
         }
@@ -266,11 +275,11 @@ impl<T: ?Sized> ReentrantLock<T> {
         let this_thread = current_thread_id();
         // Safety: We only touch lock_count when we own the lock.
         unsafe {
-            if self.owner.get() == this_thread {
+            if self.owner.contains(this_thread) {
                 self.increment_lock_count().expect("lock count overflow in reentrant mutex");
             } else {
                 self.mutex.lock();
-                self.owner.set(this_thread);
+                self.owner.set(Some(this_thread));
                 debug_assert_eq!(*self.lock_count.get(), 0);
                 *self.lock_count.get() = 1;
             }
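For context, this owner check is what makes a second `lock()` call from the owning thread succeed instead of deadlocking. A small nightly-only usage sketch of the public API (assuming the unstable `reentrant_lock` feature; illustrative, not part of this patch):

#![feature(reentrant_lock)]
use std::cell::Cell;
use std::sync::ReentrantLock;

fn main() {
    let lock = ReentrantLock::new(Cell::new(0));
    let outer = lock.lock();
    // Second acquisition on the same thread: the owner check matches,
    // so only the lock count is bumped and there is no deadlock.
    let inner = lock.lock();
    inner.set(inner.get() + 1);
    drop(inner);
    assert_eq!(outer.get(), 1);
}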
@@ -308,11 +317,11 @@ impl<T: ?Sized> ReentrantLock<T> {
         let this_thread = current_thread_id();
         // Safety: We only touch lock_count when we own the lock.
         unsafe {
-            if self.owner.get() == this_thread {
+            if self.owner.contains(this_thread) {
                 self.increment_lock_count()?;
                 Some(ReentrantLockGuard { lock: self })
             } else if self.mutex.try_lock() {
-                self.owner.set(this_thread);
+                self.owner.set(Some(this_thread));
                 debug_assert_eq!(*self.lock_count.get(), 0);
                 *self.lock_count.get() = 1;
                 Some(ReentrantLockGuard { lock: self })
@@ -385,7 +394,7 @@ impl<T: ?Sized> Drop for ReentrantLockGuard<'_, T> {
         unsafe {
             *self.lock.lock_count.get() -= 1;
             if *self.lock.lock_count.get() == 0 {
-                self.lock.owner.set(0);
+                self.lock.owner.set(None);
                 self.lock.mutex.unlock();
             }
         }
@@ -397,11 +406,11 @@ impl<T: ?Sized> Drop for ReentrantLockGuard<'_, T> {
 ///
 /// Panics if called during a TLS destructor on a thread that hasn't
 /// been assigned an ID.
-pub(crate) fn current_thread_id() -> u64 {
+pub(crate) fn current_thread_id() -> ThreadId {
     #[cold]
     fn no_tid() -> ! {
-        panic!("Thread hasn't been assigned an ID!")
+        rtabort!("Thread hasn't been assigned an ID!")
     }
 
-    crate::thread::try_current_id().map_or_else(|| no_tid(), |tid| tid.as_u64().get())
+    crate::thread::try_current_id().unwrap_or_else(|| no_tid())
 }
