diff --git a/src/tools/runtime.rs b/src/tools/runtime.rs index ed7bd7ed..4054a71c 100644 --- a/src/tools/runtime.rs +++ b/src/tools/runtime.rs @@ -2,6 +2,9 @@ use std::cell::RefCell; use std::future::Future; +use std::sync::Arc; +use std::task::{Context, Poll, RawWaker, Waker}; +use std::thread::{self, Thread}; use lazy_static::lazy_static; use tokio::runtime::{self, Runtime}; @@ -79,16 +82,16 @@ pub fn block_in_place(fut: impl FnOnce() -> R) -> R { pub fn block_on(fut: F) -> F::Output { if is_in_tokio() { // inside a tokio worker we need to tell tokio that we're about to really block: - tokio::task::block_in_place(move || futures::executor::block_on(fut)) + tokio::task::block_in_place(move || block_on_local_future(fut)) } else if has_runtime() { // we're already associated with a runtime, but we're not a worker-thread, we can just // block this thread directly // This is not strictly necessary, but it's a bit quicker tha the else branch below. - futures::executor::block_on(fut) + block_on_local_future(fut) } else { // not a worker thread, not associated with a runtime, make sure we have a runtime (spawn // it on demand if necessary), then enter it: - enter_runtime(move || futures::executor::block_on(fut)) + enter_runtime(move || block_on_local_future(fut)) } } @@ -120,3 +123,50 @@ where pub fn main(fut: F) -> F::Output { block_on(fut) } + +fn block_on_local_future(mut fut: F) -> F::Output { + use std::pin::Pin; + let mut fut = unsafe { Pin::new_unchecked(&mut fut) }; + + let waker = Arc::new(thread::current()); + let waker = thread_waker_clone(Arc::into_raw(waker) as *const ()); + let waker = unsafe { Waker::from_raw(waker) }; + let mut context = Context::from_waker(&waker); + loop { + match fut.as_mut().poll(&mut context) { + Poll::Ready(out) => return out, + Poll::Pending => continue, + } + } +} + +const THREAD_WAKER_VTABLE: std::task::RawWakerVTable = std::task::RawWakerVTable::new( + thread_waker_clone, + thread_waker_wake, + thread_waker_wake_by_ref, + thread_waker_drop, +); + +fn thread_waker_clone(this: *const ()) -> RawWaker { + let this = unsafe { Arc::from_raw(this as *const Thread) }; + let cloned = Arc::clone(&this); + let _ = Arc::into_raw(this); + + RawWaker::new(Arc::into_raw(cloned) as *const (), &THREAD_WAKER_VTABLE) +} + +fn thread_waker_wake(this: *const ()) { + let this = unsafe { Arc::from_raw(this as *const Thread) }; + this.unpark(); +} + +fn thread_waker_wake_by_ref(this: *const ()) { + let this = unsafe { Arc::from_raw(this as *const Thread) }; + this.unpark(); + let _ = Arc::into_raw(this); +} + +fn thread_waker_drop(this: *const ()) { + let this = unsafe { Arc::from_raw(this as *const Thread) }; + drop(this); +}