protocol: allow explicitly polling for a single packet

When using *blocking* I/O wait_for_* functions need to be
able to not read past their awaited packet.

Signed-off-by: Wolfgang Bumiller <w.bumiller@proxmox.com>
This commit is contained in:
Wolfgang Bumiller 2019-03-07 13:50:41 +01:00
parent 4932843923
commit e402769332
4 changed files with 31 additions and 23 deletions

View File

@ -273,7 +273,7 @@ pub extern "C" fn proxmox_backup_continue_upload(me: *mut CClient) -> c_int {
#[no_mangle]
pub extern "C" fn proxmox_backup_poll_read(me: *mut CClient) -> c_int {
let me = unsafe { &mut *me };
match me.client.poll_read() {
match me.client.poll_read(false) {
Ok(_) => 0,
Err(e) => me.set_error(e),
}

View File

@ -103,7 +103,7 @@ where
pub fn wait_for_handshake(&mut self) -> Result<bool> {
if !self.handshake_done {
self.poll_read()?;
self.poll_read(true)?;
}
Ok(self.handshake_done)
}
@ -126,8 +126,10 @@ where
}
pub fn wait_for_hashes(&mut self) -> Result<bool> {
if self.hash_download.is_some() {
self.poll_read()?;
while self.hash_download.is_some() {
if !self.poll_read(true)? {
break;
}
}
Ok(self.hash_download.is_none())
}
@ -254,8 +256,8 @@ where
}
}
// generic data polling method
pub fn poll_read(&mut self) -> Result<()> {
// generic data polling method, returns true if at least one packet was received
pub fn poll_read(&mut self, one: bool) -> Result<bool> {
if self.common.eof {
// polls after EOF are errors:
bail!("server disconnected");
@ -266,7 +268,7 @@ where
if self.common.eof {
bail!("server disconnected");
}
return Ok(());
return Ok(false);
}
loop {
@ -280,11 +282,12 @@ where
self.common.current_packet_type as u32,
),
}
if !self.common.next()? {
self.common.next()?;
if one || !self.common.poll_read()? {
break;
}
}
Ok(())
Ok(true)
}
// None => nothing was queued
@ -414,21 +417,25 @@ where
if !self.waiting_ids.contains_key(&id.0) {
bail!("wait_for_id() called on unexpected id {}", id.0);
}
self.poll_read()?;
loop {
if !self.poll_read(true)? {
return Ok(false);
}
use hash_map::Entry::*;
match self.waiting_ids.entry(id.0) {
Vacant(_) => Ok(true),
Vacant(_) => return Ok(true),
Occupied(entry) => match entry.get() {
AckState::Received => {
entry.remove();
Ok(true)
return Ok(true);
}
_ => Ok(false),
_ => continue,
},
}
}
}
pub fn discard_id(&mut self, id: StreamId) -> Result<()> {
use hash_map::Entry::*;

View File

@ -138,7 +138,7 @@ where
&self.buffer[beg..end]
}
pub fn next(&mut self) -> Result<bool> {
pub fn next(&mut self) -> Result<()> {
let pktlen = self.packet_length();
unsafe {
if self.buffer.len() != pktlen {
@ -150,7 +150,7 @@ where
}
self.buffer.set_len(self.buffer.len() - pktlen);
}
self.poll_data_do()
Ok(())
}
// NOTE: After calling this you must `self.buffer.set_len()` when done!

View File

@ -184,7 +184,8 @@ where
self.common.current_packet_type as u32,
),
};
if !self.common.next()? {
self.common.next()?;
if !self.common.poll_read()? {
break;
}
}