//! depending on the value of cfg!(parallel_compiler).

use crate::owning_ref::{Erased, OwningRef};
+use std::cell::{RefCell, RefMut};
use std::collections::HashMap;
use std::hash::{BuildHasher, Hash};
+use std::mem;
+use std::num::NonZeroUsize;
use std::ops::{Deref, DerefMut};
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};

@@ -460,6 +463,91 @@ impl<T: Clone> Clone for Lock<T> {
    }
}

+/// Hands out thread ids from a global counter, starting at 1 so the value
+/// always fits in a `NonZeroUsize`.
+fn next() -> NonZeroUsize {
+    static COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(1);
+    NonZeroUsize::new(COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst))
+        .expect("more than usize::MAX threads")
+}
+
+/// Returns the id of the current thread, assigning one on first use and
+/// caching it in a thread-local afterwards.
+pub(crate) fn get_thread_id() -> NonZeroUsize {
+    thread_local!(static THREAD_ID: NonZeroUsize = next());
+    THREAD_ID.with(|&x| x)
+}
+
+// `RefLock` can be shared between threads (it is `Sync`) even though it wraps
+// a plain `RefCell`, because every access asserts that it happens on the
+// thread that created it, so the `RefCell` is never touched concurrently.
+pub struct RefLock<T> {
+    val: RefCell<T>,
+    thread_id: NonZeroUsize,
+}
+
+impl<T> RefLock<T> {
+    #[inline(always)]
+    pub fn new(value: T) -> Self {
+        Self { val: RefCell::new(value), thread_id: get_thread_id() }
+    }
+
+    /// Panics if called from any thread other than the one that created `self`.
+    #[inline(always)]
+    fn assert_thread(&self) {
+        assert_eq!(get_thread_id(), self.thread_id);
+    }
+
+    #[inline(always)]
+    pub fn get_mut(&mut self) -> &mut T {
+        self.assert_thread();
+        self.val.get_mut()
+    }
+
+    #[inline(always)]
+    pub fn try_lock(&self) -> Option<RefMut<'_, T>> {
+        self.assert_thread();
+        self.val.try_borrow_mut().ok()
+    }
+
+    #[inline(always)]
+    #[track_caller]
+    pub fn lock(&self) -> RefMut<'_, T> {
+        self.assert_thread();
+        self.val.borrow_mut()
+    }
+
+    #[inline(always)]
+    #[track_caller]
+    pub fn with_lock<F: FnOnce(&mut T) -> R, R>(&self, f: F) -> R {
+        f(&mut *self.lock())
+    }
+
+    #[inline(always)]
+    #[track_caller]
+    pub fn borrow(&self) -> RefMut<'_, T> {
+        self.lock()
+    }
+
+    #[inline(always)]
+    #[track_caller]
+    pub fn borrow_mut(&self) -> RefMut<'_, T> {
+        self.lock()
+    }
+}
+
+// SAFETY: every method asserts that it runs on the creating thread, so the
+// non-`Sync` `RefCell` inside is never accessed from more than one thread.
+unsafe impl<T> std::marker::Sync for RefLock<T> {}
+
+impl<T> Drop for RefLock<T> {
+    fn drop(&mut self) {
+        // Only enforce the thread check when `T` actually has a destructor;
+        // running that destructor on a foreign thread would escape the
+        // single-thread discipline the assertions above establish.
+        if mem::needs_drop::<T>() {
+            if get_thread_id() != self.thread_id {
+                panic!("destructor of fragile object ran on wrong thread");
+            }
+        }
+    }
+}
+
+impl<T: Default> Default for RefLock<T> {
+    #[inline]
+    fn default() -> Self {
+        RefLock::new(T::default())
+    }
+}
+
#[derive(Debug, Default)]
pub struct RwLock<T>(InnerRwLock<T>);

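To illustrate the contract the assertions enforce, here is a minimal usage sketch, written as a hypothetical test module that would sit next to the definitions above (the module and test names are illustrative and not part of the change): locking from the owning thread behaves like `RefCell::borrow_mut`, while locking a shared reference from another thread trips the thread-id assertion and panics.

```rust
#[cfg(test)]
mod ref_lock_usage {
    use super::RefLock;

    #[test]
    fn same_thread_access_works() {
        let cell = RefLock::new(0usize);
        // Mutable access goes through a dynamically checked borrow,
        // exactly like `RefCell::borrow_mut`.
        *cell.lock() += 1;
        cell.with_lock(|v| *v += 1);
        assert_eq!(*cell.borrow(), 2);
        // A second simultaneous borrow is refused rather than deadlocking.
        let guard = cell.lock();
        assert!(cell.try_lock().is_none());
        drop(guard);
    }

    #[test]
    fn cross_thread_access_panics() {
        let cell = std::sync::Arc::new(RefLock::new(0usize));
        let shared = std::sync::Arc::clone(&cell);
        // `RefLock` is `Sync`, so sharing it compiles, but any access from a
        // foreign thread fails the thread-id assertion and unwinds.
        let result = std::thread::spawn(move || *shared.lock()).join();
        assert!(result.is_err());
    }
}
```

Note that in the second test only a shared handle crosses threads; the last `Arc` is dropped on the creating thread, so the `Drop` check never fires.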