src/server/rest.rs: switch to async

Signed-off-by: Wolfgang Bumiller <w.bumiller@proxmox.com>
This commit is contained in:
Wolfgang Bumiller 2019-08-26 13:33:38 +02:00
parent b9b7f7ec1f
commit 91e4587343

View File

@ -1,31 +1,27 @@
use crate::tools;
use crate::api_schema::*;
use crate::api_schema::router::*;
use crate::api_schema::config::*;
use crate::auth_helpers::*;
use super::environment::RestEnvironment;
use super::formatter::*;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use failure::*;
use futures::future::{self, Either, FutureExt, TryFutureExt};
use futures::stream::TryStreamExt;
use hyper::header;
use hyper::http::request::Parts;
use hyper::rt::Future;
use hyper::{Body, Request, Response, StatusCode};
use serde_json::{json, Value};
use tokio::fs::File;
use url::form_urlencoded;
use futures::future::{self, Either};
//use tokio::prelude::*;
//use tokio::timer::Delay;
use tokio::fs::File;
//use bytes::{BytesMut, BufMut};
//use hyper::body::Payload;
use hyper::http::request::Parts;
use hyper::{Body, Request, Response, StatusCode};
use hyper::service::{Service, MakeService};
use hyper::rt::{Future, Stream};
use hyper::header;
use super::environment::RestEnvironment;
use super::formatter::*;
use crate::api_schema::config::*;
use crate::api_schema::router::*;
use crate::api_schema::*;
use crate::auth_helpers::*;
use crate::tools;
extern "C" { fn tzset(); }
@ -40,47 +36,48 @@ impl RestServer {
}
}
impl MakeService<&tokio_openssl::SslStream<tokio::net::TcpStream>> for RestServer
{
type ReqBody = Body;
type ResBody = Body;
impl tower_service::Service<&tokio_openssl::SslStream<tokio::net::TcpStream>> for RestServer {
type Response = ApiService;
type Error = Error;
type MakeError = Error;
type Service = ApiService;
type Future = Box<dyn Future<Item = Self::Service, Error = Self::MakeError> + Send>;
fn make_service(&mut self, ctx: &tokio_openssl::SslStream<tokio::net::TcpStream>) -> Self::Future {
match ctx.get_ref().get_ref().peer_addr() {
type Future = Pin<Box<dyn Future<Output = Result<ApiService, Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, ctx: &tokio_openssl::SslStream<tokio::net::TcpStream>) -> Self::Future {
match ctx.get_ref().peer_addr() {
Err(err) => {
Box::new(future::err(format_err!("unable to get peer address - {}", err)))
future::err(format_err!("unable to get peer address - {}", err)).boxed()
}
Ok(peer) => {
Box::new(future::ok(ApiService { peer, api_config: self.api_config.clone() }))
future::ok(ApiService { peer, api_config: self.api_config.clone() }).boxed()
}
}
}
}
impl MakeService<&tokio::net::TcpStream> for RestServer
{
type ReqBody = Body;
type ResBody = Body;
impl tower_service::Service<&tokio::net::TcpStream> for RestServer {
type Response = ApiService;
type Error = Error;
type MakeError = Error;
type Service = ApiService;
type Future = Box<dyn Future<Item = Self::Service, Error = Self::MakeError> + Send>;
fn make_service(&mut self, ctx: &tokio::net::TcpStream) -> Self::Future {
type Future = Pin<Box<dyn Future<Output = Result<ApiService, Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, ctx: &tokio::net::TcpStream) -> Self::Future {
match ctx.peer_addr() {
Err(err) => {
Box::new(future::err(format_err!("unable to get peer address - {}", err)))
future::err(format_err!("unable to get peer address - {}", err)).boxed()
}
Ok(peer) => {
Box::new(future::ok(ApiService { peer, api_config: self.api_config.clone() }))
future::ok(ApiService { peer, api_config: self.api_config.clone() }).boxed()
}
}
}
}
pub struct ApiService {
pub peer: std::net::SocketAddr,
pub api_config: Arc<ApiConfig>,
@ -109,19 +106,22 @@ fn log_response(
}
}
impl Service for ApiService {
type ReqBody = Body;
type ResBody = Body;
impl tower_service::Service<Request<Body>> for ApiService {
type Response = Response<Body>;
type Error = Error;
type Future = Box<dyn Future<Item = Response<Body>, Error = Self::Error> + Send>;
type Future = Pin<Box<dyn Future<Output = Result<Response<Body>, Self::Error>> + Send>>;
fn call(&mut self, req: Request<Self::ReqBody>) -> Self::Future {
fn poll_ready(&mut self, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let path = req.uri().path().to_owned();
let method = req.method().clone();
let peer = self.peer.clone();
Box::new(handle_request(self.api_config.clone(), req).then(move |result| {
match result {
Pin::from(handle_request(self.api_config.clone(), req))
.map(move |result| match result {
Ok(res) => {
log_response(&peer, method, &path, &res);
Ok::<_, Self::Error>(res)
@ -139,8 +139,8 @@ impl Service for ApiService {
Ok(resp)
}
}
}
}))
})
.boxed()
}
}
@ -149,7 +149,7 @@ fn get_request_parameters_async(
parts: Parts,
req_body: Body,
uri_param: HashMap<String, String>,
) -> Box<dyn Future<Item = Value, Error = failure::Error> + Send>
) -> Box<dyn Future<Output = Result<Value, failure::Error>> + Send>
{
let mut is_json = false;
@ -169,15 +169,15 @@ fn get_request_parameters_async(
let resp = req_body
.map_err(|err| http_err!(BAD_REQUEST, format!("Promlems reading request body: {}", err)))
.fold(Vec::new(), |mut acc, chunk| {
.try_fold(Vec::new(), |mut acc, chunk| async move {
if acc.len() + chunk.len() < 64*1024 { //fimxe: max request body size?
acc.extend_from_slice(&*chunk);
Ok(acc)
} else {
Err(http_err!(BAD_REQUEST, format!("Request body too large")))
}
else { Err(http_err!(BAD_REQUEST, format!("Request body too large"))) }
})
.and_then(move |body| {
.and_then(move |body| async move {
let utf8 = std::str::from_utf8(&body)?;
let obj_schema = &info.parameters;
@ -216,7 +216,7 @@ fn get_request_parameters_async(
let params = parse_parameter_strings(&param_list, obj_schema, true)?;
Ok(params)
});
}.boxed());
Box::new(resp)
}
@ -227,8 +227,7 @@ fn proxy_protected_request(
info: &'static ApiMethod,
mut parts: Parts,
req_body: Body,
) -> BoxFut
{
) -> BoxFut {
let mut uri_parts = parts.uri.clone().into_parts();
@ -243,19 +242,22 @@ fn proxy_protected_request(
let resp = hyper::client::Client::new()
.request(request)
.map_err(Error::from)
.map(|mut resp| {
.map_ok(|mut resp| {
resp.extensions_mut().insert(NoLogExtension());
resp
});
let resp = if info.reload_timezone {
Either::A(resp.then(|resp| {unsafe { tzset() }; resp }))
} else {
Either::B(resp)
};
return Box::new(resp);
let reload_timezone = info.reload_timezone;
Box::new(async move {
let result = resp.await;
if reload_timezone {
unsafe {
tzset();
}
}
result
})
}
pub fn handle_sync_api_request<Env: RpcEnvironment>(
@ -271,14 +273,16 @@ pub fn handle_sync_api_request<Env: RpcEnvironment>(
let delay_unauth_time = std::time::Instant::now() + std::time::Duration::from_millis(3000);
let resp = params
let resp = Pin::from(params)
.and_then(move |params| {
let mut delay = false;
let resp = match (info.handler.as_ref().unwrap())(params, info, &mut rpcenv) {
Ok(data) => (formatter.format_data)(data, &rpcenv),
Err(err) => {
if let Some(httperr) = err.downcast_ref::<HttpError>() {
if httperr.code == StatusCode::UNAUTHORIZED { delay = true; }
if httperr.code == StatusCode::UNAUTHORIZED {
delay = true;
}
}
(formatter.format_error)(err)
}
@ -289,13 +293,13 @@ pub fn handle_sync_api_request<Env: RpcEnvironment>(
}
if delay {
Either::A(delayed_response(resp, delay_unauth_time))
Either::Left(delayed_response(resp, delay_unauth_time))
} else {
Either::B(future::ok(resp))
Either::Right(future::ok(resp))
}
})
.or_else(move |err| {
Ok((formatter.format_error)(err))
future::ok((formatter.format_error)(err))
});
Box::new(resp)
@ -426,35 +430,37 @@ fn extension_to_content_type(filename: &Path) -> (&'static str, bool) {
("application/octet-stream", false)
}
fn simple_static_file_download(filename: PathBuf) -> BoxFut {
async fn simple_static_file_download(filename: PathBuf) -> Result<Response<Body>, Error> {
let (content_type, _nocomp) = extension_to_content_type(&filename);
Box::new(File::open(filename)
.map_err(|err| http_err!(BAD_REQUEST, format!("File open failed: {}", err)))
.and_then(move |file| {
let buf: Vec<u8> = Vec::new();
tokio::io::read_to_end(file, buf)
.map_err(|err| http_err!(BAD_REQUEST, format!("File read failed: {}", err)))
.and_then(move |data| {
let mut response = Response::new(data.1.into());
use tokio::io::AsyncReadExt;
let mut file = File::open(filename)
.await
.map_err(|err| http_err!(BAD_REQUEST, format!("File open failed: {}", err)))?;
let mut data: Vec<u8> = Vec::new();
file.read_to_end(&mut data)
.await
.map_err(|err| http_err!(BAD_REQUEST, format!("File read failed: {}", err)))?;
let mut response = Response::new(data.into());
response.headers_mut().insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static(content_type));
Ok(response)
})
}))
}
fn chuncked_static_file_download(filename: PathBuf) -> BoxFut {
async fn chuncked_static_file_download(filename: PathBuf) -> Result<Response<Body>, Error> {
let (content_type, _nocomp) = extension_to_content_type(&filename);
Box::new(File::open(filename)
.map_err(|err| http_err!(BAD_REQUEST, format!("File open failed: {}", err)))
.and_then(move |file| {
let payload = tokio::codec::FramedRead::new(file, tokio::codec::BytesCodec::new()).
map(|bytes| hyper::Chunk::from(bytes.freeze()));
let file = File::open(filename)
.await
.map_err(|err| http_err!(BAD_REQUEST, format!("File open failed: {}", err)))?;
let payload = tokio::codec::FramedRead::new(file, tokio::codec::BytesCodec::new())
.map_ok(|bytes| hyper::Chunk::from(bytes.freeze()));
let body = Body::wrap_stream(payload);
// fixme: set other headers ?
@ -462,19 +468,19 @@ fn chuncked_static_file_download(filename: PathBuf) -> BoxFut {
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, content_type)
.body(body)
.unwrap())
}))
.unwrap()
)
}
fn handle_static_file_download(filename: PathBuf) -> BoxFut {
let response = tokio::fs::metadata(filename.clone())
.map_err(|err| http_err!(BAD_REQUEST, format!("File access problems: {}", err)))
.and_then(|metadata| {
.and_then(|metadata| async move {
if metadata.len() < 1024*32 {
Either::A(simple_static_file_download(filename))
simple_static_file_download(filename).await
} else {
Either::B(chuncked_static_file_download(filename))
chuncked_static_file_download(filename).await
}
});
@ -523,11 +529,13 @@ fn check_auth(method: &hyper::Method, ticket: &Option<String>, token: &Option<St
Ok(username)
}
fn delayed_response(resp: Response<Body>, delay_unauth_time: std::time::Instant) -> BoxFut {
Box::new(tokio::timer::Delay::new(delay_unauth_time)
.map_err(|err| http_err!(INTERNAL_SERVER_ERROR, format!("tokio timer delay error: {}", err)))
.and_then(|_| Ok(resp)))
async fn delayed_response(
resp: Response<Body>,
delay_unauth_time: std::time::Instant,
) -> Result<Response<Body>, Error> {
tokio::timer::Delay::new(delay_unauth_time)
.await;
Ok(resp)
}
pub fn handle_request(api: Arc<ApiConfig>, req: Request<Body>) -> BoxFut {
@ -579,7 +587,9 @@ pub fn handle_request(api: Arc<ApiConfig>, req: Request<Body>) -> BoxFut {
Err(err) => {
// always delay unauthorized calls by 3 seconds (from start of request)
let err = http_err!(UNAUTHORIZED, format!("permission check failed - {}", err));
return delayed_response((formatter.format_error)(err), delay_unauth_time);
return Box::new(
delayed_response((formatter.format_error)(err), delay_unauth_time)
);
}
}
}
@ -616,7 +626,9 @@ pub fn handle_request(api: Arc<ApiConfig>, req: Request<Body>) -> BoxFut {
let new_token = assemble_csrf_prevention_token(csrf_secret(), &username);
return Box::new(future::ok(get_index(Some(username), Some(new_token))));
}
_ => return delayed_response(get_index(None, None), delay_unauth_time),
_ => {
return Box::new(delayed_response(get_index(None, None), delay_unauth_time));
}
}
} else {
return Box::new(future::ok(get_index(None, None)));