Add subsampling for points and images
ArthurBrussee committed Nov 19, 2024
1 parent aad8d79 commit edd52c5
Showing 7 changed files with 194 additions and 121 deletions.
19 changes: 15 additions & 4 deletions crates/brush-dataset/src/formats/colmap.rs
@@ -11,6 +11,7 @@ use brush_render::{
Backend,
};
use brush_train::scene::SceneView;
use glam::Vec3;
use tokio_stream::StreamExt;

fn read_views(
@@ -114,7 +115,11 @@ pub(crate) fn load_dataset<B: Backend>(
load_args: &LoadDatasetArgs,
device: &B::Device,
) -> Result<(DataStream<Splats<B>>, DataStream<Dataset>)> {
let handles = read_views(archive.clone(), load_args)?;
let mut handles = read_views(archive.clone(), load_args)?;

if let Some(subsample) = load_args.subsample_frames {
handles = handles.into_iter().step_by(subsample as usize).collect();
}

let mut train_views = vec![];
let mut eval_views = vec![];
@@ -167,9 +172,8 @@ pub(crate) fn load_dataset<B: Backend>(
if !points_data.is_empty() {
log::info!("Starting from colmap points {}", points_data.len());

let positions = points_data.values().map(|p| p.xyz).collect();

let colors = points_data
let mut positions: Vec<Vec3> = points_data.values().map(|p| p.xyz).collect();
let mut colors: Vec<f32> = points_data
.values()
.flat_map(|p| {
[
@@ -180,6 +184,13 @@
})
.collect();

// Other data loaders handle subsampling during the ply import; here it is
// done manually. It might be nice to unify this at some point.
if let Some(subsample) = load_args.subsample_points {
positions = positions.into_iter().step_by(subsample as usize).collect();
// Colors are flat [r, g, b] triples, so step over whole chunks of three
// floats to keep every channel of each retained point.
colors = colors
.chunks(3)
.step_by(subsample as usize)
.flatten()
.copied()
.collect();
}

let init_ply = Splats::from_raw(positions, None, None, Some(colors), None, &device);
emitter.emit(init_ply).await;
}
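For reference, the subsampling above amounts to the following standalone sketch (the `subsample_points` helper is hypothetical; it assumes, as in the code above, that `colors` stores flat `[r, g, b]` triples per point):

```rust
use glam::Vec3;

/// Keep every `n`-th point together with all three of its color channels.
/// A minimal sketch of the `step_by` subsampling used in `load_dataset`.
fn subsample_points(positions: Vec<Vec3>, colors: Vec<f32>, n: usize) -> (Vec<Vec3>, Vec<f32>) {
    let positions: Vec<Vec3> = positions.into_iter().step_by(n).collect();
    // Step over whole chunks of three floats so each retained point keeps
    // its full RGB color.
    let colors: Vec<f32> = colors.chunks(3).step_by(n).flatten().copied().collect();
    (positions, colors)
}
```

With `n = 2`, four points shrink to two (indices 0 and 2), and the twelve color floats shrink to the six belonging to those two points.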
6 changes: 5 additions & 1 deletion crates/brush-dataset/src/formats/mod.rs
@@ -28,7 +28,11 @@ pub fn load_dataset<B: Backend>(
let init_stream = if let Ok(path) = init_path {
let ply_data = archive.read_bytes_at_path(&path)?;
log::info!("Using {path:?} as initial point cloud.");
let splat_stream = load_splat_from_ply(Cursor::new(ply_data), device.clone());
let splat_stream = load_splat_from_ply(
Cursor::new(ply_data),
load_args.subsample_points,
device.clone(),
);
Box::pin(splat_stream)
} else {
streams.0
28 changes: 20 additions & 8 deletions crates/brush-dataset/src/formats/nerfstudio.rs
@@ -180,28 +180,35 @@ pub fn read_dataset<B: Backend>(

let transforms_path = archive.find_with_extension(".json", "_train")?;
let train_scene: JsonScene = serde_json::from_reader(archive.file_at_path(&transforms_path)?)?;
let train_handles = read_transforms_file(
let mut train_handles = read_transforms_file(
train_scene.clone(),
transforms_path.clone(),
archive.clone(),
load_args,
)?;

let load_args = load_args.clone();
let mut archive_move = archive.clone();
if let Some(subsample) = load_args.subsample_frames {
train_handles = train_handles
.into_iter()
.step_by(subsample as usize)
.collect();
}

let load_args_clone = load_args.clone();
let mut archive_clone = archive.clone();

let transforms_path_clone = transforms_path.clone();

let dataset_stream = try_fn_stream(|emitter| async move {
let mut train_views = vec![];
let mut eval_views = vec![];

let eval_trans_path = archive_move.find_with_extension(".json", "_val")?;
let eval_trans_path = archive_clone.find_with_extension(".json", "_val")?;

// If a separate eval file is specified, read it.
let val_stream = if eval_trans_path != transforms_path_clone {
let val_scene = serde_json::from_reader(archive_move.file_at_path(&eval_trans_path)?)?;
read_transforms_file(val_scene, eval_trans_path, archive_move, &load_args).ok()
let val_scene = serde_json::from_reader(archive_clone.file_at_path(&eval_trans_path)?)?;
read_transforms_file(val_scene, eval_trans_path, archive_clone, &load_args_clone).ok()
} else {
None
};
@@ -215,7 +222,7 @@ pub fn read_dataset<B: Backend>(

let mut i = 0;
while let Some(view) = train_handles.next().await {
if let Some(eval_period) = load_args.eval_split_every {
if let Some(eval_period) = load_args_clone.eval_split_every {
// Include extra eval images only when the dataset doesn't have them.
if i % eval_period == 0 && val_stream.is_some() {
eval_views.push(view?);
Expand Down Expand Up @@ -248,14 +255,19 @@ pub fn read_dataset<B: Backend>(
});

let device = device.clone();
let load_args = load_args.clone();

let splat_stream = try_fn_stream(|emitter| async move {
if let Some(init) = train_scene.ply_file_path {
let init_path = transforms_path.parent().unwrap().join(init);
let ply_data = archive.read_bytes_at_path(&init_path);

if let Ok(ply_data) = ply_data {
let splat_stream = load_splat_from_ply(Cursor::new(ply_data), device.clone());
let splat_stream = load_splat_from_ply(
Cursor::new(ply_data),
load_args.subsample_points,
device.clone(),
);

let mut splat_stream = std::pin::pin!(splat_stream);

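The frame subsampling above relies on `Iterator::step_by`, which always yields the first element and then every `subsample`-th one after it. A minimal sketch of the effect on a list of hypothetical view names:

```rust
fn main() {
    let views = vec!["v0", "v1", "v2", "v3", "v4", "v5", "v6"];
    let subsample: u32 = 3;
    // The first view is always kept, then every third one after it.
    let kept: Vec<&str> = views.into_iter().step_by(subsample as usize).collect();
    assert_eq!(kept, ["v0", "v3", "v6"]);
}
```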
2 changes: 2 additions & 0 deletions crates/brush-dataset/src/lib.rs
@@ -22,6 +22,8 @@ pub struct LoadDatasetArgs {
pub max_frames: Option<usize>,
pub max_resolution: Option<u32>,
pub eval_split_every: Option<usize>,
pub subsample_frames: Option<u32>,
pub subsample_points: Option<u32>,
}

#[derive(Clone)]
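The two new options compose with the existing limits. A hypothetical construction, assuming `LoadDatasetArgs` implements `Default` for the fields not shown in this hunk:

```rust
let load_args = LoadDatasetArgs {
    max_resolution: Some(1920),
    eval_split_every: Some(8),
    subsample_frames: Some(2), // keep every 2nd training view
    subsample_points: Some(4), // keep every 4th initial point
    ..Default::default()
};
```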
47 changes: 28 additions & 19 deletions crates/brush-dataset/src/splat_import.rs
@@ -125,6 +125,7 @@ fn interleave_coeffs(sh_dc: [f32; 3], sh_rest: &[f32]) -> Vec<f32> {

pub fn load_splat_from_ply<T: AsyncRead + Unpin + 'static, B: Backend>(
reader: T,
subsample_points: Option<u32>,
device: B::Device,
) -> impl Stream<Item = Result<Splats<B>>> + 'static {
// set up a reader, in this case a file.
@@ -173,6 +174,25 @@ pub fn load_splat_from_ply<T: AsyncRead + Unpin + 'static, B: Backend>(
let mut ascii_line = String::new();

for i in 0..element.count {
// Occasionally yield.
if i % 500 == 0 {
tokio::task::yield_now().await;
}

// Occasionally send some updated splats.
if i % update_every == update_every - 1 {
let splats = Splats::from_raw(
means.clone(),
rotation.clone(),
scales.clone(),
sh_coeffs.clone(),
opacity.clone(),
&device,
);

emitter.emit(splats).await;
}

let splat = match header.encoding {
ply_rs::ply::Encoding::Ascii => {
reader.read_line(&mut ascii_line).await?;
@@ -192,6 +212,14 @@ pub fn load_splat_from_ply<T: AsyncRead + Unpin + 'static, B: Backend>(
}
};

// Skipping only after the point has been read and parsed is somewhat
// wasteful, but we do need to advance the reader past every record.
if let Some(subsample) = subsample_points {
if i % subsample as usize != 0 {
continue;
}
}

means.push(splat.means);
if let Some(scales) = scales.as_mut() {
scales.push(splat.scale);
@@ -207,25 +235,6 @@ pub fn load_splat_from_ply<T: AsyncRead + Unpin + 'static, B: Backend>(
interleave_coeffs(splat.sh_dc, &splat.sh_coeffs_rest);
sh_coeffs.extend(sh_coeffs_interleaved);
}

// Occasionally send some updated splats.
if i % update_every == update_every - 1 {
let splats = Splats::from_raw(
means.clone(),
rotation.clone(),
scales.clone(),
sh_coeffs.clone(),
opacity.clone(),
&device,
);

emitter.emit(splats).await;
}

// Occasionally yield.
if i % 500 == 0 {
tokio::task::yield_now().await;
}
}

let splats = Splats::from_raw(means, rotation, scales, sh_coeffs, opacity, &device);
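As the comment in the loop notes, a sequential reader cannot seek past records it hasn't parsed, so every splat is read even when subsampling, and only the bookkeeping is skipped. A minimal sketch of that read-then-skip pattern, with a hypothetical `read_record` closure standing in for the ply parsing:

```rust
/// Read `count` records sequentially, keeping every `subsample`-th one.
/// Every record is still parsed, because a sequential reader has to be
/// advanced past the skipped entries before the next one can be read.
fn read_subsampled<R>(
    count: usize,
    subsample: Option<u32>,
    mut read_record: impl FnMut() -> R,
) -> Vec<R> {
    let mut kept = Vec::new();
    for i in 0..count {
        let record = read_record(); // always advance the reader
        if let Some(n) = subsample {
            if i % n as usize != 0 {
                continue; // parsed, then dropped
            }
        }
        kept.push(record);
    }
    kept
}
```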