diff --git a/src/backup/data_blob.rs b/src/backup/data_blob.rs
index f2599302..3706104d 100644
--- a/src/backup/data_blob.rs
+++ b/src/backup/data_blob.rs
@@ -294,52 +294,97 @@ impl DataBlob {
 
 use std::io::{Read, BufRead, Write, Seek, SeekFrom};
 
-/// Write compressed data blobs
-pub struct CompressedDataBlobWriter<W: Write> {
-    compr: Option<zstd::stream::write::Encoder<W>>,
-    hasher: crc32fast::Hasher,
+enum BlobWriterState<W: Write> {
+    Uncompressed { writer: W, hasher: crc32fast::Hasher },
+    Compressed { compr: zstd::stream::write::Encoder<W>, hasher: crc32fast::Hasher },
 }
 
-impl <W: Write + Seek> CompressedDataBlobWriter<W> {
+/// Write compressed data blobs
+pub struct DataBlobWriter<W: Write> {
+    state: BlobWriterState<W>,
+}
-
-    pub fn new(mut out: W) -> Result<Self, Error> {
-        out.seek(SeekFrom::Start(0))?;
+impl <W: Write + Seek> DataBlobWriter<W> {
+
+    pub fn new_uncompressed(mut writer: W) -> Result<Self, Error> {
+        let hasher = crc32fast::Hasher::new();
+        writer.seek(SeekFrom::Start(0))?;
+        let head = DataBlobHeader { magic: UNCOMPRESSED_BLOB_MAGIC_1_0, crc: [0; 4] };
+        unsafe {
+            writer.write_le_value(head)?;
+        }
+        let state = BlobWriterState::Uncompressed { writer, hasher };
+        Ok(Self { state })
+    }
+
+    pub fn new_compressed(mut writer: W) -> Result<Self, Error> {
+        let hasher = crc32fast::Hasher::new();
+        writer.seek(SeekFrom::Start(0))?;
         let head = DataBlobHeader { magic: COMPRESSED_BLOB_MAGIC_1_0, crc: [0; 4] };
         unsafe {
-            out.write_le_value(head)?;
+            writer.write_le_value(head)?;
         }
-        let compr = zstd::stream::write::Encoder::new(out, 1)?;
-        Ok(Self { compr: Some(compr), hasher: crc32fast::Hasher::new() })
+        let compr = zstd::stream::write::Encoder::new(writer, 1)?;
+        let state = BlobWriterState::Compressed { compr, hasher };
+        Ok(Self { state })
     }
 
-    pub fn finish(mut self) -> Result<W, Error> {
-        let compr = self.compr.take().expect("blob writer already finished");
-        let mut out = compr.finish()?;
+    pub fn finish(self) -> Result<W, Error> {
+        match self.state {
+            BlobWriterState::Uncompressed { mut writer, hasher } => {
+                // write CRC
+                let crc = hasher.finalize();
+                let head = DataBlobHeader { magic: UNCOMPRESSED_BLOB_MAGIC_1_0, crc: crc.to_le_bytes() };
 
-        // write CRC
-        let crc = self.hasher.finalize();
-        let head = DataBlobHeader { magic: COMPRESSED_BLOB_MAGIC_1_0, crc: crc.to_le_bytes() };
+                writer.seek(SeekFrom::Start(0))?;
+                unsafe {
+                    writer.write_le_value(head)?;
+                }
 
-        out.seek(SeekFrom::Start(0))?;
-        unsafe {
-            out.write_le_value(head)?;
+                return Ok(writer)
+            }
+            BlobWriterState::Compressed { compr, hasher } => {
+                let mut writer = compr.finish()?;
+
+                // write CRC
+                let crc = hasher.finalize();
+                let head = DataBlobHeader { magic: COMPRESSED_BLOB_MAGIC_1_0, crc: crc.to_le_bytes() };
+
+                writer.seek(SeekFrom::Start(0))?;
+                unsafe {
+                    writer.write_le_value(head)?;
+                }
+
+                return Ok(writer)
+            }
         }
-
-        Ok(out)
     }
 }
 
-impl <W: Write + Seek> Write for CompressedDataBlobWriter<W> {
+impl <W: Write + Seek> Write for DataBlobWriter<W> {
     fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
-        let compr = self.compr.as_mut().expect("blob writer already finished");
-        self.hasher.update(buf);
-        compr.write(buf)
+        match self.state {
+            BlobWriterState::Uncompressed { ref mut writer, ref mut hasher } => {
+                hasher.update(buf);
+                writer.write(buf)
+            }
+            BlobWriterState::Compressed { ref mut compr, ref mut hasher } => {
+                hasher.update(buf);
+                compr.write(buf)
+            }
+        }
     }
 
     fn flush(&mut self) -> Result<(), std::io::Error> {
-        let compr = self.compr.as_mut().expect("blob writer already finished");
-        compr.flush()
+        match self.state {
+            BlobWriterState::Uncompressed { ref mut writer, .. } => {
+                writer.flush()
+            }
+            BlobWriterState::Compressed { ref mut compr, .. } => {
+                compr.flush()
+            }
+        }
    }
 }
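
Not part of the patch: a minimal usage sketch of the new API against an in-memory buffer. It assumes DataBlobWriter and the crate's Error type are in scope as in src/backup/data_blob.rs, and that Error converts from std::io::Error (the patch itself relies on that for the seek and write_le_value calls); write_blob_example is a hypothetical name.

// Sketch only (hypothetical helper, not part of the patch): build a compressed
// blob in memory and get the raw bytes back. Cursor<Vec<u8>> satisfies the
// `Write + Seek` bound required by the constructors.
fn write_blob_example(data: &[u8]) -> Result<Vec<u8>, Error> {
    use std::io::{Cursor, Write};

    let mut blob_writer = DataBlobWriter::new_compressed(Cursor::new(Vec::new()))?;
    blob_writer.write_all(data)?;        // feeds the CRC hasher and the zstd encoder
    let cursor = blob_writer.finish()?;  // ends the zstd stream, patches the CRC into the header
    Ok(cursor.into_inner())              // header followed by the compressed payload
}

The same pattern with new_uncompressed() produces a blob with UNCOMPRESSED_BLOB_MAGIC_1_0 and the data written through unchanged.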