Skip to content

Commit

Permalink
Fix clawback sender resync issue (#15853)
Browse files Browse the repository at this point in the history
* Fix clawback sender resync issue

* enhance unit test

* Fix CI

* Add comment

* Fix unit test & revert changes

* Remove prints

* Resolve comments

* Resolve comments

* Remove check
  • Loading branch information
ytx1991 authored Aug 9, 2023
1 parent 0de10d5 commit 17659b8
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 62 deletions.
2 changes: 2 additions & 0 deletions chia/wallet/wallet_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,8 @@ async def _await_closed(self, shutting_down: bool = True) -> None:
await proxy.close()
await asyncio.sleep(0.5) # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown
self.wallet_peers = None
self.race_cache = {}
self.race_cache_hashes = []
self._balance_cache = {}

def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
Expand Down
19 changes: 18 additions & 1 deletion chia/wallet/wallet_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1419,6 +1419,8 @@ async def _add_coin_states(
fee = 0

to_puzzle_hash = None
coin_spend: Optional[CoinSpend] = None
clawback_metadata: Optional[ClawbackMetadata] = None
# Find coin that doesn't belong to us
amount = 0
for coin in additions:
Expand All @@ -1428,6 +1430,18 @@ async def _add_coin_states(
if derivation_record is None: # not change
to_puzzle_hash = coin.puzzle_hash
amount += coin.amount
if coin_spend is None:
# To prevent unnecessary fetch, we only fetch once,
# if there is a child coin that is not owned by the wallet.
coin_spend = await fetch_coin_spend_for_coin_state(coin_state, peer)
# Check if the parent coin is a Clawback coin
puzzle: Program = coin_spend.puzzle_reveal.to_program()
solution: Program = coin_spend.solution.to_program()
uncurried = uncurry_puzzle(puzzle)
clawback_metadata = match_clawback_puzzle(uncurried, puzzle, solution)
if clawback_metadata is not None:
# Add the Clawback coin as the interested coin for the sender
await self.add_interested_coin_ids([coin.name()])
elif wallet_identifier.type == WalletType.CAT:
# We subscribe to change for CATs since they didn't hint previously
await self.add_interested_coin_ids([coin.name()])
Expand Down Expand Up @@ -1746,8 +1760,11 @@ async def coin_added(

parent_coin_record: Optional[WalletCoinRecord] = await self.coin_store.get_coin_record(coin.parent_coin_info)
change = parent_coin_record is not None and wallet_type.value == parent_coin_record.wallet_type
# If the coin is from a Clawback spent, we want to add the INCOMING_TX,
# no matter if there is another TX updated.
clawback = parent_coin_record is not None and parent_coin_record.coin_type == CoinType.CLAWBACK

if coinbase or not coin_confirmed_transaction and not change:
if coinbase or clawback or not coin_confirmed_transaction and not change:
tx_record = TransactionRecord(
confirmed_at_height=uint32(height),
created_at_time=await self.wallet_node.get_timestamp_for_height(height),
Expand Down
213 changes: 152 additions & 61 deletions tests/wallet/test_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from chia.types.coin_spend import compute_additions
from chia.types.peer_info import PeerInfo
from chia.util.bech32m import encode_puzzle_hash
from chia.util.config import load_config
from chia.util.ints import uint16, uint32, uint64
from chia.wallet.derive_keys import master_sk_to_wallet_sk
from chia.wallet.payment import Payment
Expand Down Expand Up @@ -762,95 +761,187 @@ async def test_clawback_resync(
num_blocks = 1
full_nodes, wallets, _ = two_wallet_nodes
full_node_api = full_nodes[0]
server_1 = full_node_api.full_node.server
wallet_node, server_2 = wallets[0]
wallet_node_2, server_3 = wallets[1]
wallet = wallet_node.wallet_state_manager.main_wallet
api_0 = WalletRpcApi(wallet_node)
full_node_server = full_node_api.full_node.server
wallet_node_1, wallet_server_1 = wallets[0]
wallet_node_2, wallet_server_2 = wallets[1]
wallet_1 = wallet_node_1.wallet_state_manager.main_wallet
wallet_2 = wallet_node_2.wallet_state_manager.main_wallet
api_1 = WalletRpcApi(wallet_node_1)
if trusted:
wallet_node.config["trusted_peers"] = {server_1.node_id.hex(): server_1.node_id.hex()}
wallet_node_2.config["trusted_peers"] = {server_1.node_id.hex(): server_1.node_id.hex()}
wallet_node_1.config["trusted_peers"] = {full_node_server.node_id.hex(): full_node_server.node_id.hex()}
wallet_node_2.config["trusted_peers"] = {full_node_server.node_id.hex(): full_node_server.node_id.hex()}
else:
wallet_node.config["trusted_peers"] = {}
wallet_node_1.config["trusted_peers"] = {}
wallet_node_2.config["trusted_peers"] = {}

await server_2.start_client(PeerInfo(self_hostname, uint16(server_1._port)), None)
await server_3.start_client(PeerInfo(self_hostname, uint16(server_1._port)), None)
expected_confirmed_balance = await full_node_api.farm_blocks_to_wallet(count=num_blocks, wallet=wallet)
normal_puzhash = await wallet.get_new_puzzlehash()
await wallet_server_1.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
await wallet_server_2.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
expected_confirmed_balance = await full_node_api.farm_blocks_to_wallet(count=num_blocks, wallet=wallet_1)
wallet_1_puzhash = await wallet_1.get_new_puzzlehash()
wallet_2_puzhash = await wallet_2.get_new_puzzlehash()

# Transfer to normal wallet
tx = await wallet.generate_signed_transaction(
tx1 = await wallet_1.generate_signed_transaction(
uint64(500),
normal_puzhash,
wallet_2_puzhash,
uint64(0),
puzzle_decorator_override=[{"decorator": "CLAWBACK", "clawback_timelock": 5}],
)
clawback_coin_id = tx.additions[0].name()
assert tx.spend_bundle is not None
await wallet.push_transaction(tx)
await full_node_api.wait_transaction_records_entered_mempool(records=[tx])
expected_confirmed_balance += await full_node_api.farm_blocks_to_wallet(count=num_blocks, wallet=wallet)

clawback_coin_id_1 = tx1.additions[0].name()
assert tx1.spend_bundle is not None
await wallet_1.push_transaction(tx1)
await full_node_api.wait_transaction_records_entered_mempool(records=[tx1])
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(bytes32(b"\00" * 32)))
# Check merkle coins
await time_out_assert(
20, wallet_node.wallet_state_manager.coin_store.count_small_unspent, 1, 1000, CoinType.CLAWBACK
20, wallet_node_1.wallet_state_manager.coin_store.count_small_unspent, 1, 1000, CoinType.CLAWBACK
)
assert await wallet.get_confirmed_balance() == 3999999999500
await time_out_assert(
20, wallet_node_2.wallet_state_manager.coin_store.count_small_unspent, 1, 1000, CoinType.CLAWBACK
)
tx2 = await wallet_1.generate_signed_transaction(
uint64(700),
wallet_1_puzhash,
uint64(0),
puzzle_decorator_override=[{"decorator": "CLAWBACK", "clawback_timelock": 5}],
)
clawback_coin_id_2 = tx2.additions[0].name()
assert tx2.spend_bundle is not None
await wallet_1.push_transaction(tx2)
await full_node_api.wait_transaction_records_entered_mempool(records=[tx2])
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(bytes32(b"\00" * 32)))
# Check merkle coins
await time_out_assert(
20, wallet_node_1.wallet_state_manager.coin_store.count_small_unspent, 2, 1000, CoinType.CLAWBACK
)
await time_out_assert(
20, wallet_node_2.wallet_state_manager.coin_store.count_small_unspent, 1, 1000, CoinType.CLAWBACK
)
assert await wallet_1.get_confirmed_balance() == 1999999998800
assert await wallet_2.get_confirmed_balance() == 0
await asyncio.sleep(10)
# clawback merkle coin
resp = await api_0.spend_clawback_coins(
dict({"coin_ids": [normal_puzhash.hex(), clawback_coin_id.hex()], "fee": 0})
)
resp = await api_1.spend_clawback_coins(dict({"coin_ids": [clawback_coin_id_1.hex()], "fee": 0}))
json.dumps(resp)
assert resp["success"]
assert len(resp["transaction_ids"]) == 1
expected_confirmed_balance += await full_node_api.farm_blocks_to_wallet(count=num_blocks, wallet=wallet)
resp = await api_1.spend_clawback_coins(dict({"coin_ids": [clawback_coin_id_2.hex()], "fee": 0}))
assert resp["success"]
assert len(resp["transaction_ids"]) == 1
expected_confirmed_balance += await full_node_api.farm_blocks_to_wallet(count=num_blocks, wallet=wallet_1)
await time_out_assert(
20, wallet_node.wallet_state_manager.coin_store.count_small_unspent, 0, 1000, CoinType.CLAWBACK
20, wallet_node_1.wallet_state_manager.coin_store.count_small_unspent, 0, 1000, CoinType.CLAWBACK
)
await time_out_assert(
20, wallet_node_2.wallet_state_manager.coin_store.count_small_unspent, 0, 1000, CoinType.CLAWBACK
)

assert len(await wallet_node_1.wallet_state_manager.coin_store.get_all_unspent_coins()) == 6
assert len(await wallet_node_2.wallet_state_manager.coin_store.get_all_unspent_coins()) == 0
before_txs: Dict[str, Dict[TransactionType, int]] = {"sender": {}, "recipient": {}}
before_txs["sender"][
TransactionType.INCOMING_CLAWBACK_SEND
] = await wallet_node_1.wallet_state_manager.tx_store.get_transaction_count_for_wallet(
1, type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CLAWBACK_SEND])
)
before_txs["sender"][
TransactionType.OUTGOING_CLAWBACK
] = await wallet_node_1.wallet_state_manager.tx_store.get_transaction_count_for_wallet(
1, type_filter=TransactionTypeFilter.include([TransactionType.OUTGOING_CLAWBACK])
)
before_txs["sender"][
TransactionType.OUTGOING_TX
] = await wallet_node_1.wallet_state_manager.tx_store.get_transaction_count_for_wallet(
1, type_filter=TransactionTypeFilter.include([TransactionType.OUTGOING_TX])
)
before_txs["sender"][
TransactionType.INCOMING_TX
] = await wallet_node_1.wallet_state_manager.tx_store.get_transaction_count_for_wallet(
1, type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_TX])
)
before_txs["sender"][
TransactionType.COINBASE_REWARD
] = await wallet_node_1.wallet_state_manager.tx_store.get_transaction_count_for_wallet(
1, type_filter=TransactionTypeFilter.include([TransactionType.COINBASE_REWARD])
)
before_txs["recipient"][
TransactionType.INCOMING_CLAWBACK_RECEIVE
] = await wallet_node_2.wallet_state_manager.tx_store.get_transaction_count_for_wallet(
1, type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CLAWBACK_RECEIVE])
)
# Resync start
wallet_node_1._close()
await wallet_node_1._await_closed()
wallet_node_2._close()
await wallet_node_2._await_closed()
# set flag to reset wallet sync data on start
await api_0.set_wallet_resync_on_startup({"enable": True})
fingerprint = wallet_node.logged_in_fingerprint
assert wallet_node._wallet_state_manager
# 2 reward coins, 1 clawbacked coin
assert len(await wallet_node._wallet_state_manager.coin_store.get_all_unspent_coins()) == 7
# standard wallet
assert len(await wallet_node.wallet_state_manager.user_store.get_all_wallet_info_entries()) == 1
before_txs = await wallet_node.wallet_state_manager.tx_store.get_all_transactions()
# Delete tx records
await wallet_node.wallet_state_manager.tx_store.rollback_to_block(0)
wallet_node._close()
await wallet_node._await_closed()
config = load_config(wallet_node.root_path, "config.yaml")
# check that flag was set in config file
assert config["wallet"]["reset_sync_for_fingerprint"] == fingerprint
new_config = wallet_node.config.copy()
new_config["reset_sync_for_fingerprint"] = config["wallet"]["reset_sync_for_fingerprint"]
new_config["database_path"] = "wallet/db/blockchain_wallet_v2_test_CHALLENGE_KEY.sqlite"
wallet_node_2.config = new_config
wallet_node_2.root_path = wallet_node.root_path
wallet_node_2.local_keychain = wallet_node.local_keychain
wallet_node_1.config["database_path"] = "wallet/db/blockchain_wallet_v2_test1_CHALLENGE_KEY.sqlite"
wallet_node_2.config["database_path"] = "wallet/db/blockchain_wallet_v2_test2_CHALLENGE_KEY.sqlite"

# use second node to start the same wallet, reusing config and db
await wallet_node_2._start_with_fingerprint(fingerprint)
assert wallet_node_2._wallet_state_manager
await server_3.start_client(PeerInfo(self_hostname, uint16(server_1._port)), None)
await wallet_node_1._start()
await wallet_server_1.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
await wallet_node_2._start()
await wallet_server_2.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(bytes32(b"\00" * 32)))
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node_1, timeout=20)
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node_2, timeout=20)
after_txs = await wallet_node_2.wallet_state_manager.tx_store.get_all_transactions()
# transactions should be the same
assert len(after_txs) == len(before_txs)
after_txs: Dict[str, Dict[TransactionType, int]] = {"sender": {}, "recipient": {}}
after_txs["sender"][
TransactionType.INCOMING_CLAWBACK_SEND
] = await wallet_node_1.wallet_state_manager.tx_store.get_transaction_count_for_wallet(
1, type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CLAWBACK_SEND])
)
after_txs["sender"][
TransactionType.OUTGOING_CLAWBACK
] = await wallet_node_1.wallet_state_manager.tx_store.get_transaction_count_for_wallet(
1, type_filter=TransactionTypeFilter.include([TransactionType.OUTGOING_CLAWBACK])
)
after_txs["sender"][
TransactionType.OUTGOING_TX
] = await wallet_node_1.wallet_state_manager.tx_store.get_transaction_count_for_wallet(
1, type_filter=TransactionTypeFilter.include([TransactionType.OUTGOING_TX])
)
after_txs["sender"][
TransactionType.INCOMING_TX
] = await wallet_node_1.wallet_state_manager.tx_store.get_transaction_count_for_wallet(
1, type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_TX])
)
after_txs["sender"][
TransactionType.COINBASE_REWARD
] = await wallet_node_1.wallet_state_manager.tx_store.get_transaction_count_for_wallet(
1, type_filter=TransactionTypeFilter.include([TransactionType.COINBASE_REWARD])
)
after_txs["recipient"][
TransactionType.INCOMING_CLAWBACK_RECEIVE
] = await wallet_node_2.wallet_state_manager.tx_store.get_transaction_count_for_wallet(
1, type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CLAWBACK_RECEIVE])
)
# Check clawback
clawback_tx = await wallet_node_2.wallet_state_manager.tx_store.get_transaction_record(clawback_coin_id)
assert clawback_tx is not None
assert clawback_tx.confirmed
outgoing_clawback_txs = await wallet_node_2.wallet_state_manager.tx_store.get_transactions_between(
clawback_tx_1 = await wallet_node_1.wallet_state_manager.tx_store.get_transaction_record(clawback_coin_id_1)
clawback_tx_2 = await wallet_node_1.wallet_state_manager.tx_store.get_transaction_record(clawback_coin_id_2)
assert clawback_tx_1 is not None
assert clawback_tx_1.confirmed
assert clawback_tx_2 is not None
assert clawback_tx_2.confirmed
outgoing_clawback_txs = await wallet_node_1.wallet_state_manager.tx_store.get_transactions_between(
1, 0, 100, type_filter=TransactionTypeFilter.include([TransactionType.OUTGOING_CLAWBACK])
)
assert len(outgoing_clawback_txs) == 1
assert len(outgoing_clawback_txs) == 2
assert outgoing_clawback_txs[0].confirmed
assert outgoing_clawback_txs[1].confirmed

# transactions should be the same

assert (
before_txs["sender"][TransactionType.OUTGOING_CLAWBACK]
== after_txs["sender"][TransactionType.OUTGOING_CLAWBACK]
)
assert before_txs["sender"] == after_txs["sender"]
assert before_txs["recipient"] == after_txs["recipient"]

# Check unspent coins
assert len(await wallet_node_2._wallet_state_manager.coin_store.get_all_unspent_coins()) == 7
assert len(await wallet_node_1.wallet_state_manager.coin_store.get_all_unspent_coins()) == 6

@pytest.mark.parametrize(
"trusted",
Expand Down

0 comments on commit 17659b8

Please sign in to comment.