diff --git a/src/tools/async_mutex.rs b/src/tools/async_mutex.rs index bd6cb06b..eeaa2af3 100644 --- a/src/tools/async_mutex.rs +++ b/src/tools/async_mutex.rs @@ -1,50 +1,46 @@ -use std::marker::PhantomData; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; -use failure::{bail, Error}; -use futures::{Async, Poll}; -use futures::future::Future; -use tokio::sync::lock::Lock as TokioLock; -pub use tokio::sync::lock::LockGuard as AsyncLockGuard; +use failure::Error; +use futures::future::FutureExt; +use tokio::sync::Lock as TokioLock; -pub struct AsyncMutex(TokioLock); +pub use tokio::sync::LockGuard as AsyncLockGuard; -unsafe impl Sync for AsyncMutex {} +pub struct AsyncMutex(TokioLock); -impl AsyncMutex { +unsafe impl Sync for AsyncMutex {} + +impl AsyncMutex { pub fn new(value: T) -> Self { Self(TokioLock::new(value)) } - // to allow any error type (we never error, so we have no error type of our own) - pub fn lock(&self) -> LockFuture { + pub fn lock(&self) -> LockFuture { + let mut lock = self.0.clone(); LockFuture { - lock: self.0.clone(), - _error: PhantomData, + lock: async move { lock.lock().await }.boxed(), } } + // FIXME: remove Result<> from this. pub fn new_locked(value: T) -> Result<(Self, AsyncLockGuard), Error> { let mut this = Self::new(value); - let guard = match this.0.poll_lock() { - Async::Ready(guard) => guard, - _ => bail!("failed to create locked mutex"), - }; + let guard = futures::executor::block_on(this.0.lock()); Ok((this, guard)) } } /// Represents a lock to be held in the future: -pub struct LockFuture { - lock: TokioLock, - // We can't error and we don't want to enforce a specific error type either - _error: PhantomData, +pub struct LockFuture { + lock: Pin> + Send + 'static>>, } -impl Future for LockFuture { - type Item = AsyncLockGuard; - type Error = E; +impl Future for LockFuture { + type Output = AsyncLockGuard; - fn poll(&mut self) -> Poll, E> { - Ok(self.lock.poll_lock()) + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.lock.poll_unpin(cx) } }