Skip to content

Commit

Permalink
fix(prost-build): Remove derived(Copy) on boxed fields (#1157)
Browse files Browse the repository at this point in the history
* fix(prost-build): Remove `derived(Copy)` on boxed fields

* Add regression test
  • Loading branch information
ldm0 authored Sep 20, 2024
1 parent 8d4cac5 commit fb977f4
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 8 deletions.
2 changes: 1 addition & 1 deletion prost-build/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,7 @@ impl Config {
let mut modules = HashMap::new();
let mut packages = HashMap::new();

let message_graph = MessageGraph::new(requests.iter().map(|x| &x.1));
let message_graph = MessageGraph::new(requests.iter().map(|x| &x.1), self.boxed.clone());
let extern_paths = ExternPaths::new(&self.extern_paths, self.prost_types)
.map_err(|error| Error::new(ErrorKind::InvalidInput, error))?;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub struct Foo {
#[prost(string, tag="1")]
pub foo: ::prost::alloc::string::String,
}
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, boxed, tag="1")]
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub struct Foo {
#[prost(string, tag = "1")]
pub foo: ::prost::alloc::string::String,
}
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, boxed, tag = "1")]
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
Expand Down
28 changes: 24 additions & 4 deletions prost-build/src/message_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,28 @@ use prost_types::{
DescriptorProto, FieldDescriptorProto, FileDescriptorProto,
};

use crate::path::PathMap;

/// `MessageGraph` builds a graph of messages whose edges correspond to nesting.
/// The goal is to recognize when message types are recursively nested, so
/// that fields can be boxed when necessary.
pub struct MessageGraph {
index: HashMap<String, NodeIndex>,
graph: Graph<String, ()>,
messages: HashMap<String, DescriptorProto>,
boxed: PathMap<()>,
}

impl MessageGraph {
pub fn new<'a>(files: impl Iterator<Item = &'a FileDescriptorProto>) -> MessageGraph {
pub(crate) fn new<'a>(
files: impl Iterator<Item = &'a FileDescriptorProto>,
boxed: PathMap<()>,
) -> MessageGraph {
let mut msg_graph = MessageGraph {
index: HashMap::new(),
graph: Graph::new(),
messages: HashMap::new(),
boxed,
};

for file in files {
Expand Down Expand Up @@ -74,6 +81,11 @@ impl MessageGraph {
}
}

/// Try get a message descriptor from current message graph
pub fn get_message(&self, message: &str) -> Option<&DescriptorProto> {
self.messages.get(message)
}

/// Returns true if message type `inner` is nested in message type `outer`.
pub fn is_nested(&self, outer: &str, inner: &str) -> bool {
let outer = match self.index.get(outer) {
Expand All @@ -91,8 +103,9 @@ impl MessageGraph {
/// Returns `true` if this message can automatically derive Copy trait.
pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool {
assert_eq!(".", &fq_message_name[..1]);
let msg = self.messages.get(fq_message_name).unwrap();
msg.field
self.get_message(fq_message_name)
.unwrap()
.field
.iter()
.all(|field| self.can_field_derive_copy(fq_message_name, field))
}
Expand All @@ -105,10 +118,17 @@ impl MessageGraph {
) -> bool {
assert_eq!(".", &fq_message_name[..1]);

// repeated field cannot derive Copy
if field.label() == Label::Repeated {
false
} else if field.r#type() == Type::Message {
if self.is_nested(field.type_name(), fq_message_name) {
// nested and boxed messages cannot derive Copy
if self.is_nested(field.type_name(), fq_message_name)
|| self
.boxed
.get_first_field(fq_message_name, field.name())
.is_some()
{
false
} else {
self.can_message_derive_copy(field.type_name())
Expand Down
2 changes: 1 addition & 1 deletion prost-build/src/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use std::iter;

/// Maps a fully-qualified Protobuf path to a value using path matchers.
#[derive(Debug, Default)]
#[derive(Clone, Debug, Default)]
pub(crate) struct PathMap<T> {
// insertion order might actually matter (to avoid warning about legacy-derive-helpers)
// see: https://doc.rust-lang.org/rustc/lints/listing/warn-by-default.html#legacy-derive-helpers
Expand Down
10 changes: 10 additions & 0 deletions tests/src/boxed_field.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
syntax = "proto3";

package boxed_field;

message Foo {
Bar bar = 1;
}

message Bar {
}
5 changes: 5 additions & 0 deletions tests/src/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ fn main() {
.compile_protos(&[src.join("type_names.proto")], includes)
.unwrap();

prost_build::Config::new()
.boxed("Foo.bar")
.compile_protos(&[src.join("boxed_field.proto")], includes)
.unwrap();

// Check that attempting to compile a .proto without a package declaration does not result in an error.
config
.compile_protos(&[src.join("no_package.proto")], includes)
Expand Down
4 changes: 4 additions & 0 deletions tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ pub mod invalid {
}
}

pub mod boxed_field {
include!(concat!(env!("OUT_DIR"), "/boxed_field.rs"));
}

pub mod default_string_escape {
include!(concat!(env!("OUT_DIR"), "/default_string_escape.rs"));
}
Expand Down

0 comments on commit fb977f4

Please sign in to comment.