diff --git a/src/backup/catalog.rs b/src/backup/catalog.rs index 65605e6e..9cbf275b 100644 --- a/src/backup/catalog.rs +++ b/src/backup/catalog.rs @@ -8,10 +8,12 @@ use std::convert::TryFrom; use chrono::offset::{TimeZone, Local}; use proxmox::tools::io::ReadExt; +use proxmox::sys::error::io_err_other; use crate::pxar::catalog::BackupCatalogWriter; use crate::pxar::{MatchPattern, MatchPatternSlice, MatchType}; use crate::backup::file_formats::PROXMOX_CATALOG_FILE_MAGIC_1_0; +use crate::tools::runtime::block_on; #[repr(u8)] #[derive(Copy,Clone,PartialEq)] @@ -384,12 +386,12 @@ impl SenderWriter { impl Write for SenderWriter { fn write(&mut self, buf: &[u8]) -> Result { - tokio::task::block_in_place(|| { - futures::executor::block_on(async move { - self.0.send(Ok(buf.to_vec())).await - .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.to_string()))?; - Ok(buf.len()) - }) + block_on(async move { + self.0 + .send(Ok(buf.to_vec())) + .await + .map_err(io_err_other) + .and(Ok(buf.len())) }) } diff --git a/src/bin/download-speed.rs b/src/bin/download-speed.rs index 3e4010a2..c4b36217 100644 --- a/src/bin/download-speed.rs +++ b/src/bin/download-speed.rs @@ -57,7 +57,7 @@ async fn run() -> Result<(), Error> { #[tokio::main] async fn main() { - if let Err(err) = run().await { + if let Err(err) = proxmox_backup::tools::runtime::main(run()) { eprintln!("ERROR: {}", err); } println!("DONE"); diff --git a/src/bin/h2client.rs b/src/bin/h2client.rs index 542ecff0..b82ec55d 100644 --- a/src/bin/h2client.rs +++ b/src/bin/h2client.rs @@ -69,8 +69,11 @@ fn send_request( }) } -#[tokio::main] -async fn main() -> Result<(), Error> { +fn main() -> Result<(), Error> { + proxmox_backup::tools::runtime::main(run()) +} + +async fn run() -> Result<(), Error> { let start = std::time::SystemTime::now(); diff --git a/src/bin/h2s-client.rs b/src/bin/h2s-client.rs index df9a95b8..6c848cd2 100644 --- a/src/bin/h2s-client.rs +++ b/src/bin/h2s-client.rs @@ -67,8 +67,11 @@ fn send_request( }) } -#[tokio::main] -async fn main() -> Result<(), Error> { +fn main() -> Result<(), Error> { + proxmox_backup::tools::runtime::main(run()) +} + +async fn run() -> Result<(), Error> { let start = std::time::SystemTime::now(); let conn = diff --git a/src/bin/h2s-server.rs b/src/bin/h2s-server.rs index 39483af2..5218cccc 100644 --- a/src/bin/h2s-server.rs +++ b/src/bin/h2s-server.rs @@ -10,8 +10,11 @@ use proxmox_backup::configdir; // Simple H2 server to test H2 speed with h2s-client.rs -#[tokio::main] -async fn main() -> Result<(), Error> { +fn main() -> Result<(), Error> { + proxmox_backup::tools::runtime::main(run()) +} + +async fn run() -> Result<(), Error> { let key_path = configdir!("/proxy.key"); let cert_path = configdir!("/proxy.pem"); diff --git a/src/bin/h2server.rs b/src/bin/h2server.rs index 3d602134..5e31288d 100644 --- a/src/bin/h2server.rs +++ b/src/bin/h2server.rs @@ -8,8 +8,11 @@ use tokio::io::{AsyncRead, AsyncWrite}; use proxmox_backup::client::pipe_to_stream::PipeToSendStream; -#[tokio::main] -async fn main() -> Result<(), Error> { +fn main() -> Result<(), Error> { + proxmox_backup::tools::runtime::main(run()) +} + +async fn run() -> Result<(), Error> { let mut listener = TcpListener::bind(std::net::SocketAddr::from(([127,0,0,1], 8008))).await?; println!("listening on {:?}", listener.local_addr()); diff --git a/src/bin/proxmox-backup-api.rs b/src/bin/proxmox-backup-api.rs index baee4a41..a1ac399a 100644 --- a/src/bin/proxmox-backup-api.rs +++ b/src/bin/proxmox-backup-api.rs @@ -13,9 +13,8 @@ use proxmox_backup::auth_helpers::*; use proxmox_backup::config; use proxmox_backup::buildcfg; -#[tokio::main] -async fn main() { - if let Err(err) = run().await { +fn main() { + if let Err(err) = proxmox_backup::tools::runtime::main(run()) { eprintln!("Error: {}", err); std::process::exit(-1); } diff --git a/src/bin/proxmox-backup-proxy.rs b/src/bin/proxmox-backup-proxy.rs index ccf7c88a..fdd54cad 100644 --- a/src/bin/proxmox-backup-proxy.rs +++ b/src/bin/proxmox-backup-proxy.rs @@ -15,9 +15,8 @@ use proxmox_backup::tools::daemon; use proxmox_backup::server::{ApiConfig, rest::*}; use proxmox_backup::auth_helpers::*; -#[tokio::main] -async fn main() { - if let Err(err) = run().await { +fn main() { + if let Err(err) = proxmox_backup::tools::runtime::main(run()) { eprintln!("Error: {}", err); std::process::exit(-1); } diff --git a/src/bin/test_chunk_speed2.rs b/src/bin/test_chunk_speed2.rs index 27f7a157..7d061a5d 100644 --- a/src/bin/test_chunk_speed2.rs +++ b/src/bin/test_chunk_speed2.rs @@ -12,9 +12,8 @@ use proxmox_backup::backup::*; // // Note: I can currently get about 830MB/s -#[tokio::main] -async fn main() { - if let Err(err) = run().await { +fn main() { + if let Err(err) = proxmox_backup::tools::runtime::main(run()) { panic!("ERROR: {}", err); } } diff --git a/src/bin/upload-speed.rs b/src/bin/upload-speed.rs index 2b9a50b4..8fcf4f4d 100644 --- a/src/bin/upload-speed.rs +++ b/src/bin/upload-speed.rs @@ -21,9 +21,8 @@ async fn upload_speed() -> Result { Ok(res) } -#[tokio::main] -async fn main() { - match upload_speed().await { +fn main() { + match proxmox_backup::tools::runtime::main(upload_speed()) { Ok(mbs) => { println!("average upload speed: {} MB/s", mbs); } diff --git a/src/client/remote_chunk_reader.rs b/src/client/remote_chunk_reader.rs index aeb82183..d997d350 100644 --- a/src/client/remote_chunk_reader.rs +++ b/src/client/remote_chunk_reader.rs @@ -5,6 +5,7 @@ use failure::*; use super::BackupReader; use crate::backup::{ReadChunk, DataBlob, CryptConfig}; +use crate::tools::runtime::block_on; /// Read chunks from remote host using ``BackupReader`` pub struct RemoteChunkReader { @@ -35,7 +36,14 @@ impl ReadChunk for RemoteChunkReader { let mut chunk_data = Vec::with_capacity(4*1024*1024); - tokio::task::block_in_place(|| futures::executor::block_on(self.client.download_chunk(&digest, &mut chunk_data)))?; + //tokio::task::block_in_place(|| futures::executor::block_on(self.client.download_chunk(&digest, &mut chunk_data)))?; + block_on(async { + // download_chunk returns the writer back to us, but we need to return a 'static value + self.client + .download_chunk(&digest, &mut chunk_data) + .await + .map(drop) + })?; let chunk = DataBlob::from_raw(chunk_data)?; chunk.verify_crc()?; diff --git a/src/tools/runtime.rs b/src/tools/runtime.rs index f5e2ca92..92155f05 100644 --- a/src/tools/runtime.rs +++ b/src/tools/runtime.rs @@ -1,20 +1,131 @@ //! Helpers for quirks of the current tokio runtime. +use std::cell::RefCell; use std::future::Future; -pub fn main(fut: F) -> T -where - F: Future + Send + 'static, - T: std::fmt::Debug + Send + 'static, -{ - let mut rt = tokio::runtime::Runtime::new().unwrap(); - rt.block_on(async { - let (tx, rx) = tokio::sync::oneshot::channel(); +use lazy_static::lazy_static; +use tokio::runtime::{self, Runtime}; - tokio::spawn(async move { - tx.send(fut.await).unwrap() - }); - - rx.await.unwrap() - }) +thread_local! { + static HAS_RUNTIME: RefCell = RefCell::new(false); + static IN_TOKIO: RefCell = RefCell::new(false); +} + +fn is_in_tokio() -> bool { + IN_TOKIO.with(|v| *v.borrow()) +} + +fn has_runtime() -> bool { + HAS_RUNTIME.with(|v| *v.borrow()) +} + +struct RuntimeGuard(bool); + +impl RuntimeGuard { + fn enter() -> Self { + Self(HAS_RUNTIME.with(|v| { + let old = *v.borrow(); + *v.borrow_mut() = true; + old + })) + } +} + +impl Drop for RuntimeGuard { + fn drop(&mut self) { + HAS_RUNTIME.with(|v| { + *v.borrow_mut() = self.0; + }); + } +} + +lazy_static! { + static ref RUNTIME: Runtime = { + runtime::Builder::new() + .threaded_scheduler() + .enable_all() + .on_thread_start(|| IN_TOKIO.with(|v| *v.borrow_mut() = true)) + .build() + .expect("failed to spawn tokio runtime") + }; +} + +/// Get or create the current main tokio runtime. +/// +/// This makes sure that tokio's worker threads are marked for us so that we know whether we +/// can/need to use `block_in_place` in our `block_on` helper. +pub fn get_runtime() -> &'static Runtime { + &RUNTIME +} + +/// Associate the current newly spawned thread with the main tokio runtime. +pub fn enter_runtime(f: impl FnOnce() -> R) -> R { + let _guard = RuntimeGuard::enter(); + get_runtime().enter(f) +} + +/// Block on a synchronous piece of code. +pub fn block_in_place(fut: impl FnOnce() -> R) -> R { + if is_in_tokio() { + // we are in an actual tokio worker thread, block it: + tokio::task::block_in_place(fut) + } else { + // we're not inside a tokio worker, so just run the code: + fut() + } +} + +/// Block on a future in this thread. +pub fn block_on(fut: F) -> R +where + R: Send + 'static, + F: Future + Send, +{ + + if is_in_tokio() { + // inside a tokio worker we need to tell tokio that we're about to really block: + tokio::task::block_in_place(move || futures::executor::block_on(fut)) + } else if has_runtime() { + // we're already associated with a runtime, but we're not a worker-thread, we can just + // block this thread directly + // This is not strictly necessary, but it's a bit quicker tha the else branch below. + futures::executor::block_on(fut) + } else { + // not a worker thread, not associated with a runtime, make sure we have a runtime (spawn + // it on demand if necessary), then enter it: + enter_runtime(move || futures::executor::block_on(fut)) + } +} + +/* +fn block_on_impl(mut fut: F) -> F::Output +where + F: Future + Send, + F::Output: Send + 'static, +{ + let (tx, rx) = tokio::sync::oneshot::channel(); + let fut_ptr = &mut fut as *mut F as usize; // hack to not require F to be 'static + tokio::spawn(async move { + let fut: F = unsafe { std::ptr::read(fut_ptr as *mut F) }; + tx + .send(fut.await) + .map_err(drop) + .expect("failed to send block_on result to channel") + }); + + futures::executor::block_on(async move { + rx.await.expect("failed to receive block_on result from channel") + }) + std::mem::forget(fut); +} +*/ + +/// This used to be our tokio main entry point. Now this just calls out to `block_on` for +/// compatibility, which will perform all the necessary tasks on-demand anyway. +pub fn main(fut: F) -> F::Output +where + F: Future + Send, + F::Output: Send + 'static, +{ + block_on(fut) } diff --git a/src/tools/wrapped_reader_stream.rs b/src/tools/wrapped_reader_stream.rs index 422b17cd..927132dc 100644 --- a/src/tools/wrapped_reader_stream.rs +++ b/src/tools/wrapped_reader_stream.rs @@ -2,9 +2,10 @@ use std::io::{self, Read}; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio::task::block_in_place; use futures::stream::Stream; +use crate::tools::runtime::block_in_place; + pub struct WrappedReaderStream { reader: R, buffer: Vec,