diff --git a/Cargo.lock b/Cargo.lock index aef6c60ec2..428e2ce1c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -253,6 +253,15 @@ dependencies = [ "memory_addr", ] +[[package]] +name = "axasync-std" +version = "0.2.0" +dependencies = [ + "arceos_api", + "axerrno", + "axio", +] + [[package]] name = "axconfig" version = "0.2.0" @@ -1262,6 +1271,15 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hybrid-task-schedule" +version = "0.2.0" +dependencies = [ + "axasync-std", + "axstd", + "rand", +] + [[package]] name = "iana-time-zone" version = "0.1.63" diff --git a/Cargo.toml b/Cargo.toml index 18813c573a..8ae6165e01 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ members = [ "ulib/axstd", "ulib/axlibc", + "ulib/axasync-std", "examples/helloworld", "examples/helloworld-myplat", @@ -30,6 +31,7 @@ members = [ "examples/httpserver", "examples/httpserver", "examples/shell", + "examples/hybrid-task-schedule", ] [workspace.package] @@ -45,6 +47,7 @@ categories = ["os", "no-std"] [workspace.dependencies] axstd = { path = "ulib/axstd" } +axasync-std = { path = "ulib/axasync-std" } axlibc = { path = "ulib/axlibc" } arceos_api = { path = "api/arceos_api" } diff --git a/api/arceos_api/src/imp/task.rs b/api/arceos_api/src/imp/task.rs index 2ad194cb7b..624f23ce4c 100644 --- a/api/arceos_api/src/imp/task.rs +++ b/api/arceos_api/src/imp/task.rs @@ -5,6 +5,13 @@ pub fn ax_sleep_until(deadline: crate::time::AxTimeValue) { axhal::time::busy_wait_until(deadline); } +pub async fn ax_sleep_until_f(deadline: crate::time::AxTimeValue) { + #[cfg(feature = "multitask")] + axtask::sleep_until_f(deadline).await; + #[cfg(not(feature = "multitask"))] + axhal::time::busy_wait_until(deadline); +} + pub fn ax_yield_now() { #[cfg(feature = "multitask")] axtask::yield_now(); @@ -16,6 +23,17 @@ pub fn ax_yield_now() { } } +pub async fn ax_yield_now_f() { + #[cfg(feature = "multitask")] + axtask::yield_now_f().await; + #[cfg(not(feature = "multitask"))] + if cfg!(feature = "irq") { + axhal::arch::wait_for_irqs(); + } else { + core::hint::spin_loop(); + } +} + pub fn ax_exit(_exit_code: i32) -> ! { #[cfg(feature = "multitask")] axtask::exit(_exit_code); @@ -23,6 +41,13 @@ pub fn ax_exit(_exit_code: i32) -> ! { crate::sys::ax_terminate(); } +pub async fn ax_exit_f(_exit_code: i32) { + #[cfg(feature = "multitask")] + axtask::exit_f(_exit_code).await; + #[cfg(not(feature = "multitask"))] + axhal::misc::terminate(); +} + cfg_task! { use core::time::Duration; @@ -70,10 +95,25 @@ cfg_task! { } } + pub fn ax_spawn_f(f: F, name: alloc::string::String) -> AxTaskHandle + where + F: core::future::Future + Send + 'static, + { + let inner = axtask::spawn_raw_f(f, name); + AxTaskHandle { + id: inner.id().as_u64(), + inner, + } + } + pub fn ax_wait_for_exit(task: AxTaskHandle) -> Option { task.inner.join() } + pub async fn ax_wait_for_exit_f(task: AxTaskHandle) -> Option { + task.inner.join_f().await + } + pub fn ax_set_current_priority(prio: isize) -> crate::AxResult { if axtask::set_priority(prio) { Ok(()) @@ -109,6 +149,19 @@ cfg_task! { false } + pub async fn ax_wait_queue_wait_f(wq: &AxWaitQueueHandle, timeout: Option) -> bool { + #[cfg(feature = "irq")] + if let Some(dur) = timeout { + return wq.0.wait_timeout_f(dur).await; + } + + if timeout.is_some() { + axlog::warn!("ax_wait_queue_wait: the `timeout` argument is ignored without the `irq` feature"); + } + wq.0.wait_f().await; + false + } + pub fn ax_wait_queue_wait_until( wq: &AxWaitQueueHandle, until_condition: impl Fn() -> bool, @@ -126,6 +179,23 @@ cfg_task! { false } + pub async fn ax_wait_queue_wait_until_f( + wq: &AxWaitQueueHandle, + until_condition: impl Fn() -> bool, + timeout: Option, + ) -> bool { + #[cfg(feature = "irq")] + if let Some(dur) = timeout { + return wq.0.wait_timeout_until_f(dur, until_condition).await; + } + + if timeout.is_some() { + axlog::warn!("ax_wait_queue_wait_until: the `timeout` argument is ignored without the `irq` feature"); + } + wq.0.wait_until_f(until_condition).await; + false + } + pub fn ax_wait_queue_wake(wq: &AxWaitQueueHandle, count: u32) { if count == u32::MAX { wq.0.notify_all(true); diff --git a/api/arceos_api/src/lib.rs b/api/arceos_api/src/lib.rs index 3e6d291870..aa3841b908 100644 --- a/api/arceos_api/src/lib.rs +++ b/api/arceos_api/src/lib.rs @@ -140,9 +140,50 @@ pub mod task { pub fn ax_exit(exit_code: i32) -> !; } + define_api! { + /// Current coroutine task is going to sleep, it will be woken up at the given deadline. + /// + /// If the feature `multitask` is not enabled, it uses busy-wait instead + pub async fn ax_sleep_until_f(deadline: crate::time::AxTimeValue); + /// Current coroutine task gives up the CPU time voluntarily, and switches to another + /// ready task. + /// + /// If the feature `multitask` is not enabled, it does nothing. + pub async fn ax_yield_now_f(); + + /// Exits the current coroutine task with the given exit code. + pub async fn ax_exit_f(exit_code: i32); + } + + define_api! { + @cfg "multitask"; + + /// Waits for the given task to exit, and returns its exit code (the + /// argument of [`ax_exit`]). + pub async fn ax_wait_for_exit_f(task: AxTaskHandle) -> Option; + /// Blocks the current task and put it into the wait queue, until + /// other tasks notify the wait queue, or the the given duration has + /// elapsed (if specified). + pub async fn ax_wait_queue_wait_f(wq: &AxWaitQueueHandle, timeout: Option) -> bool; + /// Blocks the current task and put it into the wait queue, until the + /// given condition becomes true, or the the given duration has elapsed + /// (if specified). + pub async fn ax_wait_queue_wait_until_f( + wq: &AxWaitQueueHandle, + until_condition: impl Fn() -> bool, + timeout: Option, + ) -> bool; + } + define_api! { @cfg "multitask"; + /// Spawns a new task with the given entry point and other arguments. + pub fn ax_spawn_f( + f: impl core::future::Future + Send + 'static, + name: alloc::string::String, + ) -> AxTaskHandle; + /// Returns the current task's ID. pub fn ax_current_task_id() -> u64; /// Spawns a new task with the given entry point and other arguments. diff --git a/api/arceos_api/src/macros.rs b/api/arceos_api/src/macros.rs index ae96825967..9c261bcacf 100644 --- a/api/arceos_api/src/macros.rs +++ b/api/arceos_api/src/macros.rs @@ -20,6 +20,14 @@ macro_rules! define_api_type { } macro_rules! define_api { + ($( $(#[$attr:meta])* $vis:vis async fn $name:ident( $($arg:ident : $type:ty),* $(,)? ) $( -> $ret:ty )? ; )+) => { + $( + $(#[$attr])* + $vis async fn $name( $($arg : $type),* ) $( -> $ret )? { + $crate::imp::$name( $($arg),* ).await + } + )+ + }; ($( $(#[$attr:meta])* $vis:vis fn $name:ident( $($arg:ident : $type:ty),* $(,)? ) $( -> $ret:ty )? ; )+) => { $( $(#[$attr])* @@ -55,6 +63,25 @@ macro_rules! define_api { } )+ }; + ( + @cfg $feature:literal; + $( $(#[$attr:meta])* $vis:vis async fn $name:ident( $($arg:ident : $type:ty),* $(,)? ) $( -> $ret:ty )? ; )+ + ) => { + $( + #[cfg(feature = $feature)] + $(#[$attr])* + $vis async fn $name( $($arg : $type),* ) $( -> $ret )? { + $crate::imp::$name( $($arg),* ).await + } + + #[allow(unused_variables)] + #[cfg(all(feature = "dummy-if-not-enabled", not(feature = $feature)))] + $(#[$attr])* + $vis async fn $name( $($arg : $type),* ) $( -> $ret )? { + unimplemented!(stringify!($name)) + } + )+ + }; ( @cfg $feature:literal; $( $(#[$attr:meta])* $vis:vis unsafe fn $name:ident( $($arg:ident : $type:ty),* $(,)? ) $( -> $ret:ty )? ; )+ diff --git a/examples/hybrid-task-schedule/Cargo.toml b/examples/hybrid-task-schedule/Cargo.toml new file mode 100644 index 0000000000..2605ae0a12 --- /dev/null +++ b/examples/hybrid-task-schedule/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "hybrid-task-schedule" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +homepage.workspace = true +documentation.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true + +[dependencies] +axstd = { workspace = true, features = [ + "alloc", + "multitask", + "net", + "irq", +], optional = true } +axasync-std = { workspace = true } +rand = { version = "0.9.1", default-features = false, features = ["small_rng"] } diff --git a/examples/hybrid-task-schedule/src/main.rs b/examples/hybrid-task-schedule/src/main.rs new file mode 100644 index 0000000000..a91e8be25a --- /dev/null +++ b/examples/hybrid-task-schedule/src/main.rs @@ -0,0 +1,137 @@ +//! Hybrid task scheduling examples + +#![cfg_attr(feature = "axstd", no_std)] +#![cfg_attr(feature = "axstd", no_main)] + +#[macro_use] +#[cfg(feature = "axstd")] +extern crate axstd as std; + +extern crate axasync_std as async_std; + +use rand::{RngCore, SeedableRng, rngs::SmallRng}; +use std::thread; +use std::{sync::Arc, vec::Vec}; + +#[cfg(feature = "axstd")] +use std::os::arceos::api::task::{self as api, AxWaitQueueHandle}; + +const NUM_DATA: usize = 2_000_000; +const NUM_TASKS: usize = 100; + +#[cfg(feature = "axstd")] +fn barrier() { + use std::sync::atomic::{AtomicUsize, Ordering}; + static BARRIER_WQ: AxWaitQueueHandle = AxWaitQueueHandle::new(); + static BARRIER_COUNT: AtomicUsize = AtomicUsize::new(0); + + BARRIER_COUNT.fetch_add(1, Ordering::Relaxed); + api::ax_wait_queue_wait_until( + &BARRIER_WQ, + || BARRIER_COUNT.load(Ordering::Relaxed) == NUM_TASKS, + None, + ); + api::ax_wait_queue_wake(&BARRIER_WQ, u32::MAX); // wakeup all +} + +#[cfg(feature = "axstd")] +async fn barrier_f() { + use std::sync::atomic::{AtomicUsize, Ordering}; + static BARRIER_WQ: AxWaitQueueHandle = AxWaitQueueHandle::new(); + static BARRIER_COUNT: AtomicUsize = AtomicUsize::new(0); + + BARRIER_COUNT.fetch_add(1, Ordering::Relaxed); + api::ax_wait_queue_wait_until_f( + &BARRIER_WQ, + || BARRIER_COUNT.load(Ordering::Relaxed) == NUM_TASKS, + None, + ) + .await; + api::ax_wait_queue_wake(&BARRIER_WQ, u32::MAX); // wakeup all +} + +#[cfg(not(feature = "axstd"))] +fn barrier() { + use std::sync::{Barrier, OnceLock}; + static BARRIER: OnceLock = OnceLock::new(); + BARRIER.get_or_init(|| Barrier::new(NUM_TASKS)).wait(); +} + +fn sqrt(n: &u64) -> u64 { + let mut x = *n; + loop { + if x * x <= *n && (x + 1) * (x + 1) > *n { + return x; + } + x = (x + *n / x) / 2; + } +} + +#[cfg_attr(feature = "axstd", unsafe(no_mangle))] +fn main() { + let mut rng = SmallRng::seed_from_u64(0xdead_beef); + let vec = Arc::new( + (0..NUM_DATA) + .map(|_| rng.next_u32() as u64) + .collect::>(), + ); + let expect: u64 = vec.iter().map(sqrt).sum(); + + let mut tasks = Vec::with_capacity(NUM_TASKS); + for i in 0..NUM_TASKS { + let vec = vec.clone(); + tasks.push(async_std::task::spawn(move || async move { + let left = i * (NUM_DATA / NUM_TASKS); + let right = (left + (NUM_DATA / NUM_TASKS)).min(NUM_DATA); + println!( + "part {}: {:?} [{}, {})", + i, + thread::current().id(), + left, + right + ); + + async_std::task::spawn(|| async { + println!("spawn a thread"); + }) + .join() + .unwrap(); + + let partial_sum: u64 = vec[left..right].iter().map(sqrt).sum(); + barrier(); + async_std::task::yield_now().await; + #[cfg(feature = "axstd")] + barrier_f().await; + async_std::task::sleep(core::time::Duration::from_millis(1)).await; + + println!("part {}: {:?} finished", i, thread::current().id()); + partial_sum + })); + } + + let actual = tasks.into_iter().map(|t| t.join().unwrap()).sum(); + println!("sum = {}", actual); + assert_eq!(expect, actual); + + println!("Parallel summation tests run OK!"); + async_std::block_on! {hello_world()}; + async_std::callasync! {test()}; +} + +async fn hello_world() { + println!("hello world!"); +} + +async fn test() -> i32 { + let mut flag = false; + core::future::poll_fn(|_cx| { + if !flag { + flag = true; + core::task::Poll::Pending + } else { + core::task::Poll::Ready(()) + } + }) + .await; + 43 +} diff --git a/modules/axtask/src/api.rs b/modules/axtask/src/api.rs index 1e3fdca726..a0b3fd5e51 100644 --- a/modules/axtask/src/api.rs +++ b/modules/axtask/src/api.rs @@ -116,6 +116,16 @@ where spawn_task(TaskInner::new(f, name, stack_size)) } +/// Spawns a new coroutine task with the given future and name. +/// +/// Returns the task reference. +pub fn spawn_raw_f(f: F, name: String) -> AxTaskRef +where + F: Future + Send + 'static, +{ + spawn_task(TaskInner::new_f(f, name)) +} + /// Spawns a new task with the default parameters. /// /// The default task name is an empty string. The default task stack size is @@ -129,6 +139,18 @@ where spawn_raw(f, "".into(), axconfig::TASK_STACK_SIZE) } +/// Spawns a new coroutine task with the default parameters. +/// +/// The default task name is an empty string. +/// +/// Returns the task reference. +pub fn spawn_f(f: F) -> AxTaskRef +where + F: Future + Send + 'static, +{ + spawn_raw_f(f, "".into()) +} + /// Set the priority for current task. /// /// The range of the priority is dependent on the underlying scheduler. For @@ -185,6 +207,20 @@ pub fn yield_now() { current_run_queue::().yield_current() } +/// Current coroutine task gives up the CPU time voluntarily, and switches to another +/// ready task. +#[inline] +pub async fn yield_now_f() { + crate::run_queue::YieldFuture::::new().await; +} + +/// Current coroutine task is going to sleep for the given duration. +/// +/// If the feature `irq` is not enabled, it uses busy-wait instead. +pub async fn sleep_f(dur: core::time::Duration) { + sleep_until_f(axhal::time::wall_time() + dur).await; +} + /// Current task is going to sleep for the given duration. /// /// If the feature `irq` is not enabled, it uses busy-wait instead. @@ -202,11 +238,26 @@ pub fn sleep_until(deadline: axhal::time::TimeValue) { axhal::time::busy_wait_until(deadline); } +/// Current coroutine task is going to sleep, it will be woken up at the given deadline. +/// +/// If the feature `irq` is not enabled, it uses busy-wait instead. +pub async fn sleep_until_f(deadline: axhal::time::TimeValue) { + #[cfg(feature = "irq")] + crate::run_queue::SleepUntilFuture::::new(deadline).await; + #[cfg(not(feature = "irq"))] + axhal::time::busy_wait_until(deadline); +} + /// Exits the current task. pub fn exit(exit_code: i32) -> ! { current_run_queue::().exit_current(exit_code) } +/// Exits the current coroutine task. +pub async fn exit_f(exit_code: i32) { + crate::run_queue::ExitFuture::::new(exit_code).await; +} + /// The idle task routine. /// /// It runs an infinite loop that keeps calling [`yield_now()`]. diff --git a/modules/axtask/src/run_queue.rs b/modules/axtask/src/run_queue.rs index 9d70918d95..11319f4625 100644 --- a/modules/axtask/src/run_queue.rs +++ b/modules/axtask/src/run_queue.rs @@ -1,10 +1,13 @@ use alloc::collections::VecDeque; use alloc::sync::Arc; -use core::mem::MaybeUninit; - #[cfg(feature = "smp")] use alloc::sync::Weak; +use core::future::Future; +use core::mem::MaybeUninit; +use core::pin::Pin; +use core::task::{Context, Poll}; + use axsched::BaseScheduler; use kernel_guard::BaseGuard; use kspin::SpinRaw; @@ -548,6 +551,8 @@ impl AxRunQueue { unsafe { let prev_ctx_ptr = prev_task.ctx_mut_ptr(); let next_ctx_ptr = next_task.ctx_mut_ptr(); + // If the next task is a coroutine, this will set the kstack and ctx. + next_task.set_kstack(); // Store the weak pointer of **prev_task** in percpu variable `PREV_TASK`. #[cfg(feature = "smp")] @@ -570,6 +575,67 @@ impl AxRunQueue { clear_prev_task_on_cpu(); } } + + /// Core reschedule subroutine. + /// Pick the next task to run and switch to it. + /// This function is only used in `YieldFuture`, `ExitFuture`, + /// `SleepUntilFuture` and `BlockedReschedFuture`. + fn resched_f(&mut self) -> Poll<()> { + let next_task = self + .scheduler + .lock() + .pick_next_task() + .unwrap_or_else(|| unsafe { + // Safety: IRQs must be disabled at this time. + IDLE_TASK.current_ref_raw().get_unchecked().clone() + }); + assert!( + next_task.is_ready(), + "next {} is not ready: {:?}", + next_task.id_name(), + next_task.state() + ); + let prev_task = crate::current(); + // Make sure that IRQs are disabled by kernel guard or other means. + #[cfg(all(not(test), feature = "irq"))] // Note: irq is faked under unit tests. + assert!( + !axhal::asm::irqs_enabled(), + "IRQs must be disabled during scheduling" + ); + trace!( + "context switch: {} -> {}", + prev_task.id_name(), + next_task.id_name() + ); + #[cfg(feature = "preempt")] + next_task.set_preempt_pending(false); + next_task.set_state(TaskState::Running); + if prev_task.ptr_eq(&next_task) { + return Poll::Ready(()); + } + + // Claim the task as running, we do this before switching to it + // such that any running task will have this set. + #[cfg(feature = "smp")] + next_task.set_on_cpu(true); + + unsafe { + // Store the weak pointer of **prev_task** in percpu variable `PREV_TASK`. + #[cfg(feature = "smp")] + { + *PREV_TASK.current_ref_mut_raw() = Arc::downgrade(prev_task.as_task_ref()); + } + + // The strong reference count of `prev_task` will be decremented by 1, + // but won't be dropped until `gc_entry()` is called. + assert!(Arc::strong_count(prev_task.as_task_ref()) > 1); + assert!(Arc::strong_count(&next_task) >= 1); + + // Directly change the `CurrentTask` and return `Pending`. + CurrentTask::set_current(prev_task, next_task); + Poll::Pending + } + } } fn gc_entry() { @@ -664,3 +730,252 @@ pub(crate) fn init_secondary() { RUN_QUEUES[cpu_id].write(RUN_QUEUE.current_ref_mut_raw()); } } + +/// The `YieldFuture` used when yielding the current task and reschedule. +/// When polling this future, the current task will be put into the run queue +/// with `Ready` state and reschedule to the next task on the run queue. +/// +/// The polling operation is as the same as the +/// `current_run_queue::().yield_current()` function. +/// +/// SAFETY: +/// Due to this future is constructed with `current_run_queue::()`, +/// the operation about manipulating the RunQueue and the switching to next task is +/// safe(The `IRQ` and `Preempt` are disabled). +pub(crate) struct YieldFuture<'a, G: BaseGuard> { + current_run_queue: CurrentRunQueueRef<'a, G>, + flag: bool, +} + +impl<'a, G: BaseGuard> YieldFuture<'a, G> { + pub(crate) fn new() -> Self { + Self { + current_run_queue: current_run_queue::(), + flag: false, + } + } +} + +impl<'a, G: BaseGuard> Unpin for YieldFuture<'a, G> {} + +impl<'a, G: BaseGuard> Future for YieldFuture<'a, G> { + type Output = (); + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + let Self { + current_run_queue, + flag, + } = self.get_mut(); + if !(*flag) { + *flag = !*flag; + let curr = ¤t_run_queue.current_task; + trace!("task yield: {}", curr.id_name()); + assert!(curr.is_running()); + current_run_queue + .inner + .put_task_with_state(curr.clone(), TaskState::Running, false); + current_run_queue.inner.resched_f() + } else { + Poll::Ready(()) + } + } +} + +/// Due not manually release the `current_run_queue.state`, +/// otherwise it will cause double release. +impl<'a, G: BaseGuard> Drop for YieldFuture<'a, G> { + fn drop(&mut self) {} +} + +/// The `ExitFuture` used when exiting the current task +/// with the specified exit code, which is always return `Poll::Pending`. +/// +/// The polling operation is as the same as the +/// `current_run_queue::().exit_current()` function. +/// +/// SAFETY: as the same as the `YieldFuture`. However, It wrap the `CurrentRunQueueRef` +/// with `ManuallyDrop`, otherwise the `IRQ` and `Preempt` state of other +/// tasks(maybe `main` or `gc` task) which recycle the exited task(which used this future) +/// will be error due to automatically drop the `CurrentRunQueueRef. +/// The `CurrentRunQueueRef` should never be drop. +pub(crate) struct ExitFuture<'a, G: BaseGuard> { + current_run_queue: core::mem::ManuallyDrop>, + exit_code: i32, +} + +impl<'a, G: BaseGuard> ExitFuture<'a, G> { + pub(crate) fn new(exit_code: i32) -> Self { + Self { + current_run_queue: core::mem::ManuallyDrop::new(current_run_queue::()), + exit_code, + } + } +} + +impl<'a, G: BaseGuard> Unpin for ExitFuture<'a, G> {} + +impl<'a, G: BaseGuard> Future for ExitFuture<'a, G> { + type Output = (); + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + let Self { + current_run_queue, + exit_code, + } = self.get_mut(); + let exit_code = *exit_code; + let curr = ¤t_run_queue.current_task; + debug!("task exit: {}, exit_code={}", curr.id_name(), exit_code); + assert!(curr.is_running(), "task is not running: {:?}", curr.state()); + assert!(!curr.is_idle()); + curr.set_state(TaskState::Exited); + + // Notify the joiner task. + curr.notify_exit(exit_code); + + // Safety: it is called from `current_run_queue::().exit_current(exit_code)`, + // which disabled IRQs and preemption. + unsafe { + // Push current task to the `EXITED_TASKS` list, which will be consumed by the GC task. + EXITED_TASKS.current_ref_mut_raw().push_back(curr.clone()); + // Wake up the GC task to drop the exited tasks. + WAIT_FOR_EXIT.current_ref_mut_raw().notify_one(false); + } + + assert!(current_run_queue.inner.resched_f().is_pending()); + Poll::Pending + } +} + +#[cfg(feature = "irq")] +pub(crate) struct SleepUntilFuture<'a, G: BaseGuard> { + current_run_queue: CurrentRunQueueRef<'a, G>, + deadline: axhal::time::TimeValue, + flag: bool, +} + +#[cfg(feature = "irq")] +impl<'a, G: BaseGuard> SleepUntilFuture<'a, G> { + pub fn new(deadline: axhal::time::TimeValue) -> Self { + Self { + current_run_queue: current_run_queue::(), + deadline, + flag: false, + } + } +} + +#[cfg(feature = "irq")] +impl<'a, G: BaseGuard> Unpin for SleepUntilFuture<'a, G> {} + +#[cfg(feature = "irq")] +impl<'a, G: BaseGuard> Future for SleepUntilFuture<'a, G> { + type Output = (); + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + let Self { + current_run_queue, + deadline, + flag, + } = self.get_mut(); + if !(*flag) { + *flag = !*flag; + let deadline = *deadline; + let curr = ¤t_run_queue.current_task; + debug!("task sleep: {}, deadline={:?}", curr.id_name(), deadline); + assert!(curr.is_running()); + assert!(!curr.is_idle()); + + let now = axhal::time::wall_time(); + if now < deadline { + crate::timers::set_alarm_wakeup(deadline, curr.clone()); + curr.set_state(TaskState::Blocked); + assert!(current_run_queue.inner.resched_f().is_pending()); + Poll::Pending + } else { + Poll::Ready(()) + } + } else { + Poll::Ready(()) + } + } +} + +#[cfg(feature = "irq")] +impl<'a, G: BaseGuard> Drop for SleepUntilFuture<'a, G> { + fn drop(&mut self) {} +} + +/// The `BlockedReschedFuture` used when blocking the current task. +/// +/// When polling this future, current task will be put into the wait queue and reschedule, +/// the state of current task will be marked as `Blocked`, set the `in_wait_queue` flag as true. +/// Note: +/// 1. When polling this future, the wait queue is locked. +/// 2. When polling this future, the current task is in the running state. +/// 3. When polling this future, the current task is not the idle task. +/// 4. The lock of the wait queue will be released explicitly after current task is pushed into it. +/// +/// SAFETY: +/// as the same as the `YieldFuture`. Due to the `WaitQueueGuard` is not implemented +/// the `Send` trait, this future must hold the reference about the `WaitQueue` instead +/// of the `WaitQueueGuard`. +pub(crate) struct BlockedReschedFuture<'a, G: BaseGuard> { + current_run_queue: CurrentRunQueueRef<'a, G>, + wq: &'a WaitQueue, + flag: bool, +} + +impl<'a, G: BaseGuard> BlockedReschedFuture<'a, G> { + pub fn new(current_run_queue: CurrentRunQueueRef<'a, G>, wq: &'a WaitQueue) -> Self { + Self { + current_run_queue, + wq, + flag: false, + } + } +} + +impl<'a, G: BaseGuard> Unpin for BlockedReschedFuture<'a, G> {} + +impl<'a, G: BaseGuard> Future for BlockedReschedFuture<'a, G> { + type Output = (); + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + let Self { + current_run_queue, + wq, + flag, + } = self.get_mut(); + if !(*flag) { + *flag = !*flag; + let mut wq_guard = wq.queue.lock(); + let curr = ¤t_run_queue.current_task; + assert!(curr.is_running()); + assert!(!curr.is_idle()); + // we must not block current task with preemption disabled. + // Current expected preempt count is 2. + // 1 for `NoPreemptIrqSave`, 1 for wait queue's `SpinNoIrq`. + #[cfg(feature = "preempt")] + assert!(curr.can_preempt(2)); + + // Mark the task as blocked, this has to be done before adding it to the wait queue + // while holding the lock of the wait queue. + curr.set_state(TaskState::Blocked); + curr.set_in_wait_queue(true); + + wq_guard.push_back(curr.clone()); + // Drop the lock of wait queue explictly. + drop(wq_guard); + + // Current task's state has been changed to `Blocked` and added to the wait queue. + // Note that the state may have been set as `Ready` in `unblock_task()`, + // see `unblock_task()` for details. + + debug!("task block: {}", curr.id_name()); + assert!(current_run_queue.inner.resched_f().is_pending()); + Poll::Pending + } else { + Poll::Ready(()) + } + } +} + +impl<'a, G: BaseGuard> Drop for BlockedReschedFuture<'a, G> { + fn drop(&mut self) {} +} diff --git a/modules/axtask/src/task.rs b/modules/axtask/src/task.rs index 1a2926218d..5496208b05 100644 --- a/modules/axtask/src/task.rs +++ b/modules/axtask/src/task.rs @@ -69,12 +69,16 @@ pub struct TaskInner { exit_code: AtomicI32, wait_for_exit: WaitQueue, - kstack: Option, + kstack: UnsafeCell>, ctx: UnsafeCell, task_ext: AxTaskExt, #[cfg(feature = "tls")] tls: TlsArea, + + /// The future of coroutine task. + pub(crate) future: + UnsafeCell + Send + 'static>>>>, } impl TaskId { @@ -122,13 +126,27 @@ impl TaskInner { t.entry = Some(Box::into_raw(Box::new(entry))); t.ctx_mut().init(task_entry as usize, kstack.top(), tls); - t.kstack = Some(kstack); + t.kstack = UnsafeCell::new(Some(kstack)); if t.name == "idle" { t.is_idle = true; } t } + /// Create a new task with the given future. + pub fn new_f(future: F, name: String) -> Self + where + F: Future + Send + 'static, + { + let mut t = Self::new_common(TaskId::new(), name); + debug!("new task: {}", t.id_name()); + t.future = UnsafeCell::new(Some(Box::pin(async { + future.await; + crate::exit_f(0).await + }))); + t + } + /// Gets the ID of the task. pub const fn id(&self) -> TaskId { self.id @@ -153,6 +171,16 @@ impl TaskInner { Some(self.exit_code.load(Ordering::Acquire)) } + /// Wait for the task to exit, and return the exit code. + /// + /// It will return immediately if the task has already exited (but not dropped). + pub async fn join_f(&self) -> Option { + self.wait_for_exit + .wait_until_f(|| self.state() == TaskState::Exited) + .await; + Some(self.exit_code.load(Ordering::Acquire)) + } + /// Returns the pointer to the user-defined task extended data. /// /// # Safety @@ -187,12 +215,38 @@ impl TaskInner { /// Returns the top address of the kernel stack. #[inline] pub const fn kernel_stack_top(&self) -> Option { - match &self.kstack { + match unsafe { &*self.kstack.get() } { Some(s) => Some(s.top()), None => None, } } + /// Get the mut ref about the `kstack` field. + #[inline] + const unsafe fn kernel_stack(&self) -> *mut Option { + self.kstack.get() + } + + /// Once the `kstack` field is None, the task is a coroutine. + /// The `kstack` and the `ctx` will be set up, + /// so the next coroutine will start at `coroutine_schedule` function. + /// + /// This function is only used before switching task. + pub(crate) fn set_kstack(&self) { + let kstack = unsafe { &mut *self.kernel_stack() }; + if kstack.is_none() && !self.is_init && !self.is_idle { + let stack = alloc_stack_for_coroutine(); + let kstack_top = stack.top(); + *kstack = Some(stack); + let ctx = unsafe { &mut *self.ctx_mut_ptr() }; + #[cfg(feature = "tls")] + let tls = VirtAddr::from(self.tls.tls_ptr() as usize); + #[cfg(not(feature = "tls"))] + let tls = VirtAddr::from(0); + ctx.init(coroutine_schedule as usize, kstack_top, tls); + } + } + /// Gets the cpu affinity mask of the task. /// /// Returns the cpu affinity mask of the task in type [`AxCpuMask`]. @@ -234,11 +288,12 @@ impl TaskInner { preempt_disable_count: AtomicUsize::new(0), exit_code: AtomicI32::new(0), wait_for_exit: WaitQueue::new(), - kstack: None, + kstack: UnsafeCell::new(None), ctx: UnsafeCell::new(TaskContext::new()), task_ext: AxTaskExt::empty(), #[cfg(feature = "tls")] tls: TlsArea::alloc(), + future: UnsafeCell::new(None), } } @@ -535,3 +590,77 @@ extern "C" fn task_entry() -> ! { } crate::exit(0); } + +#[percpu::def_percpu] +static COROUTINE_STACK_POOL: alloc::vec::Vec = alloc::vec::Vec::new(); + +/// Alloc a stack for running a coroutine. +/// If the `COROUTINE_STACK_POOL` is empty, +/// it will alloc a new stack on the allocator. +fn alloc_stack_for_coroutine() -> TaskStack { + unsafe { + COROUTINE_STACK_POOL + .current_ref_mut_raw() + .pop() + .unwrap_or_else(|| TaskStack::alloc(axconfig::TASK_STACK_SIZE)) + } +} + +/// Recycle the stack after the coroutine running to a certain stage. +fn recycle_stack_of_coroutine(kstack: TaskStack) { + unsafe { + COROUTINE_STACK_POOL.current_ref_mut_raw().push(kstack); + } +} + +/// The function about coroutine scheduling. +pub(crate) extern "C" fn coroutine_schedule() { + use core::task::{Context, Waker}; + loop { + #[cfg(feature = "smp")] + unsafe { + // Clear the prev task on CPU before running the task entry function. + crate::run_queue::clear_prev_task_on_cpu(); + } + // Enable irq (if feature "irq" is enabled) before running the task entry function. + #[cfg(feature = "irq")] + axhal::asm::enable_irqs(); + let waker = Waker::noop(); + let mut cx = Context::from_waker(waker); + let curr = crate::current(); + let future = unsafe { &mut *curr.future.get() } + .as_mut() + .expect("The task should be a coroutine."); + let _res = future.as_mut().poll(&mut cx); + assert!(!curr.is_running()); + // Make sure that IRQs are disabled by kernel guard or other means. + #[cfg(all(not(test), feature = "irq"))] // Note: irq is faked under unit tests. + assert!( + !axhal::asm::irqs_enabled(), + "IRQs must be disabled during scheduling" + ); + let prev_task = curr; + // pick the kstack of prev_task + let kstack = unsafe { &mut *prev_task.kernel_stack() } + .take() + .expect("The kernel stack should be taken out after running."); + let next_task = crate::current(); + let next_kstack = unsafe { &mut *next_task.kernel_stack() }; + if next_kstack.is_none() && !next_task.is_init() && !next_task.is_idle() { + // Pass the `kstack` to the next coroutine task. + *next_kstack = Some(kstack); + } else { + unsafe { + let prev_ctx_ptr = prev_task.ctx_mut_ptr(); + let next_ctx_ptr = next_task.ctx_mut_ptr(); + // Recycle the `kstack` before switching to the next thread task. + recycle_stack_of_coroutine(kstack); + // After switching to the thread task, it will restore to the `switch_to` function. + // The prev task will be cleaned in the `switch_to` function. + // The irq_state will be restore by dropping the `current_run_queue`. + (*prev_ctx_ptr).switch_to(&*next_ctx_ptr); + panic!("Shoule never reach here."); + } + } + } +} diff --git a/modules/axtask/src/wait_queue.rs b/modules/axtask/src/wait_queue.rs index b213579476..ef15bffa4e 100644 --- a/modules/axtask/src/wait_queue.rs +++ b/modules/axtask/src/wait_queue.rs @@ -30,7 +30,7 @@ use crate::{AxTaskRef, CurrentTask, current_run_queue, select_run_queue}; /// assert_eq!(VALUE.load(Ordering::Acquire), 1); /// ``` pub struct WaitQueue { - queue: SpinNoIrq>, + pub(crate) queue: SpinNoIrq>, } pub(crate) type WaitQueueGuard<'a> = SpinNoIrqGuard<'a, VecDeque>; @@ -81,6 +81,14 @@ impl WaitQueue { self.cancel_events(crate::current(), false); } + /// Blocks the current coroutine task and put it into the wait queue, until other task + /// notifies it. + pub async fn wait_f(&self) { + let rq = current_run_queue::(); + crate::run_queue::BlockedReschedFuture::new(rq, self).await; + self.cancel_events(crate::current(), false); + } + /// Blocks the current task and put it into the wait queue, until the given /// `condition` becomes true. /// @@ -103,6 +111,27 @@ impl WaitQueue { self.cancel_events(curr, false); } + /// Blocks the current coroutine task and put it into the wait queue, until the given + /// `condition` becomes true. + /// + /// Note that even other tasks notify this task, it will not wake up until + /// the condition becomes true. + pub async fn wait_until_f(&self, condition: F) + where + F: Fn() -> bool, + { + let curr = crate::current(); + loop { + let rq = current_run_queue::(); + if condition() { + break; + } + crate::run_queue::BlockedReschedFuture::new(rq, self).await; + // Preemption may occur here. + } + self.cancel_events(curr, false); + } + /// Blocks the current task and put it into the wait queue, until other tasks /// notify it, or the given duration has elapsed. #[cfg(feature = "irq")] @@ -126,6 +155,29 @@ impl WaitQueue { timeout } + /// Blocks the current coroutine task and put it into the wait queue, until other tasks + /// notify it, or the given duration has elapsed. + #[cfg(feature = "irq")] + pub async fn wait_timeout_f(&self, dur: core::time::Duration) -> bool { + let rq = current_run_queue::(); + let curr = crate::current(); + let deadline = axhal::time::wall_time() + dur; + debug!( + "task wait_timeout: {} deadline={:?}", + curr.id_name(), + deadline + ); + crate::timers::set_alarm_wakeup(deadline, curr.clone()); + + crate::run_queue::BlockedReschedFuture::new(rq, self).await; + + let timeout = curr.in_wait_queue(); // still in the wait queue, must have timed out + + // Always try to remove the task from the timer list. + self.cancel_events(curr, true); + timeout + } + /// Blocks the current task and put it into the wait queue, until the given /// `condition` becomes true, or the given duration has elapsed. /// @@ -165,6 +217,44 @@ impl WaitQueue { timeout } + /// Blocks the current coroutine task and put it into the wait queue, until the given + /// `condition` becomes true, or the given duration has elapsed. + /// + /// Note that even other tasks notify this task, it will not wake up until + /// the above conditions are met. + #[cfg(feature = "irq")] + pub async fn wait_timeout_until_f(&self, dur: core::time::Duration, condition: F) -> bool + where + F: Fn() -> bool, + { + let curr = crate::current(); + let deadline = axhal::time::wall_time() + dur; + debug!( + "task wait_timeout: {}, deadline={:?}", + curr.id_name(), + deadline + ); + crate::timers::set_alarm_wakeup(deadline, curr.clone()); + + let mut timeout = true; + loop { + let rq = current_run_queue::(); + if axhal::time::wall_time() >= deadline { + break; + } + if condition() { + timeout = false; + break; + } + + crate::run_queue::BlockedReschedFuture::new(rq, self).await; + // Preemption may occur here. + } + // Always try to remove the task from the timer list. + self.cancel_events(curr, true); + timeout + } + /// Wakes up one task in the wait queue, usually the first one. /// /// If `resched` is true, the current task will be preempted when the diff --git a/ulib/axasync-std/Cargo.toml b/ulib/axasync-std/Cargo.toml new file mode 100644 index 0000000000..43f237755a --- /dev/null +++ b/ulib/axasync-std/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "axasync-std" +version.workspace = true +edition.workspace = true +license.workspace = true +homepage.workspace = true +documentation.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +authors = ["Fangliang Zhao <1491657576@qq.com>"] +description = "ArceOS user library with an async interface" + + +[dependencies] +arceos_api = { workspace = true, features = ["multitask"] } +axerrno = "0.1" +axio = "0.1" diff --git a/ulib/axasync-std/src/lib.rs b/ulib/axasync-std/src/lib.rs new file mode 100644 index 0000000000..1ab42e83b6 --- /dev/null +++ b/ulib/axasync-std/src/lib.rs @@ -0,0 +1,33 @@ +//! # The ArceOS Async Standard Library +//! +//! [ArceOS]: https://github.com/arceos-org/arceos + +#![cfg_attr(all(not(test), not(doc)), no_std)] +#![feature(doc_cfg)] +#![feature(doc_auto_cfg)] + +#[macro_use] +mod macros; +pub mod task; + +/// Traits, helpers, and type definitions for core I/O functionality. +pub mod io { + + /// A specialized [`Result`] type for I/O operations. + /// + /// This type is broadly used across [`axasync_std::io`] for any operation which may + /// produce an error. + /// + /// This typedef is generally used to avoid writing out [`io::Error`] directly and + /// is otherwise a direct mapping to [`Result`]. + /// + /// While usual Rust style is to import types directly, aliases of [`Result`] + /// often are not, to make it easier to distinguish between them. [`Result`] is + /// generally assumed to be [`std::result::Result`][`Result`], and so users of this alias + /// will generally use `io::Result` instead of shadowing the [prelude]'s import + /// of [`std::result::Result`][`Result`]. + /// + /// [`axstd::io`]: crate::io + /// [`io::Error`]: Error + pub type Result = axio::Result; +} diff --git a/ulib/axasync-std/src/macros.rs b/ulib/axasync-std/src/macros.rs new file mode 100644 index 0000000000..75756aefbc --- /dev/null +++ b/ulib/axasync-std/src/macros.rs @@ -0,0 +1,100 @@ +//! Async Standard library macros + +/// The `block_on!` macro which is used to poll a async function in a busy spinning manner. +/// +/// This macro is used in the normal function. In the async function, the future can be direcyly invoked. +/// +/// Usage scenarios: when the future can be polled without waiting or the waiting for a short time, it should use this macro. +/// +/// Examples: +/// +/// ```rust +/// fn main() { +/// block_on!{hello_world()}; +/// } +/// +/// async fn hello_world() { +/// println!("hello world!"); +/// } +/// ``` +/// +#[macro_export] +macro_rules! block_on { + ($l:expr) => { + // The future can be pinned on the stack directly + // because the stack cannot be used by other task. + let mut future = $l; + let mut pinned_fut = unsafe { core::pin::Pin::new_unchecked(&mut future) }; + // The waker can use the `Waker::noop()` because + // there is no task switching while polling the future. + // The task which call this macro and poll this future can + // be preempt by the timer IRQ. + let waker = core::task::Waker::noop(); + let mut cx = core::task::Context::from_waker(&waker); + loop { + if let core::task::Poll::Ready(res) = pinned_fut.as_mut().poll(&mut cx) { + break res; + } + } + }; +} + +/// The `callasync!` macro is the same as the `block_on!`, +/// but it is combined with thread switching. +/// +/// This macro is used in the normal function. +/// In the async function, the future can be direcyly invoked. +/// +/// Usage scenarios: +/// when the future need wait for a long time to be `Poll::Ready`, +/// and the thread must wait for the result of the future, +/// it should use this macro. +/// It can yield the thread to run other task. +/// +/// The yield operation can be defined through a `trait` +/// which is as the same as the implementation in +/// [`axlog`](https://github.com/arceos-org/arceos/tree/main/modules/axlog) crate. +/// +/// Examples: +/// ```rust +/// fn main() { +/// callasync!{test()}; +/// } +/// +/// async fn test() -> i32 { +/// let mut flag = false; +/// core::future::poll_fn(|_cx| { +/// if !flag { +/// flag = true; +/// core::task::Poll::Pending +/// } else { +/// core::task::Poll::Ready(()) +/// } +/// }).await; +/// 43 +/// } +/// ``` +#[macro_export] +macro_rules! callasync { + ($l:expr) => { + // The future can be pinned on the stack directly + // because the stack cannot be used by other task. + let mut future = $l; + let mut pinned_fut = unsafe { core::pin::Pin::new_unchecked(&mut future) }; + // The waker can use the `Waker::noop()` because + // the task is switched as a thread. + // The task which call this macro and poll this future can + // be preempt by the timer IRQ. + let waker = core::task::Waker::noop(); + let mut cx = core::task::Context::from_waker(&waker); + loop { + match pinned_fut.as_mut().poll(&mut cx) { + core::task::Poll::Ready(r) => break r, + core::task::Poll::Pending => { + // Yield the task which call this marco when the future return `Pending`. + $crate::task::_api::ax_yield_now(); + } + } + } + }; +} diff --git a/ulib/axasync-std/src/task/mod.rs b/ulib/axasync-std/src/task/mod.rs new file mode 100644 index 0000000000..b47a6fbf72 --- /dev/null +++ b/ulib/axasync-std/src/task/mod.rs @@ -0,0 +1,44 @@ +//! The async interfaces about coroutines. +use arceos_api::task as api; +pub use core::task::*; + +mod multi; +pub use multi::*; + +/// Current coroutine gives up the CPU time voluntarily, and switches to another +/// ready task. +/// +/// For single-threaded configuration (`multitask` feature is disabled), we just +/// relax the CPU and wait for incoming interrupts. +pub async fn yield_now() { + api::ax_yield_now_f().await; +} + +/// Exits the current coroutine. +/// +/// For single-threaded configuration (`multitask` feature is disabled), +/// it directly terminates the main thread and shutdown. +pub async fn exit(exit_code: i32) { + api::ax_exit_f(exit_code).await; +} + +/// Current coroutine is going to sleep for the given duration. +/// +/// If one of `multitask` or `irq` features is not enabled, it uses busy-wait +/// instead. +pub async fn sleep(dur: core::time::Duration) { + sleep_until(arceos_api::time::ax_wall_time() + dur).await; +} + +/// Current thread is going to sleep, it will be woken up at the given deadline. +/// +/// If one of `multitask` or `irq` features is not enabled, it uses busy-wait +/// instead. +pub async fn sleep_until(deadline: arceos_api::time::AxTimeValue) { + api::ax_sleep_until_f(deadline).await; +} + +#[doc(hidden)] +pub mod _api { + pub use arceos_api::task::ax_yield_now; +} diff --git a/ulib/axasync-std/src/task/multi.rs b/ulib/axasync-std/src/task/multi.rs new file mode 100644 index 0000000000..985dcc45d0 --- /dev/null +++ b/ulib/axasync-std/src/task/multi.rs @@ -0,0 +1,196 @@ +//! Coroutine APIs for multi-task configuration. + +extern crate alloc; + +use crate::io; +use alloc::{string::String, sync::Arc}; +use arceos_api::task::{self as api, AxTaskHandle}; +use axerrno::ax_err_type; +use core::{cell::UnsafeCell, future::Future, num::NonZeroU64}; + +/// A unique identifier for a running coroutine task. +#[derive(Eq, PartialEq, Clone, Copy, Debug)] +pub struct TaskId(NonZeroU64); + +/// A handle to a coroutine. +pub struct Task { + id: TaskId, +} + +impl TaskId { + /// This returns a numeric identifier for the coroutine task identified by this + /// `TaskId`. + pub fn as_u64(&self) -> NonZeroU64 { + self.0 + } +} + +impl Task { + fn from_id(id: u64) -> Self { + Self { + id: TaskId(NonZeroU64::new(id).unwrap()), + } + } + + /// Gets the coroutine task's unique identifier. + pub fn id(&self) -> TaskId { + self.id + } +} + +/// Task factory, which can be used in order to configure the properties of +/// a new coroutine task. +/// +/// Methods can be chained on it in order to configure it. +#[derive(Debug)] +pub struct Builder { + // A name for the coroutine task-to-be, for identification in panic messages + name: Option, +} + +impl Builder { + /// Generates the base configuration for spawning a coroutine task, from which + /// configuration methods can be chained. + pub const fn new() -> Builder { + Builder { name: None } + } + + /// Names the coroutine task-to-be. + pub fn name(mut self, name: String) -> Builder { + self.name = Some(name); + self + } + + /// Spawns a new coroutine task by taking ownership of the `Builder`, and returns an + /// [`io::Result`] to its [`JoinHandle`]. + /// + /// The spawned coroutine task may outlive the caller (unless the caller coroutine task + /// is the main coroutine task; the whole process is terminated when the main + /// coroutine task finishes). The join handle can be used to block on + /// termination of the spawned coroutine task. + pub fn spawn(self, f: F1) -> io::Result> + where + F1: FnOnce() -> F2, + F1: Send + 'static, + F2: Future + Send + 'static, + T: Send + 'static, + { + unsafe { self.spawn_unchecked(f) } + } + + unsafe fn spawn_unchecked(self, f: F1) -> io::Result> + where + F1: FnOnce() -> F2, + F1: Send + 'static, + F2: Future + Send + 'static, + T: Send + 'static, + { + let name = self.name.unwrap_or_default(); + + let my_packet = Arc::new(Packet { + result: UnsafeCell::new(None), + }); + let their_packet = my_packet.clone(); + + let main = async move { + let ret = f().await; + // SAFETY: `their_packet` as been built just above and moved by the + // closure (it is an Arc<...>) and `my_packet` will be stored in the + // same `JoinHandle` as this closure meaning the mutation will be + // safe (not modify it and affect a value far away). + unsafe { *their_packet.result.get() = Some(ret) }; + drop(their_packet); + }; + + let task = api::ax_spawn_f(main, name); + Ok(JoinHandle { + task: Task::from_id(task.id()), + native: task, + packet: my_packet, + }) + } +} + +/// Gets a handle to the coroutine task that invokes it. +pub fn current() -> Task { + let id = api::ax_current_task_id(); + Task::from_id(id) +} + +/// Spawns a new coroutine task, returning a [`JoinHandle`] for it. +/// +/// The join handle provides a [`join`] method that can be used to join the +/// spawned coroutine task. +/// +/// The default task name is an empty string. The default coroutine task stack size is +/// [`arceos_api::config::TASK_STACK_SIZE`]. +/// +/// [`join`]: JoinHandle::join +pub fn spawn(f: F1) -> JoinHandle +where + F1: FnOnce() -> F2, + F1: Send + 'static, + F2: Future + Send + 'static, + T: Send + 'static, +{ + Builder::new() + .spawn(f) + .expect("failed to spawn coroutine task") +} + +struct Packet { + result: UnsafeCell>, +} + +unsafe impl Sync for Packet {} + +/// An owned permission to join on a coroutine task (block on its termination). +/// +/// A `JoinHandle` *detaches* the associated coroutine task when it is dropped, which +/// means that there is no longer any handle to the coroutine task and no way to `join` +/// on it. +pub struct JoinHandle { + native: AxTaskHandle, + task: Task, + packet: Arc>, +} + +unsafe impl Send for JoinHandle {} +unsafe impl Sync for JoinHandle {} + +impl JoinHandle { + /// Extracts a handle to the underlying coroutine task. + pub fn task(&self) -> &Task { + &self.task + } + + /// Waits for the associated coroutine task to finish. + /// + /// This function will return immediately if the associated coroutine task has + /// already finished. + pub async fn join_f(mut self) -> io::Result { + api::ax_wait_for_exit_f(self.native) + .await + .ok_or_else(|| ax_err_type!(BadState))?; + Arc::get_mut(&mut self.packet) + .unwrap() + .result + .get_mut() + .take() + .ok_or_else(|| ax_err_type!(BadState)) + } + + /// Waits for the associated coroutine task to finish. + /// + /// This function will return immediately if the associated coroutine task has + /// already finished. + pub fn join(mut self) -> io::Result { + api::ax_wait_for_exit(self.native).ok_or_else(|| ax_err_type!(BadState))?; + Arc::get_mut(&mut self.packet) + .unwrap() + .result + .get_mut() + .take() + .ok_or_else(|| ax_err_type!(BadState)) + } +}