diff --git a/src/server/rest.rs b/src/server/rest.rs index 2b835c4a..365e3570 100644 --- a/src/server/rest.rs +++ b/src/server/rest.rs @@ -531,51 +531,89 @@ async fn handle_static_file_download(filename: PathBuf) -> Result (Option, Option, Option) { - - let mut ticket = None; - let mut language = None; +fn extract_lang_header(headers: &http::HeaderMap) -> Option { if let Some(raw_cookie) = headers.get("COOKIE") { if let Ok(cookie) = raw_cookie.to_str() { - ticket = tools::extract_cookie(cookie, "PBSAuthCookie"); - language = tools::extract_cookie(cookie, "PBSLangCookie"); + return tools::extract_cookie(cookie, "PBSLangCookie"); } } - let csrf_token = match headers.get("CSRFPreventionToken").map(|v| v.to_str()) { - Some(Ok(v)) => Some(v.to_owned()), - _ => None, - }; + None +} - (ticket, csrf_token, language) +struct UserAuthData{ + ticket: String, + csrf_token: Option, +} + +enum AuthData { + User(UserAuthData), + ApiToken(String), +} + +fn extract_auth_data(headers: &http::HeaderMap) -> Option { + if let Some(raw_cookie) = headers.get("COOKIE") { + if let Ok(cookie) = raw_cookie.to_str() { + if let Some(ticket) = tools::extract_cookie(cookie, "PBSAuthCookie") { + let csrf_token = match headers.get("CSRFPreventionToken").map(|v| v.to_str()) { + Some(Ok(v)) => Some(v.to_owned()), + _ => None, + }; + return Some(AuthData::User(UserAuthData { + ticket, + csrf_token, + })); + } + } + } + + match headers.get("AUTHORIZATION").map(|v| v.to_str()) { + Some(Ok(v)) => Some(AuthData::ApiToken(v.to_owned())), + _ => None, + } } fn check_auth( method: &hyper::Method, - ticket: &Option, - csrf_token: &Option, + auth_data: &AuthData, user_info: &CachedUserInfo, ) -> Result { - let ticket_lifetime = tools::ticket::TICKET_LIFETIME; + match auth_data { + AuthData::User(user_auth_data) => { + let ticket = user_auth_data.ticket.clone(); + let ticket_lifetime = tools::ticket::TICKET_LIFETIME; - let ticket = ticket.as_ref().map(String::as_str); - let userid: Userid = Ticket::parse(&ticket.ok_or_else(|| format_err!("missing ticket"))?)? - .verify_with_time_frame(public_auth_key(), "PBS", None, -300..ticket_lifetime)?; + let userid: Userid = Ticket::parse(&ticket)? + .verify_with_time_frame(public_auth_key(), "PBS", None, -300..ticket_lifetime)?; - let auth_id = Authid::from(userid.clone()); - if !user_info.is_active_auth_id(&auth_id) { - bail!("user account disabled or expired."); - } + let auth_id = Authid::from(userid.clone()); + if !user_info.is_active_auth_id(&auth_id) { + bail!("user account disabled or expired."); + } - if method != hyper::Method::GET { - if let Some(csrf_token) = csrf_token { - verify_csrf_prevention_token(csrf_secret(), &userid, &csrf_token, -300, ticket_lifetime)?; - } else { - bail!("missing CSRF prevention token"); + if method != hyper::Method::GET { + if let Some(csrf_token) = &user_auth_data.csrf_token { + verify_csrf_prevention_token(csrf_secret(), &userid, &csrf_token, -300, ticket_lifetime)?; + } else { + bail!("missing CSRF prevention token"); + } + } + + Ok(auth_id) + }, + AuthData::ApiToken(api_token) => { + let mut parts = api_token.splitn(2, ':'); + let tokenid = parts.next() + .ok_or_else(|| format_err!("failed to split API token header"))?; + let tokenid: Authid = tokenid.parse()?; + + let tokensecret = parts.next() + .ok_or_else(|| format_err!("failed to split API token header"))?; + crate::config::token_shadow::verify_secret(&tokenid, &tokensecret)?; + + Ok(tokenid) } } - - Ok(Authid::from(userid)) } async fn handle_request( @@ -631,8 +669,11 @@ async fn handle_request( } if auth_required { - let (ticket, csrf_token, _) = extract_auth_data(&parts.headers); - match check_auth(&method, &ticket, &csrf_token, &user_info) { + let auth_result = match extract_auth_data(&parts.headers) { + Some(auth_data) => check_auth(&method, &auth_data, &user_info), + None => Err(format_err!("no authentication credentials provided.")), + }; + match auth_result { Ok(authid) => rpcenv.set_auth_id(Some(authid.to_string())), Err(err) => { // always delay unauthorized calls by 3 seconds (from start of request) @@ -685,14 +726,14 @@ async fn handle_request( } if comp_len == 0 { - let (ticket, csrf_token, language) = extract_auth_data(&parts.headers); - if ticket != None { - match check_auth(&method, &ticket, &csrf_token, &user_info) { - Ok(auth_id) => { + let language = extract_lang_header(&parts.headers); + if let Some(auth_data) = extract_auth_data(&parts.headers) { + match check_auth(&method, &auth_data, &user_info) { + Ok(auth_id) if !auth_id.is_token() => { let userid = auth_id.user(); let new_csrf_token = assemble_csrf_prevention_token(csrf_secret(), userid); return Ok(get_index(Some(userid.clone()), Some(new_csrf_token), language, &api, parts)); - } + }, _ => { tokio::time::delay_until(Instant::from_std(delay_unauth_time)).await; return Ok(get_index(None, None, language, &api, parts));