diff --git a/monarch_hyperactor/Cargo.toml b/monarch_hyperactor/Cargo.toml
index 4dd77e107..928cb8726 100644
--- a/monarch_hyperactor/Cargo.toml
+++ b/monarch_hyperactor/Cargo.toml
@@ -48,6 +48,7 @@ serde_multipart = { version = "0.0.0", path = "../serde_multipart" }
 tempfile = "3.22"
 thiserror = "2.0.12"
 tokio = { version = "1.47.1", features = ["full", "test-util", "tracing"] }
+tokio-util = { version = "0.7.15", features = ["full"] }
 tracing = { version = "0.1.41", features = ["attributes", "valuable"] }
 
 [dev-dependencies]
diff --git a/monarch_hyperactor/src/v1/actor_mesh.rs b/monarch_hyperactor/src/v1/actor_mesh.rs
index d67b03787..d4d6518a2 100644
--- a/monarch_hyperactor/src/v1/actor_mesh.rs
+++ b/monarch_hyperactor/src/v1/actor_mesh.rs
@@ -46,6 +46,7 @@ use pyo3::prelude::*;
 use pyo3::types::PyBytes;
 use tokio::sync::watch;
 use tokio::task::JoinHandle;
+use tokio_util::sync::CancellationToken;
 
 use crate::actor::PythonActor;
 use crate::actor::PythonMessage;
@@ -90,7 +91,7 @@ impl RootHealthState {
 
 #[derive(Debug)]
 struct SupervisionMonitor {
-    task: JoinHandle<()>,
+    cancel: CancellationToken,
     receiver: watch::Receiver<Option<ActorSupervisionEvent>>,
 }
 
@@ -98,7 +99,7 @@ impl Drop for SupervisionMonitor {
     fn drop(&mut self) {
         // The task is continuously polling for events on this mesh, but when
         // the mesh is no longer available we can stop querying it.
-        self.task.abort();
+        self.cancel.cancel();
     }
 }
 
@@ -306,6 +307,8 @@ impl PythonActorMeshImpl {
         // not share the health state. This is fine because requerying a slice
         // of a mesh will still return any failed state.
         let (sender, receiver) = watch::channel(None);
+        let cancel = CancellationToken::new();
+        let canceled = cancel.clone();
         let task = get_tokio_runtime().spawn(async move {
             // 3 seconds is chosen to not penalize short-lived successful calls,
             // while still able to catch issues before they look like a hang or timeout.
@@ -322,7 +325,8 @@ impl PythonActorMeshImpl {
                         unhandled,
                         health_state,
                         time_between_checks,
-                        sender.clone(),
+                        sender,
+                        canceled,
                     )
                     .await;
                 }
@@ -342,13 +346,14 @@ impl PythonActorMeshImpl {
                         unhandled,
                         health_state,
                         time_between_checks,
-                        sender.clone(),
+                        sender,
+                        canceled,
                     )
                     .await;
                 }
             };
         });
 
-        SupervisionMonitor { task, receiver }
+        SupervisionMonitor { cancel, receiver }
     }
 }
@@ -466,6 +471,7 @@ async fn actor_states_monitor<A>(
     health_state: Arc<RootHealthState>,
     time_between_checks: tokio::time::Duration,
     sender: watch::Sender<Option<ActorSupervisionEvent>>,
+    canceled: CancellationToken,
 ) where
     A: Actor + RemotableActor + Referable,
     A::Params: RemoteMessage,
@@ -479,7 +485,10 @@
     let mut existing_states: HashMap> = HashMap::new();
     loop {
         // Wait in between checking to avoid using too much network.
-        RealClock.sleep(time_between_checks).await;
+        tokio::select! {
+            _ = RealClock.sleep(time_between_checks) => (),
+            _ = canceled.cancelled() => break,
+        }
         // First check if the proc mesh is dead before trying to query their agents.
         let proc_states = mesh.proc_mesh().proc_states(cx).await;
         if let Err(e) = proc_states {
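
The patch replaces `JoinHandle::abort()` with cooperative cancellation: `abort()` tears the monitor future down at whatever await point it happens to be suspended on, while `CancellationToken::cancel()` lets the loop observe the request at an await point it chooses (via `tokio::select!`), break out, and run any remaining cleanup before the task finishes. Below is a minimal, self-contained sketch of the same pattern; `poll_loop` and `Monitor` are illustrative stand-ins, not names from the patch:

```rust
use std::time::Duration;

use tokio_util::sync::CancellationToken;

// Stand-in for the supervision monitor: a periodic polling loop that
// exits cleanly when its token is cancelled.
async fn poll_loop(cancel: CancellationToken) {
    loop {
        tokio::select! {
            // Sleep between checks (the patch uses RealClock.sleep here).
            _ = tokio::time::sleep(Duration::from_secs(3)) => {
                // ... query state and publish results on a watch channel ...
            }
            // Exit at a well-defined point instead of being aborted mid-poll.
            _ = cancel.cancelled() => break,
        }
    }
    // Cleanup placed here still runs; abort() would have skipped it.
}

// Stand-in for SupervisionMonitor: owns the token and cancels on drop,
// mirroring the patch's Drop impl.
struct Monitor {
    cancel: CancellationToken,
}

impl Drop for Monitor {
    fn drop(&mut self) {
        self.cancel.cancel();
    }
}

#[tokio::main]
async fn main() {
    let cancel = CancellationToken::new();
    let task = tokio::spawn(poll_loop(cancel.clone()));
    let monitor = Monitor { cancel };

    drop(monitor); // request cancellation...
    task.await.unwrap(); // ...the loop breaks at its next await point
}
```

A side benefit of the token over `abort()`: it is cheaply clonable, so the same cancellation signal can be threaded into helper functions (as the patch does by passing `canceled` into `actor_states_monitor`) or extended with `child_token()` for hierarchical shutdown.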