diff --git a/mlos_bench/mlos_bench/event_loop_context.py b/mlos_bench/mlos_bench/event_loop_context.py index 32b778dd582..6fcc51e0c35 100644 --- a/mlos_bench/mlos_bench/event_loop_context.py +++ b/mlos_bench/mlos_bench/event_loop_context.py @@ -19,6 +19,12 @@ else: from typing_extensions import TypeAlias +CoroReturnType = TypeVar('CoroReturnType') # pylint: disable=invalid-name +if sys.version_info >= (3, 9): + FutureReturnType: TypeAlias = Future[CoroReturnType] +else: + FutureReturnType: TypeAlias = Future + class EventLoopContext: """ @@ -78,12 +84,6 @@ def exit(self) -> None: self._event_loop = None self._event_loop_thread = None - CoroReturnType = TypeVar('CoroReturnType') - if sys.version_info >= (3, 9): - FutureReturnType: TypeAlias = Future[CoroReturnType] - else: - FutureReturnType: TypeAlias = Future - def run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType: """ Runs the given coroutine in the background event loop thread and diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py index f67e71b7225..4ba1a52246a 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py @@ -50,6 +50,8 @@ def __init__(self, methods : Union[Dict[str, Callable], List[Callable], None] New methods to register with the service. """ + # Same methods are also provided by the AzureVMService class + # pylint: disable=duplicate-code super().__init__( config, global_config, parent, self.merge_methods(methods, [ diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py index 21db889bb59..2ffecd4885a 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py @@ -8,27 +8,18 @@ from abc import ABCMeta from asyncio import Event as CoroEvent, Lock as CoroLock -from concurrent.futures import Future from types import TracebackType -from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Type, Union from threading import current_thread import logging import os -import sys import asyncssh - from asyncssh.connection import SSHClientConnection from mlos_bench.services.base_service import Service -from mlos_bench.event_loop_context import EventLoopContext - -if sys.version_info >= (3, 10): - from typing import TypeAlias -else: - from typing_extensions import TypeAlias - +from mlos_bench.event_loop_context import EventLoopContext, CoroReturnType, FutureReturnType _LOG = logging.getLogger(__name__) @@ -284,12 +275,6 @@ def clear_client_cache(cls) -> None: """ cls._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.cleanup() - CoroReturnType = TypeVar('CoroReturnType') - if sys.version_info >= (3, 9): - FutureReturnType: TypeAlias = Future[CoroReturnType] - else: - FutureReturnType: TypeAlias = Future - def _run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType: """ Runs the given coroutine in the background event loop thread. diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py index 84a61e3ef18..97ef6eba3d1 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py @@ -35,15 +35,13 @@ @pytest.fixture(scope="session") def ssh_test_server_hostname() -> str: """Returns the local hostname to use to connect to the test ssh server.""" - if sys.platform == 'win32': - # Docker (Desktop) for Windows (WSL2) uses a special networking magic - # to refer to the host machine when exposing ports. - return 'localhost' - # On Linux, if we're running in a docker container, we can use the - # --add-host (extra_hosts in docker-compose.yml) to refer to the host IP. - if resolve_host_name(HOST_DOCKER_NAME): + if sys.platform != 'win32' and resolve_host_name(HOST_DOCKER_NAME): + # On Linux, if we're running in a docker container, we can use the + # --add-host (extra_hosts in docker-compose.yml) to refer to the host IP. return HOST_DOCKER_NAME - # Otherwise, assume we're executing directly inside conda on the host. + # Docker (Desktop) for Windows (WSL2) uses a special networking magic + # to refer to the host machine as `localhost` when exposing ports. + # In all other cases, assume we're executing directly inside conda on the host. return 'localhost' diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py index 14fc015be3f..d3b5f3c24fe 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py @@ -96,7 +96,7 @@ def test_ssh_service_context_handler() -> None: ssh_fileshare_service._run_coroutine(asyncio.sleep(0.1)) # The background thread should remain running since we have another context still open. - assert isinstance(SshService._EVENT_LOOP_CONTEXT._event_loop_thread, Thread) # type: ignore[unreachable] + assert isinstance(SshService._EVENT_LOOP_CONTEXT._event_loop_thread, Thread) assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None