use crate::sync::linked_list::{LinkedList, Node};
use core::{
cell::UnsafeCell,
fmt,
future::Future,
mem::MaybeUninit,
ops::{Deref, DerefMut},
pin::Pin,
sync::atomic::{AtomicU8, Ordering},
task::{Context, Poll, Waker},
};
pub struct Mutex<T: ?Sized> {
state: AtomicU8,
waiters: LinkedList<Waiter>,
data: UnsafeCell<T>,
}
const DATA_LOCKED: u8 = 1 << 0;
const WAITERS_LOCKED: u8 = 1 << 1;
#[must_use = "if unused the Mutex will immediately unlock"]
pub struct MutexGuard<'a, T: ?Sized> {
mutex: &'a Mutex<T>,
}
pub struct MutexLockFuture<'a, T: ?Sized> {
mutex: &'a Mutex<T>,
waiter: Option<*const Node<Waiter>>,
}
struct Waiter {
state: AtomicU8,
wakers: [UnsafeCell<MaybeUninit<Waker>>; 2],
}
const WAITER_INDEX: u8 = 1 << 0;
const WAITER_DISABLED: u8 = 1 << 1;
unsafe impl<T: ?Sized + Send> Send for Mutex<T> {}
unsafe impl<T: ?Sized + Send> Sync for Mutex<T> {}
unsafe impl<T: ?Sized + Send> Send for MutexGuard<'_, T> {}
unsafe impl<T: ?Sized + Sync> Sync for MutexGuard<'_, T> {}
unsafe impl<T: ?Sized + Send> Send for MutexLockFuture<'_, T> {}
impl<T> Mutex<T> {
#[inline]
pub const fn new(data: T) -> Self {
Self { state: AtomicU8::new(0), waiters: LinkedList::new(), data: UnsafeCell::new(data) }
}
#[inline]
pub fn into_inner(self) -> T {
self.data.into_inner()
}
}
impl<T: ?Sized> Mutex<T> {
#[inline]
pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
if self.state.fetch_or(DATA_LOCKED, Ordering::Acquire) & DATA_LOCKED == 0 {
Some(MutexGuard { mutex: self })
} else {
None
}
}
#[inline]
pub fn lock(&self) -> MutexLockFuture<'_, T> {
MutexLockFuture { mutex: self, waiter: None }
}
#[inline]
pub fn get_mut(&mut self) -> &mut T {
unsafe { &mut *self.data.get() }
}
fn unlock(&self) {
let waiters_lock =
self.state.fetch_or(WAITERS_LOCKED, Ordering::Acquire) & WAITERS_LOCKED == 0;
if waiters_lock {
unsafe {
self.waiters
.drain_filter_raw(|waiter| (*waiter).is_disabled())
.for_each(|node| drop(Box::from_raw(node)));
}
}
self.state.fetch_and(!DATA_LOCKED, Ordering::Release);
for waiter in unsafe { self.waiters.iter_mut_unchecked() } {
if waiter.wake() {
break;
}
}
if waiters_lock {
self.state.fetch_and(!WAITERS_LOCKED, Ordering::Release);
}
}
}
impl<T: ?Sized> MutexLockFuture<'_, T> {
fn disable_waiter(&mut self) {
if let Some(waiter) = self.waiter.take() {
unsafe { (*waiter).disable() };
}
}
}
impl<'a, T: ?Sized> Future for MutexLockFuture<'a, T> {
type Output = MutexGuard<'a, T>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Some(lock) = self.mutex.try_lock() {
self.disable_waiter();
return Poll::Ready(lock);
}
if let Some(waiter) = self.waiter {
unsafe { (*waiter).register(cx.waker()) };
} else {
let waiter = Box::into_raw(Box::new(Node::from(Waiter::from(cx.waker().clone()))));
self.waiter = Some(waiter);
unsafe { self.mutex.waiters.push_raw(waiter) };
}
if let Some(lock) = self.mutex.try_lock() {
self.disable_waiter();
return Poll::Ready(lock);
}
Poll::Pending
}
}
impl<T: ?Sized> Drop for MutexLockFuture<'_, T> {
fn drop(&mut self) {
if let Some(waiter) = self.waiter {
if unsafe { (*waiter).disable() } & WAITER_DISABLED != 0 {
drop(self.mutex.try_lock());
}
}
}
}
impl Waiter {
fn register(&self, waker: &Waker) {
let state = self.state.load(Ordering::Acquire);
let mut index = (state & WAITER_INDEX) as usize;
if state & WAITER_DISABLED != 0
|| !waker
.will_wake(unsafe { (*self.wakers.get_unchecked(index).get()).assume_init_ref() })
{
index = (index + 1) % 2;
unsafe { (*self.wakers.get_unchecked(index).get()).write(waker.clone()) };
self.state.store(index as u8, Ordering::Release);
}
}
fn wake(&self) -> bool {
let state = self.disable();
if state & WAITER_DISABLED == 0 {
let index = (state & WAITER_INDEX) as usize;
unsafe { (*self.wakers.get_unchecked(index).get()).assume_init_read().wake() };
true
} else {
false
}
}
fn disable(&self) -> u8 {
self.state.fetch_or(WAITER_DISABLED, Ordering::Relaxed)
}
fn is_disabled(&self) -> bool {
self.state.load(Ordering::Relaxed) & WAITER_DISABLED != 0
}
}
impl From<Waker> for Waiter {
fn from(waker: Waker) -> Self {
Self {
state: AtomicU8::new(0),
wakers: [
UnsafeCell::new(MaybeUninit::new(waker)),
UnsafeCell::new(MaybeUninit::uninit()),
],
}
}
}
impl<T> From<T> for Mutex<T> {
#[inline]
fn from(data: T) -> Self {
Self::new(data)
}
}
impl<T: ?Sized + Default> Default for Mutex<T> {
#[inline]
fn default() -> Self {
Self::new(Default::default())
}
}
impl<T: ?Sized + fmt::Debug> fmt::Debug for Mutex<T> {
#[allow(clippy::option_if_let_else)]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(guard) = self.try_lock() {
f.debug_struct("Mutex").field("data", &&*guard).finish()
} else {
struct LockedPlaceholder;
impl fmt::Debug for LockedPlaceholder {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("<locked>")
}
}
f.debug_struct("Mutex").field("data", &LockedPlaceholder).finish()
}
}
}
impl<T: ?Sized> Deref for MutexGuard<'_, T> {
type Target = T;
#[inline]
fn deref(&self) -> &T {
unsafe { &*self.mutex.data.get() }
}
}
impl<T: ?Sized> DerefMut for MutexGuard<'_, T> {
#[inline]
fn deref_mut(&mut self) -> &mut T {
unsafe { &mut *self.mutex.data.get() }
}
}
impl<T: ?Sized> Drop for MutexGuard<'_, T> {
#[inline]
fn drop(&mut self) {
self.mutex.unlock();
}
}
impl<T: ?Sized + fmt::Debug> fmt::Debug for MutexGuard<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MutexGuard").field("mutex", &self.mutex).finish()
}
}
impl<T: ?Sized + fmt::Display> fmt::Display for MutexGuard<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
(**self).fmt(f)
}
}
#[cfg(test)]
mod tests {
use super::*;
use alloc::sync::Arc;
use core::{
future::Future,
sync::atomic::{AtomicUsize, Ordering},
task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
};
use futures::pin_mut;
#[derive(Eq, PartialEq, Debug)]
struct NonCopy(i32);
struct Counter(AtomicUsize);
impl Counter {
fn to_waker(&'static self) -> Waker {
unsafe fn clone(counter: *const ()) -> RawWaker {
RawWaker::new(counter, &VTABLE)
}
unsafe fn wake(counter: *const ()) {
unsafe { (*(counter as *const Counter)).0.fetch_add(1, Ordering::SeqCst) };
}
static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake, drop);
unsafe { Waker::from_raw(RawWaker::new(self as *const _ as *const (), &VTABLE)) }
}
}
#[test]
fn try_lock() {
let m = Mutex::new(());
*m.try_lock().unwrap() = ();
}
#[test]
fn lock() {
static COUNTER: Counter = Counter(AtomicUsize::new(0));
let waker = COUNTER.to_waker();
let mut cx = Context::from_waker(&waker);
let a = Arc::new(Mutex::new(1));
let b = Arc::clone(&a);
let c = Arc::clone(&b);
let d = Arc::new(Mutex::new(0));
let e = Arc::clone(&d);
let f = async move {
let mut b = b.lock().await;
let mut _e = e.lock().await;
*b *= 3;
};
let g = async move {
let mut c = c.lock().await;
*c *= 5;
};
pin_mut!(f);
pin_mut!(g);
let d = d.try_lock().unwrap();
assert_eq!(*d, 0);
assert_eq!(f.as_mut().poll(&mut cx), Poll::Pending);
assert_eq!(g.as_mut().poll(&mut cx), Poll::Pending);
assert_eq!(COUNTER.0.load(Ordering::SeqCst), 0);
drop(d);
assert_eq!(COUNTER.0.load(Ordering::SeqCst), 1);
assert_eq!(g.as_mut().poll(&mut cx), Poll::Pending);
assert_eq!(f.as_mut().poll(&mut cx), Poll::Ready(()));
assert_eq!(COUNTER.0.load(Ordering::SeqCst), 2);
assert!(!a.waiters.is_empty());
assert_eq!(g.as_mut().poll(&mut cx), Poll::Ready(()));
assert!(a.waiters.is_empty());
assert_eq!(*a.try_lock().unwrap(), 15);
}
#[test]
fn into_inner() {
let m = Mutex::new(NonCopy(10));
assert_eq!(m.into_inner(), NonCopy(10));
}
#[test]
fn into_inner_drop() {
struct Foo(Arc<AtomicUsize>);
impl Drop for Foo {
fn drop(&mut self) {
self.0.fetch_add(1, Ordering::SeqCst);
}
}
let num_drops = Arc::new(AtomicUsize::new(0));
let m = Mutex::new(Foo(num_drops.clone()));
assert_eq!(num_drops.load(Ordering::SeqCst), 0);
{
let _inner = m.into_inner();
assert_eq!(num_drops.load(Ordering::SeqCst), 0);
}
assert_eq!(num_drops.load(Ordering::SeqCst), 1);
}
#[test]
fn get_mut() {
let mut m = Mutex::new(NonCopy(10));
*m.get_mut() = NonCopy(20);
assert_eq!(m.into_inner(), NonCopy(20));
}
#[test]
fn mutex_unsized() {
let mutex: &Mutex<[i32]> = &Mutex::new([1, 2, 3]);
{
let b = &mut *mutex.try_lock().unwrap();
b[0] = 4;
b[2] = 5;
}
let comp: &[i32] = &[4, 2, 5];
assert_eq!(&*mutex.try_lock().unwrap(), comp);
}
}