diff --git a/src/agentscope/agents/rpc_agent.py b/src/agentscope/agents/rpc_agent.py index b7dd62d63..615d74962 100644 --- a/src/agentscope/agents/rpc_agent.py +++ b/src/agentscope/agents/rpc_agent.py @@ -4,6 +4,7 @@ from multiprocessing import Process from multiprocessing import Event from multiprocessing.synchronize import Event as EventClass +from multiprocessing import Pipe import socket import threading import time @@ -80,7 +81,7 @@ def __init__( use_memory: bool = True, memory_config: Optional[dict] = None, host: str = "localhost", - port: int = 80, + port: int = 12000, max_pool_size: int = 100, max_timeout_seconds: int = 1800, launch_server: bool = True, @@ -110,7 +111,7 @@ def __init__( The config of memory. host (`str`, defaults to "localhost"): Hostname of the rpc agent server. - port (`int`, defaults to `80`): + port (`int`, defaults to `None`): Port of the rpc agent server. max_pool_size (`int`, defaults to `100`): The max number of task results that the server can @@ -159,11 +160,9 @@ def __init__( max_timeout_seconds=max_timeout_seconds, local_mode=local_mode, ) - self.port = self.server_launcher.port self.client = None if not lazy_launch: - self.server_launcher.launch() - self.client = RpcAgentClient(host=self.host, port=self.port) + self._launch_server() # is_servicer is True only in the rpc server process if is_servicer: self.result_pool = ExpiringDict( @@ -180,6 +179,13 @@ def __init__( if not launch_server and not is_servicer: self.client = RpcAgentClient(host=self.host, port=self.port) + def _launch_server(self) -> None: + """Launch a rpc server and update the port and the client + """ + self.server_launcher.launch() + self.port = self.server_launcher.port + self.client = RpcAgentClient(host=self.host, port=self.port) + def get_task_id(self) -> int: """Get the auto-increment task id.""" with self.task_id_lock: @@ -294,8 +300,7 @@ def observe(self, x: Union[dict, Sequence[dict]]) -> None: The input to be observed. """ if self.client is None: - self.server_launcher.launch() - self.client = RpcAgentClient(host=self.host, port=self.port) + self._launch_server() self.client.call_func( func_name="_observe", value=serialize(x), # type: ignore [arg-type] @@ -312,8 +317,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> dict: if x is not None: assert isinstance(x, MessageBase) if self.client is None: - self.server_launcher.launch() - self.client = RpcAgentClient(host=self.host, port=self.port) + self._launch_server() res_msg = self.client.call_func( func_name="_call", value=x.serialize() if x is not None else "", @@ -340,6 +344,7 @@ def setup_rcp_agent_server( servicer_class: Type[RpcAgentServicer], start_event: EventClass = None, stop_event: EventClass = None, + pipe: int = None, max_workers: int = 4, local_mode: bool = True, init_settings: dict = None, @@ -358,6 +363,8 @@ def setup_rcp_agent_server( stop_event (`EventClass`, defaults to `None`): The stop Event instance used to determine whether the child process has been stopped. + pipe (`int`, defaults to `None`): + A pipe instance used to pass the actual port of the server. max_workers (`int`, defaults to `4`): max worker number of grpc server. local_mode (`bool`, defaults to `None`): @@ -368,18 +375,35 @@ def setup_rcp_agent_server( if init_settings is not None: init_process(**init_settings) - server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers)) - servicer = servicer_class(**kwargs) - add_RpcAgentServicer_to_server(servicer, server) - if local_mode: - server.add_insecure_port(f"localhost:{port}") - else: - server.add_insecure_port(f"0.0.0.0:{port}") - server.start() + while True: + try: + port = check_port(port) + logger.info( + f"Starting rpc server [{servicer_class.__name__}] at port" + f" [{port}]...", + ) + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=max_workers), + ) + kwargs['port'] = port + servicer = servicer_class(**kwargs) + add_RpcAgentServicer_to_server(servicer, server) + if local_mode: + server.add_insecure_port(f"localhost:{port}") + else: + server.add_insecure_port(f"0.0.0.0:{port}") + server.start() + break + except OSError: + logger.warning( + f"Failed to start rpc server at port [{port}], " + f"try another port", + ) logger.info( f"rpc server [{servicer_class.__name__}] at port [{port}] started " "successfully", ) + pipe.send(port) start_event.set() stop_event.wait() logger.info( @@ -392,6 +416,42 @@ def setup_rcp_agent_server( ) +def find_available_port() -> int: + """Get an unoccupied socket port number.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def check_port(port: Optional[int] = None) -> int: + """Check if the port is available. + + Args: + port (`int`): + the port number being checked. + + Returns: + `int`: the port number that passed the check. If the port is found + to be occupied, an available port number will be automatically + returned. + """ + if port is None: + new_port = find_available_port() + logger.warning( + "gRpc server port is not provided, automatically select " + f"[{new_port}] as the port number.", + ) + return new_port + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + if s.connect_ex(("localhost", port)) == 0: + new_port = find_available_port() + logger.warning( + f"Port [{port}] is occupied, use [{new_port}] instead", + ) + return new_port + return port + + class RpcAgentServerLauncher: """Launcher of rpc agent server.""" @@ -453,55 +513,19 @@ def __init__( self.memory_config = memory_config self.agent_class = agent_class self.host = host - self.port = self.check_port(port) + self.port = check_port(port) self.max_pool_size = max_pool_size self.max_timeout_seconds = max_timeout_seconds self.local_model = local_mode self.server = None self.stop_event = None + self.parent_con = None self.kwargs = kwargs - def find_available_port(self) -> int: - """Get an unoccupied socket port number.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - - def check_port(self, port: int) -> int: - """Check if the port is available. - - Args: - port (`int`): - the port number being checked. - - Returns: - `int`: the port number that passed the check. If the port is found - to be occupied, an available port number will be automatically - returned. - """ - if port is None: - new_port = self.find_available_port() - logger.warning( - "gRpc server port is not provided, automatically select " - f"[{new_port}] as the port number.", - ) - return new_port - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - if s.connect_ex(("localhost", port)) == 0: - new_port = self.find_available_port() - logger.warning( - f"Port [{port}] is occupied, use [{new_port}] instead", - ) - return new_port - return port - def launch(self) -> None: """launch a local rpc agent server.""" self.stop_event = Event() - logger.info( - f"Starting rpc server [{self.agent_class.__name__}] at port" - f" [{self.port}]...", - ) + self.parent_con, child_con = Pipe() start_event = Event() server_process = Process( target=setup_rcp_agent_server, @@ -510,6 +534,7 @@ def launch(self) -> None: "servicer_class": self.agent_class, "start_event": start_event, "stop_event": self.stop_event, + "pipe": child_con, "local_mode": self.local_model, "init_settings": _INIT_SETTINGS, "kwargs": { @@ -531,6 +556,7 @@ def launch(self) -> None: }, ) server_process.start() + self.port = self.parent_con.recv() start_event.wait() self.server = server_process diff --git a/tests/rpc_agent_test.py b/tests/rpc_agent_test.py index ce23609df..6c2e2978d 100644 --- a/tests/rpc_agent_test.py +++ b/tests/rpc_agent_test.py @@ -284,23 +284,14 @@ def test_mix_rpc_agent_and_local_agent(self) -> None: def test_msghub_compatibility(self) -> None: """test compatibility with msghub""" - port1 = 12001 - port2 = 12002 - port3 = 12003 agent_a = DemoRpcAgentWithMemory( name="a", - lazy_launch=False, - port=port1, ) agent_b = DemoRpcAgentWithMemory( name="b", - lazy_launch=False, - port=port2, ) agent_c = DemoRpcAgentWithMemory( name="c", - lazy_launch=False, - port=port3, ) participants = [agent_a, agent_b, agent_c] annonuncement_msgs = [