Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New ports allocation system #699

Merged
merged 18 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/usr/bin/env python3

#
# Copyright (C) 2024 Nethesis S.r.l.
# SPDX-License-Identifier: GPL-3.0-or-later
#

import agent
import os

try:
    # Release the TCP port registered under the "<MODULE_ID>_rsync" name.
    agent.deallocate_ports("tcp", os.environ['MODULE_ID'] + "_rsync")
except Exception:
    # Best-effort cleanup: any failure is deliberately ignored.
    # "except Exception" is used instead of a bare "except:" so that
    # KeyboardInterrupt and SystemExit are not swallowed as well.
    pass
60 changes: 60 additions & 0 deletions core/imageroot/usr/local/agent/pypkg/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,3 +607,63 @@ def get_bound_domain_list(rdb, module_id=None):
return rval.split()
else:
return []

def allocate_ports(ports_number: int, protocol: str, module_id: str=""):
    """
    Ask the local node agent to allocate a range of ports for a module.

    The request is sent as an 'allocate-ports' task to the agent
    "node/<NODE_ID>", where NODE_ID is read from the environment.
    If the module already has a range allocated, it is deallocated first.

    :param ports_number: Number of consecutive ports required.
    :param protocol: Protocol type ('tcp' or 'udp').
    :param module_id: Name of the module requesting the ports.
                      Optional: when empty, the MODULE_ID environment variable is used.
    :return: The 'output' value of the task; expected to be the allocated
             (start_port, end_port) range.
    :raises Exception: if the task exit code is not 0; the message is the task 'error' value.
    :raises KeyError: if NODE_ID (or MODULE_ID, when needed) is missing from the environment.
    """
    # Resolve the owner of the range before anything else, as the caller may
    # rely on the MODULE_ID of the running module.
    owner = os.environ['MODULE_ID'] if module_id == "" else module_id
    target_agent = f"node/{os.environ['NODE_ID']}"

    payload = {
        'ports': ports_number,
        'module_id': owner,
        'protocol': protocol
    }
    result = agent.tasks.run(
        agent_id=target_agent,
        action='allocate-ports',
        data=payload
    )

    if result['exit_code'] == 0:
        return result['output']

    raise Exception(f"{result['error']}")


def deallocate_ports(protocol: str, module_id: str=""):
    """
    Ask the local node agent to deallocate the ports of a module.

    The request is sent as a 'deallocate-ports' task to the agent
    "node/<NODE_ID>", where NODE_ID is read from the environment.

    :param protocol: Protocol type ('tcp' or 'udp').
    :param module_id: Name of the module whose ports are to be deallocated.
                      Optional: when empty, the MODULE_ID environment variable is used.
    :return: The 'output' value of the task; expected to be the released
             (start_port, end_port) range.
    :raises Exception: if the task exit code is not 0; the message is the task 'error' value.
    :raises KeyError: if NODE_ID (or MODULE_ID, when needed) is missing from the environment.
    """
    owner = os.environ['MODULE_ID'] if module_id == "" else module_id
    target_agent = f"node/{os.environ['NODE_ID']}"

    payload = {
        'module_id': owner,
        'protocol': protocol
    }
    result = agent.tasks.run(
        agent_id=target_agent,
        action='deallocate-ports',
        data=payload
    )

    if result['exit_code'] == 0:
        return result['output']

    raise Exception(f"{result['error']}")
182 changes: 182 additions & 0 deletions core/imageroot/usr/local/agent/pypkg/node/ports_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#
# Copyright (C) 2024 Nethesis S.r.l.
# SPDX-License-Identifier: GPL-3.0-or-later
#

import sqlite3

class PortError(Exception):
    """Common ancestor of every port-related exception.

    Catching this class catches all the errors derived from it.
    """

class PortRangeExceededError(PortError):
    """Raised when the port range is exceeded.

    The text is available both as str(exc) and as the ``message`` attribute.
    """

    def __init__(self, message="Ports range max exceeded!"):
        super().__init__(message)
        self.message = message

class StorageError(PortError):
    """Raised when a database error occurs.

    The text is available both as str(exc) and as the ``message`` attribute.
    """

    def __init__(self, message="Database operation failed."):
        super().__init__(message)
        self.message = message

class ModuleNotFoundError(PortError):
    """Raised when a module is not found for deallocation.

    NOTE: within this Python module the class shadows the built-in
    ModuleNotFoundError exception.

    :param module_name: Name of the module that was looked up; kept in the
                        ``module_name`` attribute.
    :param message: Optional custom text; when None a default text that
                    mentions ``module_name`` is generated.
    """

    def __init__(self, module_name, message=None):
        text = f"Module '{module_name}' not found." if message is None else message
        self.module_name = module_name
        self.message = text
        super().__init__(text)

class InvalidPortRequestError(PortError):
    """Raised when the requested number of ports is invalid.

    The text is available both as str(exc) and as the ``message`` attribute.
    """

    def __init__(self, message="The number of required ports must be at least 1."):
        super().__init__(message)
        self.message = message

def create_tables(cursor: sqlite3.Cursor):
    """
    Make sure the TCP_PORTS and UDP_PORTS tables exist.

    Both tables share the same layout: the first and last port of a range
    ("start", "end") and the name of the owning module ("module").
    Existing tables are left untouched (CREATE TABLE IF NOT EXISTS).

    :param cursor: Cursor of the database where the tables are created.
    """
    # One table per protocol, identical schema
    for table_name in ("TCP_PORTS", "UDP_PORTS"):
        cursor.execute(f"""
        CREATE TABLE IF NOT EXISTS {table_name} (
            start INT NOT NULL,
            end INT NOT NULL,
            module CHAR(255) NOT NULL
        );
        """)

def is_port_used(ports_used, port_to_check):
    """
    Tell whether a port belongs to one of the given ranges.

    :param ports_used: Iterable of ranges; items 0 and 1 of each element are
                       the first and the last port of the range (both included).
    :param port_to_check: Port number to look for.
    :return: True if the port falls inside at least one range, False otherwise.
    """
    return any(
        port_to_check in range(used_range[0], used_range[1] + 1)
        for used_range in ports_used
    )

def allocate_ports(required_ports: int, module_name: str, protocol: str):
    """
    Allocate a range of consecutive ports for a given module,
    if it is already allocated it is deallocated first.

    The lowest free gap that fits inside the managed range (20000-45000,
    both included) is selected and recorded in the database.

    :param required_ports: Number of consecutive ports required.
    :param module_name: Name of the module requesting the ports.
    :param protocol: Protocol type ('tcp' or 'udp').
    :return: A tuple (start_port, end_port) with the allocated range.
    :raises InvalidPortRequestError: if required_ports is less than 1 or the
                                     protocol is neither 'tcp' nor 'udp'.
    :raises PortRangeExceededError: if no free gap of the requested size
                                    fits inside the managed range.
    :raises StorageError: if a database error occurs.
    """
    if required_ports < 1:
        raise InvalidPortRequestError() # Raise error if requested ports are less than 1

    # The table name comes from this fixed mapping, never from the caller,
    # so it is safe to interpolate it into the SQL text.
    tables = {'tcp': 'TCP_PORTS', 'udp': 'UDP_PORTS'}
    if protocol not in tables:
        # Previously an unknown protocol returned a range without storing it.
        raise InvalidPortRequestError(f"Unsupported protocol '{protocol}': expected 'tcp' or 'udp'.")
    select_used = f"SELECT start,end,module FROM {tables[protocol]} ORDER BY start;"

    range_start = 20000
    range_end = 45000

    try:
        database = sqlite3.connect('./ports.sqlite', isolation_level='EXCLUSIVE', timeout=30)
        try:
            # "with database" only commits or rolls back: the connection is
            # closed explicitly in the "finally" clause below.
            # NOTE(review): in the sqlite3 legacy transaction mode a transaction is
            # opened implicitly only before data-modifying statements, so the
            # SELECT and the INSERT are presumably not covered by one EXCLUSIVE
            # transaction - confirm whether concurrent callers need it.
            with database:
                cursor = database.cursor()
                create_tables(cursor) # Ensure the tables exist

                # Fetch used ports based on protocol
                cursor.execute(select_used)
                ports_used = cursor.fetchall()

                # If the module already has an assigned range, deallocate it first
                if any(module_name == used[2] for used in ports_used):
                    deallocate_ports(module_name, protocol)
                    # Reload the used ports after deallocation
                    cursor.execute(select_used)
                    ports_used = cursor.fetchall()

                # First-fit search: ranges are sorted by start, so walk them once
                # and stop at the first gap wide enough for the request.
                candidate = range_start
                for used_start, used_end, _ in ports_used:
                    if used_start - candidate >= required_ports:
                        break
                    # max() keeps the search moving forward even if a stored
                    # range lies below the candidate or overlaps another one
                    candidate = max(candidate, used_end + 1)

                last_port = candidate + required_ports - 1
                if last_port > range_end:
                    # The whole range, not only its first port, must fit
                    raise PortRangeExceededError()

                write_range(candidate, last_port, module_name, protocol, database)
                return (candidate, last_port)
        finally:
            database.close()
    except sqlite3.Error as e:
        raise StorageError(f"Database error: {e}") from e # Raise custom database error

def deallocate_ports(module_name: str, protocol: str):
    """
    Deallocate the ports for a given module.

    :param module_name: Name of the module whose ports are to be deallocated.
    :param protocol: Protocol type ('tcp' or 'udp').
    :return: A tuple (start_port, end_port) with the first range that was
             recorded for the module.
    :raises InvalidPortRequestError: if the protocol is neither 'tcp' nor 'udp'.
    :raises ModuleNotFoundError: if the module has no range for the protocol.
    :raises StorageError: if a database error occurs.
    """
    # The table name comes from this fixed mapping, never from the caller,
    # so it is safe to interpolate it into the SQL text.
    tables = {'tcp': 'TCP_PORTS', 'udp': 'UDP_PORTS'}
    if protocol not in tables:
        # Previously an unknown protocol was reported as "module not found".
        raise InvalidPortRequestError(f"Unsupported protocol '{protocol}': expected 'tcp' or 'udp'.")
    table = tables[protocol]

    try:
        database = sqlite3.connect('./ports.sqlite', isolation_level='EXCLUSIVE', timeout=30)
        try:
            # "with database" only commits or rolls back: the connection is
            # closed explicitly in the "finally" clause below.
            with database:
                cursor = database.cursor()
                create_tables(cursor) # Ensure the tables exist

                # Fetch the port range for the given module and protocol
                cursor.execute(f"SELECT start,end,module FROM {table} WHERE module=?;", (module_name,))
                ports_deallocated = cursor.fetchall()

                if not ports_deallocated:
                    raise ModuleNotFoundError(module_name) # Raise error if the module is not found

                # Delete the allocated port range for the module
                cursor.execute(f"DELETE FROM {table} WHERE module=?;", (module_name,))
                database.commit()
                return (ports_deallocated[0][0], ports_deallocated[0][1])
        finally:
            database.close()
    except sqlite3.Error as e:
        raise StorageError(f"Database error: {e}") from e # Raise custom database error

def write_range(start: int, end: int, module: str, protocol: str, database: sqlite3.Connection=None):
    """
    Write a port range for a module directly to the database.

    No check is made against the ranges already stored.

    :param start: Starting port number.
    :param end: Ending port number.
    :param module: Name of the module.
    :param protocol: Protocol type ('tcp' or 'udp').
    :param database: Optional open connection to use. When None, a connection
                     to './ports.sqlite' is opened and closed by this function;
                     a connection passed by the caller is left open.
    :raises InvalidPortRequestError: if the protocol is neither 'tcp' nor 'udp'.
    :raises StorageError: if a database error occurs.
    """
    # The table name comes from this fixed mapping, never from the caller,
    # so it is safe to interpolate it into the SQL text.
    tables = {'tcp': 'TCP_PORTS', 'udp': 'UDP_PORTS'}
    if protocol not in tables:
        # Previously an unknown protocol was silently ignored.
        raise InvalidPortRequestError(f"Unsupported protocol '{protocol}': expected 'tcp' or 'udp'.")

    owns_connection = database is None
    try:
        if owns_connection:
            database = sqlite3.connect('./ports.sqlite', isolation_level='EXCLUSIVE', timeout=30)

        try:
            # "with database" only commits or rolls back, it does not close
            with database:
                cursor = database.cursor()
                create_tables(cursor) # Ensure the tables exist

                # Insert the port range into the table of the protocol
                cursor.execute(f"INSERT INTO {tables[protocol]} (start, end, module) VALUES (?, ?, ?);",
                               (start, end, module))
                database.commit()
        finally:
            # Close only the connection opened here
            if owns_connection:
                database.close()

    except sqlite3.Error as e:
        raise StorageError(f"Database error: {e}") from e # Raise custom database error
Original file line number Diff line number Diff line change
Expand Up @@ -31,34 +31,6 @@ import os
import re
import uuid

def allocate_tcp_ports_range(node_id, module_environment, size):
"""Allocate in "node_id" a TCP port range of the given "size" and record it in "module_environment"

The node/<node_id>/tcp_ports_sequence counter is incremented by "size";
with "seq" being the value returned by the increment, the allocated
ports are seq-size .. seq-1.

The "module_environment" mapping is modified in place: TCP_PORT gets the
first port, TCP_PORTS_RANGE the "first-last" string and TCP_PORTS the
comma-separated list of ports, under the conditions checked below.
"""
global rdb
agent.assert_exp(size > 0)

# rdb is a global of the enclosing script, presumably its Redis client (TODO confirm)
seq = rdb.incrby(f'node/{int(node_id)}/tcp_ports_sequence', size)
agent.assert_exp(int(seq) > 0)
module_environment['TCP_PORT'] = f'{seq - size}' # Always set the first port
if size > 1: # Multiple ports: always set the ports range variable
module_environment['TCP_PORTS_RANGE'] = f'{seq - size}-{seq - 1}'
# NOTE(review): indentation is not visible in this view, so it is unclear whether
# the next check is nested under "size > 1" - confirm against the real file.
if size <= 8: # Few ports: set also a comma-separated list of ports variable
module_environment['TCP_PORTS'] = ','.join(str(port) for port in range(seq-size, seq))

def allocate_udp_ports_range(node_id, module_environment, size):
"""Allocate in "node_id" a UDP port range of the given "size" and record it in "module_environment"

The node/<node_id>/udp_ports_sequence counter is incremented by "size";
with "seq" being the value returned by the increment, the allocated
ports are seq-size .. seq-1.

The "module_environment" mapping is modified in place: UDP_PORT gets the
first port, UDP_PORTS_RANGE the "first-last" string and UDP_PORTS the
comma-separated list of ports, under the conditions checked below.
"""
global rdb
agent.assert_exp(size > 0)

# rdb is a global of the enclosing script, presumably its Redis client (TODO confirm)
seq = rdb.incrby(f'node/{int(node_id)}/udp_ports_sequence', size)
agent.assert_exp(int(seq) > 0)
module_environment['UDP_PORT'] = f'{seq - size}' # Always set the first port
if size > 1: # Multiple ports: always set the ports range variable
module_environment['UDP_PORTS_RANGE'] = f'{seq - size}-{seq - 1}'
# NOTE(review): indentation is not visible in this view, so it is unclear whether
# the next check is nested under "size > 1" - confirm against the real file.
if size <= 8: # Few ports: set also a comma-separated list of ports variable
module_environment['UDP_PORTS'] = ','.join(str(port) for port in range(seq-size, seq))

# The action input is a JSON document read from standard input
request = json.load(sys.stdin)
# The "node" field selects the target node: it must convert to a positive integer
node_id = int(request['node'])
agent.assert_exp(node_id > 0)
Expand Down Expand Up @@ -146,14 +118,6 @@ module_environment = {
'MODULE_UUID': str(uuid.uuid4())
}

# Allocate TCP ports first, then UDP ports; a protocol is skipped when
# the module does not demand any port for it
for ports_demand, allocate_range in (
    (tcp_ports_demand, allocate_tcp_ports_range),
    (udp_ports_demand, allocate_udp_ports_range),
):
    if ports_demand > 0:
        allocate_range(node_id, module_environment, ports_demand)

# Set the "default_instance" keys for cluster and node, if module_id is the first instance of image
for kdefault_instance in [f'cluster/default_instance/{image_id}', f'node/{node_id}/default_instance/{image_id}']:
default_instance = rdb.get(kdefault_instance)
Expand All @@ -174,6 +138,8 @@ add_module_result = agent.tasks.run(
"module_id": module_id,
"is_rootfull": is_rootfull,
"environment": module_environment,
"tcp_ports_demand": tcp_ports_demand,
"udp_ports_demand": udp_ports_demand,
},
endpoint="redis://cluster-leader",
progress_callback=agent.get_progress_callback(34,66),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,6 @@ agent.assert_exp(rdb.hset(f'node/{node_id}/vpn', mapping={
# Add each flag to the node/<node_id>/flags key, one sadd call per flag
for flag in flags:
rdb.sadd(f'node/{node_id}/flags', flag)

# Initialize the node ports sequence
# Both the TCP and the UDP counters start from 20000; each set call must return True
agent.assert_exp(rdb.set(f'node/{node_id}/tcp_ports_sequence', 20000) is True)
agent.assert_exp(rdb.set(f'node/{node_id}/udp_ports_sequence', 20000) is True)

#
# Create redis acls for the node agent
#
Expand Down Expand Up @@ -168,6 +164,9 @@ cluster.grants.grant(rdb, "remove-custom-zone", f'node/{node_id}', "tunadm")
cluster.grants.grant(rdb, "add-tun", f'node/{node_id}', "tunadm")
cluster.grants.grant(rdb, "remove-tun", f'node/{node_id}', "tunadm")

cluster.grants.grant(rdb, "allocate-ports", f'node/{node_id}', "portsadm")
cluster.grants.grant(rdb, "deallocate-ports", f'node/{node_id}', "portsadm")

# Grant on cascade the owner role on the new node, to users with the owner
# role on cluster
for userk in rdb.scan_iter('roles/*'):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,19 @@ add_module_result = agent.tasks.run("cluster", "add-module",
agent.assert_exp(add_module_result['exit_code'] == 0) # add-module is successful

dmid = add_module_result['output']['module_id'] # Destination module ID
# NOTE(review): rsyncd_port is assigned again further down, from the output of
# the allocate-ports task; this increment of the legacy tcp_ports_sequence
# counter looks like a leftover - confirm and remove it if so.
rsyncd_port = int(rdb.incrby(f'node/{node_id}/tcp_ports_sequence', 1)) # Allocate a TCP port for rsyncd
# Ask the agent of the destination node for a single TCP port, registered
# under the "<destination module ID>_rsync" name
allocated_range = agent.tasks.run(
    agent_id=f'node/{node_id}',
    action="allocate-ports",
    data={
        'ports': 1,
        'module_id': dmid + '_rsync',
        'protocol': 'tcp'
    },
    endpoint="redis://cluster-leader",
    progress_callback=agent.get_progress_callback(26,40),
)
# Check the task result before reading its output: on failure the output
# is not guaranteed to contain a port range
agent.assert_exp(allocated_range['exit_code'] == 0) # allocate-ports is successful
# One port was requested: first and last port of the range must coincide
agent.assert_exp(allocated_range['output'][0] == allocated_range['output'][1])
rsyncd_port = allocated_range['output'][0] # Allocate a TCP port for rsyncd
agent.assert_exp(rsyncd_port > 0) # valid destination port number

# Rootfull modules require a volume name remapping:
Expand Down Expand Up @@ -103,7 +115,7 @@ client_task = {
# Send and receive tasks run in parallel until both finish
# clone_errors is compared with 0 right after this call to detect a failure
clone_errors = agent.tasks.runp_brief([server_task, client_task],
endpoint="redis://cluster-leader",
# NOTE(review): progress_callback appears twice in this call and Python rejects
# a repeated keyword argument. Presumably the two lines are the removed (26, 94)
# and the added (41, 90) side of a diff - keep only one of them.
progress_callback=agent.get_progress_callback(26, 94),
progress_callback=agent.get_progress_callback(41, 90),
)

if clone_errors > 0:
Expand All @@ -122,10 +134,23 @@ if replace:
"preserve_data": False
},
endpoint="redis://cluster-leader",
progress_callback=agent.get_progress_callback(95, 98),
progress_callback=agent.get_progress_callback(91, 94),
)
if remove_retval['exit_code'] != 0:
print(f"Removal of module/{smid} has failed!")
sys.exit(1)

# Deallocate rsync port
# The port is released with the same "<destination module ID>_rsync" name
# and protocol used by the allocate-ports request
deallocated_range = agent.tasks.run(
agent_id=f'node/{node_id}',
action="deallocate-ports",
data={
'module_id': dmid + '_rsync',
'protocol': 'tcp'
},
endpoint="redis://cluster-leader",
progress_callback=agent.get_progress_callback(96,99),
)
# The released range must be exactly the one that was allocated
agent.assert_exp(allocated_range['output'] == deallocated_range['output'])

# The output of the add-module task is written to standard output as JSON
json.dump(add_module_result['output'], fp=sys.stdout)
Loading
Loading