drop RawWaker usage

this was also leaking a refcount before, this is fixed now

See-also: proxmox/proxmox-async:
  * d0a3e38006fe ("drop RawWaker usage")
  * ff132e93c6fd ("rustfmt")

Signed-off-by: Wolfgang Bumiller <w.bumiller@proxmox.com>
This commit is contained in:
Wolfgang Bumiller 2022-01-20 15:41:00 +01:00
parent d2351f1a81
commit 9b00099ead
1 changed files with 24 additions and 44 deletions

View File

@ -2,8 +2,8 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::future::Future; use std::future::Future;
use std::sync::{Arc, Weak, Mutex}; use std::sync::{Arc, Mutex, Weak};
use std::task::{Context, Poll, RawWaker, Waker}; use std::task::{Context, Poll, Waker};
use std::thread::{self, Thread}; use std::thread::{self, Thread};
use lazy_static::lazy_static; use lazy_static::lazy_static;
@ -15,8 +15,7 @@ thread_local! {
} }
fn is_in_tokio() -> bool { fn is_in_tokio() -> bool {
tokio::runtime::Handle::try_current() tokio::runtime::Handle::try_current().is_ok()
.is_ok()
} }
fn is_blocking() -> bool { fn is_blocking() -> bool {
@ -49,7 +48,8 @@ lazy_static! {
static ref RUNTIME: Mutex<Weak<Runtime>> = Mutex::new(Weak::new()); static ref RUNTIME: Mutex<Weak<Runtime>> = Mutex::new(Weak::new());
} }
extern { #[link(name = "crypto")]
extern "C" {
fn OPENSSL_thread_stop(); fn OPENSSL_thread_stop();
} }
@ -58,16 +58,19 @@ extern {
/// This makes sure that tokio's worker threads are marked for us so that we know whether we /// This makes sure that tokio's worker threads are marked for us so that we know whether we
/// can/need to use `block_in_place` in our `block_on` helper. /// can/need to use `block_in_place` in our `block_on` helper.
pub fn get_runtime_with_builder<F: Fn() -> runtime::Builder>(get_builder: F) -> Arc<Runtime> { pub fn get_runtime_with_builder<F: Fn() -> runtime::Builder>(get_builder: F) -> Arc<Runtime> {
let mut guard = RUNTIME.lock().unwrap(); let mut guard = RUNTIME.lock().unwrap();
if let Some(rt) = guard.upgrade() { return rt; } if let Some(rt) = guard.upgrade() {
return rt;
}
let mut builder = get_builder(); let mut builder = get_builder();
builder.on_thread_stop(|| { builder.on_thread_stop(|| {
// avoid openssl bug: https://github.com/openssl/openssl/issues/6214 // avoid openssl bug: https://github.com/openssl/openssl/issues/6214
// call OPENSSL_thread_stop to avoid race with openssl cleanup handlers // call OPENSSL_thread_stop to avoid race with openssl cleanup handlers
unsafe { OPENSSL_thread_stop(); } unsafe {
OPENSSL_thread_stop();
}
}); });
let runtime = builder.build().expect("failed to spawn tokio runtime"); let runtime = builder.build().expect("failed to spawn tokio runtime");
@ -82,7 +85,6 @@ pub fn get_runtime_with_builder<F: Fn() -> runtime::Builder>(get_builder: F) ->
/// ///
/// This calls get_runtime_with_builder() using the tokio default threaded scheduler /// This calls get_runtime_with_builder() using the tokio default threaded scheduler
pub fn get_runtime() -> Arc<Runtime> { pub fn get_runtime() -> Arc<Runtime> {
get_runtime_with_builder(|| { get_runtime_with_builder(|| {
let mut builder = runtime::Builder::new_multi_thread(); let mut builder = runtime::Builder::new_multi_thread();
builder.enable_all(); builder.enable_all();
@ -90,7 +92,6 @@ pub fn get_runtime() -> Arc<Runtime> {
}) })
} }
/// Block on a synchronous piece of code. /// Block on a synchronous piece of code.
pub fn block_in_place<R>(fut: impl FnOnce() -> R) -> R { pub fn block_in_place<R>(fut: impl FnOnce() -> R) -> R {
// don't double-exit the context (tokio doesn't like that) // don't double-exit the context (tokio doesn't like that)
@ -155,12 +156,22 @@ pub fn main<F: Future>(fut: F) -> F::Output {
block_on(fut) block_on(fut)
} }
struct ThreadWaker(Thread);
impl std::task::Wake for ThreadWaker {
fn wake(self: Arc<Self>) {
self.0.unpark();
}
fn wake_by_ref(self: &Arc<Self>) {
self.0.unpark();
}
}
fn block_on_local_future<F: Future>(fut: F) -> F::Output { fn block_on_local_future<F: Future>(fut: F) -> F::Output {
pin_mut!(fut); pin_mut!(fut);
let waker = Arc::new(thread::current()); let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
let waker = thread_waker_clone(Arc::into_raw(waker) as *const ());
let waker = unsafe { Waker::from_raw(waker) };
let mut context = Context::from_waker(&waker); let mut context = Context::from_waker(&waker);
loop { loop {
match fut.as_mut().poll(&mut context) { match fut.as_mut().poll(&mut context) {
@ -169,34 +180,3 @@ fn block_on_local_future<F: Future>(fut: F) -> F::Output {
} }
} }
} }
const THREAD_WAKER_VTABLE: std::task::RawWakerVTable = std::task::RawWakerVTable::new(
thread_waker_clone,
thread_waker_wake,
thread_waker_wake_by_ref,
thread_waker_drop,
);
fn thread_waker_clone(this: *const ()) -> RawWaker {
let this = unsafe { Arc::from_raw(this as *const Thread) };
let cloned = Arc::clone(&this);
let _ = Arc::into_raw(this);
RawWaker::new(Arc::into_raw(cloned) as *const (), &THREAD_WAKER_VTABLE)
}
fn thread_waker_wake(this: *const ()) {
let this = unsafe { Arc::from_raw(this as *const Thread) };
this.unpark();
}
fn thread_waker_wake_by_ref(this: *const ()) {
let this = unsafe { Arc::from_raw(this as *const Thread) };
this.unpark();
let _ = Arc::into_raw(this);
}
fn thread_waker_drop(this: *const ()) {
let this = unsafe { Arc::from_raw(this as *const Thread) };
drop(this);
}