From a63197278e3996dc51750764e4d43ea604f3af2c Mon Sep 17 00:00:00 2001 From: Jacob Rosborg Date: Sat, 29 Jul 2023 17:19:21 +0200 Subject: [PATCH 1/2] concurrent directory download --- Cargo.lock | 18 ++++- Cargo.toml | 2 + src/files.rs | 2 + src/files/download.rs | 157 +++++++++++++++++++++++++++++++++++------- src/files/ext.rs | 21 ++++++ src/files/list.rs | 1 + src/main.rs | 1 + 7 files changed, 176 insertions(+), 26 deletions(-) create mode 100644 src/files/ext.rs diff --git a/Cargo.lock b/Cargo.lock index ec7caaa..c2ba7b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -427,6 +427,8 @@ dependencies = [ "tar", "tempfile", "tokio", + "tokio-stream", + "tokio-util", ] [[package]] @@ -1453,14 +1455,26 @@ dependencies = [ "webpki", ] +[[package]] +name = "tokio-stream" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" -version = "0.7.4" +version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bb2e075f03b3d66d8d8785356224ba688d2906a371015e225beeb65ca92c740" +checksum = "806fe8c2c87eccc8b3267cbae29ed3ab2d0bd37fca70ab622e46aaa9375ddb7d" dependencies = [ "bytes", "futures-core", + "futures-io", "futures-sink", "pin-project-lite", "tokio", diff --git a/Cargo.toml b/Cargo.toml index 9dd9bf5..2fca443 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,3 +26,5 @@ tabwriter = "1.2.1" tar = "0.4.38" tempfile = "3.3.0" tokio = { version = "1.23.0", features = ["full"] } +tokio-stream = "0.1" +tokio-util = { version = "0.7.8", features = ["io", "compat"] } diff --git a/src/files.rs b/src/files.rs index 89c968b..8ee82f0 100644 --- a/src/files.rs +++ b/src/files.rs @@ -2,6 +2,7 @@ pub mod copy; pub mod delete; pub mod download; pub mod export; +mod ext; pub mod generate_ids; pub mod import; pub mod info; @@ -16,6 +17,7 @@ pub use copy::copy; pub use delete::delete; pub use download::download; pub use export::export; +pub use ext::FileExtension; pub use generate_ids::generate_ids; pub use import::import; pub use info::info; diff --git a/src/files/download.rs b/src/files/download.rs index 0f08f60..6d3f3b8 100644 --- a/src/files/download.rs +++ b/src/files/download.rs @@ -2,11 +2,17 @@ use crate::common::drive_file; use crate::common::file_tree_drive; use crate::common::file_tree_drive::FileTreeDrive; use crate::common::hub_helper; -use crate::common::md5_writer::Md5Writer; + use crate::files; +use crate::files::list; +use crate::files::list::ListQuery; +use crate::files::FileExtension; use crate::hub::Hub; -use async_recursion::async_recursion; + +use futures::stream; use futures::stream::StreamExt; + +use futures::TryStreamExt; use google_drive3::hyper; use human_bytes::human_bytes; use std::error; @@ -18,13 +24,22 @@ use std::io; use std::io::BufReader; use std::io::Read; use std::io::Write; +use std::path::Path; use std::path::PathBuf; +use tokio_util::compat::FuturesAsyncReadCompatExt; +use tokio_util::io::InspectReader; + +use super::list::list_files; + +type GFile = google_drive3::api::File; + pub struct Config { pub file_id: String, pub existing_file_action: ExistingFileAction, pub follow_shortcuts: bool, pub download_directories: bool, + pub parallelisme: usize, pub destination: Destination, } @@ -71,7 +86,84 @@ pub enum ExistingFileAction { Overwrite, } -#[async_recursion] +pub async fn _download_file( + hub: &Hub, + file_path: impl AsRef, + file: &GFile, +) -> Result<(), Error> { + let file_id = file.id.as_ref().ok_or_else(|| Error::MissingFileName)?; + let body = download_file(&hub, file_id.as_str()) + .await + .map_err(Error::DownloadFile)?; + + let file_path = file_path.as_ref(); + + println!("Downloading file '{}'", file_path.display()); + save_body_to_file(body, &file_path, None).await?; + + Ok(()) +} + +pub async fn _download_dir(hub: &Hub, file: GFile, config: &Config) -> Result<(), Error> { + let root_path = config.canonical_destination_root()?; + let file_name = file.name.as_ref().ok_or_else(|| Error::MissingFileName)?; + let path = root_path.join(file_name.as_str()); + + stream::unfold(vec![(path, file)], |mut to_visit| async { + let (path, file) = to_visit.pop()?; + let file_id = file.id.as_ref()?; + let files = list_files( + &hub, + &list::ListFilesConfig { + query: ListQuery::FilesInFolder { + folder_id: file_id.clone(), + }, + order_by: Default::default(), + max_files: usize::MAX, + }, + ) + .await; + + let file_stream = match files { + Ok(files) => { + let (dirs, others): (Vec<_>, Vec<_>) = + files.into_iter().partition(|f| f.is_directory()); // TODO: drain filter + to_visit.extend( + dirs.into_iter() + .filter_map(|file| Some((path.join(file.name.as_ref()?), file))), + ); + stream::iter( + others + .into_iter() + .filter_map(move |file| Some((path.join(file.name.as_ref()?), file))), + ) + .map(Ok) + .left_stream() + } + Err(err) => stream::once(async { + Err(Error::CreateFileTree(file_tree_drive::Error::ListFiles( + err, + ))) + }) + .right_stream(), + }; + + Some((file_stream, to_visit)) + }) + .flatten() + .map(|file| async move { + match file { + Ok((path, file)) => _download_file(&hub, &path, &file).await, + Err(_err) => Err(Error::MissingFileName), // TODO: fix error + } + }) + .buffer_unordered(config.parallelisme) + .collect::>() + .await; + + Ok(()) +} + pub async fn download(config: Config) -> Result<(), Error> { let hub = hub_helper::get_hub().await.map_err(Error::Hub)?; @@ -84,19 +176,19 @@ pub async fn download(config: Config) -> Result<(), Error> { err_if_shortcut(&file, &config)?; if drive_file::is_shortcut(&file) { - let target_file_id = file.shortcut_details.and_then(|details| details.target_id); + // let target_file_id = file.shortcut_details.and_then(|details| details.target_id); - err_if_shortcut_target_is_missing(&target_file_id)?; + // err_if_shortcut_target_is_missing(&target_file_id)?; - download(Config { - file_id: target_file_id.unwrap_or_default(), - ..config - }) - .await?; + // download(Config { + // file_id: target_file_id.unwrap_or_default(), + // ..config + // }) + // .await?; } else if drive_file::is_directory(&file) { - download_directory(&hub, &file, &config).await?; + _download_dir(&hub, file, &config).await?; } else { - download_regular(&hub, &file, &config).await?; + // download_regular(&hub, &file, &config).await?; } Ok(()) @@ -294,28 +386,45 @@ impl Display for Error { // TODO: move to common pub async fn save_body_to_file( - mut body: hyper::Body, - file_path: &PathBuf, + body: hyper::Body, + file_path: impl AsRef, expected_md5: Option, ) -> Result<(), Error> { + let file_path = file_path.as_ref(); // Create temporary file + + tokio::fs::create_dir_all(file_path.parent().unwrap()) + .await + .map_err(|err| Error::CreateDirectory(file_path.to_path_buf(), err))?; + let tmp_file_path = file_path.with_extension("incomplete"); - let file = File::create(&tmp_file_path).map_err(Error::CreateFile)?; + let mut file = tokio::fs::File::create(&tmp_file_path) + .await + .map_err(Error::CreateFile)?; - // Wrap file in writer that calculates md5 - let mut writer = Md5Writer::new(file); + let mut md5 = md5::Context::new(); - // Read chunks from stream and write to file - while let Some(chunk_result) = body.next().await { - let chunk = chunk_result.map_err(Error::ReadChunk)?; - writer.write_all(&chunk).map_err(Error::WriteChunk)?; - } + let body = body + .into_stream() + .map(|result| { + result.map_err(|_error| std::io::Error::new(std::io::ErrorKind::Other, "Error!")) + }) + .into_async_read() + .compat(); + + let mut body = InspectReader::new(body, |bytes| md5.consume(&bytes)); + + tokio::io::copy(&mut body, &mut file) + .await + .map_err(|err| Error::WriteChunk(err))?; // Check md5 - err_if_md5_mismatch(expected_md5, writer.md5())?; + err_if_md5_mismatch(expected_md5, format!("{:x}", md5.compute()))?; // Rename temporary file to final file - fs::rename(&tmp_file_path, &file_path).map_err(Error::RenameFile) + tokio::fs::rename(&tmp_file_path, &file_path) + .await + .map_err(Error::RenameFile) } // TODO: move to common diff --git a/src/files/ext.rs b/src/files/ext.rs new file mode 100644 index 0000000..8bf67bd --- /dev/null +++ b/src/files/ext.rs @@ -0,0 +1,21 @@ +use crate::common::drive_file::{MIME_TYPE_DRIVE_FOLDER, MIME_TYPE_DRIVE_SHORTCUT}; + +pub trait FileExtension { + fn is_directory(&self) -> bool; + fn is_binary(&self) -> bool; + fn is_shortcut(&self) -> bool; +} + +impl FileExtension for google_drive3::api::File { + fn is_directory(&self) -> bool { + self.mime_type == Some(String::from(MIME_TYPE_DRIVE_FOLDER)) + } + + fn is_binary(&self) -> bool { + self.md5_checksum != None + } + + fn is_shortcut(&self) -> bool { + self.mime_type == Some(String::from(MIME_TYPE_DRIVE_SHORTCUT)) + } +} \ No newline at end of file diff --git a/src/files/list.rs b/src/files/list.rs index bc9bce0..d937c2f 100644 --- a/src/files/list.rs +++ b/src/files/list.rs @@ -72,6 +72,7 @@ pub async fn list(config: Config) -> Result<(), Error> { Ok(()) } +#[derive(Default)] pub struct ListFilesConfig { pub query: ListQuery, pub order_by: ListSortOrder, diff --git a/src/main.rs b/src/main.rs index b309066..cfb571a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -519,6 +519,7 @@ async fn main() { follow_shortcuts, download_directories: recursive, destination: dst, + parallelisme: 10, }) .await .unwrap_or_else(handle_error) From 44a1b61aefb7f1cff001142b3b0a7c8ef8993696 Mon Sep 17 00:00:00 2001 From: Jacob Rosborg Date: Sat, 29 Jul 2023 17:29:24 +0200 Subject: [PATCH 2/2] cli arg --concurent --- src/main.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/main.rs b/src/main.rs index cfb571a..6ac51ff 100644 --- a/src/main.rs +++ b/src/main.rs @@ -171,6 +171,9 @@ enum FileCommand { #[arg(long)] recursive: bool, + #[arg(long)] + concurent: usize, + /// Path where the file/directory should be downloaded to #[arg(long, value_name = "PATH")] destination: Option, @@ -498,6 +501,7 @@ async fn main() { recursive, destination, stdout, + concurent, } => { let existing_file_action = if overwrite { files::download::ExistingFileAction::Overwrite @@ -519,7 +523,7 @@ async fn main() { follow_shortcuts, download_directories: recursive, destination: dst, - parallelisme: 10, + parallelisme: concurent, }) .await .unwrap_or_else(handle_error)