diff --git a/src/api2/node.rs b/src/api2/node.rs index b572a9cf..b80a6618 100644 --- a/src/api2/node.rs +++ b/src/api2/node.rs @@ -267,7 +267,7 @@ fn upgrade_to_websocket( let (ws, response) = WebSocket::new(parts.headers)?; - tokio::spawn(async move { + crate::server::spawn_internal_task(async move { let conn: Upgraded = match req_body.on_upgrade().map_err(Error::from).await { Ok(upgraded) => upgraded, _ => bail!("error"), diff --git a/src/server/state.rs b/src/server/state.rs index 41e910dd..d0732434 100644 --- a/src/server/state.rs +++ b/src/server/state.rs @@ -19,6 +19,7 @@ pub struct ServerState { pub shutdown_listeners: BroadcastData<()>, pub last_worker_listeners: BroadcastData<()>, pub worker_count: usize, + pub task_count: usize, pub reload_request: bool, } @@ -28,6 +29,7 @@ lazy_static! { shutdown_listeners: BroadcastData::new(), last_worker_listeners: BroadcastData::new(), worker_count: 0, + task_count: 0, reload_request: false, }); } @@ -101,20 +103,40 @@ pub fn last_worker_future() -> impl Future> { } pub fn set_worker_count(count: usize) { - let mut data = SERVER_STATE.lock().unwrap(); - data.worker_count = count; + SERVER_STATE.lock().unwrap().worker_count = count; - if !(data.mode == ServerMode::Shutdown && data.worker_count == 0) { return; } - - data.last_worker_listeners.notify_listeners(Ok(())); + check_last_worker(); } - pub fn check_last_worker() { - let mut data = SERVER_STATE.lock().unwrap(); - if !(data.mode == ServerMode::Shutdown && data.worker_count == 0) { return; } + if !(data.mode == ServerMode::Shutdown && data.worker_count == 0 && data.task_count == 0) { return; } data.last_worker_listeners.notify_listeners(Ok(())); } + +/// Spawns a tokio task that will be tracked for reload +/// and if it is finished, notify the last_worker_listener if we +/// are in shutdown mode +pub fn spawn_internal_task(task: T) +where + T: Future + Send + 'static, + T::Output: Send + 'static, +{ + let mut data = SERVER_STATE.lock().unwrap(); + data.task_count += 1; + + tokio::spawn(async move { + let _ = tokio::spawn(task).await; // ignore errors + + { // drop mutex + let mut data = SERVER_STATE.lock().unwrap(); + if data.task_count > 0 { + data.task_count -= 1; + } + } + + check_last_worker(); + }); +}