diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index f146984b3..b19c93739 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -6,11 +6,13 @@ from starlette._utils import collapse_excgroups from starlette.requests import ClientDisconnect, Request -from starlette.responses import AsyncContentStream, Response +from starlette.responses import Response from starlette.types import ASGIApp, Message, Receive, Scope, Send RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]] +BodyStreamGenerator = typing.AsyncGenerator[typing.Union[bytes, typing.MutableMapping[str, typing.Any]], None] +AsyncContentStream = typing.AsyncIterable[typing.Union[str, bytes, memoryview, typing.MutableMapping[str, typing.Any]]] T = typing.TypeVar("T") @@ -156,8 +158,11 @@ async def coro() -> None: assert message["type"] == "http.response.start" - async def body_stream() -> typing.AsyncGenerator[bytes, None]: + async def body_stream() -> BodyStreamGenerator: async for message in recv_stream: + if message["type"] == "http.response.pathsend": + yield message + break assert message["type"] == "http.response.body" body = message.get("body", b"") if body: @@ -212,10 +217,17 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: } ) + should_close_body = True async for chunk in self.body_iterator: + if isinstance(chunk, dict): + # We got an ASGI message which is not response body (eg: pathsend) + should_close_body = False + await send(chunk) + continue await send({"type": "http.response.body", "body": chunk, "more_body": True}) - await send({"type": "http.response.body", "body": b"", "more_body": False}) + if should_close_body: + await send({"type": "http.response.body", "body": b"", "more_body": False}) if self.background: await self.background() diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py index b677063da..1bfea9d4c 100644 --- a/starlette/middleware/gzip.py +++ b/starlette/middleware/gzip.py @@ -102,6 +102,10 @@ async def send_with_gzip(self, message: Message) -> None: self.gzip_buffer.truncate() await self.send(message) + elif message_type == "http.response.pathsend": + # Don't apply GZip to pathsend responses + await self.send(self.initial_message) + await self.send(message) async def unattached_send(message: Message) -> typing.NoReturn: diff --git a/starlette/responses.py b/starlette/responses.py index 31874f655..62cad0802 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -336,6 +336,8 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: send_header_only: bool = scope["method"].upper() == "HEAD" + send_pathsend: bool = "http.response.pathsend" in scope.get("extensions", {}) + if self.stat_result is None: try: stat_result = await anyio.to_thread.run_sync(os.stat, self.path) @@ -354,7 +356,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: http_if_range = headers.get("if-range") if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)): - await self._handle_simple(send, send_header_only) + await self._handle_simple(send, send_header_only, send_pathsend) else: try: ranges = self._parse_range_header(http_range, stat_result.st_size) @@ -373,10 +375,12 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.background is not None: await self.background() - async def _handle_simple(self, send: Send, send_header_only: bool) -> None: + async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend: bool) -> None: await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}) if send_header_only: await send({"type": "http.response.body", "body": b"", "more_body": False}) + elif send_pathsend: + await send({"type": "http.response.pathsend", "path": str(self.path)}) else: async with await anyio.open_file(self.path, mode="rb") as file: more_body = True diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 7232cfd18..6618233b9 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -3,6 +3,7 @@ import contextvars from collections.abc import AsyncGenerator, AsyncIterator, Generator from contextlib import AsyncExitStack +from pathlib import Path from typing import Any import anyio @@ -14,7 +15,7 @@ from starlette.middleware import Middleware, _MiddlewareFactory from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import ClientDisconnect, Request -from starlette.responses import PlainTextResponse, Response, StreamingResponse +from starlette.responses import FileResponse, PlainTextResponse, Response, StreamingResponse from starlette.routing import Route, WebSocketRoute from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -1154,3 +1155,47 @@ async def send(message: Message) -> None: {"type": "http.response.body", "body": b"good!", "more_body": True}, {"type": "http.response.body", "body": b"", "more_body": False}, ] + + +@pytest.mark.anyio +async def test_asgi_pathsend_events(tmpdir: Path) -> None: + path = tmpdir / "example.txt" + with path.open("w") as file: + file.write("") + + response_complete = anyio.Event() + events: list[Message] = [] + + async def endpoint_with_pathsend(_: Request) -> FileResponse: + return FileResponse(path) + + async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response: + return await call_next(request) + + app = Starlette( + middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)], + routes=[Route("/", endpoint_with_pathsend)], + ) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + "headers": [], + "extensions": {"http.response.pathsend": {}}, + } + + async def receive() -> Message: + raise NotImplementedError("Should not be called!") # pragma: no cover + + async def send(message: Message) -> None: + events.append(message) + if message["type"] == "http.response.pathsend": + response_complete.set() + + await app(scope, receive, send) + + assert len(events) == 2 + assert events[0]["type"] == "http.response.start" + assert events[1]["type"] == "http.response.pathsend" diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index b20a7cb84..4dd79af28 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -1,9 +1,21 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.gzip import GZipMiddleware from starlette.requests import Request -from starlette.responses import ContentStream, PlainTextResponse, StreamingResponse +from starlette.responses import ( + ContentStream, + FileResponse, + PlainTextResponse, + StreamingResponse, +) from starlette.routing import Route +from starlette.types import Message from tests.types import TestClientFactory @@ -104,3 +116,42 @@ async def generator(bytes: bytes, count: int) -> ContentStream: assert response.text == "x" * 4000 assert response.headers["Content-Encoding"] == "text" assert "Content-Length" not in response.headers + + +@pytest.mark.anyio +async def test_gzip_ignored_for_pathsend_responses(tmpdir: Path) -> None: + path = tmpdir / "example.txt" + with path.open("w") as file: + file.write("") + + events: list[Message] = [] + + async def endpoint_with_pathsend(request: Request) -> FileResponse: + _ = await request.body() + return FileResponse(path) + + app = Starlette( + routes=[Route("/", endpoint=endpoint_with_pathsend)], + middleware=[Middleware(GZipMiddleware)], + ) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + "headers": [(b"accept-encoding", b"gzip, text")], + "extensions": {"http.response.pathsend": {}}, + } + + async def receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + async def send(message: Message) -> None: + events.append(message) + + await app(scope, receive, send) + + assert len(events) == 2 + assert events[0]["type"] == "http.response.start" + assert events[1]["type"] == "http.response.pathsend" diff --git a/tests/test_responses.py b/tests/test_responses.py index d5ed83499..f019db807 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -353,6 +353,38 @@ def test_file_response_with_range_header(tmp_path: Path, test_client_factory: Te assert response.headers["content-range"] == f"bytes 0-4/{len(content)}" +@pytest.mark.anyio +async def test_file_response_with_pathsend(tmpdir: Path) -> None: + path = tmpdir / "xyz" + content = b"" * 1000 + with open(path, "wb") as file: + file.write(content) + + app = FileResponse(path=path, filename="example.png") + + async def receive() -> Message: # type: ignore[empty-body] + ... # pragma: no cover + + async def send(message: Message) -> None: + if message["type"] == "http.response.start": + assert message["status"] == status.HTTP_200_OK + headers = Headers(raw=message["headers"]) + assert headers["content-type"] == "image/png" + assert "content-length" in headers + assert "content-disposition" in headers + assert "last-modified" in headers + assert "etag" in headers + elif message["type"] == "http.response.pathsend": + assert message["path"] == str(path) + + # Since the TestClient doesn't support `pathsend`, we need to test this directly. + await app( + {"type": "http", "method": "get", "headers": [], "extensions": {"http.response.pathsend": {}}}, + receive, + send, + ) + + def test_set_cookie(test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch) -> None: # Mock time used as a reference for `Expires` by stdlib `SimpleCookie`. mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc)