diff --git a/contracts/examples/adder/src/adder.rs b/contracts/examples/adder/src/adder.rs index 307f4112f5..ebb327d524 100644 --- a/contracts/examples/adder/src/adder.rs +++ b/contracts/examples/adder/src/adder.rs @@ -13,18 +13,18 @@ pub trait Adder { fn sum(&self) -> SingleValueMapper; #[init] - fn init(&self, initial_value: BigUint) { + fn init(&mut self, initial_value: BigUint) { self.sum().set(initial_value); } #[upgrade] - fn upgrade(&self, initial_value: BigUint) { + fn upgrade(&mut self, initial_value: BigUint) { self.init(initial_value); } /// Add desired amount to the storage variable. #[endpoint] - fn add(&self, value: BigUint) { + fn add(&mut self, value: BigUint) { self.sum().update(|sum| *sum += value); } } diff --git a/framework/derive/src/model/endpoint_mutability_metadata.rs b/framework/derive/src/model/endpoint_mutability_metadata.rs index 08da088f1b..4fa6eb175d 100644 --- a/framework/derive/src/model/endpoint_mutability_metadata.rs +++ b/framework/derive/src/model/endpoint_mutability_metadata.rs @@ -1,4 +1,4 @@ -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum EndpointMutabilityMetadata { Mutable, Readonly, @@ -19,4 +19,8 @@ impl EndpointMutabilityMetadata { }, } } + + pub fn is_mutable(&self) -> bool { + matches!(self, EndpointMutabilityMetadata::Mutable) + } } diff --git a/framework/derive/src/model/method.rs b/framework/derive/src/model/method.rs index 8a4e8ec5eb..6cc4be9d3e 100644 --- a/framework/derive/src/model/method.rs +++ b/framework/derive/src/model/method.rs @@ -12,6 +12,7 @@ pub enum AutoImpl { StorageClear { identifier: String }, ProxyGetter, } + #[derive(Clone, Debug)] pub enum MethodImpl { /// Implementation auto-generated by the framework. There can (obviously) be only one per method. @@ -31,6 +32,12 @@ impl MethodImpl { } } +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum MethodMutability { + Readonly, + Mutable, +} + /// Models any method argument from a contract, module or callable proxy trait. #[derive(Clone, Debug)] pub struct Method { @@ -39,6 +46,7 @@ pub struct Method { pub name: syn::Ident, pub generics: syn::Generics, pub unprocessed_attributes: Vec, + pub method_mutability: MethodMutability, pub method_args: Vec, pub title: Option, pub output_names: Vec, @@ -113,4 +121,21 @@ impl Method { PublicRole::Private => false, } } + + pub fn accepts_labels(&self) -> bool { + matches!( + self.public_role, + PublicRole::Init(_) + | PublicRole::Endpoint(_) + | PublicRole::CallbackPromise(_) + | PublicRole::Upgrade(_) + ) + } + + pub fn accepts_mut_self(&self) -> bool { + match &self.public_role { + PublicRole::Endpoint(endpoint_metadata) => endpoint_metadata.mutability.is_mutable(), + _ => true, + } + } } diff --git a/framework/derive/src/parse/argument_parse.rs b/framework/derive/src/parse/argument_parse.rs index 1eaf386214..a423d0d11a 100644 --- a/framework/derive/src/parse/argument_parse.rs +++ b/framework/derive/src/parse/argument_parse.rs @@ -1,37 +1,42 @@ use super::attributes::*; -use crate::model::{ArgMetadata, ArgPaymentMetadata, MethodArgument}; +use crate::model::{ArgMetadata, ArgPaymentMetadata, MethodArgument, MethodMutability}; -pub fn extract_method_args(m: &syn::TraitItemFn) -> Vec { +pub fn extract_method_args(m: &syn::TraitItemFn) -> (MethodMutability, Vec) { if m.sig.inputs.is_empty() { missing_self_panic(m); } - let mut receiver_processed = false; - m.sig + let mut receiver_mutability = Option::::None; + let arguments = m + .sig .inputs .iter() .filter_map(|arg| match arg { syn::FnArg::Receiver(selfref) => { - if selfref.mutability.is_some() || receiver_processed { - missing_self_panic(m); + if selfref.mutability.is_some() { + receiver_mutability = Some(MethodMutability::Mutable); + } else { + receiver_mutability = Some(MethodMutability::Readonly); } - receiver_processed = true; None }, syn::FnArg::Typed(pat_typed) => { - if !receiver_processed { + if receiver_mutability.is_none() { missing_self_panic(m); } Some(extract_method_arg(pat_typed)) }, }) - .collect() + .collect(); + + let method_mutability = receiver_mutability.unwrap_or_else(|| missing_self_panic(m)); + (method_mutability, arguments) } fn missing_self_panic(m: &syn::TraitItemFn) -> ! { panic!( - "Trait method `{}` must have `&self` as its first argument.", + "Trait method `{}` must have `&self` or `&mut self` as its first argument.", m.sig.ident ) } diff --git a/framework/derive/src/parse/method_parse.rs b/framework/derive/src/parse/method_parse.rs index b5fa90b54d..d28ad8c79b 100644 --- a/framework/derive/src/parse/method_parse.rs +++ b/framework/derive/src/parse/method_parse.rs @@ -1,4 +1,6 @@ -use crate::model::{Method, MethodImpl, MethodPayableMetadata, PublicRole, TraitProperties}; +use crate::model::{ + Method, MethodImpl, MethodMutability, MethodPayableMetadata, PublicRole, TraitProperties, +}; use super::{ attributes::extract_doc, @@ -25,7 +27,7 @@ pub struct MethodAttributesPass1 { } pub fn process_method(m: &syn::TraitItemFn, trait_attributes: &TraitProperties) -> Method { - let method_args = extract_method_args(m); + let (method_mutability, method_args) = extract_method_args(m); let implementation = if let Some(body) = m.default.clone() { MethodImpl::Explicit(body) @@ -55,6 +57,7 @@ pub fn process_method(m: &syn::TraitItemFn, trait_attributes: &TraitProperties) name: m.sig.ident.clone(), generics: m.sig.generics.clone(), unprocessed_attributes: Vec::new(), + method_mutability, method_args, title: None, output_names: Vec::new(), @@ -138,12 +141,17 @@ fn process_attribute_second_pass( } fn validate_method(method: &Method) { + let method_name = method.name.to_string(); + + if !method.accepts_mut_self() { + assert!( + method.method_mutability == MethodMutability::Readonly, + "Method '{method_name}' cannot take &mut self, since it is declared as readonly", + ); + } + assert!( - matches!( - method.public_role, - PublicRole::Init(_) | PublicRole::Endpoint(_) | PublicRole::CallbackPromise(_) | PublicRole::Upgrade(_) - ) || method.label_names.is_empty(), - "Labels can only be placed on endpoints, constructors, and promises callbacks. Method '{}' is neither.", - &method.name.to_string() + method.accepts_labels() || method.label_names.is_empty(), + "Labels can only be placed on endpoints, constructors, and promises callbacks. Method '{method_name}' is neither", ) }