protocol: allow explicitly polling for a single packet

When using *blocking* I/O wait_for_* functions need to be
able to not read past their awaited packet.

Signed-off-by: Wolfgang Bumiller <w.bumiller@proxmox.com>
This commit is contained in:
Wolfgang Bumiller 2019-03-07 13:50:41 +01:00
parent 4932843923
commit e402769332
4 changed files with 31 additions and 23 deletions

View File

@ -273,7 +273,7 @@ pub extern "C" fn proxmox_backup_continue_upload(me: *mut CClient) -> c_int {
#[no_mangle] #[no_mangle]
pub extern "C" fn proxmox_backup_poll_read(me: *mut CClient) -> c_int { pub extern "C" fn proxmox_backup_poll_read(me: *mut CClient) -> c_int {
let me = unsafe { &mut *me }; let me = unsafe { &mut *me };
match me.client.poll_read() { match me.client.poll_read(false) {
Ok(_) => 0, Ok(_) => 0,
Err(e) => me.set_error(e), Err(e) => me.set_error(e),
} }

View File

@ -103,7 +103,7 @@ where
pub fn wait_for_handshake(&mut self) -> Result<bool> { pub fn wait_for_handshake(&mut self) -> Result<bool> {
if !self.handshake_done { if !self.handshake_done {
self.poll_read()?; self.poll_read(true)?;
} }
Ok(self.handshake_done) Ok(self.handshake_done)
} }
@ -126,8 +126,10 @@ where
} }
pub fn wait_for_hashes(&mut self) -> Result<bool> { pub fn wait_for_hashes(&mut self) -> Result<bool> {
if self.hash_download.is_some() { while self.hash_download.is_some() {
self.poll_read()?; if !self.poll_read(true)? {
break;
}
} }
Ok(self.hash_download.is_none()) Ok(self.hash_download.is_none())
} }
@ -254,8 +256,8 @@ where
} }
} }
// generic data polling method // generic data polling method, returns true if at least one packet was received
pub fn poll_read(&mut self) -> Result<()> { pub fn poll_read(&mut self, one: bool) -> Result<bool> {
if self.common.eof { if self.common.eof {
// polls after EOF are errors: // polls after EOF are errors:
bail!("server disconnected"); bail!("server disconnected");
@ -266,7 +268,7 @@ where
if self.common.eof { if self.common.eof {
bail!("server disconnected"); bail!("server disconnected");
} }
return Ok(()); return Ok(false);
} }
loop { loop {
@ -280,11 +282,12 @@ where
self.common.current_packet_type as u32, self.common.current_packet_type as u32,
), ),
} }
if !self.common.next()? { self.common.next()?;
if one || !self.common.poll_read()? {
break; break;
} }
} }
Ok(()) Ok(true)
} }
// None => nothing was queued // None => nothing was queued
@ -414,19 +417,23 @@ where
if !self.waiting_ids.contains_key(&id.0) { if !self.waiting_ids.contains_key(&id.0) {
bail!("wait_for_id() called on unexpected id {}", id.0); bail!("wait_for_id() called on unexpected id {}", id.0);
} }
self.poll_read()?;
use hash_map::Entry::*; loop {
if !self.poll_read(true)? {
return Ok(false);
}
match self.waiting_ids.entry(id.0) { use hash_map::Entry::*;
Vacant(_) => Ok(true), match self.waiting_ids.entry(id.0) {
Occupied(entry) => match entry.get() { Vacant(_) => return Ok(true),
AckState::Received => { Occupied(entry) => match entry.get() {
entry.remove(); AckState::Received => {
Ok(true) entry.remove();
} return Ok(true);
_ => Ok(false), }
}, _ => continue,
},
}
} }
} }

View File

@ -138,7 +138,7 @@ where
&self.buffer[beg..end] &self.buffer[beg..end]
} }
pub fn next(&mut self) -> Result<bool> { pub fn next(&mut self) -> Result<()> {
let pktlen = self.packet_length(); let pktlen = self.packet_length();
unsafe { unsafe {
if self.buffer.len() != pktlen { if self.buffer.len() != pktlen {
@ -150,7 +150,7 @@ where
} }
self.buffer.set_len(self.buffer.len() - pktlen); self.buffer.set_len(self.buffer.len() - pktlen);
} }
self.poll_data_do() Ok(())
} }
// NOTE: After calling this you must `self.buffer.set_len()` when done! // NOTE: After calling this you must `self.buffer.set_len()` when done!

View File

@ -184,7 +184,8 @@ where
self.common.current_packet_type as u32, self.common.current_packet_type as u32,
), ),
}; };
if !self.common.next()? { self.common.next()?;
if !self.common.poll_read()? {
break; break;
} }
} }