diff --git a/starlette/testclient.py b/starlette/testclient.py index 08d03fa5c..b9d5485c0 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -400,7 +400,6 @@ def receive_json(self, mode: str = "text") -> typing.Any: class TestClient(requests.Session): __test__ = False # For pytest to not discover this up. task: "Future[None]" - portal: typing.Optional[anyio.abc.BlockingPortal] = None def __init__( self, @@ -410,6 +409,7 @@ def __init__( root_path: str = "", backend: str = "asyncio", backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None, + portal: typing.Optional[anyio.abc.BlockingPortal] = None, ) -> None: super().__init__() self.async_backend = _AsyncBackend( @@ -434,6 +434,7 @@ def __init__( self.headers.update({"user-agent": "testclient"}) self.app = asgi_app self.base_url = base_url + self.portal = portal @contextlib.contextmanager def _portal_factory( @@ -506,13 +507,16 @@ def websocket_connect( def __enter__(self) -> "TestClient": with contextlib.ExitStack() as stack: - self.portal = portal = stack.enter_context( - anyio.start_blocking_portal(**self.async_backend) - ) + if self.portal: + portal = self.portal + else: + portal = self.portal = stack.enter_context( + anyio.start_blocking_portal(**self.async_backend) + ) - @stack.callback - def reset_portal() -> None: - self.portal = None + @stack.callback + def reset_portal() -> None: + self.portal = None self.stream_send = StapledObjectStream( *anyio.create_memory_object_stream(math.inf) diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 8c0666789..055749191 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -71,6 +71,31 @@ def homepage(request): assert response.json() == {"mock": "example"} +def test_use_multiple_testclients_with_asyncio(test_client_factory, anyio_backend_name): + if anyio_backend_name != "asyncio": + return + service = Starlette() + + @service.on_event("startup") + async def save_event_loop(): + service.state.event_loop = asyncio.get_running_loop() + + @service.route("/") + async def handle_request(request): + await service.state.event_loop.create_task(asyncio.sleep(0)) + return JSONResponse({}) + + with test_client_factory(service) as parent_client: + client_2 = test_client_factory(service, portal=parent_client.portal) + assert client_2.get("/").json() == {} + assert parent_client.get("/").json() == {} + + client_3 = test_client_factory(service, portal=parent_client.portal) + with client_3: + assert client_3.get("/").json() == {} + assert parent_client.get("/").json() == {} + + def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_name): """ This test asserts a number of properties that are important for an