diff --git a/examples/uncompress_iterator.rs b/examples/uncompress_iterator.rs index 73b16db..52d4ce1 100644 --- a/examples/uncompress_iterator.rs +++ b/examples/uncompress_iterator.rs @@ -16,7 +16,7 @@ fn main() -> compress_tools::Result<()> { let source = std::fs::File::open(cmd.source_path)?; - for content in ArchiveIterator::from_read(source)? { + for content in ArchiveIterator::from_read(source, None)? { if let ArchiveContents::StartOfEntry(name, stat) = content { println!("{name}: size={}", stat.st_size); } diff --git a/scripts/generate-ffi b/scripts/generate-ffi index 365dccd..3bfcbda 100755 --- a/scripts/generate-ffi +++ b/scripts/generate-ffi @@ -48,6 +48,7 @@ bindgen \ --whitelist-function "archive_read_data_block" \ --whitelist-function "archive_read_next_header" \ --whitelist-function "archive_read_open" \ + --whitelist-function "archive_read_add_passphrase" \ --whitelist-function "archive_write_disk_new" \ --whitelist-function "archive_write_disk_set_options" \ --whitelist-function "archive_write_disk_set_standard_lookup" \ diff --git a/src/ffi/generated.rs b/src/ffi/generated.rs index 50f04c4..92d3aa5 100644 --- a/src/ffi/generated.rs +++ b/src/ffi/generated.rs @@ -105,6 +105,12 @@ extern "C" { offset: *mut la_int64_t, ) -> ::std::os::raw::c_int; } +extern "C" { + pub(crate) fn archive_read_add_passphrase( + arg1: *mut archive, + arg2: *const ::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} extern "C" { pub(crate) fn archive_read_close(arg1: *mut archive) -> ::std::os::raw::c_int; } diff --git a/src/iterator.rs b/src/iterator.rs index 7b34a39..dbf616a 100644 --- a/src/iterator.rs +++ b/src/iterator.rs @@ -44,6 +44,23 @@ pub enum ArchiveContents { /// The entry is processed on a return value of `true` and ignored on `false`. pub type EntryFilterCallbackFn = dyn Fn(&str, &libc::stat) -> bool; +pub struct ArchivePassword(CString); + +impl ArchivePassword { + pub fn extract(&self) -> *const i8 { + self.0.as_ptr() as *const i8 + } +} + +impl From for ArchivePassword +where + T: AsRef, +{ + fn from(s: T) -> Self { + Self(CString::new(s.as_ref()).unwrap()) + } +} + /// An iterator over the contents of an archive. #[allow(clippy::module_name_repetitions)] pub struct ArchiveIterator { @@ -119,6 +136,7 @@ impl ArchiveIterator { source: R, decode: DecodeCallback, filter: Option>, + password: Option, ) -> Result> where R: Read + Seek, @@ -132,6 +150,10 @@ impl ArchiveIterator { let archive_entry: *mut ffi::archive_entry = std::ptr::null_mut(); let archive_reader = ffi::archive_read_new(); + if let Some(password) = password { + ffi::archive_read_add_passphrase(archive_reader, password.extract()); + } + let res = (|| { archive_result( ffi::archive_read_support_filter_all(archive_reader), @@ -230,7 +252,7 @@ impl ArchiveIterator { where R: Read + Seek, { - Self::new(source, decode, None) + Self::new(source, decode, None, None) } /// Iterate over the contents of an archive, streaming the contents of each @@ -245,7 +267,7 @@ impl ArchiveIterator { /// /// let mut name = String::default(); /// let mut size = 0; - /// let mut iter = ArchiveIterator::from_read(file)?; + /// let mut iter = ArchiveIterator::from_read(file, None)?; /// /// for content in &mut iter { /// match content { @@ -265,11 +287,11 @@ impl ArchiveIterator { /// # Ok(()) /// # } /// ``` - pub fn from_read(source: R) -> Result> + pub fn from_read(source: R, password: Option) -> Result> where R: Read + Seek, { - Self::new(source, crate::decode_utf8, None) + Self::new(source, crate::decode_utf8, None, password) } /// Close the iterator, freeing up the associated resources. @@ -392,6 +414,7 @@ where source: R, decoder: DecodeCallback, filter: Option>, + password: Option, } /// A builder to generate an archive iterator over the contents of an @@ -430,6 +453,7 @@ where source, decoder: crate::decode_utf8, filter: None, + password: None, } } @@ -450,8 +474,14 @@ where self } + /// Set a custom password to decode content of archive entries. + pub fn with_password(mut self, password: ArchivePassword) -> ArchiveIteratorBuilder { + self.password = Some(password); + self + } + /// Finish the builder and generate the configured `ArchiveIterator`. pub fn build(self) -> Result> { - ArchiveIterator::new(self.source, self.decoder, self.filter) + ArchiveIterator::new(self.source, self.decoder, self.filter, self.password) } } diff --git a/src/lib.rs b/src/lib.rs index 16a4599..9627d98 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -60,7 +60,7 @@ pub mod tokio_support; use error::archive_result; pub use error::{Error, Result}; use io::{Seek, SeekFrom}; -pub use iterator::{ArchiveContents, ArchiveIterator, ArchiveIteratorBuilder}; +pub use iterator::{ArchiveContents, ArchiveIterator, ArchiveIteratorBuilder, ArchivePassword}; use std::{ ffi::{CStr, CString}, io::{self, Read, Write}, diff --git a/tests/fixtures/with-password.zip b/tests/fixtures/with-password.zip new file mode 100644 index 0000000..dce2093 Binary files /dev/null and b/tests/fixtures/with-password.zip differ diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 9c12d54..adfb382 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -641,8 +641,7 @@ fn iterate_zip_with_cjk_pathname() { #[test] fn iterate_truncated_archive() { let source = std::fs::File::open("tests/fixtures/truncated.log.gz").unwrap(); - - for content in ArchiveIterator::from_read(source).unwrap() { + for content in ArchiveIterator::from_read(source, None).unwrap() { if let ArchiveContents::Err(Error::Unknown) = content { return; } @@ -654,7 +653,7 @@ fn iterate_truncated_archive() { fn uncompress_bytes_helper(bytes: &[u8]) { let wrapper = Cursor::new(bytes); - for content in ArchiveIterator::from_read(wrapper).unwrap() { + for content in ArchiveIterator::from_read(wrapper, None).unwrap() { if let ArchiveContents::Err(Error::Unknown) = content { return; } @@ -805,3 +804,49 @@ fn iterate_archive_with_filter_path() { "filtered file list inside the archive did not match" ); } + +#[test] +fn iterate_archive_with_password() { + let source = std::fs::File::open("tests/fixtures/with-password.zip").unwrap(); + let source_password: ArchivePassword = "123".into(); + + let mut files_result: Vec = Vec::new(); + let mut current_file_content: Vec = vec![]; + let mut current_file_name = String::new(); + + let mut iter = ArchiveIteratorBuilder::new(source) + .with_password(source_password) + .filter(|name, _| name.ends_with(".txt")) + .build() + .unwrap(); + + for content in &mut iter { + match content { + ArchiveContents::StartOfEntry(name, _stat) => { + current_file_name = name; + } + ArchiveContents::DataChunk(dt) => { + current_file_content.extend(dt); + } + ArchiveContents::EndOfEntry => { + let content_raw = String::from_utf8(current_file_content.clone()).unwrap(); + current_file_content.clear(); + + let content = format!("{}={}", current_file_name, content_raw); + files_result.push(content); + } + _ => {} + } + } + + iter.close().unwrap(); + + assert_eq!(files_result.len(), 2); + assert_eq!( + files_result, + vec![ + "with-password/file1.txt=its encrypted file".to_string(), + "with-password/file2.txt=file 2 in archive encrypted!".to_string() + ] + ); +}