Skip to content

Commit

Permalink
[EVM] prevent duplicate txs from getting inserted (#196)
Browse files Browse the repository at this point in the history
* prevent duplicates in mempool

* use timestamp in priority queue
  • Loading branch information
stevenlanders authored and udpatil committed Apr 16, 2024
1 parent cbf2ae7 commit 7295161
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 32 deletions.
5 changes: 2 additions & 3 deletions internal/consensus/mempool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func checkTxsRange(ctx context.Context, t *testing.T, cs *State, start, end int)
var rCode uint32
err := assertMempool(t, cs.txNotifier).CheckTx(ctx, txBytes, func(r *abci.ResponseCheckTx) { rCode = r.Code }, mempool.TxInfo{})
require.NoError(t, err, "error after checkTx")
require.Equal(t, code.CodeTypeOK, rCode, "checkTx code is error, txBytes %X", txBytes)
require.Equal(t, code.CodeTypeOK, rCode, "checkTx code is error, txBytes %X, index=%d", txBytes, i)
}
}

Expand All @@ -166,7 +166,7 @@ func TestMempoolTxConcurrentWithCommit(t *testing.T) {
require.NoError(t, err)
newBlockHeaderCh := subscribe(ctx, t, cs.eventBus, types.EventQueryNewBlockHeader)

const numTxs int64 = 100
const numTxs int64 = 50
go checkTxsRange(ctx, t, cs, 0, int(numTxs))

startTestRound(ctx, cs, cs.roundState.Height(), cs.roundState.Round())
Expand Down Expand Up @@ -331,7 +331,6 @@ func txAsUint64(tx []byte) uint64 {
func (app *CounterApplication) Commit(context.Context) (*abci.ResponseCommit, error) {
app.mu.Lock()
defer app.mu.Unlock()

app.mempoolTxCount = app.txCount
return &abci.ResponseCommit{}, nil
}
Expand Down
32 changes: 22 additions & 10 deletions internal/mempool/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,11 @@ func (txmp *TxMempool) CheckTx(
return nil
}

func (txmp *TxMempool) isInMempool(tx types.Tx) bool {
existingTx := txmp.txStore.GetTxByHash(tx.Key())
return existingTx != nil && !existingTx.removed
}

func (txmp *TxMempool) RemoveTxByKey(txKey types.TxKey) error {
txmp.Lock()
defer txmp.Unlock()
Expand Down Expand Up @@ -635,15 +640,17 @@ func (txmp *TxMempool) addNewTransaction(wtx *WrappedTx, res *abci.ResponseCheck
txmp.metrics.Size.Set(float64(txmp.Size()))
txmp.metrics.PendingSize.Set(float64(txmp.PendingSize()))

txmp.insertTx(wtx)
txmp.logger.Debug(
"inserted good transaction",
"priority", wtx.priority,
"tx", fmt.Sprintf("%X", wtx.tx.Hash()),
"height", txmp.height,
"num_txs", txmp.Size(),
)
txmp.notifyTxsAvailable()
if txmp.insertTx(wtx) {
txmp.logger.Debug(
"inserted good transaction",
"priority", wtx.priority,
"tx", fmt.Sprintf("%X", wtx.tx.Hash()),
"height", txmp.height,
"num_txs", txmp.Size(),
)
txmp.notifyTxsAvailable()
}

return nil
}

Expand Down Expand Up @@ -809,7 +816,11 @@ func (txmp *TxMempool) canAddTx(wtx *WrappedTx) error {
return nil
}

func (txmp *TxMempool) insertTx(wtx *WrappedTx) {
func (txmp *TxMempool) insertTx(wtx *WrappedTx) bool {
if txmp.isInMempool(wtx.tx) {
return false
}

txmp.txStore.SetTx(wtx)
txmp.priorityIndex.PushTx(wtx)
txmp.heightIndex.Insert(wtx)
Expand All @@ -822,6 +833,7 @@ func (txmp *TxMempool) insertTx(wtx *WrappedTx) {
wtx.gossipEl = gossipEl

atomic.AddInt64(&txmp.sizeBytes, int64(wtx.Size()))
return true
}

func (txmp *TxMempool) removeTx(wtx *WrappedTx, removeFromCache bool) {
Expand Down
90 changes: 71 additions & 19 deletions internal/mempool/priority_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func binarySearch(queue []*WrappedTx, tx *WrappedTx) int {
low, high := 0, len(queue)
for low < high {
mid := low + (high-low)/2
if queue[mid].evmNonce <= tx.evmNonce {
if queue[mid].IsBefore(tx) {
low = mid + 1
} else {
high = mid
Expand Down Expand Up @@ -118,11 +118,6 @@ func (pq *TxPriorityQueue) removeQueuedEvmTxUnsafe(tx *WrappedTx) {
pq.evmQueue[tx.evmAddress] = append(queue[:i], queue[i+1:]...)
if len(pq.evmQueue[tx.evmAddress]) == 0 {
delete(pq.evmQueue, tx.evmAddress)
} else {
// only if removing the first item, then push next onto queue
if i == 0 {
heap.Push(pq, pq.evmQueue[tx.evmAddress][0])
}
}
break
}
Expand All @@ -132,7 +127,7 @@ func (pq *TxPriorityQueue) removeQueuedEvmTxUnsafe(tx *WrappedTx) {

func (pq *TxPriorityQueue) findTxIndexUnsafe(tx *WrappedTx) (int, bool) {
for i, t := range pq.txs {
if t == tx {
if t.tx.Key() == tx.tx.Key() {
return i, true
}
}
Expand All @@ -146,9 +141,13 @@ func (pq *TxPriorityQueue) RemoveTx(tx *WrappedTx) {

if idx, ok := pq.findTxIndexUnsafe(tx); ok {
heap.Remove(pq, idx)
}

if tx.isEVM {
if tx.isEVM {
pq.removeQueuedEvmTxUnsafe(tx)
if len(pq.evmQueue[tx.evmAddress]) > 0 {
heap.Push(pq, pq.evmQueue[tx.evmAddress][0])
}
}
} else if tx.isEVM {
pq.removeQueuedEvmTxUnsafe(tx)
}
}
Expand All @@ -159,36 +158,53 @@ func (pq *TxPriorityQueue) pushTxUnsafe(tx *WrappedTx) {
return
}

// if there aren't other waiting txs, init and return
queue, exists := pq.evmQueue[tx.evmAddress]
if !exists {
pq.evmQueue[tx.evmAddress] = []*WrappedTx{tx}
heap.Push(pq, tx)
return
}

// this item is on the heap at the moment
first := queue[0]
if tx.evmNonce < first.evmNonce {

// the queue's first item (and ONLY the first item) must be on the heap
// if this tx is before the first item, then we need to remove the first
// item from the heap
if tx.IsBefore(first) {
if idx, ok := pq.findTxIndexUnsafe(first); ok {
heap.Remove(pq, idx)
}
heap.Push(pq, tx)
}

pq.evmQueue[tx.evmAddress] = insertToEVMQueue(queue, tx, binarySearch(queue, tx))

}

// These are available if we need to test the invariant checks
// these can be used to troubleshoot invariant violations
//func (pq *TxPriorityQueue) checkInvariants(msg string) {
//
// uniqHashes := make(map[string]bool)
// for _, tx := range pq.txs {
// for idx, tx := range pq.txs {
// if tx == nil {
// pq.print()
// panic(fmt.Sprintf("DEBUG PRINT: found nil item on heap: idx=%d\n", idx))
// }
// if tx.tx == nil {
// pq.print()
// panic(fmt.Sprintf("DEBUG PRINT: found nil tx.tx on heap: idx=%d\n", idx))
// }
// if _, ok := uniqHashes[fmt.Sprintf("%x", tx.tx.Key())]; ok {
// pq.print()
// panic(fmt.Sprintf("INVARIANT (%s): duplicate hash=%x in heap", msg, tx.tx.Key()))
// }
// uniqHashes[fmt.Sprintf("%x", tx.tx.Key())] = true
//
// //if _, ok := pq.keys[tx.tx.Key()]; !ok {
// // pq.print()
// // panic(fmt.Sprintf("INVARIANT (%s): tx in heap but not in keys hash=%x", msg, tx.tx.Key()))
// //}
//
// if tx.isEVM {
// if queue, ok := pq.evmQueue[tx.evmAddress]; ok {
// if queue[0].tx.Key() != tx.tx.Key() {
Expand All @@ -213,6 +229,10 @@ func (pq *TxPriorityQueue) pushTxUnsafe(tx *WrappedTx) {
// panic(fmt.Sprintf("INVARIANT (%s): did not find tx[0] hash=%x nonce=%d in heap", msg, tx.tx.Key(), tx.evmNonce))
// }
// }
// //if _, ok := pq.keys[tx.tx.Key()]; !ok {
// // pq.print()
// // panic(fmt.Sprintf("INVARIANT (%s): tx in heap but not in keys hash=%x", msg, tx.tx.Key()))
// //}
// if _, ok := hashes[fmt.Sprintf("%x", tx.tx.Key())]; ok {
// pq.print()
// panic(fmt.Sprintf("INVARIANT (%s): duplicate hash=%x in queue nonce=%d", msg, tx.tx.Key(), tx.evmNonce))
Expand All @@ -224,13 +244,31 @@ func (pq *TxPriorityQueue) pushTxUnsafe(tx *WrappedTx) {

// for debugging situations where invariant violations occur
//func (pq *TxPriorityQueue) print() {
// fmt.Println("PRINT PRIORITY QUEUE ****************** ")
// for _, tx := range pq.txs {
// fmt.Printf("DEBUG PRINT: heap: nonce=%d, hash=%x\n", tx.evmNonce, tx.tx.Key())
// if tx == nil {
// fmt.Printf("DEBUG PRINT: heap (nil): nonce=?, hash=?\n")
// continue
// }
// if tx.tx == nil {
// fmt.Printf("DEBUG PRINT: heap (%s): nonce=%d, tx.tx is nil \n", tx.evmAddress, tx.evmNonce)
// continue
// }
// fmt.Printf("DEBUG PRINT: heap (%s): nonce=%d, hash=%x, time=%d\n", tx.evmAddress, tx.evmNonce, tx.tx.Key(), tx.timestamp.UnixNano())
// }
//
// for _, queue := range pq.evmQueue {
// for addr, queue := range pq.evmQueue {
// for idx, tx := range queue {
// fmt.Printf("DEBUG PRINT: evmQueue[%d]: nonce=%d, hash=%x\n", idx, tx.evmNonce, tx.tx.Key())
// if tx == nil {
// fmt.Printf("DEBUG PRINT: found nil item on evmQueue(%s): idx=%d\n", addr, idx)
// continue
// }
// if tx.tx == nil {
// fmt.Printf("DEBUG PRINT: found nil tx.tx on evmQueue(%s): idx=%d\n", addr, idx)
// continue
// }
//
// fmt.Printf("DEBUG PRINT: evmQueue(%s)[%d]: nonce=%d, hash=%x, time=%d\n", tx.evmAddress, idx, tx.evmNonce, tx.tx.Key(), tx.timestamp.UnixNano())
// }
// }
//}
Expand All @@ -239,33 +277,47 @@ func (pq *TxPriorityQueue) pushTxUnsafe(tx *WrappedTx) {
func (pq *TxPriorityQueue) PushTx(tx *WrappedTx) {
pq.mtx.Lock()
defer pq.mtx.Unlock()

pq.pushTxUnsafe(tx)
}

func (pq *TxPriorityQueue) popTxUnsafe() *WrappedTx {
if len(pq.txs) == 0 {
return nil
}

// remove the first item from the heap
x := heap.Pop(pq)
if x == nil {
return nil
}

tx := x.(*WrappedTx)

// non-evm transactions do not have txs waiting on a nonce
if !tx.isEVM {
return tx
}

// evm transactions can have txs waiting on this nonce
// if there are any, we should replace the heap with the next nonce
// for the address

// remove the first item from the evmQueue
pq.removeQueuedEvmTxUnsafe(tx)

// if there is a next item, now it can be added to the heap
if len(pq.evmQueue[tx.evmAddress]) > 0 {
heap.Push(pq, pq.evmQueue[tx.evmAddress][0])
}

return tx
}

// PopTx removes the top priority transaction from the queue. It is thread safe.
func (pq *TxPriorityQueue) PopTx() *WrappedTx {
pq.mtx.Lock()
defer pq.mtx.Unlock()

return pq.popTxUnsafe()
}

Expand Down
1 change: 1 addition & 0 deletions internal/mempool/priority_queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ func TestTxPriorityQueue(t *testing.T) {
pq.PushTx(&WrappedTx{
priority: 1000,
timestamp: now,
tx: []byte(fmt.Sprintf("%d", time.Now().UnixNano())),
})
require.Equal(t, 1001, pq.NumTxs())

Expand Down
6 changes: 6 additions & 0 deletions internal/mempool/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ type WrappedTx struct {
isEVM bool
}

// IsBefore returns true if the WrappedTx is before the given WrappedTx
// this applies to EVM transactions only
func (wtx *WrappedTx) IsBefore(tx *WrappedTx) bool {
return wtx.evmNonce < tx.evmNonce || (wtx.evmNonce == tx.evmNonce && wtx.timestamp.Before(tx.timestamp))
}

func (wtx *WrappedTx) Size() int {
return len(wtx.tx)
}
Expand Down

0 comments on commit 7295161

Please sign in to comment.