Skip to content

Commit

Permalink
Paralellize nerfstudio as well
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurBrussee committed Nov 19, 2024
1 parent ba7b920 commit 89d61c9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
21 changes: 15 additions & 6 deletions crates/brush-dataset/src/formats/nerfstudio.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::DatasetZip;
use super::LoadDatasetArgs;
use crate::splat_import::load_splat_from_ply;
use crate::stream_fut_parallel;
use crate::{clamp_img_to_max_size, DataStream, Dataset};
use anyhow::Context;
use anyhow::Result;
Expand Down Expand Up @@ -209,26 +210,34 @@ pub fn read_dataset<B: Backend>(
// Not entirely sure yet if we want to report stats on both test
// and eval, atm this skips "transforms_test.json" even if it's there.

for (i, handle) in train_handles.into_iter().enumerate() {
let train_handles = stream_fut_parallel(train_handles);
let mut train_handles = std::pin::pin!(train_handles);

let mut i = 0;
while let Some(view) = train_handles.next().await {
if let Some(eval_period) = load_args.eval_split_every {
// Include extra eval images only when the dataset doesn't have them.
if i % eval_period == 0 && val_stream.is_some() {
eval_views.push(handle.await?);
eval_views.push(view?);
} else {
train_views.push(handle.await?);
train_views.push(view?);
}
} else {
train_views.push(handle.await?);
train_views.push(view?);
}

emitter
.emit(Dataset::from_views(train_views.clone(), eval_views.clone()))
.await;

i += 1;
}

if let Some(val_stream) = val_stream {
for handle in val_stream {
eval_views.push(handle.await?);
let val_handles = stream_fut_parallel(val_stream);
let mut val_handles = std::pin::pin!(val_handles);
while let Some(view) = val_handles.next().await {
eval_views.push(view?);
emitter
.emit(Dataset::from_views(train_views.clone(), eval_views.clone()))
.await;
Expand Down
2 changes: 2 additions & 0 deletions crates/brush-dataset/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ pub(crate) fn stream_fut_parallel<T: Send + 'static>(
.get()
};

log::info!("Loading steam with {parallel} threads");

let mut futures = futures;
fn_stream(|emitter| async move {
while !futures.is_empty() {
Expand Down

0 comments on commit 89d61c9

Please sign in to comment.