diff --git a/pynumaflow/mapper/multiproc_server.py b/pynumaflow/mapper/multiproc_server.py index 14277d9c..d14fde93 100644 --- a/pynumaflow/mapper/multiproc_server.py +++ b/pynumaflow/mapper/multiproc_server.py @@ -4,6 +4,7 @@ import os import socket from concurrent import futures +from collections.abc import Iterator import grpc from google.protobuf import empty_pb2 as _empty_pb2 @@ -66,7 +67,6 @@ class MultiProcMapper(map_pb2_grpc.MapServicer): "__map_handler", "_max_message_size", "_server_options", - "_sock_path", "_process_count", "_threads_per_proc", ) @@ -85,9 +85,12 @@ def __init__( ("grpc.so_reuseport", 1), ("grpc.so_reuseaddr", 1), ] - self._process_count = int(os.getenv("NUM_CPU_MULTIPROC") or os.cpu_count()) - # Setting the max process count to 2 * CPU count - self._process_count = min(self._process_count, 2 * os.cpu_count()) + # Set the number of processes to be spawned to the number of CPUs or + # the value of the env var NUM_CPU_MULTIPROC defined by the user + # Setting the max value to 2 * CPU count + self._process_count = min( + int(os.getenv("NUM_CPU_MULTIPROC", str(os.cpu_count()))), 2 * os.cpu_count() + ) self._threads_per_proc = int(os.getenv("MAX_THREADS", "4")) def MapFn( @@ -151,7 +154,7 @@ def _run_server(self, bind_address: str) -> None: server.wait_for_termination() @contextlib.contextmanager - def _reserve_port(self, port_num: int) -> int: + def _reserve_port(self, port_num: int) -> Iterator[int]: """Find and reserve a port for all subprocesses to use.""" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) @@ -165,14 +168,16 @@ def _reserve_port(self, port_num: int) -> int: def start(self) -> None: """ - Start N grpc servers in different processes where N = CPU Count + Start N grpc servers in different processes where N = The number of CPUs or the + value of the env var NUM_CPU_MULTIPROC defined by the user. The max value + is set to 2 * CPU count. Each server will be bound to a different port, and we will create equal number of workers to handle each server. On the client side there will be same number of connections as the number of servers. """ workers = [] server_ports = [] - for i in range(self._process_count): + for _ in range(self._process_count): # Find a port to bind to for each server, thus sending the port number = 0 # to the _reserve_port function so that kernel can find and return a free port with self._reserve_port(0) as port: @@ -187,7 +192,7 @@ def start(self) -> None: server_ports.append(port) # Convert the available ports to a comma separated string - ports = ",".join([str(p) for p in server_ports]) + ports = ",".join(map(str, server_ports)) serv_info = ServerInfo( protocol=Protocol.TCP, diff --git a/pynumaflow/sourcetransformer/multiproc_server.py b/pynumaflow/sourcetransformer/multiproc_server.py index a3035506..7aa58e9d 100644 --- a/pynumaflow/sourcetransformer/multiproc_server.py +++ b/pynumaflow/sourcetransformer/multiproc_server.py @@ -4,6 +4,7 @@ import os import socket from concurrent import futures +from collections.abc import Iterator import grpc from google.protobuf import empty_pb2 as _empty_pb2 @@ -13,7 +14,7 @@ from pynumaflow._constants import ( MAX_MESSAGE_SIZE, ) -from pynumaflow._constants import MULTIPROC_MAP_SOCK_PORT, MULTIPROC_MAP_SOCK_ADDR +from pynumaflow._constants import MULTIPROC_MAP_SOCK_ADDR from pynumaflow.exceptions import SocketError from pynumaflow.info.server import ( get_sdk_version, @@ -46,10 +47,7 @@ class MultiProcSourceTransformer(transform_pb2_grpc.SourceTransformServicer): Args: handler: Function callable following the type signature of SourceTransformCallable - sock_path: Path to the TCP Socket max_message_size: The max message size in bytes the server can receive and send - max_threads: The max number of threads to be spawned; - defaults to number of processors x4 Example invocation: >>> from typing import Iterator @@ -70,7 +68,6 @@ class MultiProcSourceTransformer(transform_pb2_grpc.SourceTransformServicer): def __init__( self, handler: SourceTransformCallable, - sock_path=MULTIPROC_MAP_SOCK_PORT, max_message_size=MAX_MESSAGE_SIZE, ): self.__transform_handler: SourceTransformCallable = handler @@ -82,11 +79,13 @@ def __init__( ("grpc.so_reuseport", 1), ("grpc.so_reuseaddr", 1), ] - self._sock_path = sock_path - self._process_count = int( - os.getenv("NUM_CPU_MULTIPROC") or os.getenv("NUMAFLOW_CPU_LIMIT", 1) + # Set the number of processes to be spawned to the number of CPUs or the value + # of the env var NUM_CPU_MULTIPROC defined by the user + # Setting the max value to 2 * CPU count + self._process_count = min( + int(os.getenv("NUM_CPU_MULTIPROC", str(os.cpu_count()))), 2 * os.cpu_count() ) - self._thread_concurrency = int(os.getenv("MAX_THREADS", 0)) or (self._process_count * 4) + self._threads_per_proc = int(os.getenv("MAX_THREADS", "4")) def SourceTransformFn( self, request: transform_pb2.SourceTransformRequest, context: NumaflowServicerContext @@ -142,53 +141,66 @@ def _run_server(self, bind_address): _LOGGER.info("Starting new server.") server = grpc.server( futures.ThreadPoolExecutor( - max_workers=self._thread_concurrency, + max_workers=self._threads_per_proc, ), options=self._server_options, ) transform_pb2_grpc.add_SourceTransformServicer_to_server(self, server) server.add_insecure_port(bind_address) server.start() - serv_info = ServerInfo( - protocol=Protocol.TCP, - language=Language.PYTHON, - version=get_sdk_version(), - metadata=get_metadata_env(envs=METADATA_ENVS), - ) - # Overwrite the CPU_LIMIT metadata using user input - serv_info.metadata["CPU_LIMIT"] = str(self._process_count) - info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH) - _LOGGER.info("GRPC Multi-Processor Server listening on: %s %d", bind_address, os.getpid()) server.wait_for_termination() @contextlib.contextmanager - def _reserve_port(self) -> int: + def _reserve_port(self, port_num: int) -> Iterator[int]: """Find and reserve a port for all subprocesses to use.""" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 0: raise SocketError("Failed to set SO_REUSEADDR.") - if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0: - raise SocketError("Failed to set SO_REUSEPORT.") - sock.bind(("", self._sock_path)) try: + sock.bind(("", port_num)) yield sock.getsockname()[1] finally: sock.close() def start(self) -> None: - """Start N grpc servers in different processes where N = CPU Count""" - with self._reserve_port() as port: - bind_address = f"{MULTIPROC_MAP_SOCK_ADDR}:{port}" - workers = [] - for _ in range(self._process_count): + """ + Start N grpc servers in different processes where N = The number of CPUs or the + value of the env var NUM_CPU_MULTIPROC defined by the user. The max value + is set to 2 * CPU count. + Each server will be bound to a different port, and we will create equal number of + workers to handle each server. + On the client side there will be same number of connections as the number of servers. + """ + workers = [] + server_ports = [] + for _ in range(self._process_count): + # Find a port to bind to for each server, thus sending the port number = 0 + # to the _reserve_port function so that kernel can find and return a free port + with self._reserve_port(0) as port: + bind_address = f"{MULTIPROC_MAP_SOCK_ADDR}:{port}" + _LOGGER.info("Starting server on port: %s", port) # NOTE: It is imperative that the worker subprocesses be forked before # any gRPC servers start up. See # https://github.com/grpc/grpc/issues/16001 for more details. worker = multiprocessing.Process(target=self._run_server, args=(bind_address,)) worker.start() workers.append(worker) - for worker in workers: - worker.join() + server_ports.append(port) + + # Convert the available ports to a comma separated string + ports = ",".join(map(str, server_ports)) + + serv_info = ServerInfo( + protocol=Protocol.TCP, + language=Language.PYTHON, + version=get_sdk_version(), + metadata=get_metadata_env(envs=METADATA_ENVS), + ) + # Add the PORTS metadata using the available ports + serv_info.metadata["SERV_PORTS"] = ports + info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH) + + for worker in workers: + worker.join() diff --git a/tests/sourcetransform/test_multiproc.py b/tests/sourcetransform/test_multiproc.py index 8ca41aa8..25cca845 100644 --- a/tests/sourcetransform/test_multiproc.py +++ b/tests/sourcetransform/test_multiproc.py @@ -1,6 +1,7 @@ import os import unittest from unittest import mock +from unittest.mock import Mock, patch import grpc from google.protobuf import empty_pb2 as _empty_pb2 @@ -33,23 +34,27 @@ def setUp(self) -> None: @mockenv(NUM_CPU_MULTIPROC="3") def test_multiproc_init(self) -> None: server = MultiProcSourceTransformer(handler=transform_handler) - self.assertEqual(server._sock_path, 55551) self.assertEqual(server._process_count, 3) - @mockenv(NUMAFLOW_CPU_LIMIT="4") + @patch("os.cpu_count", Mock(return_value=4)) def test_multiproc_process_count(self) -> None: server = MultiProcSourceTransformer(handler=transform_handler) - self.assertEqual(server._sock_path, 55551) self.assertEqual(server._process_count, 4) + @patch("os.cpu_count", Mock(return_value=4)) + @mockenv(NUM_CPU_MULTIPROC="10") + def test_max_process_count(self) -> None: + server = MultiProcSourceTransformer(handler=transform_handler) + self.assertEqual(server._process_count, 8) + # To test the reuse property for the grpc servers which allow multiple # bindings to the same server def test_reuse_port(self): - serv_options = [("grpc.so_reuseport", 1), ("grpc.so_reuseaddr", 1)] + serv_options = [("grpc.so_reuseaddr", 1)] server = MultiProcSourceTransformer(handler=transform_handler) - with server._reserve_port() as port: + with server._reserve_port(0) as port: print(port) bind_address = f"localhost:{port}" server1 = grpc.server(thread_pool=None, options=serv_options)