diff --git a/Cargo.lock b/Cargo.lock index 93ab10d31b..f6aafd73cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9030,10 +9030,10 @@ name = "virtio_p9" version = "0.0.0" dependencies = [ "anyhow", - "async-trait", - "event-listener", + "futures", "guestmem", "inspect", + "pal_async", "plan9", "task_control", "tracing", diff --git a/vm/devices/virtio/virtio_p9/Cargo.toml b/vm/devices/virtio/virtio_p9/Cargo.toml index f3876a192e..9119fa57f5 100644 --- a/vm/devices/virtio/virtio_p9/Cargo.toml +++ b/vm/devices/virtio/virtio_p9/Cargo.toml @@ -18,8 +18,8 @@ vmcore.workspace = true task_control.workspace = true anyhow.workspace = true -async-trait.workspace = true -event-listener.workspace = true +futures.workspace = true +pal_async.workspace = true tracing.workspace = true [lints] diff --git a/vm/devices/virtio/virtio_p9/src/lib.rs b/vm/devices/virtio/virtio_p9/src/lib.rs index 3738534bec..b16ac9f5b2 100644 --- a/vm/devices/virtio/virtio_p9/src/lib.rs +++ b/vm/devices/virtio/virtio_p9/src/lib.rs @@ -8,22 +8,24 @@ pub mod resolver; use anyhow::Context as _; -use async_trait::async_trait; +use futures::StreamExt; use guestmem::GuestMemory; use inspect::InspectMut; +use pal_async::wait::PolledWait; use plan9::Plan9FileSystem; -use std::sync::Arc; use std::task::Context; use std::task::Poll; use std::task::ready; +use task_control::AsyncRun; +use task_control::Cancelled; +use task_control::InspectTaskMut; +use task_control::StopTask; use task_control::TaskControl; use virtio::DeviceTraits; use virtio::Resources; use virtio::VirtioDevice; +use virtio::VirtioQueue; use virtio::VirtioQueueCallbackWork; -use virtio::VirtioQueueState; -use virtio::VirtioQueueWorker; -use virtio::VirtioQueueWorkerContext; use virtio::spec::VirtioDeviceFeatures; use vmcore::vm_task::VmTaskDriver; use vmcore::vm_task::VmTaskDriverSource; @@ -34,16 +36,10 @@ const VIRTIO_9P_F_MOUNT_TAG: u32 = 1; #[derive(InspectMut)] pub struct VirtioPlan9Device { - #[inspect(skip)] - fs: Arc, - #[inspect(skip)] tag: Vec, - memory: GuestMemory, driver: VmTaskDriver, - #[inspect(skip)] - worker: Option>, - #[inspect(skip)] - exit_event: event_listener::Event, + #[inspect(mut)] + worker: TaskControl, } impl VirtioPlan9Device { @@ -69,12 +65,9 @@ impl VirtioPlan9Device { } VirtioPlan9Device { - fs: Arc::new(fs), tag: tag_buffer, - memory, driver: driver_source.simple(), - worker: None, - exit_event: event_listener::Event::new(), + worker: TaskControl::new(Plan9Worker { mem: memory, fs }), } } } @@ -121,68 +114,96 @@ impl VirtioDevice for VirtioPlan9Device { return Ok(()); } - let worker = VirtioPlan9Worker { - mem: self.memory.clone(), - fs: self.fs.clone(), - }; - let worker = VirtioQueueWorker::new(self.driver.clone(), Box::new(worker)); - self.worker = Some(worker.into_running_task( - "virtio-9p-queue".to_string(), - self.memory.clone(), - resources.features.clone(), - queue_resources, - self.exit_event.listen(), - )); + let queue_event = PolledWait::new(&self.driver, queue_resources.event) + .context("failed to create polled wait")?; + let queue = VirtioQueue::new( + resources.features, + queue_resources.params, + self.worker.task().mem.clone(), + queue_resources.notify, + queue_event, + ) + .context("failed to create virtio queue")?; + + self.worker + .insert(self.driver.clone(), "virtio-9p-queue", Plan9Queue { queue }); + self.worker.start(); Ok(()) } fn poll_disable(&mut self, cx: &mut Context<'_>) -> Poll<()> { - self.exit_event.notify(usize::MAX); - if let Some(worker) = &mut self.worker { - ready!(worker.poll_stop(cx)); + ready!(self.worker.poll_stop(cx)); + if self.worker.has_state() { + self.worker.remove(); } - self.worker = None; Poll::Ready(()) } } -struct VirtioPlan9Worker { +#[derive(InspectMut)] +struct Plan9Worker { mem: GuestMemory, - fs: Arc, + #[inspect(skip)] + fs: Plan9FileSystem, +} + +#[derive(InspectMut)] +struct Plan9Queue { + queue: VirtioQueue, +} + +impl InspectTaskMut for Plan9Worker { + fn inspect_mut(&mut self, req: inspect::Request<'_>, state: Option<&mut Plan9Queue>) { + req.respond().merge(self).merge(state); + } } -#[async_trait] -impl VirtioQueueWorkerContext for VirtioPlan9Worker { - async fn process_work(&mut self, work: anyhow::Result) -> bool { - if let Err(err) = work { - tracing::error!(err = err.as_ref() as &dyn std::error::Error, "queue error"); - return false; +impl AsyncRun for Plan9Worker { + async fn run( + &mut self, + stop: &mut StopTask<'_>, + state: &mut Plan9Queue, + ) -> Result<(), Cancelled> { + loop { + let work = stop.until_stopped(state.queue.next()).await?; + let Some(work) = work else { break }; + match work { + Ok(work) => { + process_9p_request(self, work); + } + Err(err) => { + tracing::error!(error = &err as &dyn std::error::Error, "queue error"); + break; + } + } } - let mut work = work.unwrap(); - // Make a copy of the incoming message. - let mut message = vec![0; work.get_payload_length(false) as usize]; - if let Err(e) = work.read(&self.mem, &mut message) { + Ok(()) + } +} + +fn process_9p_request(worker: &Plan9Worker, mut work: VirtioQueueCallbackWork) { + // Make a copy of the incoming message. + let mut message = vec![0; work.get_payload_length(false) as usize]; + if let Err(e) = work.read(&worker.mem, &mut message) { + tracing::error!( + error = &e as &dyn std::error::Error, + "[VIRTIO 9P] Failed to read guest memory" + ); + return; + } + + // Allocate a temporary buffer for the response. + let mut response = vec![9; work.get_payload_length(true) as usize]; + if let Ok(size) = worker.fs.process_message(&message, &mut response) { + // Write out the response. + if let Err(e) = work.write(&worker.mem, &response[0..size]) { tracing::error!( error = &e as &dyn std::error::Error, - "[VIRTIO 9P] Failed to read guest memory" + "[VIRTIO 9P] Failed to write guest memory" ); - return false; + return; } - // Allocate a temporary buffer for the response. - let mut response = vec![9; work.get_payload_length(true) as usize]; - if let Ok(size) = self.fs.process_message(&message, &mut response) { - // Write out the response. - if let Err(e) = work.write(&self.mem, &response[0..size]) { - tracing::error!( - error = &e as &dyn std::error::Error, - "[VIRTIO 9P] Failed to write guest memory" - ); - return false; - } - - work.complete(size as u32); - } - true + work.complete(size as u32); } }