diff --git a/src/sync.rs b/src/sync.rs index 3bcc760..4bb9184 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -278,3 +278,133 @@ impl Parker { } } } + +#[cfg(feature = "alloc")] +pub mod channel { + use alloc::sync::Arc; + use core::{ + cell::{Cell, UnsafeCell}, + marker::PhantomData, + mem::MaybeUninit, + sync::atomic::{AtomicU32, Ordering}, + }; + + #[repr(C)] + #[derive(Debug)] + struct Channel { + state: AtomicU32, + val: UnsafeCell>, + } + + unsafe impl Send for Channel {} + unsafe impl Sync for Channel {} + + impl Channel { + const OCCUPIED_BIT: u32 = 0b01; + const WAITING_BIT: u32 = 0b10; + + fn new() -> Self { + Self { + state: AtomicU32::new(0), + val: UnsafeCell::new(MaybeUninit::uninit()), + } + } + } + + pub fn channel() -> (Sender, Receiver) { + let channel = Arc::new(Channel::::new()); + let receiver = Receiver(channel.clone(), PhantomData); + let sender = Sender(channel); + (sender, receiver) + } + + #[derive(Debug)] + #[repr(transparent)] + // `PhantomData>` is used to ensure that `Receiver` is `!Sync` but `Send`. + pub struct Receiver(Arc>, PhantomData>); + + #[derive(Debug)] + #[repr(transparent)] + pub struct Sender(Arc>); + + impl Receiver { + pub fn is_empty(&self) -> bool { + self.0.state.load(Ordering::Acquire) & Channel::::OCCUPIED_BIT == 0 + } + + pub fn as_sender(self) -> Sender { + Sender(self.0.clone()) + } + + fn wait(&mut self) { + loop { + let state = self + .0 + .state + .fetch_or(Channel::::WAITING_BIT, Ordering::Acquire); + if state & Channel::::OCCUPIED_BIT == 0 { + // The channel is empty, so we need to wait for a value to be sent. + // We will block until the sender wakes us up. + atomic_wait::wait(&self.0.state, Channel::::WAITING_BIT); + } else { + // The channel is occupied, so we can return. + self.0 + .state + .fetch_and(!Channel::::WAITING_BIT, Ordering::Release); + break; + } + } + } + + /// Takes the value from the channel, if it is present. + fn take(&mut self) -> Option { + // unset the OCCUPIED_BIT to indicate that we are taking the value, if any is present. + if self + .0 + .state + .fetch_and(!Channel::::OCCUPIED_BIT, Ordering::Acquire) + & Channel::::OCCUPIED_BIT + == 0 + { + // The channel was empty, so we return None. + None + } else { + unsafe { Some(self.0.val.get().read().assume_init_read()) } + } + } + + pub fn recv(mut self) -> T { + loop { + if let Some(t) = self.take() { + return t; + } + + self.wait(); + } + } + } + + impl Sender { + pub fn send(self, value: T) { + unsafe { + self.0.val.get().write(MaybeUninit::new(value)); + } + + // Set the OCCUPIED_BIT to indicate that a value is present. + let state = self + .0 + .state + .fetch_or(Channel::::OCCUPIED_BIT, Ordering::Release); + assert!( + state & Channel::::OCCUPIED_BIT == 0, + "Channel is already occupied" + ); + + // If there are any receivers waiting, we need to wake them up. + if state & Channel::::WAITING_BIT != 0 { + // There are receivers waiting, so we need to wake them up. + atomic_wait::wake_all(&self.0.state); + } + } + } +}