Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(prost-build): Generate less boxed if nested type is boxed manually #1160

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,7 @@ impl<'a> CodeGenerator<'a> {
&& (fd_type == Type::Message || fd_type == Type::Group)
&& self
.message_graph
.is_nested(field.type_name(), fq_message_name)
.is_directly_nested(field.type_name(), fq_message_name)
{
return true;
}
Expand Down
142 changes: 129 additions & 13 deletions prost-build/src/message_graph.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

use petgraph::algo::has_path_connecting;
use petgraph::graph::NodeIndex;
use petgraph::Graph;
use petgraph::visit::{EdgeRef, VisitMap};
use petgraph::{Direction, Graph};

use prost_types::{
field_descriptor_proto::{Label, Type},
Expand All @@ -15,9 +15,13 @@ use crate::path::PathMap;
/// The goal is to recognize when message types are recursively nested, so
/// that fields can be boxed when necessary.
pub struct MessageGraph {
/// Map<fq type name, graph node index>
index: HashMap<String, NodeIndex>,
graph: Graph<String, ()>,
/// Graph with fq type name as node, field name as edge
graph: Graph<String, String>,
/// Map<fq type name, DescriptorProto>
messages: HashMap<String, DescriptorProto>,
/// Manually boxed fields
boxed: PathMap<()>,
}

Expand Down Expand Up @@ -71,7 +75,8 @@ impl MessageGraph {
for field in &msg.field {
if field.r#type() == Type::Message && field.label() != Label::Repeated {
let field_index = self.get_or_insert_index(field.type_name.clone().unwrap());
self.graph.add_edge(msg_index, field_index, ());
self.graph
.add_edge(msg_index, field_index, field.name.clone().unwrap());
}
}
self.messages.insert(msg_name.clone(), msg.clone());
Expand All @@ -86,8 +91,9 @@ impl MessageGraph {
self.messages.get(message)
}

/// Returns true if message type `inner` is nested in message type `outer`.
pub fn is_nested(&self, outer: &str, inner: &str) -> bool {
/// Returns true if message type `inner` is nested in message type `outer`,
/// and no field edge in the chain of dependencies is manually boxed.
pub fn is_directly_nested(&self, outer: &str, inner: &str) -> bool {
let outer = match self.index.get(outer) {
Some(outer) => *outer,
None => return false,
Expand All @@ -97,7 +103,12 @@ impl MessageGraph {
None => return false,
};

has_path_connecting(&self.graph, outer, inner, None)
// Check if `inner` is nested in `outer` and ensure that all edge fields are not boxed manually.
is_connected_with_edge_filter(&self.graph, outer, inner, |node, field_name| {
self.boxed
.get_first_field(&self.graph[node], field_name)
.is_none()
})
}

/// Returns `true` if this message can automatically derive Copy trait.
Expand All @@ -123,11 +134,11 @@ impl MessageGraph {
false
} else if field.r#type() == Type::Message {
// nested and boxed messages cannot derive Copy
if self.is_nested(field.type_name(), fq_message_name)
|| self
.boxed
.get_first_field(fq_message_name, field.name())
.is_some()
if self
.boxed
.get_first_field(fq_message_name, field.name())
.is_some()
|| self.is_directly_nested(field.type_name(), fq_message_name)
{
false
} else {
Expand All @@ -154,3 +165,108 @@ impl MessageGraph {
}
}
}

/// Check two nodes is connected with edge filter
fn is_connected_with_edge_filter<F, N, E>(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very nifty, but it is not the core strength of prost. I understand what the function is supposed to do, but I don't understand how it works. This should either be documented better or move to petgraph library.

graph: &Graph<N, E>,
start: NodeIndex,
end: NodeIndex,
mut is_good_edge: F,
) -> bool
where
F: FnMut(NodeIndex, &E) -> bool,
{
fn visitor<F, N, E>(
graph: &Graph<N, E>,
start: NodeIndex,
end: NodeIndex,
is_good_edge: &mut F,
visited: &mut HashSet<NodeIndex>,
) -> bool
where
F: FnMut(NodeIndex, &E) -> bool,
{
if start == end {
return true;
}
visited.visit(start);
for edge in graph.edges_directed(start, Direction::Outgoing) {
// if the edge doesn't pass the filter, skip it
if !is_good_edge(start, edge.weight()) {
continue;
}
let target = edge.target();
if visited.is_visited(&target) {
continue;
}
if visitor(graph, target, end, is_good_edge, visited) {
return true;
}
}
false
}
let mut visited = HashSet::new();
visitor(graph, start, end, &mut is_good_edge, &mut visited)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_connected() {
let mut graph = Graph::new();
let n1 = graph.add_node(1);
let n2 = graph.add_node(2);
let n3 = graph.add_node(3);
let n4 = graph.add_node(4);
let n5 = graph.add_node(5);
let n6 = graph.add_node(6);
let n7 = graph.add_node(7);
let n8 = graph.add_node(8);
graph.add_edge(n1, n2, 1.);
graph.add_edge(n2, n3, 2.);
graph.add_edge(n3, n4, 3.);
graph.add_edge(n4, n5, 4.);
graph.add_edge(n5, n6, 5.);
graph.add_edge(n6, n7, 6.);
graph.add_edge(n7, n8, 7.);
graph.add_edge(n8, n1, 8.);
assert!(is_connected_with_edge_filter(&graph, n2, n1, |_, edge| {
dbg!(edge);
true
}),);
assert!(is_connected_with_edge_filter(&graph, n2, n1, |_, edge| {
dbg!(edge);
edge < &8.5
}),);
assert!(!is_connected_with_edge_filter(&graph, n2, n1, |_, edge| {
dbg!(edge);
edge < &7.5
}),);
}

#[test]
fn test_connected_multi_circle() {
let mut graph = Graph::new();
let n0 = graph.add_node(0);
let n1 = graph.add_node(1);
let n2 = graph.add_node(2);
let n3 = graph.add_node(3);
let n4 = graph.add_node(4);
graph.add_edge(n0, n1, 0.);
graph.add_edge(n1, n2, 1.);
graph.add_edge(n2, n3, 2.);
graph.add_edge(n3, n0, 3.);
graph.add_edge(n1, n4, 1.5);
graph.add_edge(n4, n0, 2.5);
assert!(is_connected_with_edge_filter(&graph, n1, n0, |_, edge| {
dbg!(edge);
edge < &2.8
}),);
assert!(!is_connected_with_edge_filter(&graph, n1, n0, |_, edge| {
dbg!(edge);
edge < &2.1
}),);
}
}
13 changes: 13 additions & 0 deletions tests/src/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,19 @@ fn main() {

std::fs::create_dir_all(&out_path).unwrap();

prost_build::Config::new()
.out_dir(src.join("nesting_complex/boxed"))
.boxed("Foo.bar")
.boxed("BazB.baz_c")
.boxed("BakC.bak_d")
.compile_protos(&[src.join("nesting_complex.proto")], includes)
.unwrap();

prost_build::Config::new()
.out_dir(src.join("nesting_complex/"))
.compile_protos(&[src.join("nesting_complex.proto")], includes)
.unwrap();

prost_build::Config::new()
.bytes(["."])
.out_dir(out_path)
Expand Down
8 changes: 8 additions & 0 deletions tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ pub mod proto3 {
}
}

pub mod nesting_complex_boxed {
include!("nesting_complex/boxed/nesting_complex.rs");
}

pub mod nesting_complex {
include!("nesting_complex/nesting_complex.rs");
}

pub mod invalid {
pub mod doctest {
include!(concat!(env!("OUT_DIR"), "/invalid.doctest.rs"));
Expand Down
47 changes: 47 additions & 0 deletions tests/src/nesting_complex.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
syntax = "proto2";

package nesting_complex;

// ----- Directly nested
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to have more documentation explaining how this structure is nested.

message Foo {
optional Bar bar = 1;
}

message Bar {
optional Foo foo = 1;
}

// ----- Transitively nested
message BazA {
optional BazB baz_b = 1;
}

message BazB {
optional BazC baz_c = 1;
}

message BazC {
optional BazA baz_a = 1;
}

// ----- Transitively nested in two chain
message BakA {
optional BakB bak_b = 1;
}

message BakB {
optional BakC bak_c = 1;
optional BakE bak_e = 2;
}

message BakC {
optional BakD bak_d = 1;
}

message BakD {
optional BakA bak_a = 1;
}

message BakE {
optional BakA bak_a = 1;
}
56 changes: 56 additions & 0 deletions tests/src/nesting_complex/boxed/nesting_complex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// This file is @generated by prost-build.
/// ----- Directly nested
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Foo {
#[prost(message, optional, boxed, tag = "1")]
pub bar: ::core::option::Option<::prost::alloc::boxed::Box<Bar>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, tag = "1")]
pub foo: ::core::option::Option<Foo>,
}
/// ----- Transitively nested
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BazA {
#[prost(message, optional, tag = "1")]
pub baz_b: ::core::option::Option<BazB>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BazB {
#[prost(message, optional, boxed, tag = "1")]
pub baz_c: ::core::option::Option<::prost::alloc::boxed::Box<BazC>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BazC {
#[prost(message, optional, tag = "1")]
pub baz_a: ::core::option::Option<BazA>,
}
/// ----- Transitively nested in two chain
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakA {
#[prost(message, optional, boxed, tag = "1")]
pub bak_b: ::core::option::Option<::prost::alloc::boxed::Box<BakB>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakB {
#[prost(message, optional, tag = "1")]
pub bak_c: ::core::option::Option<BakC>,
#[prost(message, optional, boxed, tag = "2")]
pub bak_e: ::core::option::Option<::prost::alloc::boxed::Box<BakE>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakC {
#[prost(message, optional, boxed, tag = "1")]
pub bak_d: ::core::option::Option<::prost::alloc::boxed::Box<BakD>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakD {
#[prost(message, optional, tag = "1")]
pub bak_a: ::core::option::Option<BakA>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakE {
#[prost(message, optional, boxed, tag = "1")]
pub bak_a: ::core::option::Option<::prost::alloc::boxed::Box<BakA>>,
}
56 changes: 56 additions & 0 deletions tests/src/nesting_complex/nesting_complex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// This file is @generated by prost-build.
/// ----- Directly nested
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Foo {
#[prost(message, optional, boxed, tag = "1")]
pub bar: ::core::option::Option<::prost::alloc::boxed::Box<Bar>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, boxed, tag = "1")]
pub foo: ::core::option::Option<::prost::alloc::boxed::Box<Foo>>,
}
/// ----- Transitively nested
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BazA {
#[prost(message, optional, boxed, tag = "1")]
pub baz_b: ::core::option::Option<::prost::alloc::boxed::Box<BazB>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BazB {
#[prost(message, optional, boxed, tag = "1")]
pub baz_c: ::core::option::Option<::prost::alloc::boxed::Box<BazC>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BazC {
#[prost(message, optional, boxed, tag = "1")]
pub baz_a: ::core::option::Option<::prost::alloc::boxed::Box<BazA>>,
}
/// ----- Transitively nested in two chain
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakA {
#[prost(message, optional, boxed, tag = "1")]
pub bak_b: ::core::option::Option<::prost::alloc::boxed::Box<BakB>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakB {
#[prost(message, optional, boxed, tag = "1")]
pub bak_c: ::core::option::Option<::prost::alloc::boxed::Box<BakC>>,
#[prost(message, optional, boxed, tag = "2")]
pub bak_e: ::core::option::Option<::prost::alloc::boxed::Box<BakE>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakC {
#[prost(message, optional, boxed, tag = "1")]
pub bak_d: ::core::option::Option<::prost::alloc::boxed::Box<BakD>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakD {
#[prost(message, optional, boxed, tag = "1")]
pub bak_a: ::core::option::Option<::prost::alloc::boxed::Box<BakA>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BakE {
#[prost(message, optional, boxed, tag = "1")]
pub bak_a: ::core::option::Option<::prost::alloc::boxed::Box<BakA>>,
}