diff --git a/prost-build/src/config.rs b/prost-build/src/config.rs index af57497f1..a696e404b 100644 --- a/prost-build/src/config.rs +++ b/prost-build/src/config.rs @@ -826,8 +826,8 @@ impl Config { self.write_includes( modules.keys().collect(), &mut file, - 0, if target_is_env { None } else { Some(&target) }, + &file_names, )?; file.flush()?; } @@ -955,67 +955,58 @@ impl Config { self.compile_fds(file_descriptor_set) } - fn write_includes( + pub(crate) fn write_includes( &self, - mut entries: Vec<&Module>, - outfile: &mut fs::File, - depth: usize, + mut modules: Vec<&Module>, + outfile: &mut impl Write, basepath: Option<&PathBuf>, - ) -> Result { - let mut written = 0; - entries.sort(); - - while !entries.is_empty() { - let modident = entries[0].part(depth); - let matching: Vec<&Module> = entries - .iter() - .filter(|&v| v.part(depth) == modident) - .copied() - .collect(); - { - // Will NLL sort this mess out? - let _temp = entries - .drain(..) - .filter(|&v| v.part(depth) != modident) - .collect(); - entries = _temp; + file_names: &HashMap, + ) -> Result<()> { + modules.sort(); + + let mut stack = Vec::new(); + + for module in modules { + while !module.starts_with(&stack) { + stack.pop(); + self.write_line(outfile, stack.len(), "}")?; } - self.write_line(outfile, depth, &format!("pub mod {} {{", modident))?; - let subwritten = self.write_includes( - matching - .iter() - .filter(|v| v.len() > depth + 1) - .copied() - .collect(), - outfile, - depth + 1, - basepath, - )?; - written += subwritten; - if subwritten != matching.len() { - let modname = matching[0].to_partial_file_name(..=depth); - if basepath.is_some() { - self.write_line( - outfile, - depth + 1, - &format!("include!(\"{}.rs\");", modname), - )?; - } else { - self.write_line( - outfile, - depth + 1, - &format!("include!(concat!(env!(\"OUT_DIR\"), \"/{}.rs\"));", modname), - )?; - } - written += 1; + while stack.len() < module.len() { + self.write_line( + outfile, + stack.len(), + &format!("pub mod {} {{", module.part(stack.len())), + )?; + stack.push(module.part(stack.len()).to_owned()); } + let file_name = file_names + .get(module) + .expect("every module should have a filename"); + + if basepath.is_some() { + self.write_line( + outfile, + stack.len(), + &format!("include!(\"{}\");", file_name), + )?; + } else { + self.write_line( + outfile, + stack.len(), + &format!("include!(concat!(env!(\"OUT_DIR\"), \"/{}\"));", file_name), + )?; + } + } + + for depth in (0..stack.len()).rev() { self.write_line(outfile, depth, "}")?; } - Ok(written) + + Ok(()) } - fn write_line(&self, outfile: &mut fs::File, depth: usize, line: &str) -> Result<()> { + fn write_line(&self, outfile: &mut impl Write, depth: usize, line: &str) -> Result<()> { outfile.write_all(format!("{}{}\n", (" ").to_owned().repeat(depth), line).as_bytes()) } diff --git a/prost-build/src/fixtures/write_includes/_.includes.rs b/prost-build/src/fixtures/write_includes/_.includes.rs new file mode 100644 index 000000000..99b555635 --- /dev/null +++ b/prost-build/src/fixtures/write_includes/_.includes.rs @@ -0,0 +1,23 @@ +include!(concat!(env!("OUT_DIR"), "/_.default.rs")); +pub mod bar { + include!(concat!(env!("OUT_DIR"), "/bar.rs")); +} +pub mod foo { + include!(concat!(env!("OUT_DIR"), "/foo.rs")); + pub mod bar { + include!(concat!(env!("OUT_DIR"), "/foo.bar.rs")); + pub mod a { + pub mod b { + pub mod c { + include!(concat!(env!("OUT_DIR"), "/foo.bar.a.b.c.rs")); + } + } + } + pub mod baz { + include!(concat!(env!("OUT_DIR"), "/foo.bar.baz.rs")); + } + pub mod qux { + include!(concat!(env!("OUT_DIR"), "/foo.bar.qux.rs")); + } + } +} diff --git a/prost-build/src/lib.rs b/prost-build/src/lib.rs index a3c62d60b..304f82e4b 100644 --- a/prost-build/src/lib.rs +++ b/prost-build/src/lib.rs @@ -530,4 +530,32 @@ mod tests { f.read_to_string(&mut content).unwrap(); content } + + #[test] + fn write_includes() { + let modules = [ + Module::from_protobuf_package_name("foo.bar.baz"), + Module::from_protobuf_package_name(""), + Module::from_protobuf_package_name("foo.bar"), + Module::from_protobuf_package_name("bar"), + Module::from_protobuf_package_name("foo"), + Module::from_protobuf_package_name("foo.bar.qux"), + Module::from_protobuf_package_name("foo.bar.a.b.c"), + ]; + + let file_names = modules + .iter() + .map(|m| (m.clone(), m.to_file_name_or("_.default"))) + .collect(); + + let mut buf = Vec::new(); + Config::new() + .default_package_filename("_.default") + .write_includes(modules.iter().collect(), &mut buf, None, &file_names) + .unwrap(); + let expected = + read_all_content("src/fixtures/write_includes/_.includes.rs").replace("\r\n", "\n"); + let actual = String::from_utf8(buf).unwrap(); + assert_eq!(expected, actual); + } } diff --git a/prost-build/src/module.rs b/prost-build/src/module.rs index 21cab1163..02715c16e 100644 --- a/prost-build/src/module.rs +++ b/prost-build/src/module.rs @@ -1,5 +1,4 @@ use std::fmt; -use std::ops::RangeToInclusive; use crate::ident::to_snake; @@ -40,6 +39,15 @@ impl Module { self.components.iter().map(|s| s.as_str()) } + #[must_use] + #[inline(always)] + pub(crate) fn starts_with(&self, needle: &[String]) -> bool + where + String: PartialEq, + { + self.components.starts_with(needle) + } + /// Format the module path into a filename for generated Rust code. /// /// If the module path is empty, `default` is used to provide the root of the filename. @@ -65,10 +73,6 @@ impl Module { self.components.is_empty() } - pub(crate) fn to_partial_file_name(&self, range: RangeToInclusive) -> String { - self.components[range].join(".") - } - pub(crate) fn part(&self, idx: usize) -> &str { self.components[idx].as_str() } diff --git a/tests/src/build.rs b/tests/src/build.rs index 7b2c0bc5e..0403d2603 100644 --- a/tests/src/build.rs +++ b/tests/src/build.rs @@ -178,6 +178,7 @@ fn main() { no_root_packages_config .out_dir(&no_root_packages) .default_package_filename("__.default") + .include_file("__.include.rs") .compile_protos( &[src.join("no_root_packages/widget_factory.proto")], &[src.join("no_root_packages")], diff --git a/tests/src/no_root_packages/mod.rs b/tests/src/no_root_packages/mod.rs index 2025eb983..7ef3b4c4e 100644 --- a/tests/src/no_root_packages/mod.rs +++ b/tests/src/no_root_packages/mod.rs @@ -16,6 +16,10 @@ pub mod widget { } } +pub mod generated_include { + include!(concat!(env!("OUT_DIR"), "/no_root_packages/__.include.rs")); +} + #[test] fn test() { use prost::Message; @@ -44,3 +48,32 @@ fn test() { widget_factory.gizmo_inner = Some(gizmo::gizmo::Inner {}); assert_eq!(14, widget_factory.encoded_len()); } + +#[test] +fn generated_include() { + use prost::Message; + + let mut widget_factory = generated_include::widget::factory::WidgetFactory::default(); + assert_eq!(0, widget_factory.encoded_len()); + + widget_factory.inner = Some(generated_include::widget::factory::widget_factory::Inner {}); + assert_eq!(2, widget_factory.encoded_len()); + + widget_factory.root = Some(generated_include::Root {}); + assert_eq!(4, widget_factory.encoded_len()); + + widget_factory.root_inner = Some(generated_include::root::Inner {}); + assert_eq!(6, widget_factory.encoded_len()); + + widget_factory.widget = Some(generated_include::widget::Widget {}); + assert_eq!(8, widget_factory.encoded_len()); + + widget_factory.widget_inner = Some(generated_include::widget::widget::Inner {}); + assert_eq!(10, widget_factory.encoded_len()); + + widget_factory.gizmo = Some(generated_include::gizmo::Gizmo {}); + assert_eq!(12, widget_factory.encoded_len()); + + widget_factory.gizmo_inner = Some(generated_include::gizmo::gizmo::Inner {}); + assert_eq!(14, widget_factory.encoded_len()); +}