use new proxmox-async crate

Signed-off-by: Dietmar Maurer <dietmar@proxmox.com>
This commit is contained in:
Dietmar Maurer
2021-11-19 17:36:06 +01:00
parent ea67cd70c9
commit 9a1b24b6b1
56 changed files with 66 additions and 1686 deletions

View File

@ -33,13 +33,13 @@ walkdir = "2"
zstd = { version = "0.6", features = [ "bindgen" ] }
proxmox = { version = "0.15.3", default-features = false, features = [ "tokio" ] }
proxmox-async = "0.1"
proxmox-borrow = "1"
proxmox-io = { version = "1", features = [ "tokio" ] }
proxmox-lang = { version = "1" }
proxmox-time = { version = "1" }
pbs-buildcfg = { path = "../pbs-buildcfg" }
pbs-runtime = { path = "../pbs-runtime" }
[dev-dependencies]
tokio = { version = "1.6", features = [ "macros" ] }

View File

@ -7,7 +7,7 @@ use std::collections::HashMap;
use std::future::Future;
use std::sync::{Arc, Mutex};
use crate::broadcast_future::BroadcastFuture;
use proxmox_async::broadcast_future::BroadcastFuture;
use crate::lru_cache::LruCache;
/// Interface for asynchronously getting values on cache misses.

View File

@ -1,99 +0,0 @@
//! Async wrappers for blocking I/O (adding `block_in_place` around channels/readers)
use std::io::{self, Read};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::sync::mpsc::Receiver;
use futures::stream::Stream;
use pbs_runtime::block_in_place;
/// Wrapper struct to convert a Reader into a Stream
pub struct WrappedReaderStream<R: Read + Unpin> {
reader: R,
buffer: Vec<u8>,
}
impl <R: Read + Unpin> WrappedReaderStream<R> {
pub fn new(reader: R) -> Self {
let mut buffer = Vec::with_capacity(64*1024);
unsafe { buffer.set_len(buffer.capacity()); }
Self { reader, buffer }
}
}
impl<R: Read + Unpin> Stream for WrappedReaderStream<R> {
type Item = Result<Vec<u8>, io::Error>;
fn poll_next(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
match block_in_place(|| this.reader.read(&mut this.buffer)) {
Ok(n) => {
if n == 0 {
// EOF
Poll::Ready(None)
} else {
Poll::Ready(Some(Ok(this.buffer[..n].to_vec())))
}
}
Err(err) => Poll::Ready(Some(Err(err))),
}
}
}
/// Wrapper struct to convert a channel Receiver into a Stream
pub struct StdChannelStream<T>(pub Receiver<T>);
impl<T> Stream for StdChannelStream<T> {
type Item = T;
fn poll_next(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Option<Self::Item>> {
match block_in_place(|| self.0.recv()) {
Ok(data) => Poll::Ready(Some(data)),
Err(_) => Poll::Ready(None),// channel closed
}
}
}
#[cfg(test)]
mod test {
use std::io;
use anyhow::Error;
use futures::stream::TryStreamExt;
#[test]
fn test_wrapped_stream_reader() -> Result<(), Error> {
pbs_runtime::main(async {
run_wrapped_stream_reader_test().await
})
}
struct DummyReader(usize);
impl io::Read for DummyReader {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0 += 1;
if self.0 >= 10 {
return Ok(0);
}
unsafe {
std::ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len());
}
Ok(buf.len())
}
}
async fn run_wrapped_stream_reader_test() -> Result<(), Error> {
let mut reader = super::WrappedReaderStream::new(DummyReader(0));
while let Some(_data) = reader.try_next().await? {
// just waiting
}
Ok(())
}
}

View File

@ -1,180 +0,0 @@
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use anyhow::{format_err, Error};
use futures::future::{FutureExt, TryFutureExt};
use tokio::sync::oneshot;
/// Broadcast results to registered listeners using asnyc oneshot channels
#[derive(Default)]
pub struct BroadcastData<T> {
result: Option<Result<T, String>>,
listeners: Vec<oneshot::Sender<Result<T, Error>>>,
}
impl <T: Clone> BroadcastData<T> {
pub fn new() -> Self {
Self {
result: None,
listeners: vec![],
}
}
pub fn notify_listeners(&mut self, result: Result<T, String>) {
self.result = Some(result.clone());
loop {
match self.listeners.pop() {
None => { break; },
Some(ch) => {
match &result {
Ok(result) => { let _ = ch.send(Ok(result.clone())); },
Err(err) => { let _ = ch.send(Err(format_err!("{}", err))); },
}
},
}
}
}
pub fn listen(&mut self) -> impl Future<Output = Result<T, Error>> {
use futures::future::{ok, Either};
match &self.result {
None => {},
Some(Ok(result)) => return Either::Left(ok(result.clone())),
Some(Err(err)) => return Either::Left(futures::future::err(format_err!("{}", err))),
}
let (tx, rx) = oneshot::channel::<Result<T, Error>>();
self.listeners.push(tx);
Either::Right(rx
.map(|res| match res {
Ok(Ok(t)) => Ok(t),
Ok(Err(e)) => Err(e),
Err(e) => Err(Error::from(e)),
})
)
}
}
type SourceFuture<T> = Pin<Box<dyn Future<Output = Result<T, Error>> + Send>>;
struct BroadCastFutureBinding<T> {
broadcast: BroadcastData<T>,
future: Option<SourceFuture<T>>,
}
/// Broadcast future results to registered listeners
pub struct BroadcastFuture<T> {
inner: Arc<Mutex<BroadCastFutureBinding<T>>>,
}
impl<T: Clone + Send + 'static> BroadcastFuture<T> {
/// Create instance for specified source future.
///
/// The result of the future is sent to all registered listeners.
pub fn new(source: Box<dyn Future<Output = Result<T, Error>> + Send>) -> Self {
let inner = BroadCastFutureBinding {
broadcast: BroadcastData::new(),
future: Some(Pin::from(source)),
};
Self { inner: Arc::new(Mutex::new(inner)) }
}
/// Creates a new instance with a oneshot channel as trigger
pub fn new_oneshot() -> (Self, oneshot::Sender<Result<T, Error>>) {
let (tx, rx) = oneshot::channel::<Result<T, Error>>();
let rx = rx
.map_err(Error::from)
.and_then(futures::future::ready);
(Self::new(Box::new(rx)), tx)
}
fn notify_listeners(
inner: Arc<Mutex<BroadCastFutureBinding<T>>>,
result: Result<T, String>,
) {
let mut data = inner.lock().unwrap();
data.broadcast.notify_listeners(result);
}
fn spawn(inner: Arc<Mutex<BroadCastFutureBinding<T>>>) -> impl Future<Output = Result<T, Error>> {
let mut data = inner.lock().unwrap();
if let Some(source) = data.future.take() {
let inner1 = inner.clone();
let task = source.map(move |value| {
match value {
Ok(value) => Self::notify_listeners(inner1, Ok(value)),
Err(err) => Self::notify_listeners(inner1, Err(err.to_string())),
}
});
tokio::spawn(task);
}
data.broadcast.listen()
}
/// Register a listener
pub fn listen(&self) -> impl Future<Output = Result<T, Error>> {
let inner2 = self.inner.clone();
async move { Self::spawn(inner2).await }
}
}
#[test]
fn test_broadcast_future() {
use std::sync::atomic::{AtomicUsize, Ordering};
static CHECKSUM: AtomicUsize = AtomicUsize::new(0);
let (sender, trigger) = BroadcastFuture::new_oneshot();
let receiver1 = sender.listen()
.map_ok(|res| {
CHECKSUM.fetch_add(res, Ordering::SeqCst);
})
.map_err(|err| { panic!("got error {}", err); })
.map(|_| ());
let receiver2 = sender.listen()
.map_ok(|res| {
CHECKSUM.fetch_add(res*2, Ordering::SeqCst);
})
.map_err(|err| { panic!("got error {}", err); })
.map(|_| ());
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async move {
let r1 = tokio::spawn(receiver1);
let r2 = tokio::spawn(receiver2);
trigger.send(Ok(1)).unwrap();
let _ = r1.await;
let _ = r2.await;
});
let result = CHECKSUM.load(Ordering::SeqCst);
assert_eq!(result, 3);
// the result stays available until the BroadcastFuture is dropped
rt.block_on(sender.listen()
.map_ok(|res| {
CHECKSUM.fetch_add(res*4, Ordering::SeqCst);
})
.map_err(|err| { panic!("got error {}", err); })
.map(|_| ()));
let result = CHECKSUM.load(Ordering::SeqCst);
assert_eq!(result, 7);
}

View File

@ -1,194 +0,0 @@
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use anyhow::Error;
use bytes::Bytes;
use flate2::{Compress, Compression, FlushCompress};
use futures::ready;
use futures::stream::Stream;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use proxmox::io_format_err;
use proxmox_io::ByteBuffer;
const BUFFER_SIZE: usize = 8192;
pub enum Level {
Fastest,
Best,
Default,
Precise(u32),
}
#[derive(Eq, PartialEq)]
enum EncoderState {
Reading,
Writing,
Flushing,
Finished,
}
pub struct DeflateEncoder<T> {
inner: T,
compressor: Compress,
buffer: ByteBuffer,
input_buffer: Bytes,
state: EncoderState,
}
impl<T> DeflateEncoder<T> {
pub fn new(inner: T) -> Self {
Self::with_quality(inner, Level::Default)
}
pub fn with_quality(inner: T, level: Level) -> Self {
let level = match level {
Level::Fastest => Compression::fast(),
Level::Best => Compression::best(),
Level::Default => Compression::new(3),
Level::Precise(val) => Compression::new(val),
};
Self {
inner,
compressor: Compress::new(level, false),
buffer: ByteBuffer::with_capacity(BUFFER_SIZE),
input_buffer: Bytes::new(),
state: EncoderState::Reading,
}
}
pub fn total_in(&self) -> u64 {
self.compressor.total_in()
}
pub fn total_out(&self) -> u64 {
self.compressor.total_out()
}
pub fn into_inner(self) -> T {
self.inner
}
fn encode(
&mut self,
inbuf: &[u8],
flush: FlushCompress,
) -> Result<(usize, flate2::Status), io::Error> {
let old_in = self.compressor.total_in();
let old_out = self.compressor.total_out();
let res = self
.compressor
.compress(&inbuf[..], self.buffer.get_free_mut_slice(), flush)?;
let new_in = (self.compressor.total_in() - old_in) as usize;
let new_out = (self.compressor.total_out() - old_out) as usize;
self.buffer.add_size(new_out);
Ok((new_in, res))
}
}
impl DeflateEncoder<Vec<u8>> {
// assume small files
pub async fn compress_vec<R>(&mut self, reader: &mut R, size_hint: usize) -> Result<(), Error>
where
R: AsyncRead + Unpin,
{
let mut buffer = Vec::with_capacity(size_hint);
reader.read_to_end(&mut buffer).await?;
self.inner.reserve(size_hint); // should be enough since we want smalller files
self.compressor.compress_vec(&buffer[..], &mut self.inner, FlushCompress::Finish)?;
Ok(())
}
}
impl<T: AsyncWrite + Unpin> DeflateEncoder<T> {
pub async fn compress<R>(&mut self, reader: &mut R) -> Result<(), Error>
where
R: AsyncRead + Unpin,
{
let mut buffer = ByteBuffer::with_capacity(BUFFER_SIZE);
let mut eof = false;
loop {
if !eof && !buffer.is_full() {
let read = buffer.read_from_async(reader).await?;
if read == 0 {
eof = true;
}
}
let (read, _res) = self.encode(&buffer[..], FlushCompress::None)?;
buffer.consume(read);
self.inner.write_all(&self.buffer[..]).await?;
self.buffer.clear();
if buffer.is_empty() && eof {
break;
}
}
loop {
let (_read, res) = self.encode(&[][..], FlushCompress::Finish)?;
self.inner.write_all(&self.buffer[..]).await?;
self.buffer.clear();
if res == flate2::Status::StreamEnd {
break;
}
}
Ok(())
}
}
impl<T, O> Stream for DeflateEncoder<T>
where
T: Stream<Item = Result<O, io::Error>> + Unpin,
O: Into<Bytes>
{
type Item = Result<Bytes, io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
loop {
match this.state {
EncoderState::Reading => {
if let Some(res) = ready!(Pin::new(&mut this.inner).poll_next(cx)) {
let buf = res?;
this.input_buffer = buf.into();
this.state = EncoderState::Writing;
} else {
this.state = EncoderState::Flushing;
}
}
EncoderState::Writing => {
if this.input_buffer.is_empty() {
return Poll::Ready(Some(Err(io_format_err!("empty input during write"))));
}
let mut buf = this.input_buffer.split_off(0);
let (read, res) = this.encode(&buf[..], FlushCompress::None)?;
this.input_buffer = buf.split_off(read);
if this.input_buffer.is_empty() {
this.state = EncoderState::Reading;
}
if this.buffer.is_full() || res == flate2::Status::BufError {
let bytes = this.buffer.remove_data(this.buffer.len()).to_vec();
return Poll::Ready(Some(Ok(bytes.into())));
}
}
EncoderState::Flushing => {
let (_read, res) = this.encode(&[][..], FlushCompress::Finish)?;
if !this.buffer.is_empty() {
let bytes = this.buffer.remove_data(this.buffer.len()).to_vec();
return Poll::Ready(Some(Ok(bytes.into())));
}
if res == flate2::Status::StreamEnd {
this.state = EncoderState::Finished;
}
}
EncoderState::Finished => return Poll::Ready(None),
}
}
}
}

View File

@ -1,9 +1,6 @@
pub mod acl;
pub mod blocking;
pub mod broadcast_future;
pub mod cert;
pub mod cli;
pub mod compression;
pub mod crypt_config;
pub mod format;
pub mod fs;
@ -14,13 +11,10 @@ pub mod nom;
pub mod percent_encoding;
pub mod sha;
pub mod str;
pub mod stream;
pub mod sync;
pub mod sys;
pub mod ticket;
pub mod tokio;
pub mod xattr;
pub mod zip;
pub mod async_lru_cache;

View File

@ -1,229 +0,0 @@
//! Wrappers between async readers and streams.
use std::io::{self, Read};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use anyhow::{Error, Result};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::mpsc::Sender;
use futures::ready;
use futures::future::FutureExt;
use futures::stream::Stream;
use proxmox::io_format_err;
use proxmox::sys::error::io_err_other;
use proxmox_io::ByteBuffer;
use pbs_runtime::block_in_place;
/// Wrapper struct to convert a Reader into a Stream
pub struct WrappedReaderStream<R: Read + Unpin> {
reader: R,
buffer: Vec<u8>,
}
impl <R: Read + Unpin> WrappedReaderStream<R> {
pub fn new(reader: R) -> Self {
let mut buffer = Vec::with_capacity(64*1024);
unsafe { buffer.set_len(buffer.capacity()); }
Self { reader, buffer }
}
}
impl<R: Read + Unpin> Stream for WrappedReaderStream<R> {
type Item = Result<Vec<u8>, io::Error>;
fn poll_next(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
match block_in_place(|| this.reader.read(&mut this.buffer)) {
Ok(n) => {
if n == 0 {
// EOF
Poll::Ready(None)
} else {
Poll::Ready(Some(Ok(this.buffer[..n].to_vec())))
}
}
Err(err) => Poll::Ready(Some(Err(err))),
}
}
}
/// Wrapper struct to convert an AsyncReader into a Stream
pub struct AsyncReaderStream<R: AsyncRead + Unpin> {
reader: R,
buffer: Vec<u8>,
}
impl <R: AsyncRead + Unpin> AsyncReaderStream<R> {
pub fn new(reader: R) -> Self {
let mut buffer = Vec::with_capacity(64*1024);
unsafe { buffer.set_len(buffer.capacity()); }
Self { reader, buffer }
}
pub fn with_buffer_size(reader: R, buffer_size: usize) -> Self {
let mut buffer = Vec::with_capacity(buffer_size);
unsafe { buffer.set_len(buffer.capacity()); }
Self { reader, buffer }
}
}
impl<R: AsyncRead + Unpin> Stream for AsyncReaderStream<R> {
type Item = Result<Vec<u8>, io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let mut read_buf = ReadBuf::new(&mut this.buffer);
match ready!(Pin::new(&mut this.reader).poll_read(cx, &mut read_buf)) {
Ok(()) => {
let n = read_buf.filled().len();
if n == 0 {
// EOF
Poll::Ready(None)
} else {
Poll::Ready(Some(Ok(this.buffer[..n].to_vec())))
}
}
Err(err) => Poll::Ready(Some(Err(err))),
}
}
}
#[cfg(test)]
mod test {
use std::io;
use anyhow::Error;
use futures::stream::TryStreamExt;
#[test]
fn test_wrapped_stream_reader() -> Result<(), Error> {
pbs_runtime::main(async {
run_wrapped_stream_reader_test().await
})
}
struct DummyReader(usize);
impl io::Read for DummyReader {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0 += 1;
if self.0 >= 10 {
return Ok(0);
}
unsafe {
std::ptr::write_bytes(buf.as_mut_ptr(), 0, buf.len());
}
Ok(buf.len())
}
}
async fn run_wrapped_stream_reader_test() -> Result<(), Error> {
let mut reader = super::WrappedReaderStream::new(DummyReader(0));
while let Some(_data) = reader.try_next().await? {
// just waiting
}
Ok(())
}
}
/// Wrapper around tokio::sync::mpsc::Sender, which implements Write
pub struct AsyncChannelWriter {
sender: Option<Sender<Result<Vec<u8>, Error>>>,
buf: ByteBuffer,
state: WriterState,
}
type SendResult = io::Result<Sender<Result<Vec<u8>>>>;
enum WriterState {
Ready,
Sending(Pin<Box<dyn Future<Output = SendResult> + Send + 'static>>),
}
impl AsyncChannelWriter {
pub fn new(sender: Sender<Result<Vec<u8>, Error>>, buf_size: usize) -> Self {
Self {
sender: Some(sender),
buf: ByteBuffer::with_capacity(buf_size),
state: WriterState::Ready,
}
}
fn poll_write_impl(
&mut self,
cx: &mut Context,
buf: &[u8],
flush: bool,
) -> Poll<io::Result<usize>> {
loop {
match &mut self.state {
WriterState::Ready => {
if flush {
if self.buf.is_empty() {
return Poll::Ready(Ok(0));
}
} else {
let free_size = self.buf.free_size();
if free_size > buf.len() || self.buf.is_empty() {
let count = free_size.min(buf.len());
self.buf.get_free_mut_slice()[..count].copy_from_slice(&buf[..count]);
self.buf.add_size(count);
return Poll::Ready(Ok(count));
}
}
let sender = match self.sender.take() {
Some(sender) => sender,
None => return Poll::Ready(Err(io_err_other("no sender"))),
};
let data = self.buf.remove_data(self.buf.len()).to_vec();
let future = async move {
sender
.send(Ok(data))
.await
.map(move |_| sender)
.map_err(|err| io_format_err!("could not send: {}", err))
};
self.state = WriterState::Sending(future.boxed());
}
WriterState::Sending(ref mut future) => match ready!(future.as_mut().poll(cx)) {
Ok(sender) => {
self.sender = Some(sender);
self.state = WriterState::Ready;
}
Err(err) => return Poll::Ready(Err(err)),
},
}
}
}
}
impl AsyncWrite for AsyncChannelWriter {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let this = self.get_mut();
this.poll_write_impl(cx, buf, false)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
let this = self.get_mut();
match ready!(this.poll_write_impl(cx, &[], true)) {
Ok(_) => Poll::Ready(Ok(())),
Err(err) => Poll::Ready(Err(err)),
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.poll_flush(cx)
}
}

View File

@ -1,2 +0,0 @@
pub mod tokio_writer_adapter;
pub use tokio_writer_adapter::TokioWriterAdapter;

View File

@ -1,26 +0,0 @@
use std::io::Write;
use tokio::task::block_in_place;
/// Wrapper around a writer which implements Write
///
/// wraps each write with a 'block_in_place' so that
/// any (blocking) writer can be safely used in async context in a
/// tokio runtime
pub struct TokioWriterAdapter<W: Write>(W);
impl<W: Write> TokioWriterAdapter<W> {
pub fn new(writer: W) -> Self {
Self(writer)
}
}
impl<W: Write> Write for TokioWriterAdapter<W> {
fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
block_in_place(|| self.0.write(buf))
}
fn flush(&mut self) -> Result<(), std::io::Error> {
block_in_place(|| self.0.flush())
}
}

View File

@ -1,672 +0,0 @@
//! ZIP Helper
//!
//! Provides an interface to create a ZIP File from ZipEntries
//! for a more detailed description of the ZIP format, see:
//! https://pkware.cachefly.net/webdocs/casestudies/APPNOTE.TXT
use std::convert::TryInto;
use std::ffi::OsString;
use std::io;
use std::mem::size_of;
use std::os::unix::ffi::OsStrExt;
use std::path::{Component, Path, PathBuf};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::SystemTime;
use anyhow::{format_err, Error, Result};
use endian_trait::Endian;
use futures::ready;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use crc32fast::Hasher;
use proxmox_time::gmtime;
use crate::compression::{DeflateEncoder, Level};
const LOCAL_FH_SIG: u32 = 0x04034B50;
const LOCAL_FF_SIG: u32 = 0x08074B50;
const CENTRAL_DIRECTORY_FH_SIG: u32 = 0x02014B50;
const END_OF_CENTRAL_DIR: u32 = 0x06054B50;
const VERSION_NEEDED: u16 = 0x002d;
const VERSION_MADE_BY: u16 = 0x032d;
const ZIP64_EOCD_RECORD: u32 = 0x06064B50;
const ZIP64_EOCD_LOCATOR: u32 = 0x07064B50;
// bits for time:
// 0-4: day of the month (1-31)
// 5-8: month: (1 = jan, etc.)
// 9-15: year offset from 1980
//
// bits for date:
// 0-4: second / 2
// 5-10: minute (0-59)
// 11-15: hour (0-23)
//
// see https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-filetimetodosdatetime
fn epoch_to_dos(epoch: i64) -> (u16, u16) {
let gmtime = match gmtime(epoch) {
Ok(gmtime) => gmtime,
Err(_) => return (0, 0),
};
let seconds = (gmtime.tm_sec / 2) & 0b11111;
let minutes = gmtime.tm_min & 0xb111111;
let hours = gmtime.tm_hour & 0b11111;
let time: u16 = ((hours << 11) | (minutes << 5) | (seconds)) as u16;
let date: u16 = if gmtime.tm_year > (2108 - 1900) || gmtime.tm_year < (1980 - 1900) {
0
} else {
let day = gmtime.tm_mday & 0b11111;
let month = (gmtime.tm_mon + 1) & 0b1111;
let year = (gmtime.tm_year + 1900 - 1980) & 0b1111111;
((year << 9) | (month << 5) | (day)) as u16
};
(date, time)
}
#[derive(Endian)]
#[repr(C, packed)]
struct Zip64Field {
field_type: u16,
field_size: u16,
uncompressed_size: u64,
compressed_size: u64,
}
#[derive(Endian)]
#[repr(C, packed)]
struct Zip64FieldWithOffset {
field_type: u16,
field_size: u16,
uncompressed_size: u64,
compressed_size: u64,
offset: u64,
start_disk: u32,
}
#[derive(Endian)]
#[repr(C, packed)]
struct LocalFileHeader {
signature: u32,
version_needed: u16,
flags: u16,
compression: u16,
time: u16,
date: u16,
crc32: u32,
compressed_size: u32,
uncompressed_size: u32,
filename_len: u16,
extra_field_len: u16,
}
#[derive(Endian)]
#[repr(C, packed)]
struct LocalFileFooter {
signature: u32,
crc32: u32,
compressed_size: u64,
uncompressed_size: u64,
}
#[derive(Endian)]
#[repr(C, packed)]
struct CentralDirectoryFileHeader {
signature: u32,
version_made_by: u16,
version_needed: u16,
flags: u16,
compression: u16,
time: u16,
date: u16,
crc32: u32,
compressed_size: u32,
uncompressed_size: u32,
filename_len: u16,
extra_field_len: u16,
comment_len: u16,
start_disk: u16,
internal_flags: u16,
external_flags: u32,
offset: u32,
}
#[derive(Endian)]
#[repr(C, packed)]
struct EndOfCentralDir {
signature: u32,
disk_number: u16,
start_disk: u16,
disk_record_count: u16,
total_record_count: u16,
directory_size: u32,
directory_offset: u32,
comment_len: u16,
}
#[derive(Endian)]
#[repr(C, packed)]
struct Zip64EOCDRecord {
signature: u32,
field_size: u64,
version_made_by: u16,
version_needed: u16,
disk_number: u32,
disk_number_central_dir: u32,
disk_record_count: u64,
total_record_count: u64,
directory_size: u64,
directory_offset: u64,
}
#[derive(Endian)]
#[repr(C, packed)]
struct Zip64EOCDLocator {
signature: u32,
disk_number: u32,
offset: u64,
disk_count: u32,
}
async fn write_struct<E, T>(output: &mut T, data: E) -> io::Result<()>
where
T: AsyncWrite + ?Sized + Unpin,
E: Endian,
{
let data = data.to_le();
let data = unsafe {
std::slice::from_raw_parts(
&data as *const E as *const u8,
core::mem::size_of_val(&data),
)
};
output.write_all(data).await
}
/// Represents an Entry in a ZIP File
///
/// used to add to a ZipEncoder
pub struct ZipEntry {
filename: OsString,
mtime: i64,
mode: u16,
crc32: u32,
uncompressed_size: u64,
compressed_size: u64,
offset: u64,
is_file: bool,
}
impl ZipEntry {
/// Creates a new ZipEntry
///
/// if is_file is false the path will contain an trailing separator,
/// so that the zip file understands that it is a directory
pub fn new<P: AsRef<Path>>(path: P, mtime: i64, mode: u16, is_file: bool) -> Self {
let mut relpath = PathBuf::new();
for comp in path.as_ref().components() {
if let Component::Normal(_) = comp {
relpath.push(comp);
}
}
if !is_file {
relpath.push(""); // adds trailing slash
}
Self {
filename: relpath.into(),
crc32: 0,
mtime,
mode,
uncompressed_size: 0,
compressed_size: 0,
offset: 0,
is_file,
}
}
async fn write_local_header<W>(&self, mut buf: &mut W) -> io::Result<usize>
where
W: AsyncWrite + Unpin + ?Sized,
{
let filename = self.filename.as_bytes();
let filename_len = filename.len();
let header_size = size_of::<LocalFileHeader>();
let zip_field_size = size_of::<Zip64Field>();
let size: usize = header_size + filename_len + zip_field_size;
let (date, time) = epoch_to_dos(self.mtime);
write_struct(
&mut buf,
LocalFileHeader {
signature: LOCAL_FH_SIG,
version_needed: 0x2d,
flags: 1 << 3,
compression: 0x8,
time,
date,
crc32: 0,
compressed_size: 0xFFFFFFFF,
uncompressed_size: 0xFFFFFFFF,
filename_len: filename_len as u16,
extra_field_len: zip_field_size as u16,
},
)
.await?;
buf.write_all(filename).await?;
write_struct(
&mut buf,
Zip64Field {
field_type: 0x0001,
field_size: 2 * 8,
uncompressed_size: 0,
compressed_size: 0,
},
)
.await?;
Ok(size)
}
async fn write_data_descriptor<W: AsyncWrite + Unpin + ?Sized>(
&self,
mut buf: &mut W,
) -> io::Result<usize> {
let size = size_of::<LocalFileFooter>();
write_struct(
&mut buf,
LocalFileFooter {
signature: LOCAL_FF_SIG,
crc32: self.crc32,
compressed_size: self.compressed_size,
uncompressed_size: self.uncompressed_size,
},
)
.await?;
Ok(size)
}
async fn write_central_directory_header<W: AsyncWrite + Unpin + ?Sized>(
&self,
mut buf: &mut W,
) -> io::Result<usize> {
let filename = self.filename.as_bytes();
let filename_len = filename.len();
let header_size = size_of::<CentralDirectoryFileHeader>();
let zip_field_size = size_of::<Zip64FieldWithOffset>();
let mut size: usize = header_size + filename_len;
let (date, time) = epoch_to_dos(self.mtime);
let (compressed_size, uncompressed_size, offset, need_zip64) = if self.compressed_size
>= (u32::MAX as u64)
|| self.uncompressed_size >= (u32::MAX as u64)
|| self.offset >= (u32::MAX as u64)
{
size += zip_field_size;
(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, true)
} else {
(
self.compressed_size as u32,
self.uncompressed_size as u32,
self.offset as u32,
false,
)
};
write_struct(
&mut buf,
CentralDirectoryFileHeader {
signature: CENTRAL_DIRECTORY_FH_SIG,
version_made_by: VERSION_MADE_BY,
version_needed: VERSION_NEEDED,
flags: 1 << 3,
compression: 0x8,
time,
date,
crc32: self.crc32,
compressed_size,
uncompressed_size,
filename_len: filename_len as u16,
extra_field_len: if need_zip64 { zip_field_size as u16 } else { 0 },
comment_len: 0,
start_disk: 0,
internal_flags: 0,
external_flags: (self.mode as u32) << 16 | (!self.is_file as u32) << 4,
offset,
},
)
.await?;
buf.write_all(filename).await?;
if need_zip64 {
write_struct(
&mut buf,
Zip64FieldWithOffset {
field_type: 1,
field_size: 3 * 8 + 4,
uncompressed_size: self.uncompressed_size,
compressed_size: self.compressed_size,
offset: self.offset,
start_disk: 0,
},
)
.await?;
}
Ok(size)
}
}
// wraps an asyncreader and calculates the hash
struct HashWrapper<R> {
inner: R,
hasher: Hasher,
}
impl<R> HashWrapper<R> {
fn new(inner: R) -> Self {
Self {
inner,
hasher: Hasher::new(),
}
}
// consumes self and returns the hash and the reader
fn finish(self) -> (u32, R) {
let crc32 = self.hasher.finalize();
(crc32, self.inner)
}
}
impl<R> AsyncRead for HashWrapper<R>
where
R: AsyncRead + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<(), io::Error>> {
let this = self.get_mut();
let old_len = buf.filled().len();
ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?;
let new_len = buf.filled().len();
if new_len > old_len {
this.hasher.update(&buf.filled()[old_len..new_len]);
}
Poll::Ready(Ok(()))
}
}
/// Wraps a writer that implements AsyncWrite for creating a ZIP archive
///
/// This will create a ZIP archive on the fly with files added with
/// 'add_entry'. To Finish the file, call 'finish'
/// Example:
/// ```no_run
/// use anyhow::{Error, Result};
/// use tokio::fs::File;
///
/// use pbs_tools::zip::{ZipEncoder, ZipEntry};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Error> {
/// let target = File::open("foo.zip").await?;
/// let mut source = File::open("foo.txt").await?;
///
/// let mut zip = ZipEncoder::new(target);
/// zip.add_entry(ZipEntry::new(
/// "foo.txt",
/// 0,
/// 0o100755,
/// true,
/// ), Some(source)).await?;
///
/// zip.finish().await?;
///
/// Ok(())
/// }
/// ```
pub struct ZipEncoder<W>
where
W: AsyncWrite + Unpin,
{
byte_count: usize,
files: Vec<ZipEntry>,
target: Option<W>,
}
impl<W: AsyncWrite + Unpin> ZipEncoder<W> {
pub fn new(target: W) -> Self {
Self {
byte_count: 0,
files: Vec::new(),
target: Some(target),
}
}
pub async fn add_entry<R: AsyncRead + Unpin>(
&mut self,
mut entry: ZipEntry,
content: Option<R>,
) -> Result<(), Error> {
let mut target = self
.target
.take()
.ok_or_else(|| format_err!("had no target during add entry"))?;
entry.offset = self.byte_count.try_into()?;
self.byte_count += entry.write_local_header(&mut target).await?;
if let Some(content) = content {
let mut reader = HashWrapper::new(content);
let mut enc = DeflateEncoder::with_quality(target, Level::Fastest);
enc.compress(&mut reader).await?;
let total_in = enc.total_in();
let total_out = enc.total_out();
target = enc.into_inner();
let (crc32, _reader) = reader.finish();
self.byte_count += total_out as usize;
entry.compressed_size = total_out;
entry.uncompressed_size = total_in;
entry.crc32 = crc32;
}
self.byte_count += entry.write_data_descriptor(&mut target).await?;
self.target = Some(target);
self.files.push(entry);
Ok(())
}
async fn write_eocd(
&mut self,
central_dir_size: usize,
central_dir_offset: usize,
) -> Result<(), Error> {
let entrycount = self.files.len();
let mut target = self
.target
.take()
.ok_or_else(|| format_err!("had no target during write_eocd"))?;
let mut count = entrycount as u16;
let mut directory_size = central_dir_size as u32;
let mut directory_offset = central_dir_offset as u32;
if central_dir_size > u32::MAX as usize
|| central_dir_offset > u32::MAX as usize
|| entrycount > u16::MAX as usize
{
count = 0xFFFF;
directory_size = 0xFFFFFFFF;
directory_offset = 0xFFFFFFFF;
write_struct(
&mut target,
Zip64EOCDRecord {
signature: ZIP64_EOCD_RECORD,
field_size: 44,
version_made_by: VERSION_MADE_BY,
version_needed: VERSION_NEEDED,
disk_number: 0,
disk_number_central_dir: 0,
disk_record_count: entrycount.try_into()?,
total_record_count: entrycount.try_into()?,
directory_size: central_dir_size.try_into()?,
directory_offset: central_dir_offset.try_into()?,
},
)
.await?;
let locator_offset = central_dir_offset + central_dir_size;
write_struct(
&mut target,
Zip64EOCDLocator {
signature: ZIP64_EOCD_LOCATOR,
disk_number: 0,
offset: locator_offset.try_into()?,
disk_count: 1,
},
)
.await?;
}
write_struct(
&mut target,
EndOfCentralDir {
signature: END_OF_CENTRAL_DIR,
disk_number: 0,
start_disk: 0,
disk_record_count: count,
total_record_count: count,
directory_size,
directory_offset,
comment_len: 0,
},
)
.await?;
self.target = Some(target);
Ok(())
}
pub async fn finish(&mut self) -> Result<(), Error> {
let mut target = self
.target
.take()
.ok_or_else(|| format_err!("had no target during finish"))?;
let central_dir_offset = self.byte_count;
let mut central_dir_size = 0;
for file in &self.files {
central_dir_size += file.write_central_directory_header(&mut target).await?;
}
self.target = Some(target);
self.write_eocd(central_dir_size, central_dir_offset)
.await?;
self.target
.take()
.ok_or_else(|| format_err!("had no target for flush"))?
.flush()
.await?;
Ok(())
}
}
/// Zip a local directory and write encoded data to target. "source" has to point to a valid
/// directory, it's name will be the root of the zip file - e.g.:
/// source:
/// /foo/bar
/// zip file:
/// /bar/file1
/// /bar/dir1
/// /bar/dir1/file2
/// ...
/// ...except if "source" is the root directory
pub async fn zip_directory<W>(target: W, source: &Path) -> Result<(), Error>
where
W: AsyncWrite + Unpin + Send,
{
use walkdir::WalkDir;
use std::os::unix::fs::MetadataExt;
let base_path = source.parent().unwrap_or_else(|| Path::new("/"));
let mut encoder = ZipEncoder::new(target);
for entry in WalkDir::new(&source).into_iter() {
match entry {
Ok(entry) => {
let entry_path = entry.path().to_owned();
let encoder = &mut encoder;
if let Err(err) = async move {
let entry_path_no_base = entry.path().strip_prefix(base_path)?;
let metadata = entry.metadata()?;
let mtime = match metadata.modified().unwrap_or_else(|_| SystemTime::now()).duration_since(SystemTime::UNIX_EPOCH) {
Ok(dur) => dur.as_secs() as i64,
Err(time_error) => -(time_error.duration().as_secs() as i64)
};
let mode = metadata.mode() as u16;
if entry.file_type().is_file() {
let file = tokio::fs::File::open(entry.path()).await?;
let ze = ZipEntry::new(
&entry_path_no_base,
mtime,
mode,
true,
);
encoder.add_entry(ze, Some(file)).await?;
} else if entry.file_type().is_dir() {
let ze = ZipEntry::new(
&entry_path_no_base,
mtime,
mode,
false,
);
let content: Option<tokio::fs::File> = None;
encoder.add_entry(ze, content).await?;
}
// ignore other file types
let ok: Result<(), Error> = Ok(());
ok
}
.await
{
eprintln!(
"zip: error encoding file or directory '{}': {}",
entry_path.display(),
err
);
}
}
Err(err) => {
eprintln!("zip: error reading directory entry: {}", err);
}
}
}
encoder.finish().await
}