diff --git a/near/nep141-locker/src/lib.rs b/near/nep141-locker/src/lib.rs index 5789a0ad..7efeada2 100644 --- a/near/nep141-locker/src/lib.rs +++ b/near/nep141-locker/src/lib.rs @@ -469,7 +469,7 @@ impl Contract { &mut self, #[serializer(borsh)] storage_deposit_args: StorageDepositArgs, #[serializer(borsh)] predecessor_account_id: AccountId, - #[serializer(borsh)] native_fee_recipient: OmniAddress, + #[serializer(borsh)] native_fee_recipient: Option, ) -> PromiseOrValue { let Ok(ProverResult::InitTransfer(init_transfer)) = Self::decode_prover_result(0) else { env::panic_str("Invalid proof message") @@ -485,13 +485,22 @@ impl Contract { let mut required_balance; if let OmniAddress::Near(recipient) = &transfer_message.recipient { - required_balance = self.add_fin_transfer( - &transfer_message.get_transfer_id(), - &Some(NativeFee { + let native_fee = if transfer_message.fee.native_fee.0 != 0 { + let recipient = native_fee_recipient.sdk_expect("ERR_FEE_RECIPIENT_NOT_SET"); + require!( + transfer_message.get_origin_chain() == recipient.get_chain(), + "ERR_WRONG_FEE_RECIPIENT_CHAIN" + ); + Some(NativeFee { amount: transfer_message.fee.native_fee, - recipient: native_fee_recipient, - }), - ); + recipient, + }) + } else { + None + }; + + required_balance = + self.add_fin_transfer(&transfer_message.get_transfer_id(), &native_fee); let recipient: NearRecipient = recipient.parse().sdk_expect("Failed to parse recipient"); @@ -616,7 +625,7 @@ impl Contract { #[payable] pub fn claim_fee_callback( &mut self, - #[serializer(borsh)] native_fee_recipient: OmniAddress, + #[serializer(borsh)] native_fee_recipient: Option, #[serializer(borsh)] predecessor_account_id: AccountId, #[callback_result] #[serializer(borsh)] @@ -640,6 +649,12 @@ impl Contract { let fee = message.amount.0 - fin_transfer.amount.0; if message.fee.native_fee.0 != 0 { + let native_fee_recipient = native_fee_recipient.sdk_expect("ERR_FEE_RECIPIENT_NOT_SET"); + require!( + message.get_origin_chain() == native_fee_recipient.get_chain(), + "ERR_WRONG_FEE_RECIPIENT_CHAIN" + ); + if message.get_origin_chain() == ChainKind::Near { let OmniAddress::Near(recipient) = &native_fee_recipient else { env::panic_str("ERR_WRONG_CHAIN_KIND") @@ -667,7 +682,6 @@ impl Contract { env::log_str( &Nep141LockerEvent::ClaimFeeEvent { transfer_message: message, - native_fee_recipient, } .to_log_string(), ); diff --git a/near/omni-types/src/locker_args.rs b/near/omni-types/src/locker_args.rs index 6c02a516..1278c26a 100644 --- a/near/omni-types/src/locker_args.rs +++ b/near/omni-types/src/locker_args.rs @@ -13,7 +13,7 @@ pub struct StorageDepositArgs { #[derive(BorshDeserialize, BorshSerialize, Clone)] pub struct FinTransferArgs { pub chain_kind: ChainKind, - pub native_fee_recipient: OmniAddress, + pub native_fee_recipient: Option, pub storage_deposit_args: StorageDepositArgs, pub prover_args: Vec, } @@ -22,7 +22,7 @@ pub struct FinTransferArgs { pub struct ClaimFeeArgs { pub chain_kind: ChainKind, pub prover_args: Vec, - pub native_fee_recipient: OmniAddress, + pub native_fee_recipient: Option, } #[derive(BorshDeserialize, BorshSerialize, Clone)] diff --git a/near/omni-types/src/near_events.rs b/near/omni-types/src/near_events.rs index 5489bc07..08672983 100644 --- a/near/omni-types/src/near_events.rs +++ b/near/omni-types/src/near_events.rs @@ -33,7 +33,6 @@ pub enum Nep141LockerEvent { }, ClaimFeeEvent { transfer_message: TransferMessage, - native_fee_recipient: OmniAddress, }, }