@@ -8,6 +8,7 @@ use crate::fmt;
8
8
use crate :: ops:: Deref ;
9
9
use crate :: panic:: { RefUnwindSafe , UnwindSafe } ;
10
10
use crate :: sys:: sync as sys;
11
+ use crate :: thread:: ThreadId ;
11
12
12
13
/// A re-entrant mutual exclusion lock
13
14
///
@@ -92,18 +93,19 @@ cfg_if!(
92
93
struct Tid ( AtomicU64 ) ;
93
94
94
95
impl Tid {
95
- const fn new( tid : u64 ) -> Self {
96
- Self ( AtomicU64 :: new( tid ) )
96
+ const fn new( ) -> Self {
97
+ Self ( AtomicU64 :: new( 0 ) )
97
98
}
98
99
99
100
#[ inline]
100
- fn get ( & self ) -> u64 {
101
- self . 0 . load( Relaxed )
101
+ fn contains ( & self , owner : ThreadId ) -> bool {
102
+ owner . as_u64 ( ) . get ( ) == self . 0 . load( Relaxed )
102
103
}
103
104
104
105
#[ inline]
105
- fn set( & self , tid: u64 ) {
106
- self . 0 . store( tid, Relaxed )
106
+ fn set( & self , tid: Option <ThreadId >) {
107
+ let value = tid. map_or( 0 , |tid| tid. as_u64( ) . get( ) ) ;
108
+ self . 0 . store( value, Relaxed ) ;
107
109
}
108
110
}
109
111
} else if #[ cfg( target_has_atomic = "32" ) ] {
@@ -116,16 +118,18 @@ cfg_if!(
116
118
}
117
119
118
120
impl Tid {
119
- const fn new( tid : u64 ) -> Self {
121
+ const fn new( ) -> Self {
120
122
Self {
121
123
seq: AtomicU32 :: new( 0 ) ,
122
- low: AtomicU32 :: new( tid as u32 ) ,
123
- high: AtomicU32 :: new( ( tid >> 32 ) as u32 ) ,
124
+ low: AtomicU32 :: new( 0 ) ,
125
+ high: AtomicU32 :: new( 0 ) ,
124
126
}
125
127
}
126
128
127
129
#[ inline]
128
- fn get( & self ) -> u64 {
130
+ // NOTE: This assumes that `owner` is the ID of the current
131
+ // thread, and may spuriously return `false` if that's not the case.
132
+ fn contains( & self , owner: ThreadId ) -> bool {
129
133
// Synchronizes with the release-increment in `set()` to ensure
130
134
// we only read the data after it's been fully written.
131
135
let mut seq = self . seq. load( Acquire ) ;
@@ -137,23 +141,28 @@ cfg_if!(
137
141
// store to ensure that `get()` doesn't see data from a subsequent
138
142
// `set()` call.
139
143
match self . seq. compare_exchange_weak( seq, seq, Release , Acquire ) {
140
- Ok ( _) => return u64 :: from( low) | ( u64 :: from( high) << 32 ) ,
144
+ Ok ( _) => {
145
+ let tid = u64 :: from( low) | ( u64 :: from( high) << 32 ) ;
146
+ return owner. as_u64( ) . get( ) == tid;
147
+ } ,
141
148
Err ( new) => seq = new,
142
149
}
143
150
} else {
144
- crate :: hint:: spin_loop( ) ;
145
- seq = self . seq. load( Acquire ) ;
151
+ // Another thread is currently writing to the seqlock. That thread
152
+ // must also be holding the mutex, so we can't currently be the lock owner.
153
+ return false ;
146
154
}
147
155
}
148
156
}
149
157
150
158
#[ inline]
151
159
// This may only be called from one thread at a time, otherwise
152
160
// concurrent `get()` calls may return teared data.
153
- fn set( & self , tid: u64 ) {
161
+ fn set( & self , tid: Option <ThreadId >) {
162
+ let value = tid. map_or( 0 , |tid| tid. as_u64( ) . get( ) ) ;
154
163
self . seq. fetch_add( 1 , Acquire ) ;
155
- self . low. store( tid as u32 , Relaxed ) ;
156
- self . high. store( ( tid >> 32 ) as u32 , Relaxed ) ;
164
+ self . low. store( value as u32 , Relaxed ) ;
165
+ self . high. store( ( value >> 32 ) as u32 , Relaxed ) ;
157
166
self . seq. fetch_add( 1 , Release ) ;
158
167
}
159
168
}
@@ -213,7 +222,7 @@ impl<T> ReentrantLock<T> {
213
222
pub const fn new ( t : T ) -> ReentrantLock < T > {
214
223
ReentrantLock {
215
224
mutex : sys:: Mutex :: new ( ) ,
216
- owner : Tid :: new ( 0 ) ,
225
+ owner : Tid :: new ( ) ,
217
226
lock_count : UnsafeCell :: new ( 0 ) ,
218
227
data : t,
219
228
}
@@ -266,11 +275,11 @@ impl<T: ?Sized> ReentrantLock<T> {
266
275
let this_thread = current_thread_id ( ) ;
267
276
// Safety: We only touch lock_count when we own the lock.
268
277
unsafe {
269
- if self . owner . get ( ) == this_thread {
278
+ if self . owner . contains ( this_thread) {
270
279
self . increment_lock_count ( ) . expect ( "lock count overflow in reentrant mutex" ) ;
271
280
} else {
272
281
self . mutex . lock ( ) ;
273
- self . owner . set ( this_thread) ;
282
+ self . owner . set ( Some ( this_thread) ) ;
274
283
debug_assert_eq ! ( * self . lock_count. get( ) , 0 ) ;
275
284
* self . lock_count . get ( ) = 1 ;
276
285
}
@@ -308,11 +317,11 @@ impl<T: ?Sized> ReentrantLock<T> {
308
317
let this_thread = current_thread_id ( ) ;
309
318
// Safety: We only touch lock_count when we own the lock.
310
319
unsafe {
311
- if self . owner . get ( ) == this_thread {
320
+ if self . owner . contains ( this_thread) {
312
321
self . increment_lock_count ( ) ?;
313
322
Some ( ReentrantLockGuard { lock : self } )
314
323
} else if self . mutex . try_lock ( ) {
315
- self . owner . set ( this_thread) ;
324
+ self . owner . set ( Some ( this_thread) ) ;
316
325
debug_assert_eq ! ( * self . lock_count. get( ) , 0 ) ;
317
326
* self . lock_count . get ( ) = 1 ;
318
327
Some ( ReentrantLockGuard { lock : self } )
@@ -385,7 +394,7 @@ impl<T: ?Sized> Drop for ReentrantLockGuard<'_, T> {
385
394
unsafe {
386
395
* self . lock . lock_count . get ( ) -= 1 ;
387
396
if * self . lock . lock_count . get ( ) == 0 {
388
- self . lock . owner . set ( 0 ) ;
397
+ self . lock . owner . set ( None ) ;
389
398
self . lock . mutex . unlock ( ) ;
390
399
}
391
400
}
@@ -397,11 +406,11 @@ impl<T: ?Sized> Drop for ReentrantLockGuard<'_, T> {
397
406
///
398
407
/// Panics if called during a TLS destructor on a thread that hasn't
399
408
/// been assigned an ID.
400
- pub ( crate ) fn current_thread_id ( ) -> u64 {
409
+ pub ( crate ) fn current_thread_id ( ) -> ThreadId {
401
410
#[ cold]
402
411
fn no_tid ( ) -> ! {
403
- panic ! ( "Thread hasn't been assigned an ID!" )
412
+ rtabort ! ( "Thread hasn't been assigned an ID!" )
404
413
}
405
414
406
- crate :: thread:: try_current_id ( ) . map_or_else ( || no_tid ( ) , |tid| tid . as_u64 ( ) . get ( ) )
415
+ crate :: thread:: try_current_id ( ) . unwrap_or_else ( || no_tid ( ) )
407
416
}
0 commit comments