diff --git a/crates/brush-dataset/src/formats/nerfstudio.rs b/crates/brush-dataset/src/formats/nerfstudio.rs index f2e2d269..6f26e2be 100644 --- a/crates/brush-dataset/src/formats/nerfstudio.rs +++ b/crates/brush-dataset/src/formats/nerfstudio.rs @@ -1,6 +1,7 @@ use super::DatasetZip; use super::LoadDatasetArgs; use crate::splat_import::load_splat_from_ply; +use crate::stream_fut_parallel; use crate::{clamp_img_to_max_size, DataStream, Dataset}; use anyhow::Context; use anyhow::Result; @@ -209,26 +210,34 @@ pub fn read_dataset( // Not entirely sure yet if we want to report stats on both test // and eval, atm this skips "transforms_test.json" even if it's there. - for (i, handle) in train_handles.into_iter().enumerate() { + let train_handles = stream_fut_parallel(train_handles); + let mut train_handles = std::pin::pin!(train_handles); + + let mut i = 0; + while let Some(view) = train_handles.next().await { if let Some(eval_period) = load_args.eval_split_every { // Include extra eval images only when the dataset doesn't have them. if i % eval_period == 0 && val_stream.is_some() { - eval_views.push(handle.await?); + eval_views.push(view?); } else { - train_views.push(handle.await?); + train_views.push(view?); } } else { - train_views.push(handle.await?); + train_views.push(view?); } emitter .emit(Dataset::from_views(train_views.clone(), eval_views.clone())) .await; + + i += 1; } if let Some(val_stream) = val_stream { - for handle in val_stream { - eval_views.push(handle.await?); + let val_handles = stream_fut_parallel(val_stream); + let mut val_handles = std::pin::pin!(val_handles); + while let Some(view) = val_handles.next().await { + eval_views.push(view?); emitter .emit(Dataset::from_views(train_views.clone(), eval_views.clone())) .await; diff --git a/crates/brush-dataset/src/lib.rs b/crates/brush-dataset/src/lib.rs index b204ac87..74c4e228 100644 --- a/crates/brush-dataset/src/lib.rs +++ b/crates/brush-dataset/src/lib.rs @@ -88,6 +88,8 @@ pub(crate) fn stream_fut_parallel( .get() }; + log::info!("Loading steam with {parallel} threads"); + let mut futures = futures; fn_stream(|emitter| async move { while !futures.is_empty() {