From edd52c58010a2fbfdd8cabb872322a478839820e Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Tue, 19 Nov 2024 16:00:54 +0000 Subject: [PATCH] Add subsampling for points and images --- crates/brush-dataset/src/formats/colmap.rs | 19 +- crates/brush-dataset/src/formats/mod.rs | 6 +- .../brush-dataset/src/formats/nerfstudio.rs | 28 ++- crates/brush-dataset/src/lib.rs | 2 + crates/brush-dataset/src/splat_import.rs | 47 ++-- crates/brush-viewer/src/panels/load_data.rs | 209 ++++++++++-------- crates/brush-viewer/src/viewer.rs | 4 +- 7 files changed, 194 insertions(+), 121 deletions(-) diff --git a/crates/brush-dataset/src/formats/colmap.rs b/crates/brush-dataset/src/formats/colmap.rs index 5c2be1b0..ee2f3ce1 100644 --- a/crates/brush-dataset/src/formats/colmap.rs +++ b/crates/brush-dataset/src/formats/colmap.rs @@ -11,6 +11,7 @@ use brush_render::{ Backend, }; use brush_train::scene::SceneView; +use glam::Vec3; use tokio_stream::StreamExt; fn read_views( @@ -114,7 +115,11 @@ pub(crate) fn load_dataset( load_args: &LoadDatasetArgs, device: &B::Device, ) -> Result<(DataStream>, DataStream)> { - let handles = read_views(archive.clone(), load_args)?; + let mut handles = read_views(archive.clone(), load_args)?; + + if let Some(subsample) = load_args.subsample_frames { + handles = handles.into_iter().step_by(subsample as usize).collect(); + } let mut train_views = vec![]; let mut eval_views = vec![]; @@ -167,9 +172,8 @@ pub(crate) fn load_dataset( if !points_data.is_empty() { log::info!("Starting from colmap points {}", points_data.len()); - let positions = points_data.values().map(|p| p.xyz).collect(); - - let colors = points_data + let mut positions: Vec = points_data.values().map(|p| p.xyz).collect(); + let mut colors: Vec = points_data .values() .flat_map(|p| { [ @@ -180,6 +184,13 @@ pub(crate) fn load_dataset( }) .collect(); + // Other dataloaders handle subsampling in the ply import. Here just + // do it manually, maybe nice to unify at some point. 
+ if let Some(subsample) = load_args.subsample_points { + positions = positions.into_iter().step_by(subsample as usize).collect(); + colors = colors.chunks_exact(3).step_by(subsample as usize).flatten().copied().collect(); + } + let init_ply = Splats::from_raw(positions, None, None, Some(colors), None, &device); emitter.emit(init_ply).await; } diff --git a/crates/brush-dataset/src/formats/mod.rs b/crates/brush-dataset/src/formats/mod.rs index 230cbf9a..020aeffd 100644 --- a/crates/brush-dataset/src/formats/mod.rs +++ b/crates/brush-dataset/src/formats/mod.rs @@ -28,7 +28,11 @@ pub fn load_dataset( let init_stream = if let Ok(path) = init_path { let ply_data = archive.read_bytes_at_path(&path)?; log::info!("Using {path:?} as initial point cloud."); - let splat_stream = load_splat_from_ply(Cursor::new(ply_data), device.clone()); + let splat_stream = load_splat_from_ply( + Cursor::new(ply_data), + load_args.subsample_points, + device.clone(), + ); Box::pin(splat_stream) } else { streams.0 diff --git a/crates/brush-dataset/src/formats/nerfstudio.rs b/crates/brush-dataset/src/formats/nerfstudio.rs index 6f26e2be..a59e0c3b 100644 --- a/crates/brush-dataset/src/formats/nerfstudio.rs +++ b/crates/brush-dataset/src/formats/nerfstudio.rs @@ -180,15 +180,22 @@ pub fn read_dataset( let transforms_path = archive.find_with_extension(".json", "_train")?; let train_scene: JsonScene = serde_json::from_reader(archive.file_at_path(&transforms_path)?)?; - let train_handles = read_transforms_file( + let mut train_handles = read_transforms_file( train_scene.clone(), transforms_path.clone(), archive.clone(), load_args, )?; - let load_args = load_args.clone(); - let mut archive_move = archive.clone(); + if let Some(subsample) = load_args.subsample_frames { + train_handles = train_handles + .into_iter() + .step_by(subsample as usize) + .collect(); + } + + let load_args_clone = load_args.clone(); + let mut archive_clone = archive.clone(); + let transforms_path_clone = transforms_path.clone(); @@ -196,12 +203,12 
@@ pub fn read_dataset( let mut train_views = vec![]; let mut eval_views = vec![]; - let eval_trans_path = archive_move.find_with_extension(".json", "_val")?; + let eval_trans_path = archive_clone.find_with_extension(".json", "_val")?; // If a seperate eval file is specified, read it. let val_stream = if eval_trans_path != transforms_path_clone { - let val_scene = serde_json::from_reader(archive_move.file_at_path(&eval_trans_path)?)?; - read_transforms_file(val_scene, eval_trans_path, archive_move, &load_args).ok() + let val_scene = serde_json::from_reader(archive_clone.file_at_path(&eval_trans_path)?)?; + read_transforms_file(val_scene, eval_trans_path, archive_clone, &load_args_clone).ok() } else { None }; @@ -215,7 +222,7 @@ pub fn read_dataset( let mut i = 0; while let Some(view) = train_handles.next().await { - if let Some(eval_period) = load_args.eval_split_every { + if let Some(eval_period) = load_args_clone.eval_split_every { // Include extra eval images only when the dataset doesn't have them. 
if i % eval_period == 0 && val_stream.is_some() { eval_views.push(view?); @@ -248,6 +255,7 @@ pub fn read_dataset( }); let device = device.clone(); + let load_args = load_args.clone(); let splat_stream = try_fn_stream(|emitter| async move { if let Some(init) = train_scene.ply_file_path { @@ -255,7 +263,11 @@ pub fn read_dataset( let ply_data = archive.read_bytes_at_path(&init_path); if let Ok(ply_data) = ply_data { - let splat_stream = load_splat_from_ply(Cursor::new(ply_data), device.clone()); + let splat_stream = load_splat_from_ply( + Cursor::new(ply_data), + load_args.subsample_points, + device.clone(), + ); let mut splat_stream = std::pin::pin!(splat_stream); diff --git a/crates/brush-dataset/src/lib.rs b/crates/brush-dataset/src/lib.rs index 74c4e228..5a182ee6 100644 --- a/crates/brush-dataset/src/lib.rs +++ b/crates/brush-dataset/src/lib.rs @@ -22,6 +22,8 @@ pub struct LoadDatasetArgs { pub max_frames: Option, pub max_resolution: Option, pub eval_split_every: Option, + pub subsample_frames: Option, + pub subsample_points: Option, } #[derive(Clone)] diff --git a/crates/brush-dataset/src/splat_import.rs b/crates/brush-dataset/src/splat_import.rs index 1ecd5330..9e23e819 100644 --- a/crates/brush-dataset/src/splat_import.rs +++ b/crates/brush-dataset/src/splat_import.rs @@ -125,6 +125,7 @@ fn interleave_coeffs(sh_dc: [f32; 3], sh_rest: &[f32]) -> Vec { pub fn load_splat_from_ply( reader: T, + subsample_points: Option, device: B::Device, ) -> impl Stream>> + 'static { // set up a reader, in this case a file. @@ -173,6 +174,25 @@ pub fn load_splat_from_ply( let mut ascii_line = String::new(); for i in 0..element.count { + // Occasionally yield. + if i % 500 == 0 { + tokio::task::yield_now().await; + } + + // Occasionally send some updated splats. 
+ if i % update_every == update_every - 1 { + let splats = Splats::from_raw( + means.clone(), + rotation.clone(), + scales.clone(), + sh_coeffs.clone(), + opacity.clone(), + &device, + ); + + emitter.emit(splats).await; + } + let splat = match header.encoding { ply_rs::ply::Encoding::Ascii => { reader.read_line(&mut ascii_line).await?; @@ -192,6 +212,14 @@ pub fn load_splat_from_ply( } }; + // Doing this after first reading and parsing the points is quite wasteful, but + // we do need to advance the reader. + if let Some(subsample) = subsample_points { + if i % subsample as usize != 0 { + continue; + } + } + means.push(splat.means); if let Some(scales) = scales.as_mut() { scales.push(splat.scale); @@ -207,25 +235,6 @@ pub fn load_splat_from_ply( interleave_coeffs(splat.sh_dc, &splat.sh_coeffs_rest); sh_coeffs.extend(sh_coeffs_interleaved); } - - // Occasionally send some updated splats. - if i % update_every == update_every - 1 { - let splats = Splats::from_raw( - means.clone(), - rotation.clone(), - scales.clone(), - sh_coeffs.clone(), - opacity.clone(), - &device, - ); - - emitter.emit(splats).await; - } - - // Ocassionally yield. - if i % 500 == 0 { - tokio::task::yield_now().await; - } } let splats = Splats::from_raw(means, rotation, scales, sh_coeffs, opacity, &device); diff --git a/crates/brush-viewer/src/panels/load_data.rs b/crates/brush-viewer/src/panels/load_data.rs index 1e9d0773..c655af15 100644 --- a/crates/brush-viewer/src/panels/load_data.rs +++ b/crates/brush-viewer/src/panels/load_data.rs @@ -9,9 +9,8 @@ enum Quality { } pub(crate) struct LoadDataPanel { - max_train_resolution: Option, - max_frames: Option, - eval_split_every: Option, + load_args: LoadDatasetArgs, + sh_degree: u32, quality: Quality, proxy: bool, @@ -23,9 +22,13 @@ impl LoadDataPanel { Self { // Super high resolutions are a bit sketchy. Limit to at least // some size. 
- max_train_resolution: Some(1920), - max_frames: None, - eval_split_every: None, + load_args: LoadDatasetArgs { + max_frames: None, + max_resolution: Some(1920), + eval_split_every: None, + subsample_frames: None, + subsample_points: None, + }, sh_degree: 3, quality: Quality::Normal, proxy: false, @@ -40,113 +43,143 @@ impl ViewerPanel for LoadDataPanel { } fn ui(&mut self, ui: &mut egui::Ui, context: &mut ViewerContext) { - ui.label("Select a .ply to visualize, or a .zip with training data."); - - let file = ui.button("Load file").clicked(); + egui::ScrollArea::vertical().show(ui, |ui| { - ui.add_space(10.0); + ui.label("Select a .ply to visualize, or a .zip with training data."); - ui.checkbox(&mut self.proxy, "Proxy proxy.brush-splat.workers.dev/") - .on_hover_text("File hosting services often don't allow client-side requests. Using a proxy can solve this. In particular this makes google drive share links work!"); + let file = ui.button("Load file").clicked(); - ui.text_edit_singleline(&mut self.url); + ui.add_space(10.0); - let url = ui.button("Load URL").clicked(); + ui.checkbox(&mut self.proxy, "Proxy proxy.brush-splat.workers.dev/") + .on_hover_text("File hosting services often don't allow client-side requests. Using a proxy can solve this. 
In particular this makes google drive share links work!"); - ui.add_space(10.0); + ui.text_edit_singleline(&mut self.url); - if file || url { - let load_data_args = LoadDatasetArgs { - max_frames: self.max_frames, - max_resolution: self.max_train_resolution, - eval_split_every: self.eval_split_every, - }; - let load_init_args = LoadInitArgs { - sh_degree: self.sh_degree, - }; + let url = ui.button("Load URL").clicked(); - let mut config = TrainConfig::default(); - if matches!(self.quality, Quality::Low) { - config = config - .with_densify_grad_thresh(0.00035) - .with_refine_every(200); - } + ui.add_space(10.0); - let source = if file { - crate::viewer::DataSource::PickFile - } else { - let url = if !self.proxy { - self.url.to_string() - } else { - format!("https://proxy.brush-splat.workers.dev/{}", self.url) + if file || url { + let load_init_args = LoadInitArgs { + sh_degree: self.sh_degree, }; - crate::viewer::DataSource::Url(url) - }; - context.start_data_load(source, load_data_args, load_init_args, config); - } - ui.add_space(10.0); - ui.heading("Train settings"); + let mut config = TrainConfig::default(); + if matches!(self.quality, Quality::Low) { + config = config + .with_densify_grad_thresh(0.00035) + .with_refine_every(200); + } - ui.label("Spherical Harmonics Degree:"); - ui.add(Slider::new(&mut self.sh_degree, 0..=4)); + let source = if file { + crate::viewer::DataSource::PickFile + } else { + let url = if !self.proxy { + self.url.to_string() + } else { + format!("https://proxy.brush-splat.workers.dev/{}", self.url) + }; + crate::viewer::DataSource::Url(url) + }; + context.start_data_load(source, self.load_args.clone(), load_init_args, config); + } - ui.horizontal(|ui| { - ui.label("Quality:"); + ui.add_space(10.0); + ui.heading("Train settings"); + + ui.label("Spherical Harmonics Degree:"); + ui.add(Slider::new(&mut self.sh_degree, 0..=4)); + + ui.horizontal(|ui| { + ui.label("Quality:"); + if ui + .selectable_label(matches!(self.quality, Quality::Low), 
"Low") + .clicked() + { + self.quality = Quality::Low; + } + if ui + .selectable_label(matches!(self.quality, Quality::Normal), "Normal") + .clicked() + { + self.quality = Quality::Normal; + } + }); + + let mut limit_res = self.load_args.max_resolution.is_some(); if ui - .selectable_label(matches!(self.quality, Quality::Low), "Low") + .checkbox(&mut limit_res, "Limit training resolution") .clicked() { - self.quality = Quality::Low; + self.load_args.max_resolution = if limit_res { Some(800) } else { None }; + } + + if let Some(target_res) = self.load_args.max_resolution.as_mut() { + ui.add(Slider::new(target_res, 32..=2048)); } + + let mut limit_frames = self.load_args.max_frames.is_some(); + if ui.checkbox(&mut limit_frames, "Limit max frames").clicked() { + self.load_args.max_frames = if limit_frames { Some(32) } else { None }; + } + + if let Some(max_frames) = self.load_args.max_frames.as_mut() { + ui.add(Slider::new(max_frames, 1..=256)); + } + + let mut use_eval_split = self.load_args.eval_split_every.is_some(); if ui - .selectable_label(matches!(self.quality, Quality::Normal), "Normal") + .checkbox(&mut use_eval_split, "Split dataset for evaluation") .clicked() { - self.quality = Quality::Normal; + self.load_args.eval_split_every = if use_eval_split { Some(8) } else { None }; } - }); - - let mut limit_res = self.max_train_resolution.is_some(); - if ui - .checkbox(&mut limit_res, "Limit training resolution") - .clicked() - { - self.max_train_resolution = if limit_res { Some(800) } else { None }; - } - if let Some(target_res) = self.max_train_resolution.as_mut() { - ui.add(Slider::new(target_res, 32..=2048)); - } + if let Some(eval_split) = self.load_args.eval_split_every.as_mut() { + ui.add( + Slider::new(eval_split, 2..=32) + .prefix("1 out of ") + .suffix(" frames"), + ); + } - let mut limit_frames = self.max_frames.is_some(); - if ui.checkbox(&mut limit_frames, "Limit max frames").clicked() { - self.max_frames = if limit_frames { Some(32) } else { None }; - 
} + let mut use_frame_subsample = self.load_args.subsample_frames.is_some(); + if ui + .checkbox(&mut use_frame_subsample, "Subsample frames") + .clicked() + { + self.load_args.subsample_frames = if use_frame_subsample { Some(2) } else { None }; + } - if let Some(max_frames) = self.max_frames.as_mut() { - ui.add(Slider::new(max_frames, 1..=256)); - } + if let Some(subsample_frames) = self.load_args.subsample_frames.as_mut() { + ui.add( + Slider::new(subsample_frames, 2..=32) + .prefix("Keep 1 out of ") + .suffix(" frames"), + ); + } - let mut use_eval_split = self.eval_split_every.is_some(); - if ui - .checkbox(&mut use_eval_split, "Split dataset for evaluation") - .clicked() - { - self.eval_split_every = if use_eval_split { Some(8) } else { None }; - } + let mut use_point_subsample = self.load_args.subsample_points.is_some(); + if ui + .checkbox(&mut use_point_subsample, "Subsample points") + .clicked() + { + self.load_args.subsample_points = if use_point_subsample { Some(2) } else { None }; + } - if let Some(eval_split) = self.eval_split_every.as_mut() { - ui.add( - Slider::new(eval_split, 2..=32) - .prefix("1 out of ") - .suffix(" frames"), - ); - } + if let Some(subsample_points) = self.load_args.subsample_points.as_mut() { + ui.add( + Slider::new(subsample_points, 2..=32) + .prefix("Keep 1 out of ") + .suffix(" points"), + ); + } - #[cfg(not(target_family = "wasm"))] - if ui.input(|r| r.key_pressed(egui::Key::Escape)) { - ui.ctx().send_viewport_cmd(egui::ViewportCommand::Close); - } + #[cfg(not(target_family = "wasm"))] + if ui.input(|r| r.key_pressed(egui::Key::Escape)) { + ui.ctx().send_viewport_cmd(egui::ViewportCommand::Close); + } + }); } } diff --git a/crates/brush-viewer/src/viewer.rs b/crates/brush-viewer/src/viewer.rs index 73bb5bc3..dde6dd94 100644 --- a/crates/brush-viewer/src/viewer.rs +++ b/crates/brush-viewer/src/viewer.rs @@ -119,7 +119,9 @@ fn process_loop( let _ = emitter .emit(ViewerMessage::StartLoading { training: false }) .await; - let 
splat_stream = splat_import::load_splat_from_ply(data, device.clone()); + + let subsample = None; // Subsampling a trained ply doesn't really make sense. + let splat_stream = splat_import::load_splat_from_ply(data, subsample, device.clone()); let mut splat_stream = std::pin::pin!(splat_stream); while let Some(splats) = splat_stream.next().await {