From 247b706f941ce68ef407d9865952dc5ac113660b Mon Sep 17 00:00:00 2001 From: j75689 Date: Wed, 29 Nov 2023 15:55:25 +0800 Subject: [PATCH] fix: incorrect behavior for refund swap --- plugins/tokens/plugin.go | 13 ++++++------- plugins/tokens/swap/handler.go | 26 ++++++++++++++------------ 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/plugins/tokens/plugin.go b/plugins/tokens/plugin.go index 53ddd44a7..543931da1 100644 --- a/plugins/tokens/plugin.go +++ b/plugins/tokens/plugin.go @@ -82,20 +82,19 @@ func EndBlocker(ctx sdk.Context, timelockKeeper timelock.Keeper, swapKeeper swap i++ } - iterator = swapKeeper.GetSwapIterator(ctx) - defer iterator.Close() + swapIterator := swapKeeper.GetSwapIterator(ctx) + defer swapIterator.Close() i = 0 - for ; iterator.Valid(); iterator.Next() { + for ; swapIterator.Valid(); swapIterator.Next() { if i >= MaxUnlockItems { break } var automaticSwap swap.AtomicSwap swapKeeper.CDC().MustUnmarshalBinaryBare(iterator.Value(), &automaticSwap) swapID := iterator.Key()[len(swap.HashKey):] - result := swap.HandleClaimHashTimerLockedTransferAfterBCFusion(ctx, swapKeeper, swap.ClaimHTLTMsg{ - From: automaticSwap.From, - SwapID: swapID, - RandomNumber: automaticSwap.RandomNumber, + result := swap.HandleRefundHashTimerLockedTransferAfterBCFusion(ctx, swapKeeper, swap.RefundHTLTMsg{ + From: automaticSwap.From, + SwapID: swapID, }) if !result.IsOK() { logger.Error("Refound error", "swapId", swapID) diff --git a/plugins/tokens/swap/handler.go b/plugins/tokens/swap/handler.go index a07347e5b..4ac933448 100644 --- a/plugins/tokens/swap/handler.go +++ b/plugins/tokens/swap/handler.go @@ -22,9 +22,9 @@ func NewHandler(kp Keeper) sdk.Handler { } return handleDepositHashTimerLockedTransfer(ctx, kp, msg) case ClaimHTLTMsg: - return handleClaimHashTimerLockedTransfer(ctx, kp, msg, false) + return handleClaimHashTimerLockedTransfer(ctx, kp, msg) case RefundHTLTMsg: - return handleRefundHashTimerLockedTransfer(ctx, kp, msg) + return handleRefundHashTimerLockedTransfer(ctx, kp, msg, false) default: errMsg := fmt.Sprintf("unrecognized message type: %T", msg) return sdk.ErrUnknownRequest(errMsg).Result() @@ -107,11 +107,7 @@ func handleDepositHashTimerLockedTransfer(ctx sdk.Context, kp Keeper, msg Deposi } -func HandleClaimHashTimerLockedTransferAfterBCFusion(ctx sdk.Context, kp Keeper, msg ClaimHTLTMsg) sdk.Result { - return handleClaimHashTimerLockedTransfer(ctx, kp, msg, true) -} - -func handleClaimHashTimerLockedTransfer(ctx sdk.Context, kp Keeper, msg ClaimHTLTMsg, isBCFusionRefund bool) sdk.Result { +func handleClaimHashTimerLockedTransfer(ctx sdk.Context, kp Keeper, msg ClaimHTLTMsg) sdk.Result { swap := kp.GetSwap(ctx, msg.SwapID) if swap == nil { return ErrNonExistSwapID(fmt.Sprintf("No matched swap with swapID %v", msg.SwapID)).Result() @@ -119,11 +115,11 @@ func handleClaimHashTimerLockedTransfer(ctx sdk.Context, kp Keeper, msg ClaimHTL if swap.Status != Open { return ErrUnexpectedSwapStatus(fmt.Sprintf("Expected swap status is Open, actually it is %s", swap.Status.String())).Result() } - if !isBCFusionRefund && swap.ExpireHeight <= ctx.BlockHeight() { + if swap.ExpireHeight <= ctx.BlockHeight() { return ErrClaimExpiredSwap(fmt.Sprintf("Current block height is %d, the swap expire height(%d) is passed", ctx.BlockHeight(), swap.ExpireHeight)).Result() } - if !isBCFusionRefund && !bytes.Equal(CalculateRandomHash(msg.RandomNumber, swap.Timestamp), swap.RandomNumberHash) { + if !bytes.Equal(CalculateRandomHash(msg.RandomNumber, swap.Timestamp), swap.RandomNumberHash) { return ErrMismatchedRandomNumber("Mismatched random number").Result() } @@ -168,7 +164,11 @@ func handleClaimHashTimerLockedTransfer(ctx sdk.Context, kp Keeper, msg ClaimHTL return sdk.Result{Tags: tags} } -func handleRefundHashTimerLockedTransfer(ctx sdk.Context, kp Keeper, msg RefundHTLTMsg) sdk.Result { +func HandleRefundHashTimerLockedTransferAfterBCFusion(ctx sdk.Context, kp Keeper, msg RefundHTLTMsg) sdk.Result { + return handleRefundHashTimerLockedTransfer(ctx, kp, msg, true) +} + +func handleRefundHashTimerLockedTransfer(ctx sdk.Context, kp Keeper, msg RefundHTLTMsg, isBCFusionRefund bool) sdk.Result { swap := kp.GetSwap(ctx, msg.SwapID) if swap == nil { return ErrNonExistSwapID(fmt.Sprintf("No matched swap with swapID %v", msg.SwapID)).Result() @@ -176,8 +176,10 @@ func handleRefundHashTimerLockedTransfer(ctx sdk.Context, kp Keeper, msg RefundH if swap.Status != Open { return ErrUnexpectedSwapStatus(fmt.Sprintf("Expected swap status is Open, actually it is %s", swap.Status.String())).Result() } - if ctx.BlockHeight() < swap.ExpireHeight { - return ErrRefundUnexpiredSwap(fmt.Sprintf("Current block height is %d, the expire height (%d) is still not reached", ctx.BlockHeight(), swap.ExpireHeight)).Result() + if !isBCFusionRefund { + if ctx.BlockHeight() < swap.ExpireHeight { + return ErrRefundUnexpiredSwap(fmt.Sprintf("Current block height is %d, the expire height (%d) is still not reached", ctx.BlockHeight(), swap.ExpireHeight)).Result() + } } tags := sdk.EmptyTags()