Skip to content

Commit

Permalink
More reliable shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Oct 29, 2024
1 parent 83c2213 commit 57db287
Showing 1 changed file with 35 additions and 15 deletions.
50 changes: 35 additions & 15 deletions edb/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,13 @@ def __init__(
user_agent: str = "EdgeDB",
stat_callback: Optional[StatCallback] = None,
):
self._client = Http(limit)
self._fd = self._client._fd
self._task = None
self._client = None
self._limit = limit
self._skip_reads = 0
self._loop: Optional[asyncio.AbstractEventLoop] = asyncio.get_running_loop()
self._loop: Optional[asyncio.AbstractEventLoop] = (
asyncio.get_running_loop()
)
self._task = None
self._streaming: dict[int, asyncio.Queue[Any]] = {}
self._next_id = 0
Expand All @@ -91,15 +93,33 @@ def close(self) -> None:
self._task.cancel()
self._task = None
self._loop = None
self._client = None

def _ensure_task(self):
if self._loop is None:
raise Exception("HttpClient was closed")
if self._task is None:
self._task = self._loop.create_task(self._boot(self._loop))
self._client = Http(self._limit)
self._task = self._loop.create_task(
self._boot(self._loop, self._client._fd)
)

def _ensure_client(self):
if self._client is None:
raise Exception("HttpClient was closed")
return self._client

def _safe_close(self, id):
if self._client is not None:
self._client._close(id)

def _safe_ack(self, id):
if self._client is not None:
self._client._ack_sse(id)

def _update_limit(self, limit: int):
self._client._update_limit(limit)
if self._client is not None and limit != self._limit:
self._client._update_limit(limit)

def _process_headers(self, headers: HeaderType) -> list[tuple[str, str]]:
if headers is None:
Expand Down Expand Up @@ -173,7 +193,7 @@ async def request(
self._requests[id] = asyncio.Future()
start_time = time.time()
try:
self._client._request(id, path, method, data, headers_list)
self._ensure_client()._request(id, path, method, data, headers_list)
resp = await self._requests[id]
if self._stat_callback:
status_code, body, headers = resp
Expand Down Expand Up @@ -225,9 +245,6 @@ async def stream_sse(
data: bytes | str | dict[str, str] | None = None,
json: Any | None = None,
) -> Response | ResponseSSE:
"""Create a streaming request. Note that there is no backpressure on the
SSE events, so the called must continuously iterate on the response to
avoid excessive memory use."""
self._ensure_task()
path = self._process_path(path)
headers_list = self._process_headers(headers)
Expand All @@ -239,7 +256,9 @@ async def stream_sse(
self._requests[id] = asyncio.Future()
start_time = time.time()
try:
self._client._request_sse(id, path, method, data, headers_list)
self._ensure_client()._request_sse(
id, path, method, data, headers_list
)
resp = await self._requests[id]
if self._stat_callback:
if id in self._streaming:
Expand All @@ -265,22 +284,23 @@ async def stream_sse(
)
if id in self._streaming:
# Valid to call multiple times
cancel = lambda: self._client._close(id)
cancel = lambda: self._safe_close(id)
# Acknowledge SSE message (for backpressure)
ack = lambda: self._client._ack_sse(id)
ack = lambda: self._safe_ack(id)
return ResponseSSE.from_tuple(
resp, self._streaming[id], cancel, ack
)
return Response.from_tuple(resp)
finally:
del self._requests[id]

async def _boot(self, loop: asyncio.AbstractEventLoop) -> None:
async def _boot(self, loop: asyncio.AbstractEventLoop, fd: int) -> None:
logger.info(f"HTTP client initialized, user_agent={self._user_agent}")
reader = asyncio.StreamReader(loop=loop)
reader_protocol = asyncio.StreamReaderProtocol(reader)
fd = os.fdopen(self._client._fd, 'rb')
transport, _ = await loop.connect_read_pipe(lambda: reader_protocol, fd)
transport, _ = await loop.connect_read_pipe(
lambda: reader_protocol, os.fdopen(fd, 'rb')
)
try:
while len(await reader.read(1)) == 1:
if not self._client or not self._task:
Expand Down

0 comments on commit 57db287

Please sign in to comment.