use std::borrow::Borrow; use std::collections::hash_map; use std::collections::{HashMap, HashSet}; use std::io::{Read, Write}; use std::mem; use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use endian_trait::Endian; use failure::*; use crate::common; use crate::protocol::*; use crate::tools::swapped_data_to_buf; use crate::{ChunkEntry, FixedChunk, IndexType}; #[derive(Clone, Copy, Eq, PartialEq)] #[repr(transparent)] pub struct BackupStream(pub(crate) u8); #[derive(Clone, Copy, Eq, PartialEq)] #[repr(transparent)] pub struct StreamId(pub(crate) u8); impl From for StreamId { fn from(v: BackupStream) -> Self { Self(v.0) } } struct BackupStreamData { id: u8, index_type: IndexType, pos: u64, path: Option, } pub enum AckState { Waiting, // no ack received yet. Received, // already received an ack, but the user hasn't seen it yet. Ignore, // client doesn't care. AwaitingData, // waiting for something other than an 'Ok' packet } pub struct Client where S: Read + Write, { chunks: RwLock>, common: common::Connection, handshake_done: bool, cur_id: u8, free_ids: Vec, waiting_ids: HashMap, hash_download: Option, upload_chunk: Option, upload_id: u8, upload_pos: usize, upload_state: u8, streams: HashMap, } type Result = std::result::Result; impl Client where S: Read + Write, { pub fn new(socket: S) -> Self { Self { chunks: RwLock::new(HashSet::new()), common: common::Connection::new(socket), handshake_done: false, cur_id: 1, free_ids: Vec::new(), waiting_ids: HashMap::new(), hash_download: None, upload_state: 0, upload_pos: 0, upload_id: 0, upload_chunk: None, streams: HashMap::new(), } } pub fn eof(&self) -> bool { self.common.eof } pub fn error(&self) -> bool { self.common.error } /// It is safe to clear the error after an `io::ErrorKind::Interrupted`. pub fn clear_err(&mut self) { self.common.clear_err() } pub fn wait_for_handshake(&mut self) -> Result { if !self.handshake_done { self.poll_read(true)?; } Ok(self.handshake_done) } pub fn query_hashes(&mut self, file_name: &str) -> Result<()> { if self.hash_download.is_some() { bail!("hash query already in progress"); } let id = self.next_id()?; let mut packet = Packet::builder(id, PacketType::GetHashList); packet .write_data(client::GetHashList { name_length: file_name.len() as u16, }) .write_buf(file_name.as_bytes()); self.common.queue_data(packet.finish())?; self.hash_download = Some(id); Ok(()) } pub fn wait_for_hashes(&mut self) -> Result { while self.hash_download.is_some() { if !self.poll_read(true)? { break; } } Ok(self.hash_download.is_none()) } fn chunk_read_lock(&self) -> Result>> { self.chunks .read() .map_err(|_| format_err!("lock poisoned, disconnecting client...")) } pub fn is_chunk_available>(&self, chunk: &T) -> bool { self.chunk_read_lock().unwrap().contains(chunk.borrow()) } /// Attempts to upload a chunk. Returns an error state only on fatal errors. If the underlying /// writer returns a `WouldBlock` error this returns `None` and `continue_upload_chunk` has to /// be called until the chunk is uploaded completely. During this time no other operations /// should be performed on the this object! /// See `continue_upload_chunk()` for a description of the returned value. pub fn upload_chunk(&mut self, info: &T, data: &[u8]) -> Result> where T: Borrow, { if self.upload_chunk.is_some() { bail!("cannot concurrently upload multiple chunks"); } self.upload_id = self.next_id()?; self.upload_chunk = Some(info.borrow().clone()); self.next_upload_state(0); self.continue_upload_chunk(data) } fn next_upload_state(&mut self, state: u8) { self.upload_state = state; self.upload_pos = 0; } // This is split into a static method not borrowing self so the buffer can point into self... fn do_upload_write( writer: &mut common::Connection, buf: &[u8], ) -> Result> { match writer.write_some(buf) { Ok(put) => Ok(Some((put, put == buf.len()))), Err(e) => { if e.kind() == std::io::ErrorKind::WouldBlock { Ok(None) } else { Err(e.into()) } } } } fn after_upload_write(&mut self, put: Option<(usize, bool)>, next_state: u8) -> bool { match put { None => return false, Some((put, done)) => { if done { self.next_upload_state(next_state); } else { self.upload_pos += put; } done } } } fn upload_write(&mut self, buf: &[u8], next_state: u8) -> Result { let wrote = Self::do_upload_write(&mut self.common, buf)?; Ok(self.after_upload_write(wrote, next_state)) } /// If an `upload_chunk()` call returned `Ok(false)` this needs to be used to complete the /// upload process as the chunk may have already been partially written to the socket. /// This function will return `Ok(false)` on `WouldBlock` errors just like `upload_chunk()` /// will, after which the caller should wait for the writer to become write-ready and then /// call this method again. /// Once the complete chunk packet has been written to the underlying socket, this returns a /// packet ID which can be waited upon for completion pub fn continue_upload_chunk(&mut self, data: &[u8]) -> Result> { loop { match self.upload_state { // Writing the packet header: 0 => { let len = mem::size_of::() + mem::size_of::() + data.len(); let packet = Packet { id: self.upload_id, pkttype: PacketType::UploadChunk as _, length: len as _, } .to_le(); let buf = unsafe { swapped_data_to_buf(&packet) }; if !self.upload_write(&buf[self.upload_pos..], 1)? { return Ok(None); } } // Writing the hash: 1 => { let chunk = self.upload_chunk.as_ref().unwrap(); let buf = &chunk.0[self.upload_pos..]; let wrote = Self::do_upload_write(&mut self.common, buf)?; if !self.after_upload_write(wrote, 2) { return Ok(None); } } // Writing the data: 2 => { if !self.upload_write(&data[self.upload_pos..], 3)? { return Ok(None); } } // Done: 3 => { self.upload_chunk = None; self.expect_ok_for_id(self.upload_id); return Ok(Some(StreamId(self.upload_id))); } n => bail!("bad chunk upload state: {}", n), } } } // generic data polling method, returns true if at least one packet was received pub fn poll_read(&mut self, one: bool) -> Result { if self.common.eof { // polls after EOF are errors: bail!("server disconnected"); } if !self.common.poll_read()? { // On the client side we do not expect a server-side disconnect, so error out! if self.common.eof { bail!("server disconnected"); } return Ok(false); } loop { match self.common.current_packet_type { PacketType::Ok => self.recv_ok()?, PacketType::Hello => self.recv_hello()?, PacketType::HashListPart => self.recv_hash_list()?, PacketType::BackupCreated => self.backup_created()?, _ => bail!( "server sent an unexpected packet of type {}", self.common.current_packet_type as u32, ), } self.common.next()?; if one || !self.common.poll_read()? { break; } } Ok(true) } // None => nothing was queued // Some(true) => queue finished // Some(false) => queue not finished pub fn poll_send(&mut self) -> Result> { self.common.poll_send() } // private helpermethods fn next_id(&mut self) -> Result { if let Some(id) = self.free_ids.pop() { return Ok(id); } if self.cur_id < 0xff { self.cur_id += 1; return Ok(self.cur_id - 1); } bail!("too many concurrent transactions"); } fn free_id(&mut self, id: u8) { self.free_ids.push(id); } fn expect_ok_for_id(&mut self, id: u8) { self.waiting_ids.insert(id, AckState::Waiting); } fn expect_response_for_id(&mut self, id: u8) { self.waiting_ids.insert(id, AckState::AwaitingData); } // Methods handling received packets: fn recv_hello(&mut self) -> Result<()> { let hello = self.common.read_unaligned_data::(0)?; if hello.magic != server::HELLO_MAGIC { bail!("received an invalid hello packet"); } let ver = hello.version; if ver != server::PROTOCOL_VERSION { bail!( "hello packet contained incompatible protocol version: {} (!= {})", ver, server::PROTOCOL_VERSION, ); } self.handshake_done = true; self.free_id(0); Ok(()) } fn chunk_write_lock(&self) -> Result>> { self.chunks .write() .map_err(|_| format_err!("lock poisoned, disconnecting client...")) } fn recv_hash_list(&mut self) -> Result<()> { let data = self.common.packet_data(); if data.len() == 0 { // No more hashes, we're done self.hash_download = None; return Ok(()); } if (data.len() % mem::size_of::()) != 0 { bail!("hash list contains invalid size"); } let chunk_list: &[FixedChunk] = unsafe { std::slice::from_raw_parts( data.as_ptr() as *const FixedChunk, data.len() / mem::size_of::(), ) }; let mut my_chunks = self.chunk_write_lock()?; for c in chunk_list { eprintln!("Got chunk '{}'", c.digest_to_hex()); my_chunks.insert(c.clone()); } Ok(()) } fn ack_id(&mut self, id: u8, data_packet: bool) -> Result<()> { use hash_map::Entry::*; match self.waiting_ids.entry(id) { Vacant(_) => bail!("received unexpected packet for transaction id {}", id), Occupied(mut entry) => match entry.get() { AckState::Ignore => { entry.remove(); } AckState::Received => bail!("duplicate Ack received for transaction id {}", id), AckState::Waiting => { if data_packet { bail!("received data packet while expecting simple Ok for {}", id); } *entry.get_mut() = AckState::Received; } AckState::AwaitingData => { if !data_packet { bail!( "received empty Ok while waiting for data on stream id {}", id ); } *entry.get_mut() = AckState::Received; } }, } Ok(()) } fn recv_ok(&mut self) -> Result<()> { self.ack_id(self.common.current_packet.id, false) } pub fn wait_for_id(&mut self, id: StreamId) -> Result { if !self.waiting_ids.contains_key(&id.0) { bail!("wait_for_id() called on unexpected id {}", id.0); } loop { if !self.poll_read(true)? { return Ok(false); } use hash_map::Entry::*; match self.waiting_ids.entry(id.0) { Vacant(_) => return Ok(true), Occupied(entry) => match entry.get() { AckState::Received => { entry.remove(); return Ok(true); } _ => continue, }, } } } pub fn discard_id(&mut self, id: StreamId) -> Result<()> { use hash_map::Entry::*; match self.waiting_ids.entry(id.0) { Vacant(_) => bail!("discard_id called with unknown id {}", id.0), Occupied(mut entry) => match entry.get() { AckState::Ignore => (), AckState::Received => { entry.remove(); } AckState::Waiting | AckState::AwaitingData => { *entry.get_mut() = AckState::Ignore; } }, } Ok(()) } pub fn create_backup( &mut self, index_type: IndexType, backup_type: &str, id: &str, timestamp: i64, file_name: &str, chunk_size: usize, file_size: Option, is_new: bool, ) -> Result { let backup_type = backup_type::name_to_id(backup_type)?; if id.len() > 0xff { bail!("id too long"); } if file_name.len() > 0xff { bail!("file name too long"); } let mut flags: backup_flags::Type = 0; if is_new { flags |= backup_flags::EXCL; } if index_type == IndexType::Dynamic { flags |= backup_flags::DYNAMIC_CHUNKS; if file_size.is_some() { bail!("file size must be None on dynamic backup streams"); } } else if file_size.is_none() { bail!("file size is mandatory for fixed backup streams"); } let packet_id = self.next_id()?; let mut packet = Packet::builder(packet_id, PacketType::CreateBackup); packet .write_data(client::CreateBackup { backup_type, id_length: id.len() as _, timestamp: timestamp as u64, flags, name_length: file_name.len() as _, chunk_size: chunk_size as _, file_size: file_size.unwrap_or(0) as u64, }) .write_buf(id.as_bytes()) .write_buf(file_name.as_bytes()); self.streams.insert( packet_id, BackupStreamData { id: packet_id, index_type, pos: 0, path: None, }, ); self.expect_response_for_id(packet_id); self.common.queue_data(packet.finish())?; Ok(BackupStream(packet_id)) } fn backup_created(&mut self) -> Result<()> { let info = self .common .read_unaligned_data::(0)?; let data = &self.common.packet_data()[mem::size_of_val(&info)..]; if data.len() != info.path_length as usize { bail!("backup-created packet has invalid length"); } let name = std::str::from_utf8(data)?; let pkt_id = self.common.current_packet.id; self.streams .get_mut(&pkt_id) .ok_or_else(|| format_err!("BackupCreated response for invalid stream: {}", pkt_id))? .path = Some(name.to_string()); self.ack_id(pkt_id, true)?; Ok(()) } pub fn dynamic_chunk(&mut self, stream: BackupStream, entry: &ChunkEntry) -> Result { self.dynamic_data(stream, &entry.hash, entry.size) } pub fn dynamic_data>( &mut self, stream: BackupStream, digest: &T, size: u64, ) -> Result { let data = self .streams .get_mut(&stream.0) .ok_or_else(|| format_err!("no such active backup stream"))?; if data.index_type != IndexType::Dynamic { bail!("dynamic_data called for stream of static chunks"); } let mut packet = Packet::builder(data.id, PacketType::BackupDataDynamic); packet .write_data(data.pos as u64) .write_buf(&digest.borrow().0); data.pos += size; self.common.queue_data(packet.finish()) } pub fn fixed_data>( &mut self, stream: BackupStream, index: usize, digest: &T, ) -> Result { let data = self .streams .get_mut(&stream.0) .ok_or_else(|| format_err!("no such active backup stream"))?; if data.index_type != IndexType::Fixed { bail!("fixed_data called for stream of dynamic chunks"); } let mut packet = Packet::builder(data.id, PacketType::BackupDataFixed); packet .write_data(index as u64) .write_buf(&digest.borrow().0); self.common.queue_data(packet.finish()) } pub fn finish_backup(&mut self, stream: BackupStream) -> Result<(StreamId, String, bool)> { let path = self .streams .remove(&stream.0) .ok_or_else(|| format_err!("no such active backup stream"))? .path .unwrap_or_else(|| "".to_string()); let ack = self .common .queue_data(Packet::simple(stream.0, PacketType::BackupFinished))?; self.expect_ok_for_id(stream.0); Ok((StreamId(stream.0), path, ack)) } }