diff --git a/Cargo.lock b/Cargo.lock index 6e139eb9a..5de4bacdd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1232,6 +1232,7 @@ checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" name = "orpc-macros" version = "0.1.0" dependencies = [ + "heck", "proc-macro2", "quote", "syn 1.0.109", diff --git a/kernel/src/benchmarks/oqueue.rs b/kernel/src/benchmarks/oqueue.rs index d3bcf8555..5e0ac32e6 100644 --- a/kernel/src/benchmarks/oqueue.rs +++ b/kernel/src/benchmarks/oqueue.rs @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MPL-2.0 -// + #![allow(unsafe_code)] use alloc::{ alloc::{alloc, handle_alloc_error}, @@ -26,7 +26,7 @@ use ostd::{ }, sync::Blocker, }, - sync::Waker, + sync::{Waker, WakerKey}, }; use super::{Benchmark, BenchmarkHarness, time, *}; @@ -301,7 +301,11 @@ impl Blocker for RigtorpProducer { true } - fn prepare_to_wait(&self, _waker: &Arc) { + fn enqueue(&self, _waker: &Arc) -> WakerKey { + panic!("!"); + } + + fn remove(&self, _key: ostd::sync::WakerKey) { panic!("!"); } } @@ -328,7 +332,11 @@ impl Blocker for RigtorpConsumer { true } - fn prepare_to_wait(&self, _waker: &Arc) { + fn enqueue(&self, _waker: &Arc) -> WakerKey { + panic!("!"); + } + + fn remove(&self, _key: ostd::sync::WakerKey) { panic!("!"); } } diff --git a/ostd/libs/orpc-macros/Cargo.toml b/ostd/libs/orpc-macros/Cargo.toml index f361a4ae7..7d69aafa4 100644 --- a/ostd/libs/orpc-macros/Cargo.toml +++ b/ostd/libs/orpc-macros/Cargo.toml @@ -10,3 +10,4 @@ proc-macro = true proc-macro2 = "1.0" quote = "1.0" syn = { version = "1.0", features = ["full"] } +heck = "0.5.0" diff --git a/ostd/libs/orpc-macros/src/lib.rs b/ostd/libs/orpc-macros/src/lib.rs index 09275c59c..9049f3ebd 100644 --- a/ostd/libs/orpc-macros/src/lib.rs +++ b/ostd/libs/orpc-macros/src/lib.rs @@ -2,6 +2,7 @@ /// A set of macros for use with ORPC. The most important are the ORPC attribute macros `orpc_trait`, `orpc_server`, and /// `orpc_impl`. 
The `select` macro (for waiting on multiple OQueues) is also defined here. mod orpc_impl; +mod orpc_monitor; mod orpc_server; mod orpc_trait; mod parsing_utils; @@ -9,7 +10,8 @@ mod select; use proc_macro::TokenStream; use syn::{ - ItemImpl, ItemStruct, ItemTrait, Path, Token, parse_macro_input, punctuated::Punctuated, + ItemImpl, ItemStruct, ItemTrait, Path, Token, Visibility, parse_macro_input, + punctuated::Punctuated, }; /// Declare a trait as an ORPC trait that can be implemented by ORPC server. @@ -130,6 +132,59 @@ pub fn orpc_impl(attr: TokenStream, input: TokenStream) -> TokenStream { output.into() } +/// Declare an [ORPC monitor type](`ostd::orpc::framework::monitor`). This is applied to the `impl` +/// for monitor methods. +/// +/// This will generate the `*Monitor` type and methods on it to call methods and attach methods to +/// OQueues and a `start` method. The `start` method will initialize the monitor, giving it it's +/// initial state and associating it with a server. The default implementation spawns a thread which +/// handles all calls in an event loop. +/// +/// ```ignore +/// pub struct XYZ { +/// x: i32, +/// } +/// +/// #[orpc_monitor(pub)] +/// impl XYZ { +/// #[strong_observer] +/// pub fn update(&mut self, x: i32) -> Result<(), RPCError> { +/// // ... +/// Ok(()) +/// } +/// +/// #[consumer] +/// pub fn next(&mut self, _: ()) -> Result<(), RPCError> { +/// // ... 
+/// Ok(()) +/// } +/// +/// pub fn get(&mut self) -> Result { +/// Ok(self.x) +/// } +/// } +/// ``` +/// +/// This will generate methods: +/// +/// ```ignore +/// impl TestStateMonitor { +/// pub fn attach_update(&self, attachment: StrongObserver) -> Result<(), AttachmentError>; +/// pub fn update(&self, arg: i32) -> Result<(), RPCError>; +/// pub fn attach_next(&self, attachment: Consumer<()>) -> Result<(), AttachmentError>; +/// pub fn next(&self, arg: ()) -> Result<(), RPCError>; +/// pub fn get(&self) -> Result; +/// fn start(&self, server: Arc, state: TestState); +/// pub fn new() -> Self; +/// } +/// ``` +#[proc_macro_attribute] +pub fn orpc_monitor(arg: TokenStream, input: TokenStream) -> TokenStream { + let input_impl = parse_macro_input!(input as ItemImpl); + let vis = parse_macro_input!(arg as Visibility); + orpc_monitor::orpc_monitor_impl(vis, input_impl).into() +} + // TODO: The select syntax (and name) should be revisited. This provides a good starting point, but it also has some // issues, such as rust-analyzer refactors not working and not having a clean way to match over different message forms. diff --git a/ostd/libs/orpc-macros/src/orpc_monitor.rs b/ostd/libs/orpc-macros/src/orpc_monitor.rs new file mode 100644 index 000000000..3bc24ccfb --- /dev/null +++ b/ostd/libs/orpc-macros/src/orpc_monitor.rs @@ -0,0 +1,473 @@ +// SPDX-License-Identifier: MPL-2.0 + +use heck::ToUpperCamelCase; +use proc_macro2::TokenStream; +use quote::{format_ident, quote, quote_spanned}; +use syn::{Error, Ident, ItemImpl, Type, Visibility, parse_quote, spanned::Spanned as _}; + +/// The kind of attachment which a method should support. +#[derive(PartialEq)] +enum MethodAttachmentType { + StrongObserver, + Consumer, + None, +} + +impl MethodAttachmentType { + fn to_attachment_type_symbol(&self) -> Option { + match self { + MethodAttachmentType::StrongObserver => { + Some(quote! { ::ostd::orpc::oqueue::StrongObserver }) + } + MethodAttachmentType::Consumer => Some(quote! 
{ ::ostd::orpc::oqueue::Consumer }), + MethodAttachmentType::None => None, + } + } +} + +/// The information about a method. +struct MethodDefinition { + definition: syn::ImplItemMethod, + attachment_type: MethodAttachmentType, +} + +impl MethodDefinition { + fn argument_type(&self) -> syn::Result> { + let self_argument = self.definition.sig.inputs.first(); + match self_argument { + Some(syn::FnArg::Receiver(_)) => (), + _ => { + return Err(syn::Error::new( + self_argument.span(), + "Monitor methods must have a self argument.", + )); + } + } + let first_input = self.definition.sig.inputs.iter().nth(1); + match first_input { + Some(syn::FnArg::Typed(pat_type)) => Ok(Some(*pat_type.ty.clone())), + None => Ok(None), + _ => { + // This is an error, but will be caught by the normal compiler. + Ok(Some(parse_quote! { () })) + } + } + } +} + +pub fn orpc_monitor_impl(monitor_vis: Visibility, input_impl: ItemImpl) -> TokenStream { + // A set of errors encountered. + let mut errors = Vec::new(); + + // Information about each method definition + let mut method_definitions = Vec::new(); + + for item in &input_impl.items { + if let syn::ImplItem::Method(method) = item { + // Collect information from the attributes. + let mut is_strong_observer = false; + let mut is_consumer = false; + let mut filtered_attrs = Vec::new(); + + for attr in &method.attrs { + if attr.path.is_ident("strong_observer") { + is_strong_observer = true; + } else if attr.path.is_ident("consumer") { + is_consumer = true; + } else { + filtered_attrs.push(attr.clone()); + } + } + + let kind = match (is_strong_observer, is_consumer) { + (true, false) => MethodAttachmentType::StrongObserver, + (false, true) => MethodAttachmentType::Consumer, + (false, false) => MethodAttachmentType::None, + (true, true) => { + errors.push(Error::new( + method.span(), + "Only a single attachment type can be used at a time. 
(This restriction may be relaxed in the future.)", + )); + MethodAttachmentType::None + } + }; + + let definition = syn::ImplItemMethod { + attrs: filtered_attrs, + ..method.clone() + }; + + method_definitions.push(MethodDefinition { + definition, + attachment_type: kind, + }); + } else { + errors.push(Error::new(item.span(), "Only methods are allowed in #[orpc_monitor] impl. (You can use another impl on the same type.)")); + } + } + + // The fields of the *Attachment struct + let mut attachment_fields = Vec::new(); + + for method_def in &method_definitions { + let method_name = &method_def.definition.sig.ident; + + match method_def.argument_type() { + Ok(param_type) => { + let param_type = param_type.clone().unwrap_or(parse_quote! {()}); + let field_name = format_ident!("{}_attachment", method_name); + if let Some(attachment_type) = + method_def.attachment_type.to_attachment_type_symbol() + { + attachment_fields.push(quote! { + #field_name: ::core::option::Option<#attachment_type<#param_type>> + }); + } + } + Err(e) => errors.push(e), + } + } + + // The name of the state type + let state_name = match input_impl.self_ty.as_ref() { + Type::Path(syn::TypePath { qself: None, path }) if path.get_ident().is_some() => { + path.get_ident().unwrap().clone() + } + _ => { + errors.push(Error::new( + input_impl.self_ty.span(), + "impl'd type must be referenced by a simple identifier", + )); + format_ident!("__ERROR__") + } + }; + + // The name of monitor type + let monitor_name = format_ident!("{}Monitor", state_name); + // The name of the enum holding the commands + let command_enum_name = format_ident!("{}MonitorCommand", state_name); + // The name of the struct holding the attachments + let attachment_struct_name = format_ident!("{}MonitorAttachments", state_name); + + // The commands in the command enum. + let mut command_variants = Vec::new(); + // The match arms for `Debug` impl for the commands type. 
This cannot be derived, because it + // needs to ignore the reply channels. + let mut command_debug_match_arms = Vec::new(); + + for method_def in &method_definitions { + let return_type = match &method_def.definition.sig.output { + syn::ReturnType::Default => quote! { () }, + syn::ReturnType::Type(_, ty) => quote! { #ty }, + }; + let method_name = &method_def.definition.sig.ident; + let variant_name = format_ident!("{}", &method_name.to_string().to_upper_camel_case(),); + + if let Ok(param_type) = method_def.argument_type() { + let param_type = param_type.clone().unwrap_or(parse_quote! {()}); + command_variants.push(quote! { + #variant_name(#param_type, ::ostd::orpc::oqueue::ValueProducer<#return_type>), + }); + + if let Some(attachment_type) = method_def.attachment_type.to_attachment_type_symbol() { + let attachment_message_name = format_ident!("Attach{variant_name}"); + command_variants.push(quote! { + #attachment_message_name ( + #attachment_type<#param_type>, + ::ostd::orpc::oqueue::ValueProducer<::core::result::Result<(), ::ostd::orpc::oqueue::AttachmentError>>, + ), + }); + command_debug_match_arms.push(quote! { + #command_enum_name::#attachment_message_name(_, _) => + write!(f, "{}", stringify!(#attachment_message_name)), + }); + } + command_debug_match_arms.push(quote! 
{ + #command_enum_name::#variant_name(arg, _) => + write!(f, "{}({:?})", stringify!(#variant_name), arg), + }); + } + } + + // The methods to include in the monitor `impl` + let mut monitor_methods = Vec::new(); + + for method_def in &method_definitions { + let method_name = &method_def.definition.sig.ident; + let variant_name = format_ident!("{}", method_name.to_string().to_upper_camel_case()); + + let method_attrs = &method_def.definition.attrs; + let method_vis = &method_def.definition.vis; + + if let Ok(param_type) = method_def.argument_type() { + if let Some(attachment_type_base) = + method_def.attachment_type.to_attachment_type_symbol() + { + let attachment_variant_name = format_ident!("Attach{}", variant_name); + let attachment_method_name = + format_ident!("attach_{}", method_name, span = method_name.span()); + let param_type = param_type.clone().unwrap_or(parse_quote! {()}); + let attachment_type = quote! { #attachment_type_base<#param_type> }; + let span = method_def.definition.span(); + let attachment_docs = format!( + " +Configure the monitor to run `{method_name}` to handle values from `attachment`. + +The documentation for [`Self::{method_name}`] is:\n\n", + ); + monitor_methods.push( + quote_spanned! { span => + #[doc = #attachment_docs] + #(#method_attrs)* + #[allow(clippy::allow_attributes)] + #[allow(unused)] + #method_vis fn #attachment_method_name(&self, attachment: #attachment_type) + -> ::core::result::Result<(), ::ostd::orpc::oqueue::AttachmentError> + { + ::ostd::orpc::framework::monitor::synchronous_request( + &self.command_producer, + |reply_producer| #command_enum_name::#attachment_variant_name(attachment, reply_producer) + ) + } + } + ); + } + + let return_type = match &method_def.definition.sig.output { + syn::ReturnType::Default => quote! { () }, + syn::ReturnType::Type(_, ty) => quote! { #ty }, + }; + let (arg_decl, arg_use) = if let Some(param_type) = param_type { + (quote! { arg: #param_type }, quote! { arg }) + } else { + (quote! 
{}, quote! { () }) + }; + monitor_methods.push(quote_spanned! { method_def.definition.span() => + #(#method_attrs)* + #[allow(clippy::allow_attributes)] + #[allow(unused)] + #method_vis fn #method_name(&self, #arg_decl) -> #return_type { + ::ostd::orpc::framework::monitor::synchronous_request( + &self.command_producer, + |reply_producer| #command_enum_name::#variant_name(#arg_use, reply_producer) + ) + } + }); + } + } + + monitor_methods.push(generate_start_fn( + &method_definitions, + state_name, + &command_enum_name, + &attachment_struct_name, + )); + + let compiler_errors = { + if let Some(mut collected_error) = errors.pop() { + for e in errors { + collected_error.combine(e); + } + collected_error.into_compile_error() + } else { + quote! {} + } + }; + + let impl_without_our_attrs = { + syn::ItemImpl { + items: method_definitions + .iter() + .map(|d| d.definition.clone().into()) + .collect(), + ..input_impl + } + }; + + let expanded = quote! { + #impl_without_our_attrs + + #[doc(hidden)] + #[derive(Default)] + struct #attachment_struct_name { + #(#attachment_fields,)* + } + + #[doc(hidden)] + enum #command_enum_name { + #(#command_variants)* + } + + impl ::core::fmt::Debug for #command_enum_name { + fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { + match self { + #(#command_debug_match_arms)* + } + } + } + + #monitor_vis struct #monitor_name { + command_oqueue: ::ostd::orpc::oqueue::ConsumableOQueueRef<#command_enum_name>, + command_producer: ::ostd::orpc::oqueue::ValueProducer<#command_enum_name>, + } + + impl #monitor_name { + #(#monitor_methods)* + + #monitor_vis fn new() -> Self { + use ::ostd::orpc::oqueue::ConsumableOQueue; + let command_oqueue = ::ostd::orpc::oqueue::ConsumableOQueueRef::new(2); + Self { + command_producer: command_oqueue + .attach_value_producer() + .expect("single purpose OQueue failed."), + command_oqueue, + } + } + } + + impl ::core::default::Default for #monitor_name { + fn default() -> Self { + Self::new() 
+ } + } + + #compiler_errors + }; + + expanded +} + +/// Generate a `start` method for the monitor which spawns a thread to process commands and messages +/// from the attached observers/consumers. +fn generate_start_fn( + method_definitions: &Vec, + state_name: Ident, + command_enum_name: &Ident, + attachment_struct_name: &Ident, +) -> TokenStream { + // The attachment blockers used for blocking the event loop. + let mut attachment_blockers = Vec::new(); + // Arms for a match used to process every command + let mut command_handling_arms = Vec::new(); + // Code blocks to poll and handle messages on attached OQueues. + let mut observe_blocks = Vec::new(); + + for method_def in method_definitions { + let method_name = &method_def.definition.sig.ident; + let field_name = format_ident!("{}_attachment", method_name); + let variant_name = format_ident!("{}", method_name.to_string().to_upper_camel_case()); + let attachment_variant_name = format_ident!("Attach{}", variant_name); + + // If the message allows OQueue attachments, + if method_def.attachment_type != MethodAttachmentType::None { + // The blocker for this method's OQueue + attachment_blockers.push(quote! { + &attachments.#field_name + }); + + match method_def.attachment_type { + MethodAttachmentType::StrongObserver => { + observe_blocks.push({ + quote! { + match attachments + .#field_name + .as_ref() + .map(|a| a.try_strong_observe()) + { + Some(Ok(Some(x))) => { + state.#method_name(x)?; + } + Some(Ok(None)) => {} + Some(e @ Err(_)) => { + ::ostd::ignore_err!( + e, + log::Level::Error, + "Detaching from OQueue due to handler error" + ); + attachments.#field_name = None; + } + None => {} + } + } + }); + } + MethodAttachmentType::Consumer => { + observe_blocks.push({ + quote! { + match attachments + .#field_name + .as_ref() + .map(|a| a.try_consume()) + { + Some(Some(x)) => { + state.#method_name(x)?; + } + _ => {} + } + } + }); + } + MethodAttachmentType::None => {} + } + } + + // Handle method calls. 
+ let (arg_pat, call) = if method_def.definition.sig.inputs.len() > 1 { + (quote! { arg }, quote! { state.#method_name(arg) }) + } else { + (quote! { () }, quote! { state.#method_name() }) + }; + command_handling_arms.push(quote! { + #command_enum_name::#variant_name(#arg_pat, value_producer) => { + value_producer.produce(#call); + } + }); + + // For method which can be attached, handle the attach method. + if method_def.attachment_type != MethodAttachmentType::None { + command_handling_arms.push(quote! { + #command_enum_name::#attachment_variant_name(consumer, value_producer) => { + attachments.#field_name = Some(consumer.into()); + value_producer.produce(Ok(())) + } + }); + } + } + + // The expression holding all blockers for the event loop + let all_blockers = if attachment_blockers.is_empty() { + quote! { &[&command_consumer] } + } else { + quote! { &[&command_consumer, #(#attachment_blockers),*] } + }; + + quote! { + fn start(&self, server: ::alloc::sync::Arc, state: #state_name) { + ::ostd::orpc::framework::spawn_thread(server, { + use ::ostd::orpc::oqueue::ConsumableOQueue; + let command_consumer = self + .command_oqueue + .attach_consumer() + .expect("single purpose OQueue failed."); + move || { + let mut state = state; + let mut attachments = #attachment_struct_name::default(); + loop { + if let Some(c) = ::ostd::task::Task::current() { + c.block_on(#all_blockers); + } else { + ::ostd::task::Task::yield_now(); + } + if let Some(cmd) = command_consumer.try_consume() { + match cmd { + #(#command_handling_arms)* + } + } + #(#observe_blocks)* + } + } + }); + } + } +} diff --git a/ostd/src/orpc/framework/mod.rs b/ostd/src/orpc/framework/mod.rs index 20dfae3da..97e1cf406 100644 --- a/ostd/src/orpc/framework/mod.rs +++ b/ostd/src/orpc/framework/mod.rs @@ -24,12 +24,14 @@ pub mod errors; mod integration_test; +pub mod monitor; pub mod notifier; pub mod shutdown; pub mod threads; use alloc::{sync::Weak, vec::Vec}; use core::{ + any::Any, fmt::Display, 
num::NonZeroUsize, ops::DerefMut, @@ -47,7 +49,7 @@ use crate::{ }; /// The primary trait for all server. This provides access to information and capabilities common to all servers. -pub trait Server: Sync + Send + 'static { +pub trait Server: Any + Sync + Send + 'static { /// **INTERNAL** User code should never call this directly, however it cannot be private because generated code must /// use it. /// @@ -290,7 +292,7 @@ mod test { use super::*; use crate::{ orpc::{legacy_oqueue::generic_test, sync::Blocker}, - sync::Waker, + sync::{Waker, WakerKey}, }; struct InfiniteBlocker; @@ -300,7 +302,11 @@ mod test { false } - fn prepare_to_wait(&self, _waker: &Arc) {} + fn enqueue(&self, _task: &Arc) -> WakerKey { + WakerKey::default() + } + + fn remove(&self, _key: WakerKey) {} } struct TestServer { diff --git a/ostd/src/orpc/framework/monitor.rs b/ostd/src/orpc/framework/monitor.rs new file mode 100644 index 000000000..fe7dcda92 --- /dev/null +++ b/ostd/src/orpc/framework/monitor.rs @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! ORPC monitors are objects which hold state and have a set of methods which operate on that +//! state. They are **not** servers, and are orthogonal to them. As such, monitors must be inside a +//! server. The standard usage is to have a single monitor holding all the state of the server and +//! forwarding all server methods into that monitor. +//! +//! Monitors are created using the [`orpc_monitor`] macro. + +pub use orpc_macros::orpc_monitor; + +use crate::orpc::oqueue::{ValueProducer, new_reply_pair}; + +/// **INTERNAL FOR MACRO USE ONLY** +/// +/// Make a synchronous request over an OQueue. `make_request` is called with the reply producer +/// attachment to create the request message. The request is sent over `producer`. The consumer of +/// that message should guarantee exactly one reply will be published. 
+#[doc(hidden)] +pub fn synchronous_request( + producer: &ValueProducer, + make_request: impl FnOnce(ValueProducer) -> C, +) -> R { + let (reply_producer, reply_consumer) = new_reply_pair::(); + producer.produce(make_request(reply_producer)); + reply_consumer.consume() +} + +#[cfg(ktest)] +mod tests { + use orpc_macros::{orpc_monitor, orpc_server}; + + use crate::{ + assert_eq_eventually, + orpc::{ + framework::errors::RPCError, + oqueue::{ + ConsumableOQueue, ConsumableOQueueRef, OQueue, OQueueBase, OQueueRef, + ObservationQuery, + }, + }, + prelude::{Arc, ktest}, + }; + + #[orpc_server()] + struct TestServer { + monitor: TestStateMonitor, + } + + pub struct TestState { + x: i32, + } + + #[orpc_monitor(pub)] + impl TestState { + #[strong_observer] + pub fn update(&mut self, x: i32) -> Result<(), RPCError> { + self.x = (self.x + x * 3) / 4; + Ok(()) + } + + #[consumer] + pub fn next(&mut self, _: ()) -> Result<(), RPCError> { + self.x += 1; + Ok(()) + } + + pub fn get(&mut self) -> Result { + Ok(self.x) + } + } + + fn spawn_server() -> Arc { + let server = TestServer::new_with(|orpc_internal, _| TestServer { + orpc_internal, + monitor: Default::default(), + }); + server.monitor.start(server.clone(), TestState { x: 0 }); + server + } + + #[ktest] + fn monitor_updates_from_strong_observer() { + let values = OQueueRef::new(2); + let server = spawn_server(); + server + .monitor + .attach_update( + values + .attach_strong_observer(ObservationQuery::identity()) + .unwrap(), + ) + .unwrap(); + + let producer = values.attach_ref_producer().unwrap(); + + producer.produce_ref(&0); + producer.produce_ref(&100); + assert_eq_eventually!(server.monitor.get().unwrap(), 75); + + producer.produce_ref(&100); + assert_eq_eventually!(server.monitor.get().unwrap(), 93); + } + + #[ktest] + fn monitor_call() { + let server = spawn_server(); + + server.monitor.update(0).unwrap(); + server.monitor.update(100).unwrap(); + assert_eq_eventually!(server.monitor.get().unwrap(), 75); + + 
server.monitor.update(100).unwrap(); + assert_eq_eventually!(server.monitor.get().unwrap(), 93); + } + + #[ktest] + fn monitor_updates_from_consumer() { + let values = ConsumableOQueueRef::new(2); + let server = spawn_server(); + server + .monitor + .attach_next(values.attach_consumer().unwrap()) + .unwrap(); + let producer = values.attach_value_producer().unwrap(); + + producer.produce(()); + producer.produce(()); + assert_eq_eventually!(server.monitor.get().unwrap(), 2); + + producer.produce(()); + assert_eq_eventually!(server.monitor.get().unwrap(), 3); + } +} diff --git a/ostd/src/orpc/framework/threads.rs b/ostd/src/orpc/framework/threads.rs index 7f35799e3..4462023bb 100644 --- a/ostd/src/orpc/framework/threads.rs +++ b/ostd/src/orpc/framework/threads.rs @@ -24,8 +24,8 @@ pub(crate) type SpawnThreadFn = fn(Arc, Thre pub(crate) static SPAWN_THREAD_FN: Once = Once::new(); /// Start a new server thread. This should only be called while spawning a server. -pub fn spawn_thread( - server: Arc, +pub fn spawn_thread( + server: Arc, body: impl (FnOnce() -> Result<(), Box>) + Send + 'static, ) { if let Some(spawn_fn) = SPAWN_THREAD_FN.get() { diff --git a/ostd/src/orpc/legacy_oqueue/locking.rs b/ostd/src/orpc/legacy_oqueue/locking.rs index 2cb171b8e..b83b7a3fb 100644 --- a/ostd/src/orpc/legacy_oqueue/locking.rs +++ b/ostd/src/orpc/legacy_oqueue/locking.rs @@ -14,7 +14,7 @@ use super::{ }; use crate::{ prelude::{Arc, Box, Vec}, - sync::{SpinLock, WaitQueue, Waker}, + sync::{SpinLock, WaitQueue, Waker, WakerKey}, task::Task, }; @@ -334,8 +334,12 @@ impl Blocker for LockingProducer { self.oqueue().inner.lock().can_produce().is_some() } - fn prepare_to_wait(&self, waker: &Arc) { - self.oqueue().put_wait_queue.enqueue(waker.clone()); + fn enqueue(&self, waker: &Arc) -> WakerKey { + self.oqueue().put_wait_queue.enqueue(waker.clone()) + } + + fn remove(&self, key: WakerKey) { + self.oqueue().put_wait_queue.remove(key); } } @@ -375,9 +379,13 @@ impl Blocker for LockingConsumer 
{ self.oqueue.inner.lock().can_consume() } - fn prepare_to_wait(&self, waker: &Arc) { + fn enqueue(&self, waker: &Arc) -> WakerKey { self.oqueue.read_wait_queue.enqueue(waker.clone()) } + + fn remove(&self, key: WakerKey) { + self.oqueue.read_wait_queue.remove(key); + } } impl Drop for LockingConsumer { @@ -442,9 +450,13 @@ impl Blocker for CloningLockingConsumer { self.oqueue().inner.lock().can_consume() } - fn prepare_to_wait(&self, waker: &Arc) { + fn enqueue(&self, waker: &Arc) -> WakerKey { self.oqueue().read_wait_queue.enqueue(waker.clone()) } + + fn remove(&self, key: WakerKey) { + self.oqueue().read_wait_queue.remove(key); + } } /// A strong observer for a locking OQueue. This will clone values and works only with [`CloningLockingConsumer`]. @@ -467,8 +479,12 @@ impl Blocker for LockingStrongObserver { self.oqueue().inner.lock().can_strong_observe(self.index) } - fn prepare_to_wait(&self, waker: &Arc) { - self.oqueue().read_wait_queue.enqueue(waker.clone()); + fn enqueue(&self, waker: &Arc) -> WakerKey { + self.oqueue().read_wait_queue.enqueue(waker.clone()) + } + + fn remove(&self, key: WakerKey) { + self.oqueue().read_wait_queue.remove(key); } } @@ -520,9 +536,13 @@ impl Blocker for LockingWeakObserver { self.oqueue().inner.lock().tail_index > self.max_observed_tail.get() } - fn prepare_to_wait(&self, waker: &Arc) { + fn enqueue(&self, waker: &Arc) -> WakerKey { self.oqueue().read_wait_queue.enqueue(waker.clone()) } + + fn remove(&self, key: WakerKey) { + self.oqueue().read_wait_queue.remove(key); + } } impl WeakObserver for LockingWeakObserver { diff --git a/ostd/src/orpc/legacy_oqueue/ringbuffer/mpmc.rs b/ostd/src/orpc/legacy_oqueue/ringbuffer/mpmc.rs index ad92a9c56..c9274b8c9 100644 --- a/ostd/src/orpc/legacy_oqueue/ringbuffer/mpmc.rs +++ b/ostd/src/orpc/legacy_oqueue/ringbuffer/mpmc.rs @@ -28,7 +28,7 @@ use crate::{ Blocker, Consumer, Cursor, OQueue, OQueueAttachError, Producer, StrongObserver, WeakObserver, }, - sync::{Mutex, WaitQueue, Waker}, + 
sync::{Mutex, WaitQueue, Waker, WakerKey}, task::Task, }; @@ -471,9 +471,13 @@ impl Blocker self.oqueue.size() < self.oqueue.capacity.into() } - fn prepare_to_wait(&self, waker: &Arc) { + fn enqueue(&self, waker: &Arc) -> WakerKey { self.oqueue.put_wait_queue.enqueue(waker.clone()) } + + fn remove(&self, key: WakerKey) { + self.oqueue.put_wait_queue.remove(key); + } } impl Producer @@ -529,9 +533,13 @@ impl Blocker !self.oqueue.empty() } - fn prepare_to_wait(&self, waker: &Arc) { + fn enqueue(&self, waker: &Arc) -> WakerKey { self.oqueue.read_wait_queue.enqueue(waker.clone()) } + + fn remove(&self, key: WakerKey) { + self.oqueue.read_wait_queue.remove(key); + } } impl Consumer @@ -573,9 +581,13 @@ impl Blocker for MPMCStrongObserver< !self.oqueue.empty() } - fn prepare_to_wait(&self, waker: &Arc) { + fn enqueue(&self, waker: &Arc) -> WakerKey { self.oqueue.read_wait_queue.enqueue(waker.clone()) } + + fn remove(&self, key: WakerKey) { + self.oqueue.read_wait_queue.remove(key); + } } impl StrongObserver @@ -608,9 +620,13 @@ impl Blocker for MPMCWeakObserver 0 } - fn prepare_to_wait(&self, waker: &Arc) { + fn enqueue(&self, waker: &Arc) -> WakerKey { self.oqueue.read_wait_queue.enqueue(waker.clone()) } + + fn remove(&self, key: WakerKey) { + self.oqueue.read_wait_queue.remove(key); + } } impl WeakObserver diff --git a/ostd/src/orpc/oqueue/implementation.rs b/ostd/src/orpc/oqueue/implementation.rs index 2bd2e6850..a6467b4f9 100644 --- a/ostd/src/orpc/oqueue/implementation.rs +++ b/ostd/src/orpc/oqueue/implementation.rs @@ -18,6 +18,7 @@ use core::{ marker::PhantomData, }; +use log::warn; use slotmap::{SlotMap, new_key_type}; use snafu::ensure; use static_assertions::assert_obj_safe; @@ -30,7 +31,7 @@ use crate::{ ResourceUnavailableSnafu, single_thread_ring_buffer::RingBuffer, }, }, - sync::{SpinLock, WaitQueue}, + sync::{SpinLock, WaitQueue, WakerKey}, }; new_key_type! 
{ @@ -55,8 +56,8 @@ pub(crate) struct OQueueImplementation { /// The size to use for the consumer and strong-observer ring-buffers. len: usize, supports_consume: bool, - put_wait_queue: WaitQueue, - read_wait_queue: WaitQueue, + pub(super) put_wait_queue: WaitQueue, + pub(super) read_wait_queue: WaitQueue, } impl OQueueImplementation { @@ -64,7 +65,13 @@ impl OQueueImplementation { /// /// * `len` is the ring buffer length used for consumers and strong-observers. /// * `supports_consume` specifies the attachment it allows later. - pub(crate) fn new(len: usize, supports_consume: bool) -> Self { + pub(crate) fn new(mut len: usize, supports_consume: bool) -> Self { + if len < 2 { + warn!( + "Creating an OQueue with length {len} is automatically increased to 2. Ring buffers smaller than 2 are not supported." + ); + len = 2; + } Self { inner: SpinLock::new(OQueueInner { consumer_ring_buffer: Default::default(), @@ -234,6 +241,8 @@ impl OQueueImplementation { fn wrap_closure_ref( f: impl Fn(&T) + Send + 'static, ) -> Box { + // TODO(arthurp): This embeds a detail of ORPC in the middle of the OQueue implementation. It + // also forces this overhead on every closure regardless of it's origin. if let Some(s) = CurrentServer::current_cloned() { let f: Box = Box::new(move |v| { let _ = s.orpc_server_base().call_in_context::<_, RPCError>(|| { @@ -298,7 +307,6 @@ impl OQueueImplementation { } Ok(super::RefProducer { oqueue: self.clone(), - _phantom: PhantomData, }) } } @@ -355,6 +363,15 @@ impl OQueueImplementation { self.read_wait_queue.wait_until(|| self.try_consume()) } + pub(super) fn can_consume(&self) -> bool { + let inner = self.inner.lock(); + inner + .consumer_ring_buffer + .as_ref() + .expect("consume not supported") + .can_get_for_head(0) + } + /// Attempt to consume a value from the consumer ring buffer, taking ownership of the value. 
pub(super) fn try_consume(&self) -> Option { let mut inner = self.inner.lock(); @@ -382,7 +399,6 @@ impl OQueueImplementation { } Ok(super::ValueProducer { oqueue: self.clone(), - _phantom: PhantomData, }) } @@ -571,6 +587,12 @@ pub(super) trait UntypedOQueueImplementation: Sync + Send + Any { /// Release any resources held by inline observer with the given key. fn detach_inline_strong_observer(&self, inline_observer_id: InlineObserverKey); + fn can_strong_observe(&self, observer_id: ObserverKey) -> bool; + + fn enqueue_read_waker(&self, waker: &Arc) -> WakerKey; + + fn remove_read_waker(&self, key: WakerKey); + /// Copy the next value available to the specified observer into `dest` if it is available. This /// returns `Ok(true)` if the value was copied, `Ok(false)` if there was not value available /// yet, and an error if some other failure happened. @@ -739,4 +761,22 @@ impl UntypedOQueueImplementation for OQueueImplementation bool { + let mut inner = self.inner.lock(); + let ObservationRingBuffer { ring_buffer, .. 
} = inner + .observer_ring_buffers + .get_mut(observer_id) + .expect("should only be called with an id returned from new_observation_ring_buffer"); + let head_id = 0; + ring_buffer.can_get_for_head(head_id) + } + + fn enqueue_read_waker(&self, waker: &Arc) -> WakerKey { + self.read_wait_queue.enqueue(waker.clone()) + } + + fn remove_read_waker(&self, key: WakerKey) { + self.read_wait_queue.remove(key); + } } diff --git a/ostd/src/orpc/oqueue/mod.rs b/ostd/src/orpc/oqueue/mod.rs index 1cc7d81ae..c99704db3 100644 --- a/ostd/src/orpc/oqueue/mod.rs +++ b/ostd/src/orpc/oqueue/mod.rs @@ -60,13 +60,18 @@ use core::{ mod implementation; pub mod query; mod single_thread_ring_buffer; +mod utils; use ostd_macros::ostd_error; pub use query::ObservationQuery; use snafu::Snafu; +pub use utils::new_reply_pair; use self::implementation::{InlineObserverKey, ObserverKey}; -use crate::sync::SpinLock; +use crate::{ + orpc::sync::Blocker, + sync::{SpinLock, WakerKey}, +}; #[cfg(ktest)] pub(crate) mod generic_test; @@ -292,7 +297,6 @@ impl_oqueue_forward!(OQueueRef, inner, [+ ?Sized]); /// the consumer without copying or cloning. pub struct ValueProducer { oqueue: Arc>, - _phantom: PhantomData>, } impl ValueProducer { @@ -312,7 +316,6 @@ impl ValueProducer { /// There can be no consumers since the message is not moved into the OQueue. pub struct RefProducer { oqueue: Arc>, - _phantom: PhantomData>, } impl RefProducer { @@ -342,6 +345,20 @@ impl Drop for Consumer { } } +impl Blocker for Consumer { + fn should_try(&self) -> bool { + self.oqueue.can_consume() + } + + fn enqueue(&self, waker: &Arc) -> WakerKey { + self.oqueue.read_wait_queue.enqueue(waker.clone()) + } + + fn remove(&self, key: WakerKey) { + self.oqueue.read_wait_queue.remove(key); + } +} + impl Consumer { /// Consume a value from the queue, taking ownership of that value. 
pub fn consume(&self) -> T { @@ -401,19 +418,33 @@ type ConvertToInlineFn = fn( ) -> Result; /// An attachment to an OQueue which allows observing events from the OQueue. -pub struct StrongObserver { +pub struct StrongObserver { oqueue: Arc, observer_id: ObserverKey, convert_to_inline: ConvertToInlineFn, _phantom: PhantomData>, } -impl Drop for StrongObserver { +impl Drop for StrongObserver { fn drop(&mut self) { self.oqueue.detach_strong_observer(self.observer_id); } } +impl Blocker for StrongObserver { + fn should_try(&self) -> bool { + self.oqueue.can_strong_observe(self.observer_id) + } + + fn enqueue(&self, waker: &Arc) -> WakerKey { + self.oqueue.enqueue_read_waker(waker) + } + + fn remove(&self, key: WakerKey) { + self.oqueue.remove_read_waker(key) + } +} + impl StrongObserver { /// Observe a value from the queue. This value will have been extracted from the message by the /// query provided on attachment. diff --git a/ostd/src/orpc/oqueue/query.rs b/ostd/src/orpc/oqueue/query.rs index 86769b286..3d6e3e463 100644 --- a/ostd/src/orpc/oqueue/query.rs +++ b/ostd/src/orpc/oqueue/query.rs @@ -49,6 +49,17 @@ impl ObservationQuery { } } +impl ObservationQuery { + /// A query which observes the entire message. + /// + /// This is equivalent to `ObservationQuery::new(|x| *x)`, but may be optimized. + pub fn identity() -> Self { + Self { + extractor: Box::new(|x| Some(*x)), + } + } +} + #[cfg(ktest)] mod test { use super::*; diff --git a/ostd/src/orpc/oqueue/single_thread_ring_buffer.rs b/ostd/src/orpc/oqueue/single_thread_ring_buffer.rs index c10fd7d5f..4318b27b1 100644 --- a/ostd/src/orpc/oqueue/single_thread_ring_buffer.rs +++ b/ostd/src/orpc/oqueue/single_thread_ring_buffer.rs @@ -189,8 +189,7 @@ impl RingBuffer { None } - /// Return `Some(slot to read next)` if that head has value in the buffer to read. 
- pub(super) fn can_get_for_head(&mut self, head_id: usize) -> bool {
+ pub(super) fn can_get_for_head(&self, head_id: usize) -> bool {
self.strong_reader_heads[head_id] != self.tail_index
}
diff --git a/ostd/src/orpc/oqueue/utils.rs b/ostd/src/orpc/oqueue/utils.rs
new file mode 100644
index 000000000..16816b673
--- /dev/null
+++ b/ostd/src/orpc/oqueue/utils.rs
@@ -0,0 +1,18 @@
+// SPDX-License-Identifier: MPL-2.0
+
+use crate::orpc::oqueue::{ConsumableOQueue as _, ConsumableOQueueRef, Consumer, ValueProducer};
+
+/// Create a producer-consumer pair for a single message.
+///
+/// This is for use as a reply channel for commands or calls. The returned producer may panic if the
+/// user tries to produce more than once.
+pub fn new_reply_pair() -> (ValueProducer, Consumer) {
+ let reply = ConsumableOQueueRef::new(2);
+ let reply_consumer = reply
+ .attach_consumer()
+ .expect("new reply OQueue always allows consumer");
+ let reply_producer = reply
+ .attach_value_producer()
+ .expect("new reply OQueue always allows value producer");
+ (reply_producer, reply_consumer)
+}
diff --git a/ostd/src/orpc/sync/mod.rs b/ostd/src/orpc/sync/mod.rs
index 320bd8d89..131eb8630 100644
--- a/ostd/src/orpc/sync/mod.rs
+++ b/ostd/src/orpc/sync/mod.rs
@@ -1,16 +1,14 @@
// SPDX-License-Identifier: MPL-2.0
+
//! A trait [`Blocker`] which allows a thread to wait for a wake-up from another thread. The API is designed to allow a
//! waiter to wait on multiple blockers at the same time to support [`select!`].
-//!
-//! TODO(#73): This needs to be reworked significantly because is forces some rather odd syntax in select and is not as
-//! flexible as it should be.
pub use orpc_macros::select;
use crate::{
orpc::framework::CurrentServer,
prelude::Arc,
- sync::{Waiter, Waker},
+ sync::{Waiter, Waker, WakerKey},
task::{CurrentTask, Task},
};
@@ -23,56 +21,9 @@ use crate::{
// users of the OQueue. 
HOWEVER, the contention may occur exactly when a wake up is actually required meaning that the // cost may be very low. -// TODO(#73): The current setup for blocking is to include a call to `try_*` and infer the blocker from the consumer. This is -// kind of "magical". It would be better to have a type which encapsulates the async call information (the try function -// and the blocker). This becomes extremely similar to the `Future` types in Rust async. However, it performs no -// computation when the check is made, only doing a few instructions to check if the operation is possible. This is -// critical as the check must occur with the thread in a special scheduling state. - /// Tasks can block waiting for a [`Blocker`] to notify them to retry an action. /// -/// To use this, the scheduler will: -/// -/// 1. Lock the task (spinning) and disable preemption. -/// 2. Use `add_task` to register the task to be awoken when the blocker unblocks. -/// 3. Call `should_try` to check if it should actually block. -/// 4. If the task should try: -/// 1. Unregister the task with `remove_task`. -/// 2. Unlock the task into the running state. -/// -/// If the task should not try: -/// 1. Unlock the task into the blocked state. -/// -/// To block on multiple blockers: -/// -/// 1. Lock the task (spinning) and disable preemption. -/// 2. For each blocker: -/// 1. Use `add_task` to register the task to be awoken when the blocker unblocks. -/// 2. Call `should_try` to check if it should actually block. -/// 4. If the task should try: -/// 1. Unregister the task from all blockers registered so far with `remove_task`. -/// 2. Unlock the task into the running state. -/// -/// If the task should not try: -/// 1. Unlock the task into the blocked state. -/// -/// To wake tasks the blocker will iterate the tasks and for each: (The waker must atomically "take" the list, -/// guanteeing that exactly one waker gets the non-empty list.) -/// -/// 1. Lock the task. (Spinning) -/// 2. 
Unlock the task into the runnable state and place it into the run queue. -/// -/// As written, this does not allow for `wake_one`, only `wake_all`. This is because we have no way to know if a task -/// will actually "try" the action after it is woken. This could fail because the task could have been woken already and -/// already passed the point where it would perform the check. This could happen in an ABA situation as well, where the -/// thread has blocked again, but waiting for a different blocker. -/// -/// These wait semantics also force every blocker to be checked everytime a task is awoken. This is because multiple -/// wakes could have occurred from different blockers. These is no way to distinguish multiple wakes from a single. -/// -/// NOTE: Many requirements here can be relaxed in cases where there is guaranteed to be only one waker thread or -/// similar limitations. This *may* improve performance, but may not. An obvious case would be single sender queues -/// not requiring an atomic take operation on the wait queue. +/// User code should generally only call [`Blocker::block_until`] or [`CurrentTask::block_on`]. pub trait Blocker { /// Return true if performing the action may succeed and should be attempted. This must be *very* fast and cannot /// block for any condition itself. This is because it will be called inside the scheduler with locks held. @@ -83,17 +34,23 @@ pub trait Blocker { /// This must have Acquire ordering. fn should_try(&self) -> bool; - /// Add a task to the wait queue of `self`. After this call, the task must be awoken if [`Blocker::should_try`] may + /// Add a task to the wait queue of `self`. After this call, the task must be awoken if [`Self::should_try`] may /// return `true` again. /// /// This must have Release ordering. 
/// - /// This returns an ID which can be passed to [`Blocker::remove_task`] (on the same instance) to improve the + /// performance of removal. - fn prepare_to_wait(&self, waker: &Arc); + fn enqueue(&self, waker: &Arc) -> WakerKey;
+
+ /// Remove a task from the wait queue of `self`. After this call, the task will not be awoken
+ /// when `should_try` returns true. This undoes the effect of [`Self::enqueue`].
+ fn remove(&self, key: WakerKey);
/// Block on self repeately until `cond` returns Some. This assumes that this blocker will be woken if `cond()`
/// would change.
+ ///
+ /// To block on multiple `Blockers`, use [`CurrentTask::block_on`].
fn block_until(&self, cond: impl Fn() -> Option) -> T
where
Self: Sized,
@@ -114,26 +71,57 @@
}
}
+// Optional Blockers are allowed. If the blocker is `None`, then it will never wake up and performs
+// no other handling.
+impl Blocker for Option {
+ fn should_try(&self) -> bool {
+ self.as_ref().is_some_and(T::should_try)
+ }
+
+ fn enqueue(&self, waker: &Arc) -> WakerKey {
+ if let Some(blocker) = self {
+ blocker.enqueue(waker)
+ } else {
+ WakerKey::default()
+ }
+ }
+
+ fn remove(&self, key: WakerKey) {
+ if let Some(blocker) = self {
+ blocker.remove(key)
+ }
+ }
+}
+
impl CurrentTask {
- /// Wait for multiple blockers, waking if any wake.
+ /// Wait for multiple blockers, waking if any wake. This is equivalent to
+ /// [`Blocker::block_until`] if there is only one blocker.
pub fn block_on(&self, blockers: &[&dyn Blocker; N]) {
CurrentServer::abort_point();
let (waiter, waker) = Waiter::new_pair();
- if blockers.iter().any(|b| b.should_try()) {
- return;
+
+ // 1. Register for all blockers.
+ let keys = blockers.map(|blocker| blocker.enqueue(&waker));
+
+ // 2. Check if any of the blockers are actually ready.
+ if !blockers.iter().any(|b| b.should_try()) {
+ // 3. 
Block if no blockers are ready. This will immediately wake if any blocker woke
+ // between `enqueue` and here. This prevents this thread from dropping a wake
+ // between `should_try` and `wait_cancellable`.
+ waiter.wait_cancellable(|| {
+ if CurrentServer::is_aborted() {
+ Err(())
+ } else {
+ Ok(())
+ }
+ });
}
- for blocker in blockers.iter() {
- blocker.prepare_to_wait(&waker);
+ // 4. Unregister with all the blockers. This avoids the blocker queues growing without
+ // bound.
+ for (blocker, key) in blockers.iter().zip(keys) {
+ blocker.remove(key);
}
- // We will never
- waiter.wait_cancellable(|| {
- if CurrentServer::is_aborted() {
- Err(())
- } else {
- Ok(())
- }
- });
CurrentServer::abort_point();
}
diff --git a/ostd/src/stacktrace.rs b/ostd/src/stacktrace.rs
index 8e6c1753f..5e3bd8b0f 100644
--- a/ostd/src/stacktrace.rs
+++ b/ostd/src/stacktrace.rs
@@ -88,11 +88,9 @@ impl CapturedStackTrace {
if pc == 0 {
return UnwindReasonCode::NORMAL_STOP;
}
- if data.count >= data.skip {
- if data.res.frames.try_push(pc).is_some() {
- // Stop if there is no more space available in the frames vec.
- return UnwindReasonCode::NORMAL_STOP;
- }
+ if data.count >= data.skip && data.res.frames.try_push(pc).is_some() {
+ // Stop if there is no more space available in the frames vec. 
+ return UnwindReasonCode::NORMAL_STOP; } data.count += 1; UnwindReasonCode::NO_REASON diff --git a/ostd/src/sync/mod.rs b/ostd/src/sync/mod.rs index 8622423c8..d99f979c4 100644 --- a/ostd/src/sync/mod.rs +++ b/ostd/src/sync/mod.rs @@ -26,7 +26,7 @@ pub use self::{ RwMutexReadGuard, RwMutexUpgradeableGuard, RwMutexWriteGuard, }, spin::{ArcSpinLockGuard, SpinLock, SpinLockGuard}, - wait::{WaitQueue, Waiter, Waker}, + wait::{WaitQueue, Waiter, Waker, WakerKey}, }; pub(crate) fn init() { diff --git a/ostd/src/sync/wait.rs b/ostd/src/sync/wait.rs index 5ba8de8cd..745e9a5db 100644 --- a/ostd/src/sync/wait.rs +++ b/ostd/src/sync/wait.rs @@ -34,6 +34,11 @@ use crate::task::{Task, scheduler}; // Note that dropping a waiter must be treated as a `wait()` with zero timeout, because we need to // make sure that the wake event isn't lost in this case. +// TODO(arthurp): PERFORMANCE: waiters and wait queues are used quite a bit and require allocation +// and some inefficient data structures. Optimizations: use a "wakers" list data structure that +// allow faster removal; create a way to store a Waker without reference counting and generalized +// allocation (maybe generational references into thread local storage). + /// A wait queue. /// /// One may wait on a wait queue to put its executing thread to sleep. @@ -142,10 +147,26 @@ impl WaitQueue { /// Enqueues the input [`Waker`] to the wait queue. #[doc(hidden)] - pub fn enqueue(&self, waker: Arc) { + pub fn enqueue(&self, waker: Arc) -> WakerKey { let mut wakers = self.wakers.lock(); + let key = WakerKey(Some(waker.clone())); wakers.push_back(waker); self.num_wakers.fetch_add(1, Ordering::Acquire); + key + } + + /// Remove a waker from the queue. + #[doc(hidden)] + pub fn remove(&self, key: WakerKey) { + let Some(key_waker) = key.0 else { + return; + }; + + let mut wakers = self.wakers.lock(); + // TODO(arthurp): PERFORMANCE: This is O(n). We may need to optimize this or provide a + // variant of WaitQueue with faster remove. 
Probably by eliminating wake ordering. + wakers.retain(|w| w.task != key_waker.task); + self.num_wakers.swap(wakers.len() as u32, Ordering::Acquire); } } @@ -167,6 +188,11 @@ pub struct Waiter { impl !Send for Waiter {} impl !Sync for Waiter {} +/// A reference to a Waker in a WaitQueue. This can reference no waker at all. This is used for +/// removing wakers from a wait queue. +#[derive(Default)] +pub struct WakerKey(Option>); + /// A waker that can wake up the associated [`Waiter`]. /// /// A waker can be created by calling [`Waiter::new_pair`]. This method creates an `Arc` that can diff --git a/ostd/src/task/mod.rs b/ostd/src/task/mod.rs index 88379e94e..18a2dcc2e 100644 --- a/ostd/src/task/mod.rs +++ b/ostd/src/task/mod.rs @@ -73,6 +73,17 @@ pub struct Task { server: ForceSync>>>, } +impl PartialEq for Task { + fn eq(&self, other: &Self) -> bool { + let ret = core::ptr::eq(self as *const Self, other as *const Self); + // Check that id equality matches pointer equality. + debug_assert_eq!(self.id == other.id, ret); + ret + } +} + +impl Eq for Task {} + impl core::fmt::Debug for Task { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("Task")