Skip to content

Commit

Permalink
Migrate OAuth HTTP code to new worker
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Oct 17, 2024
1 parent 8319bbe commit c28cfa7
Show file tree
Hide file tree
Showing 9 changed files with 302 additions and 129 deletions.
191 changes: 191 additions & 0 deletions edb/server/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

from typing import (
Tuple,
Any,
Mapping,
Optional,
)

import asyncio
import dataclasses
import logging
import os
import json as json_lib
import urllib.parse

from edb.server._http import Http

logger = logging.getLogger("edb.server")


class HttpClient:
def __init__(self, limit: int):
self._client = Http(limit)
self._fd = self._client._fd
self._task = None
self._skip_reads = 0
self._loop = asyncio.get_running_loop()
self._task = self._loop.create_task(self._boot(self._loop))
self._next_id = 0
self._requests: dict[int, asyncio.Future] = {}

def __del__(self) -> None:
if self._task:
self._task.cancel()
self._task = None

def _update_limit(self, limit: int):
self._client._update_limit(limit)

async def request(
self,
*,
method: str,
url: str,
content: bytes | None,
headers: list[tuple[str, str]] | None,
) -> tuple[int, bytes, dict[str, str]]:
if content is None:
content = bytes()
if headers is None:
headers = []
id = self._next_id
self._next_id += 1
self._requests[id] = asyncio.Future()
try:
self._client._request(id, url, method, content, headers)
resp = await self._requests[id]
return resp
finally:
del self._requests[id]

async def get(
self, path: str, *, headers: dict[str, str] | None = None
) -> Response:
headers_list = [(k, v) for k, v in headers.items()] if headers else None
result = await self.request(
method="GET", url=path, content=None, headers=headers_list
)
return Response.from_tuple(result)

async def post(
self,
path: str,
*,
headers: dict[str, str] | None = None,
data: bytes | str | dict[str, str] | None = None,
json: Any | None = None,
) -> Response:
if json is not None:
data = json_lib.dumps(json).encode('utf-8')
headers = headers or {}
headers['Content-Type'] = 'application/json'
elif isinstance(data, str):
data = data.encode('utf-8')
elif isinstance(data, dict):
data = urllib.parse.urlencode(data).encode('utf-8')
headers = headers or {}
headers['Content-Type'] = 'application/x-www-form-urlencoded'

headers_list = [(k, v) for k, v in headers.items()] if headers else None
result = await self.request(
method="POST", url=path, content=data, headers=headers_list
)
return Response.from_tuple(result)

async def _boot(self, loop: asyncio.AbstractEventLoop) -> None:
logger.info("Python-side HTTP client booted")
reader = asyncio.StreamReader(loop=loop)
reader_protocol = asyncio.StreamReaderProtocol(reader)
fd = os.fdopen(self._client._fd, 'rb')
transport, _ = await loop.connect_read_pipe(lambda: reader_protocol, fd)
try:
while len(await reader.read(1)) == 1:
if not self._client or not self._task:
break
if self._skip_reads > 0:
self._skip_reads -= 1
continue
msg = self._client._read()
if not msg:
break
self._process_message(msg)
finally:
transport.close()

def _process_message(self, msg):
msg_type, id, data = msg

if id in self._requests:
if msg_type == 1:
self._requests[id].set_result(data)
elif msg_type == 0:
self._requests[id].set_exception(Exception(data))


class CaseInsensitiveDict(dict):
def __init__(self, data: Optional[list[Tuple[str, str]]] = None):
super().__init__()
if data:
for k, v in data:
self[k.lower()] = v

def __setitem__(self, key: str, value: str):
super().__setitem__(key.lower(), value)

def __getitem__(self, key: str):
return super().__getitem__(key.lower())

def get(self, key: str, default=None):
return super().get(key.lower(), default)

def update(self, *args, **kwargs: str) -> None:
if args:
data = args[0]
if isinstance(data, Mapping):
for key, value in data.items():
self[key] = value
else:
for key, value in data:
self[key] = value
for key, value in kwargs.items():
self[key] = value


@dataclasses.dataclass
class Response:
status_code: int
body: bytes
headers: CaseInsensitiveDict

@classmethod
def from_tuple(cls, data: Tuple[int, bytes, dict[str, str]]):
status_code, body, headers_list = data
headers = CaseInsensitiveDict([(k, v) for k, v in headers_list.items()])
return cls(status_code, body, headers)

def json(self):
return json_lib.loads(self.body.decode('utf-8'))

@property
def text(self) -> str:
return self.body.decode('utf-8')
74 changes: 1 addition & 73 deletions edb/server/net_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@
import asyncio
import logging
import base64
import os

from edb.ir import statypes
from edb.server import defines
from edb.server.protocol import execute
from edb.server._http import Http
from edb.server.http import HttpClient
from edb.common import retryloop
from . import dbview

Expand Down Expand Up @@ -93,77 +92,6 @@ async def _http_task(tenant: edbtenant.Tenant, http_client) -> None:
)


class HttpClient:
def __init__(self, limit: int):
self._client = Http(limit)
self._fd = self._client._fd
self._task = None
self._skip_reads = 0
self._loop = asyncio.get_running_loop()
self._task = self._loop.create_task(self._boot(self._loop))
self._next_id = 0
self._requests: dict[int, asyncio.Future] = {}

def __del__(self) -> None:
if self._task:
self._task.cancel()
self._task = None

def _update_limit(self, limit: int):
self._client._update_limit(limit)

async def request(
self,
*,
method: str,
url: str,
content: bytes | None,
headers: list[tuple[str, str]] | None,
):
if content is None:
content = bytes()
if headers is None:
headers = []
id = self._next_id
self._next_id += 1
self._requests[id] = asyncio.Future()
try:
self._client._request(id, url, method, content, headers)
resp = await self._requests[id]
return resp
finally:
del self._requests[id]

async def _boot(self, loop: asyncio.AbstractEventLoop) -> None:
logger.info("Python-side HTTP client booted")
reader = asyncio.StreamReader(loop=loop)
reader_protocol = asyncio.StreamReaderProtocol(reader)
fd = os.fdopen(self._client._fd, 'rb')
transport, _ = await loop.connect_read_pipe(lambda: reader_protocol, fd)
try:
while len(await reader.read(1)) == 1:
if not self._client or not self._task:
break
if self._skip_reads > 0:
self._skip_reads -= 1
continue
msg = self._client._read()
if not msg:
break
self._process_message(msg)
finally:
transport.close()

def _process_message(self, msg):
msg_type, id, data = msg

if id in self._requests:
if msg_type == 1:
self._requests[id].set_result(data)
elif msg_type == 0:
self._requests[id].set_exception(Exception(data))


def create_http(tenant: edbtenant.Tenant):
net_http_max_connections = tenant._server.config_lookup(
'net_http_max_connections', tenant.get_sys_config()
Expand Down
2 changes: 1 addition & 1 deletion edb/server/protocol/auth_ext/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
client_secret: str,
*,
additional_scope: str | None,
http_factory: Callable[..., http_client.HttpClient],
http_factory: Callable[..., http_client.AuthHttpClient],
):
self.name = name
self.issuer_url = issuer_url
Expand Down
39 changes: 36 additions & 3 deletions edb/server/protocol/auth_ext/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,15 @@
import mimetypes
import uuid

from typing import Any, Optional, Tuple, FrozenSet, cast, TYPE_CHECKING
from typing import (
Any,
Optional,
Tuple,
FrozenSet,
cast,
TYPE_CHECKING,
Callable,
)

import aiosmtplib
from jwcrypto import jwk, jwt
Expand Down Expand Up @@ -80,6 +88,27 @@ def __init__(
self.tenant = tenant
self.test_mode = tenant.server.in_test_mode()

def _get_url_munger(
self, request: protocol.HttpRequest
) -> Callable[[str], str] | None:
"""
Returns a callable that can be used to modify the base URL
when making requests to the OAuth provider.
This is used to redirect requests to the test OAuth provider
when running in test mode.
"""
if not self.test_mode:
return None
test_url = (
request.params[b'oauth-test-server'].decode()
if (request.params and b'oauth-test-server' in request.params)
else None
)
if test_url:
return lambda path: f"{test_url}{urllib.parse.quote(path)}"
return None

async def handle_request(
self,
request: protocol.HttpRequest,
Expand Down Expand Up @@ -270,7 +299,10 @@ async def handle_authorize(
query, "challenge", fallback_keys=["code_challenge"]
)
oauth_client = oauth.Client(
db=self.db, provider_name=provider_name, base_url=self.test_url
db=self.db,
provider_name=provider_name,
url_munger=self._get_url_munger(request),
http_client=self.tenant.get_http_client(),
)
await pkce.create(self.db, challenge)
authorize_url = await oauth_client.get_authorize_url(
Expand Down Expand Up @@ -369,7 +401,8 @@ async def handle_callback(
oauth_client = oauth.Client(
db=self.db,
provider_name=provider_name,
base_url=self.test_url,
url_munger=self._get_url_munger(request),
http_client=self.tenant.get_http_client(),
)
(
identity,
Expand Down
Loading

0 comments on commit c28cfa7

Please sign in to comment.