Skip to content

Commit

Permalink
override charge fee way when charge storage fee
Browse files Browse the repository at this point in the history
  • Loading branch information
wangjj9219 committed Dec 17, 2023
1 parent 225df86 commit 18456aa
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 61 deletions.
152 changes: 115 additions & 37 deletions modules/transaction-payment/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,13 @@ pub mod module {
pub const RESERVE_ID: ReserveIdentifier = ReserveIdentifier::TransactionPayment;
pub const DEPOSIT_ID: ReserveIdentifier = ReserveIdentifier::TransactionPaymentDeposit;

#[derive(Encode, Decode, Clone, PartialEq, RuntimeDebug, TypeInfo)]
pub enum ChargeFeeWay {
FeeCurrency(CurrencyId),
FeePath(Vec<CurrencyId>),
FeeAggregatedPath(Vec<AggregatedSwapPath<CurrencyId>>),
}

#[pallet::config]
pub trait Config: frame_system::Config {
type RuntimeEvent: From<Event<Self>> + IsType<<Self as frame_system::Config>::RuntimeEvent>;
Expand Down Expand Up @@ -494,7 +501,15 @@ pub mod module {
#[pallet::getter(fn swap_balance_threshold)]
pub type SwapBalanceThreshold<T: Config> = StorageMap<_, Twox64Concat, CurrencyId, Balance, ValueQuery>;

/// The override charge fee way.
///
/// OverrideChargeFeeWay: ChargeFeeWay
#[pallet::storage]
#[pallet::getter(fn override_charge_fee_way)]
pub type OverrideChargeFeeWay<T: Config> = StorageValue<_, ChargeFeeWay, OptionQuery>;

#[pallet::pallet]
#[pallet::without_storage_info]
pub struct Pallet<T>(_);

#[pallet::hooks]
Expand Down Expand Up @@ -842,6 +857,63 @@ where
}
}

fn charge_fee_path(
who: &T::AccountId,
fee: PalletBalanceOf<T>,
fee_swap_path: &[CurrencyId],
) -> Result<(T::AccountId, Balance), DispatchError> {
let custom_fee_surplus = T::CustomFeeSurplus::get().mul_ceil(fee);
T::Swap::swap_by_path(
who,
fee_swap_path,
SwapLimit::ExactTarget(Balance::MAX, fee.saturating_add(custom_fee_surplus)),
)
.map(|_| (who.clone(), custom_fee_surplus))
}

fn charge_fee_aggregated_path(
who: &T::AccountId,
fee: PalletBalanceOf<T>,
fee_aggregated_path: &[AggregatedSwapPath<CurrencyId>],
) -> Result<(T::AccountId, Balance), DispatchError> {
let custom_fee_surplus = T::CustomFeeSurplus::get().mul_ceil(fee);
T::Swap::swap_by_aggregated_path(
who,
fee_aggregated_path,
SwapLimit::ExactTarget(Balance::MAX, fee.saturating_add(custom_fee_surplus)),
)
.map(|_| (who.clone(), custom_fee_surplus))
}

fn charge_fee_currency(
who: &T::AccountId,
fee: PalletBalanceOf<T>,
fee_currency_id: CurrencyId,
) -> Result<(T::AccountId, Balance), DispatchError> {
let alternative_fee_surplus = T::AlternativeFeeSurplus::get().mul_ceil(fee);
let custom_fee_surplus = T::CustomFeeSurplus::get().mul_ceil(fee);

let (fee_amount, fee_surplus) = if T::DefaultFeeTokens::get().contains(&fee_currency_id) {
(fee.saturating_add(alternative_fee_surplus), alternative_fee_surplus)
} else {
(fee.saturating_add(custom_fee_surplus), custom_fee_surplus)
};

if TokenExchangeRate::<T>::contains_key(fee_currency_id) {
// token in charge fee pool should have `TokenExchangeRate` info.
Self::swap_from_pool_or_dex(who, fee_amount, fee_currency_id).map(|_| (who.clone(), fee_surplus))
} else {
// `supply_currency_id` not in charge fee pool, direct swap.
T::Swap::swap(
who,
fee_currency_id,
T::NativeCurrencyId::get(),
SwapLimit::ExactTarget(Balance::MAX, fee_amount),
)
.map(|_| (who.clone(), fee_surplus))
}
}

/// Determine the fee and surplus that should be withdraw from user. There are three kind call:
/// - TransactionPayment::with_fee_currency: swap with tx fee pool if token is enable charge fee
/// pool, else swap with dex.
Expand All @@ -855,68 +927,50 @@ where
) -> Result<(T::AccountId, Balance), DispatchError> {
match call.is_sub_type() {
Some(Call::with_fee_path { fee_swap_path, .. }) => {
// pre check before set OverrideChargeFeeWay
ensure!(
fee_swap_path.len() > 1
fee_swap_path.len() <= T::TradingPathLimit::get() as usize
&& fee_swap_path.len() > 1
&& fee_swap_path.first() != Some(&T::NativeCurrencyId::get())
&& fee_swap_path.last() == Some(&T::NativeCurrencyId::get()),
Error::<T>::InvalidSwapPath
);

// put in storage after check
OverrideChargeFeeWay::<T>::put(ChargeFeeWay::FeePath(fee_swap_path.clone()));

let fee = Self::check_native_is_not_enough(who, fee, reason).map_or_else(|| fee, |amount| amount);
let custom_fee_surplus = T::CustomFeeSurplus::get().mul_ceil(fee);
T::Swap::swap_by_path(
who,
fee_swap_path,
SwapLimit::ExactTarget(Balance::MAX, fee.saturating_add(custom_fee_surplus)),
)
.map(|_| (who.clone(), custom_fee_surplus))
Self::charge_fee_path(who, fee, fee_swap_path)
}
Some(Call::with_fee_aggregated_path {
fee_aggregated_path, ..
}) => {
let last_should_be_dex = fee_aggregated_path.last();
match last_should_be_dex {
Some(AggregatedSwapPath::<CurrencyId>::Dex(fee_swap_path)) => {
// pre check before set OverrideChargeFeeWay
ensure!(
fee_swap_path.len() > 1
&& fee_swap_path.first() != Some(&T::NativeCurrencyId::get())
fee_swap_path.len() <= T::TradingPathLimit::get() as usize
&& fee_swap_path.len() > 1 && fee_swap_path.first() != Some(&T::NativeCurrencyId::get())
&& fee_swap_path.last() == Some(&T::NativeCurrencyId::get()),
Error::<T>::InvalidSwapPath
);

// put in storage after check
OverrideChargeFeeWay::<T>::put(ChargeFeeWay::FeeAggregatedPath(fee_aggregated_path.clone()));

let fee =
Self::check_native_is_not_enough(who, fee, reason).map_or_else(|| fee, |amount| amount);
let custom_fee_surplus = T::CustomFeeSurplus::get().mul_ceil(fee);
T::Swap::swap_by_aggregated_path(
who,
fee_aggregated_path,
SwapLimit::ExactTarget(Balance::MAX, fee.saturating_add(custom_fee_surplus)),
)
.map(|_| (who.clone(), custom_fee_surplus))
Self::charge_fee_aggregated_path(who, fee, fee_aggregated_path)
}
_ => Err(Error::<T>::InvalidSwapPath.into()),
}
}
Some(Call::with_fee_currency { currency_id, .. }) => {
OverrideChargeFeeWay::<T>::put(ChargeFeeWay::FeeCurrency(*currency_id));

let fee = Self::check_native_is_not_enough(who, fee, reason).map_or_else(|| fee, |amount| amount);
let alternative_fee_surplus = T::AlternativeFeeSurplus::get().mul_ceil(fee);
let custom_fee_surplus = T::CustomFeeSurplus::get().mul_ceil(fee);
let (fee_amount, fee_surplus) = if T::DefaultFeeTokens::get().contains(currency_id) {
(fee.saturating_add(alternative_fee_surplus), alternative_fee_surplus)
} else {
(fee.saturating_add(custom_fee_surplus), custom_fee_surplus)
};
if TokenExchangeRate::<T>::contains_key(currency_id) {
// token in charge fee pool should have `TokenExchangeRate` info.
Self::swap_from_pool_or_dex(who, fee_amount, *currency_id).map(|_| (who.clone(), fee_surplus))
} else {
// `supply_currency_id` not in charge fee pool, direct swap.
T::Swap::swap(
who,
*currency_id,
T::NativeCurrencyId::get(),
SwapLimit::ExactTarget(Balance::MAX, fee.saturating_add(custom_fee_surplus)),
)
.map(|_| (who.clone(), custom_fee_surplus))
}
Self::charge_fee_currency(who, fee, *currency_id)
}
_ => Self::native_then_alternative_or_default(who, fee, reason).map(|surplus| (who.clone(), surplus)),
}
Expand Down Expand Up @@ -944,6 +998,27 @@ where
) -> Result<Balance, DispatchError> {
if let Some(amount) = Self::check_native_is_not_enough(who, fee, reason) {
// native asset is not enough

// if override charge fee way, charge fee by the config firstly.
match OverrideChargeFeeWay::<T>::get() {
Some(ChargeFeeWay::FeeCurrency(fee_currency_id)) => {
if let Ok((_, surplus)) = Self::charge_fee_currency(who, amount, fee_currency_id) {
return Ok(surplus);
}
}
Some(ChargeFeeWay::FeePath(fee_path)) => {
if let Ok((_, surplus)) = Self::charge_fee_path(who, amount, &fee_path) {
return Ok(surplus);
}
}
Some(ChargeFeeWay::FeeAggregatedPath(fee_aggregated_path)) => {
if let Ok((_, surplus)) = Self::charge_fee_aggregated_path(who, amount, &fee_aggregated_path) {
return Ok(surplus);
}
}
None => {}
}

let fee_surplus = T::AlternativeFeeSurplus::get().mul_ceil(fee);
let fee_amount = fee_surplus.saturating_add(amount);
let custom_fee_surplus = T::CustomFeeSurplus::get().mul_ceil(fee);
Expand Down Expand Up @@ -1417,6 +1492,9 @@ where
// distribute fee
<T as Config>::OnTransactionPayment::on_unbalanceds(Some(fee).into_iter().chain(Some(tip)));

// reset OverrideChargeFeeWay
OverrideChargeFeeWay::<T>::kill();

Pallet::<T>::deposit_event(Event::<T>::TransactionFeePaid {
who,
actual_fee,
Expand Down
69 changes: 55 additions & 14 deletions modules/transaction-payment/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,12 @@ fn pre_post_dispatch_and_refund_with_fee_currency_call(token: CurrencyId, surplu
assert_eq!(pre.2, Some(pallet_balances::NegativeImbalance::new(fee_surplus)));
assert_eq!(pre.3, fee_surplus);

// with_fee_currency will set OverrideChargeFeeWay when pre_dispatch
assert_eq!(
OverrideChargeFeeWay::<Runtime>::get(),
Some(ChargeFeeWay::FeeCurrency(token))
);

let token_transfer = token_rate.saturating_mul_int(fee_surplus);
System::assert_has_event(crate::mock::RuntimeEvent::Tokens(orml_tokens::Event::Transfer {
currency_id: token,
Expand Down Expand Up @@ -422,6 +428,9 @@ fn pre_post_dispatch_and_refund_with_fee_currency_call(token: CurrencyId, surplu
&Ok(())
));

// always clear OverrideChargeFeeWay when post_dispatch
assert_eq!(OverrideChargeFeeWay::<Runtime>::get(), None);

let refund = 200; // 1000 - 800
let refund_surplus = surplus_percent.mul_ceil(refund);
let actual_surplus = surplus - refund_surplus;
Expand Down Expand Up @@ -465,6 +474,13 @@ fn pre_post_dispatch_and_refund_with_fee_currency_call(token: CurrencyId, surplu
.unwrap();
assert_eq!(pre.2, Some(pallet_balances::NegativeImbalance::new(fee_surplus)));
assert_eq!(pre.3, fee_surplus);

// with_fee_currency will set OverrideChargeFeeWay when pre_dispatch
assert_eq!(
OverrideChargeFeeWay::<Runtime>::get(),
Some(ChargeFeeWay::FeeCurrency(token))
);

System::assert_has_event(crate::mock::RuntimeEvent::Tokens(orml_tokens::Event::Transfer {
currency_id: token,
from: CHARLIE,
Expand All @@ -489,6 +505,10 @@ fn pre_post_dispatch_and_refund_with_fee_currency_call(token: CurrencyId, surplu
500,
&Ok(())
));

// always clear OverrideChargeFeeWay when post_dispatch
assert_eq!(OverrideChargeFeeWay::<Runtime>::get(), None);

assert_eq!(
Currencies::free_balance(ACA, &CHARLIE),
aca_init + refund + refund_surplus
Expand Down Expand Up @@ -2133,19 +2153,27 @@ fn with_fee_call_validation_works() {
.one_hundred_thousand_for_alice_n_charlie()
.build()
.execute_with(|| {
assert_eq!(OverrideChargeFeeWay::<Runtime>::get(), None);
// dex swap not enabled, validate failed.
// with_fee_currency test
for token in vec![DOT, AUSD] {
assert_noop!(
assert_eq!(
ChargeTransactionPayment::<Runtime>::from(0).pre_dispatch(
&ALICE,
&with_fee_currency_call(token),
&INFO,
500
),
TransactionValidityError::Invalid(InvalidTransaction::Payment)
Err(TransactionValidityError::Invalid(InvalidTransaction::Payment))
);

// pre_dispatch will set OverrideChargeFeeWay and it's not transactional
assert_eq!(
OverrideChargeFeeWay::<Runtime>::get(),
Some(ChargeFeeWay::FeeCurrency(token))
);
}

assert_ok!(TransactionPayment::with_fee_currency(
RuntimeOrigin::signed(ALICE),
DOT,
Expand All @@ -2156,14 +2184,20 @@ fn with_fee_call_validation_works() {

// with_fee_path test
for path in vec![vec![DOT, AUSD, ACA], vec![AUSD, ACA]] {
assert_noop!(
assert_eq!(
ChargeTransactionPayment::<Runtime>::from(0).pre_dispatch(
&ALICE,
&with_fee_path_call(path),
&with_fee_path_call(path.clone()),
&INFO,
500
),
TransactionValidityError::Invalid(InvalidTransaction::Payment)
Err(TransactionValidityError::Invalid(InvalidTransaction::Payment))
);

// pre_dispatch will set OverrideChargeFeeWay and it's not transactional
assert_eq!(
OverrideChargeFeeWay::<Runtime>::get(),
Some(ChargeFeeWay::FeePath(path))
);
}
assert_ok!(TransactionPayment::with_fee_currency(
Expand All @@ -2176,34 +2210,41 @@ fn with_fee_call_validation_works() {

// with_fee_aggregated_path
let aggregated_path = vec![AggregatedSwapPath::Dex(vec![DOT, AUSD])];
assert_noop!(
assert_eq!(
ChargeTransactionPayment::<Runtime>::from(0).pre_dispatch(
&ALICE,
&with_fee_aggregated_path_by_call(aggregated_path),
&with_fee_aggregated_path_by_call(aggregated_path.clone()),
&INFO,
500
),
TransactionValidityError::Invalid(InvalidTransaction::Payment)
Err(TransactionValidityError::Invalid(InvalidTransaction::Payment))
);

let aggregated_path = vec![AggregatedSwapPath::Dex(vec![DOT, ACA])];
assert_noop!(
assert_eq!(
ChargeTransactionPayment::<Runtime>::from(0).pre_dispatch(
&ALICE,
&with_fee_aggregated_path_by_call(aggregated_path),
&with_fee_aggregated_path_by_call(aggregated_path.clone()),
&INFO,
500
),
TransactionValidityError::Invalid(InvalidTransaction::Payment)
Err(TransactionValidityError::Invalid(InvalidTransaction::Payment))
);
// pre_dispatch will set OverrideChargeFeeWay and it's not transactional
assert_eq!(
OverrideChargeFeeWay::<Runtime>::get(),
Some(ChargeFeeWay::FeeAggregatedPath(aggregated_path))
);

let aggregated_path = vec![AggregatedSwapPath::Taiga(0, 0, 0)];
assert_noop!(
assert_eq!(
ChargeTransactionPayment::<Runtime>::from(0).pre_dispatch(
&ALICE,
&with_fee_aggregated_path_by_call(aggregated_path),
&with_fee_aggregated_path_by_call(aggregated_path.clone()),
&INFO,
500
),
TransactionValidityError::Invalid(InvalidTransaction::Payment)
Err(TransactionValidityError::Invalid(InvalidTransaction::Payment))
);
});
}
4 changes: 2 additions & 2 deletions runtime/integration-tests/src/payment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,14 +475,14 @@ fn with_fee_call_works(
RuntimeEvent::Dex(module_dex::Event::Swap { .. })
)));
// Bob don't have any USD currency.
assert_noop!(
assert_eq!(
<module_transaction_payment::ChargeTransactionPayment::<Runtime>>::from(0).validate(
&AccountId::from(BOB),
&with_fee_currency_call(USD_CURRENCY),
&INFO,
50
),
TransactionValidityError::Invalid(InvalidTransaction::Payment)
Err(TransactionValidityError::Invalid(InvalidTransaction::Payment))
);

// Charlie have USD currency.
Expand Down
Loading

0 comments on commit 18456aa

Please sign in to comment.