From bc1571533edc7a7baa34635f6c0136353a844fca Mon Sep 17 00:00:00 2001
From: link2xt <link2xt@testrun.org>
Date: Sat, 5 Oct 2024 04:57:42 +0000
Subject: [PATCH] feat: smooth progress bar for backup transfer

---
 Cargo.lock           |   1 +
 Cargo.toml           |   1 +
 src/blob.rs          |   4 --
 src/imex.rs          | 168 ++++++++++++++++++++++++++++++++++++--------
 src/imex/transfer.rs |   4 +-
 5 files changed, 144 insertions(+), 34 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index e8ac81203d..d7012809d5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1315,6 +1315,7 @@ dependencies = [
  "parking_lot",
  "percent-encoding",
  "pgp",
+ "pin-project",
  "pretty_assertions",
  "proptest",
  "qrcodegen",
diff --git a/Cargo.toml b/Cargo.toml
index 056070a00c..7ec378b43f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -77,6 +77,7 @@ once_cell = { workspace = true }
 parking_lot = "0.12"
 percent-encoding = "2.3"
 pgp = { version = "0.13.2", default-features = false }
+pin-project = "1"
 qrcodegen = "1.7.0"
 quick-xml = "0.36"
 quoted_printable = "0.5"
diff --git a/src/blob.rs b/src/blob.rs
index 320cc011e9..6c0c032093 100644
--- a/src/blob.rs
+++ b/src/blob.rs
@@ -666,10 +666,6 @@ impl<'a> BlobDirContents<'a> {
     pub(crate) fn iter(&self) -> BlobDirIter<'_> {
         BlobDirIter::new(self.context, self.inner.iter())
     }
-
-    pub(crate) fn len(&self) -> usize {
-        self.inner.len()
-    }
 }
 
 /// A iterator over all the [`BlobObject`]s in the blobdir.
diff --git a/src/imex.rs b/src/imex.rs
index 89273a7a50..a461bfdd65 100644
--- a/src/imex.rs
+++ b/src/imex.rs
@@ -2,13 +2,16 @@
 
 use std::ffi::OsStr;
 use std::path::{Path, PathBuf};
+use std::pin::Pin;
 
 use ::pgp::types::KeyTrait;
 use anyhow::{bail, ensure, format_err, Context as _, Result};
 use futures::TryStreamExt;
 use futures_lite::FutureExt;
+use pin_project::pin_project;
 use tokio::fs::{self, File};
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
 use tokio_tar::Archive;
 
 use crate::blob::BlobDirContents;
@@ -212,7 +215,7 @@ async fn imex_inner(
         path.display()
     );
     ensure!(context.sql.is_open().await, "Database not opened.");
-    context.emit_event(EventType::ImexProgress(10));
+    context.emit_event(EventType::ImexProgress(1));
 
     if what == ImexMode::ExportBackup || what == ImexMode::ExportSelfKeys {
         // before we export anything, make sure the private key exists
@@ -294,12 +297,68 @@ pub(crate) async fn import_backup_stream<R: tokio::io::AsyncRead + Unpin>(
         .0
 }
 
+#[pin_project]
+struct ProgressReader<R> {
+    #[pin]
+    inner: R,
+
+    #[pin]
+    read: usize,
+
+    #[pin]
+    file_size: usize,
+
+    #[pin]
+    last_progress: usize,
+
+    #[pin]
+    context: Context,
+}
+
+impl<R> ProgressReader<R> {
+    fn new(r: R, context: Context, file_size: u64) -> Self {
+        Self {
+            inner: r,
+            read: 0,
+            file_size: file_size as usize,
+            last_progress: 1,
+            context,
+        }
+    }
+}
+
+impl<R> AsyncRead for ProgressReader<R>
+where
+    R: AsyncRead,
+{
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> std::task::Poll<std::io::Result<()>> {
+        let mut this = self.project();
+        let before = buf.filled().len();
+        let res = this.inner.poll_read(cx, buf);
+        if let std::task::Poll::Ready(Ok(())) = res {
+            *this.read = this.read.saturating_add(buf.filled().len() - before);
+
+            let progress = std::cmp::min(1000 * *this.read / *this.file_size, 999);
+            if progress > *this.last_progress {
+                this.context.emit_event(EventType::ImexProgress(progress));
+                *this.last_progress = progress;
+            }
+        }
+        res
+    }
+}
+
 async fn import_backup_stream_inner<R: tokio::io::AsyncRead + Unpin>(
     context: &Context,
     backup_file: R,
     file_size: u64,
     passphrase: String,
 ) -> (Result<()>,) {
+    let backup_file = ProgressReader::new(backup_file, context.clone(), file_size);
     let mut archive = Archive::new(backup_file);
 
     let mut entries = match archive.entries() {
@@ -307,29 +366,12 @@ async fn import_backup_stream_inner<R: tokio::io::AsyncRead + Unpin>(
         Ok(entries) => entries,
         Err(e) => return (Err(e).context("Failed to get archive entries"),),
     };
     let mut blobs = Vec::new();
-    // We already emitted ImexProgress(10) above
-    let mut last_progress = 10;
-    const PROGRESS_MIGRATIONS: u128 = 999;
-    let mut total_size: u64 = 0;
     let mut res: Result<()> = loop {
         let mut f = match entries.try_next().await {
             Ok(Some(f)) => f,
             Ok(None) => break Ok(()),
             Err(e) => break Err(e).context("Failed to get next entry"),
         };
-        total_size += match f.header().entry_size() {
-            Ok(size) => size,
-            Err(e) => break Err(e).context("Failed to get entry size"),
-        };
-        let max = PROGRESS_MIGRATIONS - 1;
-        let progress = std::cmp::min(
-            max * u128::from(total_size) / std::cmp::max(u128::from(file_size), 1),
-            max,
-        );
-        if progress > last_progress {
-            context.emit_event(EventType::ImexProgress(progress as usize));
-            last_progress = progress;
-        }
 
         let path = match f.path() {
             Ok(path) => path.to_path_buf(),
@@ -379,7 +421,7 @@ async fn import_backup_stream_inner<R: tokio::io::AsyncRead + Unpin>(
         .log_err(context)
         .ok();
     if res.is_ok() {
-        context.emit_event(EventType::ImexProgress(PROGRESS_MIGRATIONS as usize));
+        context.emit_event(EventType::ImexProgress(999));
         res = context.sql.run_migrations(context).await;
     }
     if res.is_ok() {
@@ -452,7 +494,14 @@ async fn export_backup(context: &Context, dir: &Path, passphrase: String) -> Res
     let file = File::create(&temp_path).await?;
 
     let blobdir = BlobDirContents::new(context).await?;
-    export_backup_stream(context, &temp_db_path, blobdir, file)
+
+    let mut file_size = 0;
+    file_size += temp_db_path.metadata()?.len();
+    for blob in blobdir.iter() {
+        file_size += blob.to_abs_path().metadata()?.len()
+    }
+
+    export_backup_stream(context, &temp_db_path, blobdir, file, file_size)
         .await
         .context("Exporting backup to file failed")?;
     fs::rename(temp_path, &dest_path).await?;
@@ -460,33 +509,96 @@ async fn export_backup(context: &Context, dir: &Path, passphrase: String) -> Res
     Ok(())
 }
 
+#[pin_project]
+struct ProgressWriter<W> {
+    #[pin]
+    inner: W,
+
+    #[pin]
+    wrote: usize,
+
+    #[pin]
+    file_size: usize,
+
+    #[pin]
+    last_progress: usize,
+
+    #[pin]
+    context: Context,
+}
+
+impl<W> ProgressWriter<W> {
+    fn new(w: W, context: Context, file_size: u64) -> Self {
+        Self {
+            inner: w,
+            wrote: 0,
+            file_size: file_size as usize,
+            last_progress: 1,
+            context,
+        }
+    }
+}
+
+impl<W> AsyncWrite for ProgressWriter<W>
+where
+    W: AsyncWrite,
+{
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &[u8],
+    ) -> std::task::Poll<std::io::Result<usize>> {
+        let mut this = self.project();
+        let res = this.inner.poll_write(cx, buf);
+        if let std::task::Poll::Ready(Ok(wrote)) = res {
+            *this.wrote = this.wrote.saturating_add(wrote);
+
+            let progress = std::cmp::min(1000 * *this.wrote / *this.file_size, 999);
+            if progress > *this.last_progress {
+                this.context.emit_event(EventType::ImexProgress(progress));
+                *this.last_progress = progress;
+            }
+        }
+        res
+    }
+
+    fn poll_flush(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<std::io::Result<()>> {
+        self.project().inner.poll_flush(cx)
+    }
+
+    fn poll_shutdown(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<std::io::Result<()>> {
+        self.project().inner.poll_shutdown(cx)
+    }
+}
+
 /// Exports the database and blobs into a stream.
 pub(crate) async fn export_backup_stream<'a, W>(
     context: &'a Context,
     temp_db_path: &Path,
     blobdir: BlobDirContents<'a>,
     writer: W,
+    file_size: u64,
 ) -> Result<()>
 where
     W: tokio::io::AsyncWrite + tokio::io::AsyncWriteExt + Unpin + Send + 'static,
 {
+    let writer = ProgressWriter::new(writer, context.clone(), file_size);
     let mut builder = tokio_tar::Builder::new(writer);
 
     builder
         .append_path_with_name(temp_db_path, DBFILE_BACKUP_NAME)
         .await?;
 
-    let mut last_progress = 10;
-
-    for (i, blob) in blobdir.iter().enumerate() {
+    for blob in blobdir.iter() {
         let mut file = File::open(blob.to_abs_path()).await?;
         let path_in_archive = PathBuf::from(BLOBS_BACKUP_NAME).join(blob.as_name());
         builder.append_file(path_in_archive, &mut file).await?;
-
-        let progress = std::cmp::min(1000 * i / blobdir.len(), 999);
-        if progress > last_progress {
-            context.emit_event(EventType::ImexProgress(progress));
-            last_progress = progress;
-        }
     }
 
     builder.finish().await?;
diff --git a/src/imex/transfer.rs b/src/imex/transfer.rs
index c50c028e17..bfc79d47e5 100644
--- a/src/imex/transfer.rs
+++ b/src/imex/transfer.rs
@@ -124,7 +124,7 @@ impl BackupProvider {
         export_database(context, &dbfile, passphrase, time())
             .await
             .context("Database export failed")?;
-        context.emit_event(EventType::ImexProgress(300));
+        context.emit_event(EventType::ImexProgress(1));
 
         let drop_token = CancellationToken::new();
         let handle = {
@@ -190,7 +190,7 @@ impl BackupProvider {
 
         send_stream.write_all(&file_size.to_be_bytes()).await?;
 
-        export_backup_stream(&context, &dbfile, blobdir, send_stream)
+        export_backup_stream(&context, &dbfile, blobdir, send_stream, file_size)
             .await
             .context("Failed to write backup into QUIC stream")?;
         info!(context, "Finished writing backup into QUIC stream.");
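
---

Note on the pattern (not part of the diff): both wrappers count bytes as they
pass through poll_read/poll_write, map the running total to a permille value,
cap it at 999 (1000 is reserved for "done" and only emitted after migrations
succeed), and deduplicate through last_progress so each permille step fires at
most one event. Because the total .tar size is known up front, this replaces
the old per-entry accounting, which jumped abruptly whenever one large blob
passed through. The sketch below is a minimal, standalone version of the
ProgressReader side: deltachat's Context/EventType::ImexProgress are replaced
by an on_progress callback (a name invented for this sketch), and, like the
patch, it assumes file_size > 0. It needs only tokio (e.g. features = "full")
and pin-project.

// Standalone sketch of the byte-counting progress wrapper.
use std::pin::Pin;
use std::task::{Context, Poll};

use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};

#[pin_project]
struct ProgressReader<R, F> {
    #[pin]
    inner: R,
    read: usize,          // bytes seen so far
    file_size: usize,     // expected total, known up front
    last_progress: usize, // dedup: emit each permille value at most once
    on_progress: F,       // stands in for Context::emit_event
}

impl<R: AsyncRead, F: FnMut(usize)> AsyncRead for ProgressReader<R, F> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.project();
        let before = buf.filled().len();
        let res = this.inner.poll_read(cx, buf);
        if let Poll::Ready(Ok(())) = res {
            // Count only the bytes this poll added to the buffer.
            *this.read += buf.filled().len() - before;
            // Map bytes to 0..=999; 1000 is reserved for "finished".
            let progress = std::cmp::min(1000 * *this.read / *this.file_size, 999);
            if progress > *this.last_progress {
                (this.on_progress)(progress);
                *this.last_progress = progress;
            }
        }
        res
    }
}

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let data = vec![0u8; 4096];
    let mut reader = ProgressReader {
        inner: data.as_slice(), // &[u8] implements AsyncRead
        read: 0,
        file_size: data.len(),
        last_progress: 0,
        on_progress: |p| println!("progress: {p}"),
    };
    // Drain in 1 KiB chunks; expect progress events 250, 500, 750, 999.
    let mut buf = [0u8; 1024];
    while reader.read(&mut buf).await? != 0 {}
    Ok(())
}

The same shape, with poll_write and the written-byte count, gives the
ProgressWriter half; flush and shutdown are plain pass-throughs because they
move no payload bytes.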