Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: update multiproc to use explicit ports for connection #127

Merged
merged 12 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 43 additions & 28 deletions pynumaflow/mapper/multiproc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import socket
from concurrent import futures
from collections.abc import Iterator

import grpc
from google.protobuf import empty_pb2 as _empty_pb2
Expand All @@ -12,7 +13,7 @@
from pynumaflow._constants import (
MAX_MESSAGE_SIZE,
)
from pynumaflow._constants import MULTIPROC_MAP_SOCK_PORT, MULTIPROC_MAP_SOCK_ADDR
from pynumaflow._constants import MULTIPROC_MAP_SOCK_ADDR
from pynumaflow.exceptions import SocketError
from pynumaflow.mapper import Datum
from pynumaflow.mapper._dtypes import MapCallable
Expand Down Expand Up @@ -44,7 +45,6 @@ class MultiProcMapper(map_pb2_grpc.MapServicer):

Args:
handler: Function callable following the type signature of MapCallable
sock_path: Path to the TCP port to bind to
max_message_size: The max message size in bytes the server can receive and send

Example invocation:
Expand All @@ -67,15 +67,13 @@ class MultiProcMapper(map_pb2_grpc.MapServicer):
"__map_handler",
"_max_message_size",
"_server_options",
"_sock_path",
"_process_count",
"_threads_per_proc",
)

def __init__(
self,
handler: MapCallable,
sock_path=MULTIPROC_MAP_SOCK_PORT,
max_message_size=MAX_MESSAGE_SIZE,
):
self.__map_handler: MapCallable = handler
Expand All @@ -87,8 +85,12 @@ def __init__(
("grpc.so_reuseport", 1),
("grpc.so_reuseaddr", 1),
]
self._sock_path = sock_path
self._process_count = int(os.getenv("NUM_CPU_MULTIPROC") or os.cpu_count())
# Set the number of processes to be spawned to the number of CPUs or
# the value of the env var NUM_CPU_MULTIPROC defined by the user
# Setting the max value to 2 * CPU count
self._process_count = min(
int(os.getenv("NUM_CPU_MULTIPROC", str(os.cpu_count()))), 2 * os.cpu_count()
)
self._threads_per_proc = int(os.getenv("MAX_THREADS", "4"))

def MapFn(
Expand Down Expand Up @@ -148,46 +150,59 @@ def _run_server(self, bind_address: str) -> None:
map_pb2_grpc.add_MapServicer_to_server(self, server)
server.add_insecure_port(bind_address)
server.start()
serv_info = ServerInfo(
protocol=Protocol.TCP,
language=Language.PYTHON,
version=get_sdk_version(),
metadata=get_metadata_env(envs=METADATA_ENVS),
)
# Overwrite the CPU_LIMIT metadata using user input
serv_info.metadata["CPU_LIMIT"] = str(self._process_count)
info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH)

_LOGGER.info("GRPC Multi-Processor Server listening on: %s %d", bind_address, os.getpid())
server.wait_for_termination()

@contextlib.contextmanager
def _reserve_port(self) -> int:
def _reserve_port(self, port_num: int) -> Iterator[int]:
"""Find and reserve a port for all subprocesses to use."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 0:
raise SocketError("Failed to set SO_REUSEADDR.")
if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
raise SocketError("Failed to set SO_REUSEPORT.")
sock.bind(("", self._sock_path))
try:
sock.bind(("", port_num))
yield sock.getsockname()[1]
finally:
sock.close()

def start(self) -> None:
"""Start N grpc servers in different processes where N = CPU Count"""
with self._reserve_port() as port:
bind_address = f"{MULTIPROC_MAP_SOCK_ADDR}:{port}"
workers = []
for _ in range(self._process_count):
"""
Start N grpc servers in different processes where N = The number of CPUs or the
value of the env var NUM_CPU_MULTIPROC defined by the user. The max value
is set to 2 * CPU count.
Each server will be bound to a different port, and we will create equal number of
workers to handle each server.
On the client side there will be same number of connections as the number of servers.
"""
workers = []
server_ports = []
for _ in range(self._process_count):
# Find a port to bind to for each server, thus sending the port number = 0
# to the _reserve_port function so that kernel can find and return a free port
with self._reserve_port(0) as port:
bind_address = f"{MULTIPROC_MAP_SOCK_ADDR}:{port}"
_LOGGER.info("Starting server on port: %s", port)
# NOTE: It is imperative that the worker subprocesses be forked before
# any gRPC servers start up. See
# https://github.com/grpc/grpc/issues/16001 for more details.
worker = multiprocessing.Process(target=self._run_server, args=(bind_address,))
worker.start()
workers.append(worker)
for worker in workers:
worker.join()
server_ports.append(port)

# Convert the available ports to a comma separated string
ports = ",".join(map(str, server_ports))

serv_info = ServerInfo(
protocol=Protocol.TCP,
language=Language.PYTHON,
version=get_sdk_version(),
metadata=get_metadata_env(envs=METADATA_ENVS),
)
# Add the PORTS metadata using the available ports
serv_info.metadata["SERV_PORTS"] = ports
info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH)

for worker in workers:
worker.join()
76 changes: 44 additions & 32 deletions pynumaflow/sourcetransformer/multiproc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import socket
from concurrent import futures
from collections.abc import Iterator

import grpc
from google.protobuf import empty_pb2 as _empty_pb2
Expand All @@ -13,7 +14,7 @@
from pynumaflow._constants import (
MAX_MESSAGE_SIZE,
)
from pynumaflow._constants import MULTIPROC_MAP_SOCK_PORT, MULTIPROC_MAP_SOCK_ADDR
from pynumaflow._constants import MULTIPROC_MAP_SOCK_ADDR
from pynumaflow.exceptions import SocketError
from pynumaflow.info.server import (
get_sdk_version,
Expand Down Expand Up @@ -46,10 +47,7 @@ class MultiProcSourceTransformer(transform_pb2_grpc.SourceTransformServicer):
Args:

handler: Function callable following the type signature of SourceTransformCallable
sock_path: Path to the TCP Socket
max_message_size: The max message size in bytes the server can receive and send
max_threads: The max number of threads to be spawned;
defaults to number of processors x4

Example invocation:
>>> from typing import Iterator
Expand All @@ -70,7 +68,6 @@ class MultiProcSourceTransformer(transform_pb2_grpc.SourceTransformServicer):
def __init__(
self,
handler: SourceTransformCallable,
sock_path=MULTIPROC_MAP_SOCK_PORT,
max_message_size=MAX_MESSAGE_SIZE,
):
self.__transform_handler: SourceTransformCallable = handler
Expand All @@ -82,11 +79,13 @@ def __init__(
("grpc.so_reuseport", 1),
("grpc.so_reuseaddr", 1),
]
self._sock_path = sock_path
self._process_count = int(
os.getenv("NUM_CPU_MULTIPROC") or os.getenv("NUMAFLOW_CPU_LIMIT", 1)
# Set the number of processes to be spawned to the number of CPUs or the value
# of the env var NUM_CPU_MULTIPROC defined by the user
# Setting the max value to 2 * CPU count
self._process_count = min(
int(os.getenv("NUM_CPU_MULTIPROC", str(os.cpu_count()))), 2 * os.cpu_count()
)
self._thread_concurrency = int(os.getenv("MAX_THREADS", 0)) or (self._process_count * 4)
self._threads_per_proc = int(os.getenv("MAX_THREADS", "4"))

def SourceTransformFn(
self, request: transform_pb2.SourceTransformRequest, context: NumaflowServicerContext
Expand Down Expand Up @@ -142,53 +141,66 @@ def _run_server(self, bind_address):
_LOGGER.info("Starting new server.")
server = grpc.server(
futures.ThreadPoolExecutor(
max_workers=self._thread_concurrency,
max_workers=self._threads_per_proc,
),
options=self._server_options,
)
transform_pb2_grpc.add_SourceTransformServicer_to_server(self, server)
server.add_insecure_port(bind_address)
server.start()
serv_info = ServerInfo(
protocol=Protocol.TCP,
language=Language.PYTHON,
version=get_sdk_version(),
metadata=get_metadata_env(envs=METADATA_ENVS),
)
# Overwrite the CPU_LIMIT metadata using user input
serv_info.metadata["CPU_LIMIT"] = str(self._process_count)
info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH)

_LOGGER.info("GRPC Multi-Processor Server listening on: %s %d", bind_address, os.getpid())
server.wait_for_termination()

@contextlib.contextmanager
def _reserve_port(self) -> int:
def _reserve_port(self, port_num: int) -> Iterator[int]:
"""Find and reserve a port for all subprocesses to use."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 0:
raise SocketError("Failed to set SO_REUSEADDR.")
if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
raise SocketError("Failed to set SO_REUSEPORT.")
sock.bind(("", self._sock_path))
try:
sock.bind(("", port_num))
yield sock.getsockname()[1]
finally:
sock.close()

def start(self) -> None:
"""Start N grpc servers in different processes where N = CPU Count"""
with self._reserve_port() as port:
bind_address = f"{MULTIPROC_MAP_SOCK_ADDR}:{port}"
workers = []
for _ in range(self._process_count):
"""
Start N grpc servers in different processes where N = The number of CPUs or the
value of the env var NUM_CPU_MULTIPROC defined by the user. The max value
is set to 2 * CPU count.
Each server will be bound to a different port, and we will create equal number of
workers to handle each server.
On the client side there will be same number of connections as the number of servers.
"""
workers = []
server_ports = []
for _ in range(self._process_count):
# Find a port to bind to for each server, thus sending the port number = 0
# to the _reserve_port function so that kernel can find and return a free port
with self._reserve_port(0) as port:
bind_address = f"{MULTIPROC_MAP_SOCK_ADDR}:{port}"
_LOGGER.info("Starting server on port: %s", port)
# NOTE: It is imperative that the worker subprocesses be forked before
# any gRPC servers start up. See
# https://github.com/grpc/grpc/issues/16001 for more details.
worker = multiprocessing.Process(target=self._run_server, args=(bind_address,))
worker.start()
workers.append(worker)
for worker in workers:
worker.join()
server_ports.append(port)

# Convert the available ports to a comma separated string
ports = ",".join(map(str, server_ports))

serv_info = ServerInfo(
protocol=Protocol.TCP,
language=Language.PYTHON,
version=get_sdk_version(),
metadata=get_metadata_env(envs=METADATA_ENVS),
)
# Add the PORTS metadata using the available ports
serv_info.metadata["SERV_PORTS"] = ports
info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH)

for worker in workers:
worker.join()
12 changes: 8 additions & 4 deletions tests/map/test_multiproc_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,27 @@ def setUp(self) -> None:
@mockenv(NUM_CPU_MULTIPROC="3")
def test_multiproc_init(self) -> None:
server = MultiProcMapper(handler=map_handler)
self.assertEqual(server._sock_path, 55551)
self.assertEqual(server._process_count, 3)

@patch("os.cpu_count", Mock(return_value=4))
def test_multiproc_process_count(self) -> None:
server = MultiProcMapper(handler=map_handler)
self.assertEqual(server._sock_path, 55551)
self.assertEqual(server._process_count, 4)

@patch("os.cpu_count", Mock(return_value=4))
@mockenv(NUM_CPU_MULTIPROC="10")
def test_max_process_count(self) -> None:
server = MultiProcMapper(handler=map_handler)
self.assertEqual(server._process_count, 8)

# To test the reuse property for the grpc servers which allow multiple
# bindings to the same server
def test_reuse_port(self):
serv_options = [("grpc.so_reuseport", 1), ("grpc.so_reuseaddr", 1)]
serv_options = [("grpc.so_reuseaddr", 1)]

server = MultiProcMapper(handler=map_handler)

with server._reserve_port() as port:
with server._reserve_port(0) as port:
print(port)
bind_address = f"localhost:{port}"
server1 = grpc.server(thread_pool=None, options=serv_options)
Expand Down
15 changes: 10 additions & 5 deletions tests/sourcetransform/test_multiproc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import unittest
from unittest import mock
from unittest.mock import Mock, patch

import grpc
from google.protobuf import empty_pb2 as _empty_pb2
Expand Down Expand Up @@ -33,23 +34,27 @@ def setUp(self) -> None:
@mockenv(NUM_CPU_MULTIPROC="3")
def test_multiproc_init(self) -> None:
server = MultiProcSourceTransformer(handler=transform_handler)
self.assertEqual(server._sock_path, 55551)
self.assertEqual(server._process_count, 3)

@mockenv(NUMAFLOW_CPU_LIMIT="4")
@patch("os.cpu_count", Mock(return_value=4))
def test_multiproc_process_count(self) -> None:
server = MultiProcSourceTransformer(handler=transform_handler)
self.assertEqual(server._sock_path, 55551)
self.assertEqual(server._process_count, 4)

@patch("os.cpu_count", Mock(return_value=4))
@mockenv(NUM_CPU_MULTIPROC="10")
def test_max_process_count(self) -> None:
server = MultiProcSourceTransformer(handler=transform_handler)
self.assertEqual(server._process_count, 8)

# To test the reuse property for the grpc servers which allow multiple
# bindings to the same server
def test_reuse_port(self):
serv_options = [("grpc.so_reuseport", 1), ("grpc.so_reuseaddr", 1)]
serv_options = [("grpc.so_reuseaddr", 1)]

server = MultiProcSourceTransformer(handler=transform_handler)

with server._reserve_port() as port:
with server._reserve_port(0) as port:
print(port)
bind_address = f"localhost:{port}"
server1 = grpc.server(thread_pool=None, options=serv_options)
Expand Down