Skip to content

Commit 1af482d

Browse files
authored
Add annotation for getting field as Rust Option (#403)
* Add annotation for getting field as Rust Option Given struct Test { field @0 :Text $Rust.getOption; } you get getters like so assert_eq!(struct_with.get_field(), Some("foo")); assert_eq!(struct_without.get_field(), None)); The setters are unchanged to match the Rust convention. Fixes #249. * Fix typo * Support asOption on AnyPointer * use is_pointer_field_null instead of get_pointer_field(...).is_null * Fix bug in StructReader::is_pointer_field_null * add ::core prefix * Make builder getters return Option<Builder> * Rename to Rust.option * Fix: forgot to rename annotation in test
1 parent 977b3db commit 1af482d

File tree

6 files changed

+274
-45
lines changed

6 files changed

+274
-45
lines changed

capnp/src/private/layout.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3479,6 +3479,15 @@ impl<'a> StructReader<'a> {
34793479
}
34803480
}
34813481

3482+
#[inline]
3483+
pub fn is_pointer_field_null(&self, ptr_index: WirePointerCount) -> bool {
3484+
if ptr_index < self.pointer_count as WirePointerCount {
3485+
unsafe { (*self.pointers.add(ptr_index)).is_null() }
3486+
} else {
3487+
true
3488+
}
3489+
}
3490+
34823491
pub fn total_size(&self) -> Result<MessageSize> {
34833492
let mut result = MessageSize {
34843493
word_count: u64::from(wire_helpers::round_bits_up_to_words(u64::from(

capnpc/rust.capnp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,26 @@ annotation parentModule @0xabee386cd1450364 (file) :Text;
2626
# }
2727
# }
2828
# }
29+
30+
annotation option @0xabfef22c4ee1964e (field) :Void;
31+
# Make the generated getters return Option<T> instead of T. Supported on
32+
# pointer types (e.g. structs, lists, and blobs).
33+
#
34+
# Capnp pointer types are nullable. Normally get_field will return the default
35+
# value if the field isn't set. With this annotation you get Some(...) when
36+
# the field is set and None when it isn't.
37+
#
38+
# Given
39+
#
40+
# struct Test {
41+
# field @0 :Text $Rust.option;
42+
# }
43+
#
44+
# you get getters like so
45+
#
46+
# assert_eq!(struct_with.get_field(), Some("foo"));
47+
# assert_eq!(struct_without.get_field(), None));
48+
#
49+
# The setters are unchanged to match the Rust convention.
50+
#
51+
# Note: Support for this annotation on interfaces isn't implemented yet.

capnpc/src/codegen.rs

Lines changed: 115 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,7 @@ fn module_name(camel_case: &str) -> String {
489489
// Annotation IDs, as defined in rust.capnp.
490490
const NAME_ANNOTATION_ID: u64 = 0xc2fe4c6d100166d0;
491491
const PARENT_MODULE_ANNOTATION_ID: u64 = 0xabee386cd1450364;
492+
const OPTION_ANNOTATION_ID: u64 = 0xabfef22c4ee1964e;
492493

493494
fn name_annotation_value(annotation: schema_capnp::annotation::Reader) -> capnp::Result<&str> {
494495
if let schema_capnp::value::Text(t) = annotation.get_value()?.which()? {
@@ -553,6 +554,32 @@ fn capnp_name_to_rust_name(capnp_name: &str, name_kind: NameKind) -> String {
553554
}
554555
}
555556

557+
fn is_option_field(field: schema_capnp::field::Reader) -> capnp::Result<bool> {
558+
use capnp::schema_capnp::*;
559+
560+
let enabled = field
561+
.get_annotations()?
562+
.iter()
563+
.any(|a| a.get_id() == OPTION_ANNOTATION_ID);
564+
565+
if enabled {
566+
let supported = match field.which()? {
567+
field::Which::Group(_) => false,
568+
field::Which::Slot(field) => {
569+
let ty = field.get_type()?;
570+
ty.is_pointer()? && !matches!(ty.which()?, type_::Interface(_))
571+
}
572+
};
573+
if !supported {
574+
return Err(capnp::Error::failed(
575+
"$Rust.option annotation only supported on pointer fields (support for optional interfaces isn't implemented yet)".to_string(),
576+
));
577+
}
578+
}
579+
580+
Ok(enabled)
581+
}
582+
556583
fn prim_default(value: &schema_capnp::value::Reader) -> ::capnp::Result<Option<String>> {
557584
use capnp::schema_capnp::value;
558585
match value.which()? {
@@ -665,37 +692,46 @@ pub fn getter_text(
665692
offset: usize,
666693
default: T,
667694
zero: T,
668-
) -> FormattedText {
695+
) -> String {
669696
if default == zero {
670-
Line(format!("self.{member}.get_data_field::<{typ}>({offset})"))
697+
format!("self.{member}.get_data_field::<{typ}>({offset})")
671698
} else {
672-
Line(format!(
673-
"self.{member}.get_data_field_mask::<{typ}>({offset}, {default})"
674-
))
699+
format!("self.{member}.get_data_field_mask::<{typ}>({offset}, {default})")
675700
}
676701
}
677702

678703
let raw_type = reg_field.get_type()?;
679-
let typ = raw_type.type_string(ctx, module)?;
704+
let inner_type = raw_type.type_string(ctx, module)?;
680705
let default_value = reg_field.get_default_value()?;
681706
let default = default_value.which()?;
682707
let default_name = format!(
683708
"DEFAULT_{}",
684709
snake_to_upper_case(&camel_to_snake_case(get_field_name(*field)?))
685710
);
711+
let should_get_option = is_option_field(*field)?;
686712

687-
let mut result_type = match raw_type.which()? {
688-
type_::Enum(_) => fmt!(ctx, "::core::result::Result<{typ},{capnp}::NotInSchema>"),
689-
type_::AnyPointer(_) if !raw_type.is_parameter()? => typ.clone(),
690-
type_::Interface(_) => {
713+
let typ = if should_get_option {
714+
format!("Option<{}>", inner_type)
715+
} else {
716+
inner_type
717+
};
718+
719+
let (is_fallible, mut result_type) = match raw_type.which()? {
720+
type_::Enum(_) => (
721+
true,
722+
fmt!(ctx, "::core::result::Result<{typ},{capnp}::NotInSchema>"),
723+
),
724+
type_::AnyPointer(_) if !raw_type.is_parameter()? => (false, typ.clone()),
725+
type_::Interface(_) => (
726+
true,
691727
fmt!(
692728
ctx,
693729
"{capnp}::Result<{}>",
694730
raw_type.type_string(ctx, Leaf::Client)?
695-
)
696-
}
697-
_ if raw_type.is_prim()? => typ.clone(),
698-
_ => fmt!(ctx, "{capnp}::Result<{typ}>"),
731+
),
732+
),
733+
_ if raw_type.is_prim()? => (false, typ.clone()),
734+
_ => (true, fmt!(ctx, "{capnp}::Result<{typ}>")),
699735
};
700736

701737
if is_fn {
@@ -706,80 +742,114 @@ pub fn getter_text(
706742
}
707743
}
708744

709-
let getter_code = match (raw_type.which()?, default) {
745+
let getter_fragment = match (raw_type.which()?, default) {
710746
(type_::Void(()), value::Void(())) => {
711747
if is_fn {
712-
Line("".to_string())
748+
"".to_string()
713749
} else {
714-
Line("()".to_string())
750+
"()".to_string()
715751
}
716-
},
752+
}
717753
(type_::Bool(()), value::Bool(b)) => {
718754
if b {
719-
Line(format!("self.{member}.get_bool_field_mask({offset}, true)"))
755+
format!("self.{member}.get_bool_field_mask({offset}, true)")
720756
} else {
721-
Line(format!("self.{member}.get_bool_field({offset})"))
757+
format!("self.{member}.get_bool_field({offset})")
722758
}
723759
}
724760
(type_::Int8(()), value::Int8(i)) => primitive_case(&typ, &member, offset, i, 0),
725761
(type_::Int16(()), value::Int16(i)) => primitive_case(&typ, &member, offset, i, 0),
726762
(type_::Int32(()), value::Int32(i)) => primitive_case(&typ, &member, offset, i, 0),
727763
(type_::Int64(()), value::Int64(i)) => primitive_case(&typ, &member, offset, i, 0),
728764
(type_::Uint8(()), value::Uint8(i)) => primitive_case(&typ, &member, offset, i, 0),
729-
(type_::Uint16(()), value::Uint16(i)) => primitive_case(&typ, &member, offset, i, 0),
730-
(type_::Uint32(()), value::Uint32(i)) => primitive_case(&typ, &member, offset, i, 0),
731-
(type_::Uint64(()), value::Uint64(i)) => primitive_case(&typ, &member, offset, i, 0),
732-
(type_::Float32(()), value::Float32(f)) =>
733-
primitive_case(&typ, &member, offset, f.to_bits(), 0),
734-
(type_::Float64(()), value::Float64(f)) =>
735-
primitive_case(&typ, &member, offset, f.to_bits(), 0),
765+
(type_::Uint16(()), value::Uint16(i)) => {
766+
primitive_case(&typ, &member, offset, i, 0)
767+
}
768+
(type_::Uint32(()), value::Uint32(i)) => {
769+
primitive_case(&typ, &member, offset, i, 0)
770+
}
771+
(type_::Uint64(()), value::Uint64(i)) => {
772+
primitive_case(&typ, &member, offset, i, 0)
773+
}
774+
(type_::Float32(()), value::Float32(f)) => {
775+
primitive_case(&typ, &member, offset, f.to_bits(), 0)
776+
}
777+
(type_::Float64(()), value::Float64(f)) => {
778+
primitive_case(&typ, &member, offset, f.to_bits(), 0)
779+
}
736780
(type_::Enum(_), value::Enum(d)) => {
737781
if d == 0 {
738-
Line(format!("::core::convert::TryInto::try_into(self.{member}.get_data_field::<u16>({offset}))"))
782+
format!("::core::convert::TryInto::try_into(self.{member}.get_data_field::<u16>({offset}))")
739783
} else {
740-
Line(
741-
format!(
742-
"::core::convert::TryInto::try_into(self.{member}.get_data_field_mask::<u16>({offset}, {d}))"))
784+
format!(
785+
"::core::convert::TryInto::try_into(self.{member}.get_data_field_mask::<u16>({offset}, {d}))")
743786
}
744787
}
745788

746-
(type_::Text(()), value::Text(_)) |
747-
(type_::Data(()), value::Data(_)) |
748-
(type_::List(_), value::List(_)) |
749-
(type_::Struct(_), value::Struct(_)) => {
789+
(type_::Text(()), value::Text(_))
790+
| (type_::Data(()), value::Data(_))
791+
| (type_::List(_), value::List(_))
792+
| (type_::Struct(_), value::Struct(_)) => {
750793
let default = if reg_field.get_had_explicit_default() {
751794
default_decl = Some(crate::pointer_constants::word_array_declaration(
752795
ctx,
753796
&default_name,
754797
::capnp::raw::get_struct_pointer_section(default_value).get(0),
755-
crate::pointer_constants::WordArrayDeclarationOptions {public: true})?);
756-
format!("Some(&_private::{default_name}[..])")
798+
crate::pointer_constants::WordArrayDeclarationOptions { public: true },
799+
)?);
800+
format!("::core::option::Option::Some(&_private::{default_name}[..])")
757801
} else {
758802
"::core::option::Option::None".to_string()
759803
};
760804

761805
if is_reader {
762-
Line(fmt!(ctx,
763-
"{capnp}::traits::FromPointerReader::get_from_pointer(&self.{member}.get_pointer_field({offset}), {default})"))
806+
fmt!(ctx,
807+
"{capnp}::traits::FromPointerReader::get_from_pointer(&self.{member}.get_pointer_field({offset}), {default})")
764808
} else {
765-
Line(fmt!(ctx,"{capnp}::traits::FromPointerBuilder::get_from_pointer(self.{member}.get_pointer_field({offset}), {default})"))
809+
fmt!(ctx,"{capnp}::traits::FromPointerBuilder::get_from_pointer(self.{member}.get_pointer_field({offset}), {default})")
766810
}
767811
}
768812

769813
(type_::Interface(_), value::Interface(_)) => {
770-
Line(fmt!(ctx,"match self.{member}.get_pointer_field({offset}).get_capability() {{ ::core::result::Result::Ok(c) => ::core::result::Result::Ok({capnp}::capability::FromClientHook::new(c)), ::core::result::Result::Err(e) => ::core::result::Result::Err(e)}}"))
814+
fmt!(ctx,"match self.{member}.get_pointer_field({offset}).get_capability() {{ ::core::result::Result::Ok(c) => ::core::result::Result::Ok({capnp}::capability::FromClientHook::new(c)), ::core::result::Result::Err(e) => ::core::result::Result::Err(e)}}")
771815
}
772816
(type_::AnyPointer(_), value::AnyPointer(_)) => {
773817
if !raw_type.is_parameter()? {
774-
Line(fmt!(ctx,"{capnp}::any_pointer::{module_string}::new(self.{member}.get_pointer_field({offset}))"))
818+
fmt!(ctx,"{capnp}::any_pointer::{module_string}::new(self.{member}.get_pointer_field({offset}))")
775819
} else if is_reader {
776-
Line(fmt!(ctx,"{capnp}::traits::FromPointerReader::get_from_pointer(&self.{member}.get_pointer_field({offset}), ::core::option::Option::None)"))
820+
fmt!(ctx,"{capnp}::traits::FromPointerReader::get_from_pointer(&self.{member}.get_pointer_field({offset}), ::core::option::Option::None)")
777821
} else {
778-
Line(fmt!(ctx,"{capnp}::traits::FromPointerBuilder::get_from_pointer(self.{member}.get_pointer_field({offset}), ::core::option::Option::None)"))
822+
fmt!(ctx,"{capnp}::traits::FromPointerBuilder::get_from_pointer(self.{member}.get_pointer_field({offset}), ::core::option::Option::None)")
779823
}
780824
}
781825
_ => return Err(Error::failed("default value was of wrong type".to_string())),
782826
};
827+
828+
let getter_code = if should_get_option {
829+
Branch(vec![
830+
Line(format!(
831+
"if self.{member}.is_pointer_field_null({offset}) {{"
832+
)),
833+
Indent(Box::new(Line(
834+
if is_fallible {
835+
"core::result::Result::Ok(core::option::Option::None)"
836+
} else {
837+
"::core::option::Option::None"
838+
}
839+
.to_string(),
840+
))),
841+
Line("} else {".to_string()),
842+
Indent(Box::new(Line(if is_fallible {
843+
format!("{getter_fragment}.map(::core::option::Option::Some)")
844+
} else {
845+
format!("::core::option::Option::Some({getter_fragment})")
846+
}))),
847+
Line("}".to_string()),
848+
])
849+
} else {
850+
Line(getter_fragment)
851+
};
852+
783853
Ok((result_type, getter_code, default_decl))
784854
}
785855
}
@@ -2454,7 +2524,7 @@ fn generate_node(
24542524
)));
24552525

24562526
client_impl_interior.push(Indent(Box::new(Line(format!(
2457-
"self.client.new_call(_private::TYPE_ID, {ordinal}, None)"
2527+
"self.client.new_call(_private::TYPE_ID, {ordinal}, ::core::option::Option::None)"
24582528
)))));
24592529
client_impl_interior.push(Line("}".to_string()));
24602530

capnpc/src/codegen_types.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ pub trait RustNodeInfo {
9797
// this is a collection of helpers acting on a "Type" (someplace where a Type is used, not defined)
9898
pub trait RustTypeInfo {
9999
fn is_prim(&self) -> Result<bool, Error>;
100+
fn is_pointer(&self) -> Result<bool, Error>;
100101
fn is_parameter(&self) -> Result<bool, Error>;
101102
fn is_branded(&self) -> Result<bool, Error>;
102103
fn type_string(&self, ctx: &GeneratorContext, module: Leaf) -> Result<String, Error>;
@@ -330,6 +331,19 @@ impl<'a> RustTypeInfo for type_::Reader<'a> {
330331
_ => Ok(false),
331332
}
332333
}
334+
335+
#[inline(always)]
336+
fn is_pointer(&self) -> Result<bool, Error> {
337+
Ok(matches!(
338+
self.which()?,
339+
type_::Text(())
340+
| type_::Data(())
341+
| type_::List(_)
342+
| type_::Struct(_)
343+
| type_::Interface(_)
344+
| type_::AnyPointer(_)
345+
))
346+
}
333347
}
334348

335349
///

capnpc/test/test.capnp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,20 @@ struct TestNewUnionVersion {
446446
}
447447
}
448448

449+
struct TestFieldGetOption {
450+
text @0 :Text $Rust.option;
451+
data @1 :Data $Rust.option;
452+
list @2 :List(UInt8) $Rust.option;
453+
emptyStruct @3 :EmptyStruct $Rust.option;
454+
simpleStruct @4 :SimpleStruct $Rust.option;
455+
any @5 :AnyPointer $Rust.option;
456+
457+
struct EmptyStruct {}
458+
struct SimpleStruct {
459+
field @0 :Text $Rust.option;
460+
}
461+
}
462+
449463
struct TestGenerics(Foo, Bar) {
450464
foo @0 :Foo;
451465
bar @1 :Bar;

0 commit comments

Comments
 (0)