Skip to content

Commit

Permalink
Merge pull request #1244 from egbertbouman/tunnel_refactoring_pt1
Browse files Browse the repository at this point in the history
Added direction argument to encrypt_str/decrypto_str
  • Loading branch information
egbertbouman authored Nov 14, 2023
2 parents 138f1dd + 67bc083 commit 3ea9d89
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 131 deletions.
6 changes: 3 additions & 3 deletions ipv8/dht/community.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def on_ping_response(self, peer: Peer, payload: PingResponsePayload, data: bytes
When receive a response to our ping, update the node's metrics.
"""
if not self.request_cache.has('ping', payload.identifier):
self.logger.error('Got ping-response with unknown identifier, dropping packet')
self.logger.warning('Got ping-response with unknown identifier, dropping packet')
return

self.logger.debug('Got ping-response from %s', peer.address)
Expand Down Expand Up @@ -550,7 +550,7 @@ def on_store_response(self, peer: Peer, payload: StoreResponsePayload) -> None:
We got confirmation of storage.
"""
if not self.request_cache.has('store', payload.identifier):
self.logger.error('Got store-response with unknown identifier, dropping packet')
self.logger.warning('Got store-response with unknown identifier, dropping packet')
return

self.logger.debug('Got store-response from %s', peer.address)
Expand Down Expand Up @@ -691,7 +691,7 @@ def on_find_response(self, peer: Peer, payload: FindResponsePayload) -> None:
We got a response for our find requests.
"""
if not self.request_cache.has('find', payload.identifier):
self.logger.error('Got find-response with unknown identifier, dropping packet')
self.logger.warning('Got find-response with unknown identifier, dropping packet')
return

self.logger.debug('Got find-response from %s', peer.address)
Expand Down
4 changes: 2 additions & 2 deletions ipv8/dht/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def on_store_peer_response(self, peer: Peer, payload: StorePeerResponsePayload)
When a peer signals storage is complete, pop it from our cache.
"""
if not self.request_cache.has('store-peer', payload.identifier):
self.logger.error('Got store-peer-response with unknown identifier, dropping packet')
self.logger.warning('Got store-peer-response with unknown identifier, dropping packet')
return

self.logger.debug('Got store-peer-response from %s', peer.address)
Expand Down Expand Up @@ -228,7 +228,7 @@ def on_connect_peer_response(self, peer: Peer, payload: ConnectPeerResponsePaylo
Handle responses of peers that performed punctures for us.
"""
if not self.request_cache.has('connect-peer', payload.identifier):
self.logger.error('Got connect-peer-response with unknown identifier, dropping packet')
self.logger.warning('Got connect-peer-response with unknown identifier, dropping packet')
return

self.logger.debug('Got connect-peer-response from %s', peer.address)
Expand Down
84 changes: 37 additions & 47 deletions ipv8/messaging/anonymization/community.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from ...community import Community, CommunitySettings
from ...keyvault.private.libnaclkey import LibNaCLSK
from ...lazy_community import lazy_wrapper
from ...messaging.payload_headers import BinMemberAuthenticationPayload, GlobalTimeDistributionPayload
from ...requestcache import RequestCache
from ...taskmanager import task
from ...types import Address
Expand All @@ -30,6 +29,7 @@
from collections.abc import Collection

from ...dht.provider import DHTCommunityProvider
from ...messaging.payload_headers import GlobalTimeDistributionPayload
from ..lazy_payload import VariablePayloadWID
from ..payload import (
IntroductionRequestPayload,
Expand Down Expand Up @@ -147,7 +147,6 @@ def __init__(self, settings: TunnelSettings) -> None:
self.add_cell_handler(TestResponsePayload, self.on_test_response)

self.circuits: dict[int, Circuit] = {}
self.directions: dict[int, int | None] = {}
self.relay_from_to: dict[int, RelayRoute] = {}
self.relay_session_keys: dict[int, SessionKeys] = {}
self.exit_sockets: dict[int, TunnelExitSocket] = {}
Expand Down Expand Up @@ -359,21 +358,21 @@ def create_circuit(self, goal_hops: int, ctype: str = CIRCUIT_TYPE_DATA,

return circuit

def send_initial_create(self, circuit: Circuit, candidate_list: list[Peer], max_tries: int) -> None:
def send_initial_create(self, circuit: Circuit, candidate_peers: list[Peer], max_tries: int) -> None:
"""
Attempt to establish the first hop in a Circuit.
"""
if self.request_cache.has(RetryRequestCache, circuit.circuit_id):
self.request_cache.pop(RetryRequestCache, circuit.circuit_id)
self.logger.info("Retrying first hop for circuit %d", circuit.circuit_id)

first_hop = random.choice(candidate_list)
alt_first_hops = [c for c in candidate_list if c != first_hop]
first_hop = random.choice(candidate_peers)
alt_first_hops = [c for c in candidate_peers if c != first_hop]

circuit.unverified_hop = Hop(first_hop, flags=self.candidates.get(first_hop))
circuit.unverified_hop.dh_secret, circuit.unverified_hop.dh_first_part = self.crypto.generate_diffie_secret()

self.logger.info("Adding first hop %s:%d to circuit %d", *((*first_hop.address, circuit.circuit_id)))
self.logger.info("Adding first hop %s:%d to circuit %d", *(*first_hop.address, circuit.circuit_id))

cache = RetryRequestCache(self, circuit, alt_first_hops, max_tries - 1,
self.send_initial_create, self.settings.next_hop_timeout)
Expand Down Expand Up @@ -412,9 +411,6 @@ async def remove_circuit(self, circuit_id: int, additional_info: str = '', remov
if circuit:
self.logger.info("Removed circuit %d %s", circuit_id, additional_info)

# Clean up the directions dictionary
self.directions.pop(circuit_id, None)

@task
async def remove_relay(self, circuit_id: int, additional_info: str = '', remove_now: bool = False,
destroy: bool = False) -> RelayRoute | None:
Expand All @@ -432,7 +428,6 @@ async def remove_relay(self, circuit_id: int, additional_info: str = '', remove_

relay = self.relay_from_to.pop(circuit_id, None)
self.relay_session_keys.pop(circuit_id, None)
self.directions.pop(circuit_id, None)

return relay

Expand Down Expand Up @@ -538,9 +533,7 @@ def send_destroy(self, peer: Address | Peer, circuit_id: int, reason: int) -> No
"""
Send a destroy message directly to the given peer.
"""
auth = BinMemberAuthenticationPayload(self.my_peer.public_key.key_to_bin())
payload = DestroyPayload(circuit_id, reason)
packet = self._ez_pack(self._prefix, DestroyPayload.msg_id, [auth, payload])
packet = self.ezr_pack(DestroyPayload.msg_id, DestroyPayload(circuit_id, reason))
self.send_packet(peer, packet)

def relay_cell(self, cell: CellPayload) -> None:
Expand All @@ -561,11 +554,12 @@ def relay_cell(self, cell: CellPayload) -> None:
cell.encrypt(self.crypto, relay_session_keys=self.relay_session_keys[next_relay.circuit_id])
cell.relay_early = False
else:
direction = self.directions[cell.circuit_id]
if direction == ORIGINATOR:
cell.encrypt(self.crypto, relay_session_keys=self.relay_session_keys[cell.circuit_id])
elif direction == EXIT_NODE:
direction = next_relay.direction
if direction == FORWARD:
cell.decrypt(self.crypto, relay_session_keys=self.relay_session_keys[cell.circuit_id])
elif direction == BACKWARD:
cell.encrypt(self.crypto, relay_session_keys=self.relay_session_keys[cell.circuit_id])

except CryptoException as e:
self.logger.warning(str(e))
return
Expand All @@ -592,23 +586,21 @@ def _ours_on_created_extended(self, circuit: Circuit, payload: CreatedPayload |
circuit.add_hop(hop)

if circuit.state == CIRCUIT_STATE_EXTENDING:
candidate_list_enc = payload.candidate_list_enc
candidate_list_bin = self.crypto.decrypt_str(candidate_list_enc,
session_keys.key_backward,
session_keys.salt_backward)
candidate_list, _ = self.serializer.unpack('varlenH-list', candidate_list_bin)
candidates_enc = payload.candidates_enc
candidates_bin = self.crypto.decrypt_str(candidates_enc, session_keys, FORWARD)
candidates, _ = self.serializer.unpack('varlenH-list', candidates_bin)

cache = self.request_cache.pop(RetryRequestCache, circuit.circuit_id)
self.send_extend(circuit, cast(List[bytes], candidate_list), cache.max_tries if cache else 1)
self.send_extend(circuit, cast(List[bytes], candidates), cache.max_tries if cache else 1)

elif circuit.state == CIRCUIT_STATE_READY:
self.request_cache.pop(RetryRequestCache, circuit.circuit_id)

def send_extend(self, circuit: Circuit, candidate_list: list[bytes], max_tries: int) -> None:
def send_extend(self, circuit: Circuit, candidates: list[bytes], max_tries: int) -> None:
"""
Extend a circuit by choosing one of the given candidates.
"""
ignore_candidates = [hop.node_public_key for hop in circuit.hops] + [self.my_peer.public_key.key_to_bin()]
ignore_candidates = [hop.public_key_bin for hop in circuit.hops] + [self.my_peer.public_key.key_to_bin()]
if circuit.required_exit:
ignore_candidates.append(circuit.required_exit.public_key.key_to_bin())

Expand All @@ -621,15 +613,15 @@ def send_extend(self, circuit: Circuit, candidate_list: list[bytes], max_tries:
else:
# The next candidate is chosen from the returned list of possible candidates
for ignore_candidate in ignore_candidates:
if ignore_candidate in candidate_list:
candidate_list.remove(ignore_candidate)
if ignore_candidate in candidates:
candidates.remove(ignore_candidate)

for i in range(len(candidate_list) - 1, -1, -1):
public_key = self.crypto.key_from_public_bin(candidate_list[i])
for i in range(len(candidates) - 1, -1, -1):
public_key = self.crypto.key_from_public_bin(candidates[i])
if not self.crypto.is_key_compatible(public_key):
candidate_list.pop(i)
candidates.pop(i)

extend_hop_public_bin = next(iter(candidate_list), b'')
extend_hop_public_bin = next(iter(candidates), b'')
extend_hop_addr = ('0.0.0.0', 0)

if extend_hop_public_bin:
Expand All @@ -646,7 +638,7 @@ def send_extend(self, circuit: Circuit, candidate_list: list[bytes], max_tries:

# Only retry if we are allowed to use another node
if not become_exit or not circuit.required_exit:
alt_candidates = [c for c in candidate_list if c != extend_hop_public_bin]
alt_candidates = [c for c in candidates if c != extend_hop_public_bin]
else:
alt_candidates = []

Expand All @@ -656,7 +648,7 @@ def send_extend(self, circuit: Circuit, candidate_list: list[bytes], max_tries:

self.send_cell(cast(Peer, circuit.peer), ExtendPayload(circuit.circuit_id,
cache.packet_identifier,
circuit.unverified_hop.node_public_key,
circuit.unverified_hop.public_key_bin,
circuit.unverified_hop.dh_first_part,
extend_hop_addr))

Expand Down Expand Up @@ -784,7 +776,6 @@ def join_circuit(self, create_payload: CreatePayload, previous_node_address: Add
"""
circuit_id = create_payload.circuit_id

self.directions[circuit_id] = EXIT_NODE
self.logger.info('We joined circuit %d with neighbour %s', circuit_id, previous_node_address)

shared_secret, key, auth = self.crypto.generate_diffie_shared_secret(create_payload.key)
Expand All @@ -798,11 +789,9 @@ def join_circuit(self, create_payload: CreatePayload, previous_node_address: Add
self.request_cache.add(CreatedRequestCache(self, circuit_id, peer, peers_keys, self.settings.unstable_timeout))
self.exit_sockets[circuit_id] = TunnelExitSocket(circuit_id, peer, self)

candidate_list_bin = self.serializer.pack('varlenH-list', list(peers_keys.keys()))
candidate_list_enc = self.crypto.encrypt_str(candidate_list_bin,
*self.crypto.get_session_keys(self.relay_session_keys[circuit_id],
EXIT_NODE))
self.send_cell(peer, CreatedPayload(circuit_id, create_payload.identifier, key, auth, candidate_list_enc))
candidates_bin = self.serializer.pack('varlenH-list', list(peers_keys.keys()))
candidates_enc = self.crypto.encrypt_str(candidates_bin, self.relay_session_keys[circuit_id], FORWARD)
self.send_cell(peer, CreatedPayload(circuit_id, create_payload.identifier, key, auth, candidates_enc))

@unpack_cell(CreatePayload)
async def on_create(self, source_address: Address, payload: CreatePayload, _: int | None) -> None:
Expand All @@ -828,23 +817,24 @@ def on_created(self, source_address: Address, payload: CreatedPayload, _: int |
Callback for when another peer signals that they have joined our circuit.
"""
circuit_id = payload.circuit_id
self.directions[circuit_id] = ORIGINATOR

if self.request_cache.has(CreateRequestCache, payload.identifier):
request = self.request_cache.pop(CreateRequestCache, payload.identifier)

self.logger.info("Got CREATED message forward as EXTENDED to origin.")

self.relay_from_to[request.to_circuit_id] = relay = RelayRoute(request.from_circuit_id, request.peer)
self.relay_from_to[request.from_circuit_id] = RelayRoute(request.to_circuit_id, request.to_peer)
self.relay_session_keys[request.to_circuit_id] = self.relay_session_keys[request.from_circuit_id]
from_circuit_id = request.from_circuit_id
to_circuit_id = request.to_circuit_id

self.relay_from_to[to_circuit_id] = relay = RelayRoute(from_circuit_id, request.peer, BACKWARD)
self.relay_from_to[from_circuit_id] = RelayRoute(to_circuit_id, request.to_peer, FORWARD)
self.relay_session_keys[to_circuit_id] = self.relay_session_keys[from_circuit_id]

self.directions[request.from_circuit_id] = EXIT_NODE
self.remove_exit_socket(request.from_circuit_id)

self.send_cell(relay.peer,
ExtendedPayload(relay.circuit_id, request.extend_identifier,
payload.key, payload.auth, payload.candidate_list_enc))
payload.key, payload.auth, payload.candidates_enc))
return

cache = self.request_cache.get(RetryRequestCache, circuit_id)
Expand All @@ -854,7 +844,7 @@ def on_created(self, source_address: Address, payload: CreatedPayload, _: int |
circuit = self.circuits[circuit_id]
self._ours_on_created_extended(circuit, payload)
else:
self.logger.warning("Received unexpected created for circuit %d", payload.circuit_id)
self.logger.warning("Received unexpected created for circuit %d", circuit_id)

@unpack_cell(ExtendPayload)
async def on_extend(self, source_address: Address, payload: ExtendPayload, _: int | None) -> None:
Expand Down Expand Up @@ -1128,7 +1118,7 @@ def on_test_response(self, source_address: Address, data: bytes, circuit_id: int
self.logger.error("Dropping test-response with unknown circuit_id")
return
if not self.request_cache.has(TestRequestCache, payload.identifier):
self.logger.error("Dropping unexpected test-response")
self.logger.warning("Dropping unexpected test-response")
return

self.logger.debug("Got test-response (%d) from %s", circuit_id, source_address)
Expand Down
Loading

0 comments on commit 3ea9d89

Please sign in to comment.