Skip to content

Commit

Permalink
Merge pull request #699 from NethServer/feature-6974-1
Browse files Browse the repository at this point in the history
New ports allocation system

Refs NethServer/dev#6974
  • Loading branch information
DavidePrincipi authored Oct 7, 2024
2 parents aa8ba9d + ec7e2b7 commit 1e9dcda
Show file tree
Hide file tree
Showing 24 changed files with 648 additions and 67 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/usr/bin/env python3

#
# Copyright (C) 2024 Nethesis S.r.l.
# SPDX-License-Identifier: GPL-3.0-or-later
#

import agent
import os

try:
agent.deallocate_ports("tcp", os.environ['MODULE_ID'] + "_rsync")
except:
pass
60 changes: 60 additions & 0 deletions core/imageroot/usr/local/agent/pypkg/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,3 +607,63 @@ def get_bound_domain_list(rdb, module_id=None):
return rval.split()
else:
return []

def allocate_ports(ports_number: int, protocol: str, module_id: str=""):
"""
Allocate a range of ports for a given module,
if it is already allocated it is deallocated first.
:param ports_number: Number of consecutive ports required.
:param protocol: Protocol type ('tcp' or 'udp').
:param module_id: Name of the module requesting the ports.
Parameter is optional, if not provided, default value is environment variable MODULE_ID.
:return: A tuple (start_port, end_port) if allocation is successful, None otherwise.
"""

if module_id == "":
module_id = os.environ['MODULE_ID']

node_id = os.environ['NODE_ID']
response = agent.tasks.run(
agent_id=f'node/{node_id}',
action='allocate-ports',
data={
'ports': ports_number,
'module_id': module_id,
'protocol': protocol
}
)

if response['exit_code'] != 0:
raise Exception(f"{response['error']}")

return response['output']


def deallocate_ports(protocol: str, module_id: str=""):
"""
Deallocate the ports for a given module.
:param protocol: Protocol type ('tcp' or 'udp').
:param module_id: Name of the module whose ports are to be deallocated.
Parameter is optional, if not provided, default value is environment variable MODULE_ID.
:return: A tuple (start_port, end_port) if deallocation is successful, None otherwise.
"""

if module_id == "":
module_id = os.environ['MODULE_ID']

node_id = os.environ['NODE_ID']
response = agent.tasks.run(
agent_id=f'node/{node_id}',
action='deallocate-ports',
data={
'module_id': module_id,
'protocol': protocol
}
)

if response['exit_code'] != 0:
raise Exception(f"{response['error']}")

return response['output']
182 changes: 182 additions & 0 deletions core/imageroot/usr/local/agent/pypkg/node/ports_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#
# Copyright (C) 2024 Nethesis S.r.l.
# SPDX-License-Identifier: GPL-3.0-or-later
#

import sqlite3

class PortError(Exception):
"""Base class for all port-related exceptions."""
pass

class PortRangeExceededError(PortError):
"""Exception raised when the port range is exceeded."""
def __init__(self, message="Ports range max exceeded!"):
self.message = message
super().__init__(self.message)

class StorageError(PortError):
"""Exception raised when a database error occurs."""
def __init__(self, message="Database operation failed."):
self.message = message
super().__init__(self.message)

class ModuleNotFoundError(PortError):
"""Exception raised when a module is not found for deallocation."""
def __init__(self, module_name, message=None):
self.module_name = module_name
if message is None:
message = f"Module '{module_name}' not found."
self.message = message
super().__init__(self.message)

class InvalidPortRequestError(PortError):
"""Exception raised when the requested number of ports is invalid."""
def __init__(self, message="The number of required ports must be at least 1."):
self.message = message
super().__init__(self.message)

def create_tables(cursor: sqlite3.Cursor):
# Create TCP table if it doesn't exist
cursor.execute("""
CREATE TABLE IF NOT EXISTS TCP_PORTS (
start INT NOT NULL,
end INT NOT NULL,
module CHAR(255) NOT NULL
);
""")

# Create UDP table if it doesn't exist
cursor.execute("""
CREATE TABLE IF NOT EXISTS UDP_PORTS (
start INT NOT NULL,
end INT NOT NULL,
module CHAR(255) NOT NULL
);
""")

def is_port_used(ports_used, port_to_check):
for port in ports_used:
if port_to_check in range(port[0], port[1] + 1):
return True
return False

def allocate_ports(required_ports: int, module_name: str, protocol: str):
"""
Allocate a range of ports for a given module,
if it is already allocated it is deallocated first.
:param required_ports: Number of consecutive ports required.
:param module_name: Name of the module requesting the ports.
:param protocol: Protocol type ('tcp' or 'udp').
:return: A tuple (start_port, end_port) if allocation is successful, None otherwise.
"""
if required_ports < 1:
raise InvalidPortRequestError() # Raise error if requested ports are less than 1

range_start = 20000
range_end = 45000

try:
with sqlite3.connect('./ports.sqlite', isolation_level='EXCLUSIVE', timeout=30) as database:
cursor = database.cursor()
create_tables(cursor) # Ensure the tables exist

# Fetch used ports based on protocol
if protocol == 'tcp':
cursor.execute("SELECT start,end,module FROM TCP_PORTS ORDER BY start;")
elif protocol == 'udp':
cursor.execute("SELECT start,end,module FROM UDP_PORTS ORDER BY start;")
ports_used = cursor.fetchall()

# If the module already has an assigned range, deallocate it first
if any(module_name == range[2] for range in ports_used):
deallocate_ports(module_name, protocol)
# Reload the used ports after deallocation
if protocol == 'tcp':
cursor.execute("SELECT start,end,module FROM TCP_PORTS ORDER BY start;")
elif protocol == 'udp':
cursor.execute("SELECT start,end,module FROM UDP_PORTS ORDER BY start;")
ports_used = cursor.fetchall()

if len(ports_used) == 0:
write_range(range_start, range_start + required_ports - 1, module_name, protocol, database)
return (range_start, range_start + required_ports - 1)

while range_start <= range_end:
# Check if the current port is within an already used range
for port_range in ports_used:
for index in range(required_ports):
if is_port_used(ports_used, range_start+index):
range_start = port_range[1] + 1 # Move to the next available port range
break
if index == required_ports-1:
write_range(range_start, range_start + required_ports - 1, module_name, protocol, database)
return (range_start, range_start + required_ports - 1)
else:
raise PortRangeExceededError()
except sqlite3.Error as e:
raise StorageError(f"Database error: {e}") from e # Raise custom database error

def deallocate_ports(module_name: str, protocol: str):
"""
Deallocate the ports for a given module.
:param module_name: Name of the module whose ports are to be deallocated.
:param protocol: Protocol type ('tcp' or 'udp').
:return: A tuple (start_port, end_port) if deallocation is successful, None otherwise.
"""
try:
with sqlite3.connect('./ports.sqlite', isolation_level='EXCLUSIVE', timeout=30) as database:
cursor = database.cursor()
create_tables(cursor) # Ensure the tables exist

# Fetch the port range for the given module and protocol
if protocol == 'tcp':
cursor.execute("SELECT start,end,module FROM TCP_PORTS WHERE module=?;", (module_name,))
elif protocol == 'udp':
cursor.execute("SELECT start,end,module FROM UDP_PORTS WHERE module=?;", (module_name,))
ports_deallocated = cursor.fetchall()

if ports_deallocated:
# Delete the allocated port range for the module
if protocol == 'tcp':
cursor.execute("DELETE FROM TCP_PORTS WHERE module=?;", (module_name,))
elif protocol == 'udp':
cursor.execute("DELETE FROM UDP_PORTS WHERE module=?;", (module_name,))
database.commit()
return (ports_deallocated[0][0], ports_deallocated[0][1])
else:
raise ModuleNotFoundError(module_name) # Raise error if the module is not found

except sqlite3.Error as e:
raise StorageError(f"Database error: {e}") from e # Raise custom database error

def write_range(start: int, end: int, module: str, protocol: str, database: sqlite3.Connection=None):
"""
Write a port range for a module directly to the database.
:param start: Starting port number.
:param end: Ending port number.
:param module: Name of the module.
:param protocol: Protocol type ('tcp' or 'udp').
"""
try:
if database is None:
database = sqlite3.connect('./ports.sqlite', isolation_level='EXCLUSIVE', timeout=30)

with database:
cursor = database.cursor()
create_tables(cursor) # Ensure the tables exist

# Insert the port range into the appropriate table based on protocol
if protocol == 'tcp':
cursor.execute("INSERT INTO TCP_PORTS (start, end, module) VALUES (?, ?, ?);",
(start, end, module))
elif protocol == 'udp':
cursor.execute("INSERT INTO UDP_PORTS (start, end, module) VALUES (?, ?, ?);",
(start, end, module))
database.commit()

except sqlite3.Error as e:
raise StorageError(f"Database error: {e}") from e # Raise custom database error
Original file line number Diff line number Diff line change
Expand Up @@ -31,34 +31,6 @@ import os
import re
import uuid

def allocate_tcp_ports_range(node_id, module_environment, size):
"""Allocate in "node_id" a TCP port range of the given "size" for "module_id"
"""
global rdb
agent.assert_exp(size > 0)

seq = rdb.incrby(f'node/{int(node_id)}/tcp_ports_sequence', size)
agent.assert_exp(int(seq) > 0)
module_environment['TCP_PORT'] = f'{seq - size}' # Always set the first port
if size > 1: # Multiple ports: always set the ports range variable
module_environment['TCP_PORTS_RANGE'] = f'{seq - size}-{seq - 1}'
if size <= 8: # Few ports: set also a comma-separated list of ports variable
module_environment['TCP_PORTS'] = ','.join(str(port) for port in range(seq-size, seq))

def allocate_udp_ports_range(node_id, module_environment, size):
"""Allocate in "node_id" a UDP port range of the given "size" for "module_id"
"""
global rdb
agent.assert_exp(size > 0)

seq = rdb.incrby(f'node/{int(node_id)}/udp_ports_sequence', size)
agent.assert_exp(int(seq) > 0)
module_environment['UDP_PORT'] = f'{seq - size}' # Always set the first port
if size > 1: # Multiple ports: always set the ports range variable
module_environment['UDP_PORTS_RANGE'] = f'{seq - size}-{seq - 1}'
if size <= 8: # Few ports: set also a comma-separated list of ports variable
module_environment['UDP_PORTS'] = ','.join(str(port) for port in range(seq-size, seq))

request = json.load(sys.stdin)
node_id = int(request['node'])
agent.assert_exp(node_id > 0)
Expand Down Expand Up @@ -146,14 +118,6 @@ module_environment = {
'MODULE_UUID': str(uuid.uuid4())
}

# Allocate TCP ports
if tcp_ports_demand > 0:
allocate_tcp_ports_range(node_id, module_environment, tcp_ports_demand)

# Allocate UDP ports
if udp_ports_demand > 0:
allocate_udp_ports_range(node_id, module_environment, udp_ports_demand)

# Set the "default_instance" keys for cluster and node, if module_id is the first instance of image
for kdefault_instance in [f'cluster/default_instance/{image_id}', f'node/{node_id}/default_instance/{image_id}']:
default_instance = rdb.get(kdefault_instance)
Expand All @@ -174,6 +138,8 @@ add_module_result = agent.tasks.run(
"module_id": module_id,
"is_rootfull": is_rootfull,
"environment": module_environment,
"tcp_ports_demand": tcp_ports_demand,
"udp_ports_demand": udp_ports_demand,
},
endpoint="redis://cluster-leader",
progress_callback=agent.get_progress_callback(34,66),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,6 @@ agent.assert_exp(rdb.hset(f'node/{node_id}/vpn', mapping={
for flag in flags:
rdb.sadd(f'node/{node_id}/flags', flag)

# Initialize the node ports sequence
agent.assert_exp(rdb.set(f'node/{node_id}/tcp_ports_sequence', 20000) is True)
agent.assert_exp(rdb.set(f'node/{node_id}/udp_ports_sequence', 20000) is True)

#
# Create redis acls for the node agent
#
Expand Down Expand Up @@ -168,6 +164,9 @@ cluster.grants.grant(rdb, "remove-custom-zone", f'node/{node_id}', "tunadm")
cluster.grants.grant(rdb, "add-tun", f'node/{node_id}', "tunadm")
cluster.grants.grant(rdb, "remove-tun", f'node/{node_id}', "tunadm")

cluster.grants.grant(rdb, "allocate-ports", f'node/{node_id}', "portsadm")
cluster.grants.grant(rdb, "deallocate-ports", f'node/{node_id}', "portsadm")

# Grant on cascade the owner role on the new node, to users with the owner
# role on cluster
for userk in rdb.scan_iter('roles/*'):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,19 @@ add_module_result = agent.tasks.run("cluster", "add-module",
agent.assert_exp(add_module_result['exit_code'] == 0) # add-module is successful

dmid = add_module_result['output']['module_id'] # Destination module ID
rsyncd_port = int(rdb.incrby(f'node/{node_id}/tcp_ports_sequence', 1)) # Allocate a TCP port for rsyncd
allocated_range = agent.tasks.run(
agent_id=f'node/{node_id}',
action="allocate-ports",
data={
'ports': 1,
'module_id': dmid + '_rsync',
'protocol': 'tcp'
},
endpoint="redis://cluster-leader",
progress_callback=agent.get_progress_callback(26,40),
)
agent.assert_exp(allocated_range['output'][0] == allocated_range['output'][1])
rsyncd_port = allocated_range['output'][0] # Allocate a TCP port for rsyncd
agent.assert_exp(rsyncd_port > 0) # valid destination port number

# Rootfull modules require a volume name remapping:
Expand Down Expand Up @@ -103,7 +115,7 @@ client_task = {
# Send and receive tasks run in parallel until both finish
clone_errors = agent.tasks.runp_brief([server_task, client_task],
endpoint="redis://cluster-leader",
progress_callback=agent.get_progress_callback(26, 94),
progress_callback=agent.get_progress_callback(41, 90),
)

if clone_errors > 0:
Expand All @@ -122,10 +134,23 @@ if replace:
"preserve_data": False
},
endpoint="redis://cluster-leader",
progress_callback=agent.get_progress_callback(95, 98),
progress_callback=agent.get_progress_callback(91, 94),
)
if remove_retval['exit_code'] != 0:
print(f"Removal of module/{smid} has failed!")
sys.exit(1)

# Deallocate rsync port
deallocated_range = agent.tasks.run(
agent_id=f'node/{node_id}',
action="deallocate-ports",
data={
'module_id': dmid + '_rsync',
'protocol': 'tcp'
},
endpoint="redis://cluster-leader",
progress_callback=agent.get_progress_callback(96,99),
)
agent.assert_exp(allocated_range['output'] == deallocated_range['output'])

json.dump(add_module_result['output'], fp=sys.stdout)
Loading

0 comments on commit 1e9dcda

Please sign in to comment.