diff --git a/docs/docs/comment_dsl.mdx b/docs/docs/comment_dsl.mdx index 994b8aa..4195069 100644 --- a/docs/docs/comment_dsl.mdx +++ b/docs/docs/comment_dsl.mdx @@ -53,6 +53,16 @@ script = [ With code like `foo = uint` this creates an alias e.g. `pub type Foo = u64;` in rust. When we use `foo = uint ; @newtype` it instead creates a `pub struct Foo(u64);`. +`@newtype` can also optionally specify a getter function e.g. `foo = uint ; @newtype custom_getter` will generate: + +```rust +impl Foo { + pub fn custom_getter(&self) -> u64 { + self.0 + } +} +``` + ## @no_alias ```cddl diff --git a/src/comment_ast.rs b/src/comment_ast.rs index d9756b3..f52c556 100644 --- a/src/comment_ast.rs +++ b/src/comment_ast.rs @@ -9,7 +9,8 @@ use nom::{ #[derive(Clone, Default, Debug, PartialEq)] pub struct RuleMetadata { pub name: Option, - pub is_newtype: bool, + /// None = not newtype, Some(Some) = generate getter, Some(None) = no getter + pub newtype: Option>, pub no_alias: bool, pub used_as_key: bool, pub custom_json: bool, @@ -18,56 +19,46 @@ pub struct RuleMetadata { pub comment: Option, } -pub fn merge_metadata(r1: &RuleMetadata, r2: &RuleMetadata) -> RuleMetadata { - let merged = RuleMetadata { - name: match (r1.name.as_ref(), r2.name.as_ref()) { - (Some(val1), Some(val2)) => { - panic!("Key \"name\" specified twice: {:?} {:?}", val1, val2) - } - (val @ Some(_), _) => val.cloned(), - (_, val) => val.cloned(), - }, - is_newtype: r1.is_newtype || r2.is_newtype, - no_alias: r1.no_alias || r2.no_alias, - used_as_key: r1.used_as_key || r2.used_as_key, - custom_json: r1.custom_json || r2.custom_json, - custom_serialize: match (r1.custom_serialize.as_ref(), r2.custom_serialize.as_ref()) { +macro_rules! merge_metadata_fields { + ($lhs:expr, $rhs:expr, $field_name:literal) => { + match ($lhs.as_ref(), $rhs.as_ref()) { (Some(val1), Some(val2)) => { panic!( - "Key \"custom_serialize\" specified twice: {:?} {:?}", + concat!("Key \"", $field_name, "\" specified twice: {:?} {:?}"), val1, val2 ) } (val @ Some(_), _) => val.cloned(), (_, val) => val.cloned(), - }, - custom_deserialize: match ( - r1.custom_deserialize.as_ref(), - r2.custom_deserialize.as_ref(), - ) { - (Some(val1), Some(val2)) => { - panic!( - "Key \"custom_deserialize\" specified twice: {:?} {:?}", - val1, val2 - ) - } - (val @ Some(_), _) => val.cloned(), - (_, val) => val.cloned(), - }, - comment: match (r1.comment.as_ref(), r2.comment.as_ref()) { - (Some(val1), Some(val2)) => { - panic!("Key \"comment\" specified twice: {:?} {:?}", val1, val2) - } - (val @ Some(_), _) => val.cloned(), - (_, val) => val.cloned(), - }, + } + }; +} + +pub fn merge_metadata(r1: &RuleMetadata, r2: &RuleMetadata) -> RuleMetadata { + let merged = RuleMetadata { + name: merge_metadata_fields!(r1.name, r2.name, "name"), + newtype: merge_metadata_fields!(r1.newtype, r2.newtype, "newtype"), + no_alias: r1.no_alias || r2.no_alias, + used_as_key: r1.used_as_key || r2.used_as_key, + custom_json: r1.custom_json || r2.custom_json, + custom_serialize: merge_metadata_fields!( + r1.custom_serialize, + r2.custom_serialize, + "custom_serialize" + ), + custom_deserialize: merge_metadata_fields!( + r1.custom_deserialize, + r2.custom_deserialize, + "custom_deserialize" + ), + comment: merge_metadata_fields!(r1.comment, r2.comment, "comment"), }; merged.verify(); merged } enum ParseResult { - NewType, + NewType(Option), Name(String), DontGenAlias, UsedAsKey, @@ -77,21 +68,30 @@ enum ParseResult { Comment(String), } +macro_rules! merge_parse_fields { + ($base:expr, $new:expr, $field_name:literal) => { + match $base.as_ref() { + Some(old) => { + panic!( + concat!("Key \"", $field_name, "\" specified twice: {:?} {:?}"), + old, $new + ) + } + None => { + $base = Some($new.to_owned()); + } + } + }; +} + impl RuleMetadata { fn from_parse_results(results: &[ParseResult]) -> RuleMetadata { let mut base = RuleMetadata::default(); for result in results { match result { - ParseResult::Name(name) => match base.name.as_ref() { - Some(old_name) => { - panic!("Key \"name\" specified twice: {:?} {:?}", old_name, name) - } - None => { - base.name = Some(name.to_string()); - } - }, - ParseResult::NewType => { - base.is_newtype = true; + ParseResult::Name(name) => merge_parse_fields!(base.name, name, "name"), + ParseResult::NewType(newtype) => { + merge_parse_fields!(base.newtype, newtype, "newtype") } ParseResult::DontGenAlias => { base.no_alias = true; @@ -104,39 +104,16 @@ impl RuleMetadata { base.custom_json = true; } ParseResult::CustomSerialize(custom_serialize) => { - match base.custom_serialize.as_ref() { - Some(old) => { - panic!( - "Key \"custom_serialize\" specified twice: {:?} {:?}", - old, custom_serialize - ) - } - None => { - base.custom_serialize = Some(custom_serialize.to_string()); - } - } + merge_parse_fields!(base.custom_serialize, custom_serialize, "custom_serialize") } - ParseResult::CustomDeserialize(custom_deserialize) => { - match base.custom_deserialize.as_ref() { - Some(old) => { - panic!( - "Key \"custom_deserialize\" specified twice: {:?} {:?}", - old, custom_deserialize - ) - } - None => { - base.custom_deserialize = Some(custom_deserialize.to_string()); - } - } + ParseResult::CustomDeserialize(custom_deserialize) => merge_parse_fields!( + base.custom_deserialize, + custom_deserialize, + "custom_deserialize" + ), + ParseResult::Comment(comment) => { + merge_parse_fields!(base.comment, comment, "comment") } - ParseResult::Comment(comment) => match base.comment.as_ref() { - Some(old) => { - panic!("Key \"comment\" specified twice: {:?} {:?}", old, comment) - } - None => { - base.comment = Some(comment.to_string()); - } - }, } } base.verify(); @@ -144,7 +121,7 @@ impl RuleMetadata { } fn verify(&self) { - if self.is_newtype && self.no_alias { + if self.newtype.is_some() && self.no_alias { // this would make no sense anyway as with newtype we're already not making an alias panic!("cannot use both @newtype and @no_alias on the same alias"); } @@ -161,8 +138,16 @@ fn tag_name(input: &str) -> IResult<&str, ParseResult> { fn tag_newtype(input: &str) -> IResult<&str, ParseResult> { let (input, _) = tag("@newtype")(input)?; - - Ok((input, ParseResult::NewType)) + // to get around type annotations + fn parse_newtype(input: &str) -> IResult<&str, ParseResult> { + let (input, _) = take_while(char::is_whitespace)(input)?; + let (input, getter) = take_while1(|ch| !char::is_whitespace(ch) && ch != '@')(input)?; + Ok((input, ParseResult::NewType(Some(getter.trim().to_owned())))) + } + match parse_newtype(input) { + Ok(ret) => Ok(ret), + Err(_) => Ok((input.trim_start(), ParseResult::NewType(None))), + } } fn tag_no_alias(input: &str) -> IResult<&str, ParseResult> { @@ -261,7 +246,7 @@ fn parse_comment_name() { "", RuleMetadata { name: Some("foo".to_string()), - is_newtype: false, + newtype: None, no_alias: false, used_as_key: false, custom_json: false, @@ -281,7 +266,7 @@ fn parse_comment_newtype() { "", RuleMetadata { name: None, - is_newtype: true, + newtype: Some(None), no_alias: false, used_as_key: false, custom_json: false, @@ -293,6 +278,46 @@ fn parse_comment_newtype() { ); } +#[test] +fn parse_comment_newtype_getter_before() { + assert_eq!( + rule_metadata("@newtype custom_getter @used_as_key"), + Ok(( + "", + RuleMetadata { + name: None, + newtype: Some(Some("custom_getter".to_owned())), + no_alias: false, + used_as_key: true, + custom_json: false, + custom_serialize: None, + custom_deserialize: None, + comment: None, + } + )) + ); +} + +#[test] +fn parse_comment_newtype_getter_after() { + assert_eq!( + rule_metadata("@used_as_key @newtype custom_getter"), + Ok(( + "", + RuleMetadata { + name: None, + newtype: Some(Some("custom_getter".to_owned())), + no_alias: false, + used_as_key: true, + custom_json: false, + custom_serialize: None, + custom_deserialize: None, + comment: None, + } + )) + ); +} + #[test] fn parse_comment_newtype_and_name() { assert_eq!( @@ -301,7 +326,7 @@ fn parse_comment_newtype_and_name() { "", RuleMetadata { name: Some("foo".to_string()), - is_newtype: true, + newtype: Some(None), no_alias: false, used_as_key: false, custom_json: false, @@ -321,7 +346,7 @@ fn parse_comment_newtype_and_name_and_used_as_key() { "", RuleMetadata { name: Some("foo".to_string()), - is_newtype: true, + newtype: Some(None), no_alias: false, used_as_key: true, custom_json: false, @@ -341,7 +366,7 @@ fn parse_comment_used_as_key() { "", RuleMetadata { name: None, - is_newtype: false, + newtype: None, no_alias: false, used_as_key: true, custom_json: false, @@ -361,7 +386,7 @@ fn parse_comment_newtype_and_name_inverse() { "", RuleMetadata { name: Some("foo".to_string()), - is_newtype: true, + newtype: Some(None), no_alias: false, used_as_key: false, custom_json: false, @@ -381,7 +406,7 @@ fn parse_comment_name_noalias() { "", RuleMetadata { name: Some("foo".to_string()), - is_newtype: false, + newtype: None, no_alias: true, used_as_key: false, custom_json: false, @@ -401,7 +426,7 @@ fn parse_comment_newtype_and_custom_json() { "", RuleMetadata { name: None, - is_newtype: true, + newtype: Some(None), no_alias: false, used_as_key: false, custom_json: true, @@ -427,7 +452,7 @@ fn parse_comment_custom_serialize_deserialize() { "", RuleMetadata { name: None, - is_newtype: false, + newtype: None, no_alias: false, used_as_key: false, custom_json: false, @@ -448,7 +473,7 @@ fn parse_comment_all_except_no_alias() { "", RuleMetadata { name: Some("baz".to_string()), - is_newtype: true, + newtype: Some(None), no_alias: false, used_as_key: true, custom_json: true, diff --git a/src/generation.rs b/src/generation.rs index fe1ca88..d785fcc 100644 --- a/src/generation.rs +++ b/src/generation.rs @@ -7560,12 +7560,14 @@ fn generate_wrapper_struct( ToWasmBoundaryOperations::format(ops.into_iter()) )); } - let mut get = codegen::Function::new("get"); - get.vis("pub") - .arg_ref_self() - .ret(field_type.for_wasm_return(types)) - .line(field_type.to_wasm_boundary(types, "self.0.get()", false)); - wrapper.s_impl.push_fn(get); + if let Some(Some(getter)) = struct_config.newtype_getter.as_ref() { + let mut get = codegen::Function::new(getter); + get.vis("pub") + .arg_ref_self() + .ret(field_type.for_wasm_return(types)) + .line(field_type.to_wasm_boundary(types, &format!("self.0.{getter}()"), false)); + wrapper.s_impl.push_fn(get); + } wrapper.push(gen_scope, types); } @@ -7734,29 +7736,28 @@ fn generate_wrapper_struct( } Some(enc_fields) } else { - s.tuple_field( - Some("pub".to_string()), - field_type.for_rust_member(types, false, cli), - ); + s.tuple_field(None, field_type.for_rust_member(types, false, cli)); None }; // TODO: is there a way to know if the encoding object is also copyable? if field_type.is_copy(types) && !cli.preserve_encodings { s.derive("Copy"); } - let mut get = codegen::Function::new("get"); - get.vis("pub").arg_ref_self(); - if field_type.is_copy(types) { - get.ret(field_type.for_rust_member(types, false, cli)) - .line(field_type.clone_if_not_copy(types, self_var)); - } else { - get.ret(format!( - "&{}", - field_type.for_rust_member(types, false, cli) - )) - .line(format!("&{self_var}")); + if let Some(Some(getter)) = struct_config.newtype_getter.as_ref() { + let mut get = codegen::Function::new(getter); + get.vis("pub").arg_ref_self(); + if field_type.is_copy(types) { + get.ret(field_type.for_rust_member(types, false, cli)) + .line(field_type.clone_if_not_copy(types, self_var)); + } else { + get.ret(format!( + "&{}", + field_type.for_rust_member(types, false, cli) + )) + .line(format!("&{self_var}")); + } + s_impl.push_fn(get); } - s_impl.push_fn(get); let mut ser_func = make_serialization_function("serialize", cli); let mut ser_impl = make_serialization_impl(type_name.as_ref(), cli); gen_scope.generate_serialize( diff --git a/src/intermediate.rs b/src/intermediate.rs index d348e5a..1265a4b 100644 --- a/src/intermediate.rs +++ b/src/intermediate.rs @@ -2308,6 +2308,7 @@ pub struct RustStructConfig { pub custom_serialize: Option, pub custom_deserialize: Option, pub doc: Option, + pub newtype_getter: Option>, } impl From> for RustStructConfig { @@ -2318,6 +2319,7 @@ impl From> for RustStructConfig { custom_serialize: rule_metadata.custom_serialize.clone(), custom_deserialize: rule_metadata.custom_deserialize.clone(), doc: rule_metadata.comment.clone(), + newtype_getter: rule_metadata.newtype.clone(), }, None => Self::default(), } diff --git a/src/parsing.rs b/src/parsing.rs index aba1efc..2dfc457 100644 --- a/src/parsing.rs +++ b/src/parsing.rs @@ -517,7 +517,9 @@ fn parse_type( min_max.1, ident_to_primitive(&cddl_ident).unwrap(), ); - if ranged_type.config.bounds.is_some() || rule_metadata.is_newtype { + if ranged_type.config.bounds.is_some() + || rule_metadata.newtype.is_some() + { // without bounds since passed in other param ranged_type.config.bounds = None; // has non-rust-primitive matching bounds @@ -603,7 +605,7 @@ fn parse_type( )) } None => { - if rule_metadata.is_newtype { + if rule_metadata.newtype.is_some() { types.register_rust_struct( parent_visitor, RustStruct::new_wrapper( @@ -1469,7 +1471,7 @@ fn parse_group_choice( }; let rust_struct = match parse_group_type(types, parent_visitor, group_choice, rep, cli) { GroupParsingType::HomogenousArray(element_type) => { - if rule_metadata.is_newtype { + if rule_metadata.newtype.is_some() { // generate newtype over array RustStruct::new_wrapper( name.clone(), @@ -1484,7 +1486,7 @@ fn parse_group_choice( } } GroupParsingType::HomogenousMap(key_type, value_type) => { - if rule_metadata.is_newtype { + if rule_metadata.newtype.is_some() { // generate newtype over map RustStruct::new_wrapper( name.clone(), @@ -1506,7 +1508,7 @@ fn parse_group_choice( } GroupParsingType::Heterogenous | GroupParsingType::WrappedBasicGroup(_) => { assert!( - !rule_metadata.is_newtype, + rule_metadata.newtype.is_none(), "Can only use @newtype on primtives + heterogenious arrays/maps" ); // Heterogenous map or array with defined key/value pairs in the cddl like a struct @@ -1551,7 +1553,7 @@ pub fn parse_group( if generic_params.is_some() { todo!("{}: generic group choices not supported", name); } - assert!(!parent_rule_metadata.is_newtype); + assert!(parent_rule_metadata.newtype.is_none()); // Generate Enum object that is not exposed to wasm, since wasm can't expose // fully featured rust enums via wasm_bindgen diff --git a/tests/comment-dsl/input.cddl b/tests/comment-dsl/input.cddl index 186aa23..32ba935 100644 --- a/tests/comment-dsl/input.cddl +++ b/tests/comment-dsl/input.cddl @@ -21,7 +21,7 @@ typechoice = / [1, bytes] ; @name case_2 -protocol_magic = uint ; @newtype +protocol_magic = uint ; @newtype get typechoice_variants = text ; @name case_1 diff --git a/tests/core/input.cddl b/tests/core/input.cddl index 7fa0a8c..4228663 100644 --- a/tests/core/input.cddl +++ b/tests/core/input.cddl @@ -187,6 +187,7 @@ top_level_single_elem = [uint] wrapper_table = { * uint => uint } ; @newtype wrapper_list = [ * uint ] ; @newtype +wrapper_int = uint ; @newtype custom_getter overlapping_inlined = [ ; @name one diff --git a/tests/core/tests.rs b/tests/core/tests.rs index a026cde..fa27ddf 100644 --- a/tests/core/tests.rs +++ b/tests/core/tests.rs @@ -575,6 +575,12 @@ mod tests { deser_test(&from_bytes); } + #[test] + fn wrapper_getter() { + let x = WrapperInt::new(128); + assert_eq!(128, x.custom_getter()); + } + #[test] fn docs() { use std::str::FromStr;