Skip to content

Commit

Permalink
Enable type checking in edb.server.protocol. (#7489)
Browse files Browse the repository at this point in the history
Additionally:
- Move/rename `auth_ext.data.Config` to `auth_ext.config.OAuthProviderConfig`
    - Change to inherit `auth_ext.config.ProviderConfig`
    - Change to `dataclass`
- Add `ai_ext.ProviderConfig` and `ai_ext.ApiStyle`
- Add `auth.scram.Session`
  • Loading branch information
dnwpark authored Jun 27, 2024
1 parent 2328483 commit e25404e
Show file tree
Hide file tree
Showing 33 changed files with 567 additions and 268 deletions.
1 change: 1 addition & 0 deletions edb/server/dbview/dbview.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class Database:
dbver: int
db_config: Config
extensions: set[str]
user_config_spec: config.Spec

@property
def server(self) -> server.Server:
Expand Down
2 changes: 1 addition & 1 deletion edb/server/protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from __future__ import annotations

from . import protocol # type: ignore
from . import protocol

HttpProtocol = protocol.HttpProtocol

Expand Down
57 changes: 41 additions & 16 deletions edb/server/protocol/ai_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import (
cast,
Any,
AsyncIterator,
ClassVar,
Expand All @@ -42,6 +43,7 @@
from edb import errors
from edb.common import asyncutil
from edb.common import debug
from edb.common import enum as s_enum
from edb.common import markup
from edb.common import uuidgen

Expand Down Expand Up @@ -101,6 +103,21 @@ class BadRequestError(AIExtError):
http_status = http.HTTPStatus.BAD_REQUEST


class ApiStyle(s_enum.StrEnum):
OpenAI = 'OpenAI'
Anthropic = 'Anthropic'


@dataclass
class ProviderConfig:
name: str
display_name: str
api_url: str
client_id: str
secret: str
api_style: ApiStyle


def start_extension(
tenant: srv_tenant.Tenant,
dbname: str,
Expand Down Expand Up @@ -565,7 +582,7 @@ async def _update_embeddings_in_db(


async def _generate_embeddings(
provider,
provider: ProviderConfig,
model_name: str,
inputs: list[str],
shortening: Optional[int],
Expand All @@ -578,7 +595,7 @@ async def _generate_embeddings(
f"of {provider.name!r} for {len(inputs)} object{suf}"
)

if provider.api_style == "OpenAI":
if provider.api_style == ApiStyle.OpenAI:
return await _generate_openai_embeddings(
provider, model_name, inputs, shortening,
)
Expand All @@ -590,7 +607,7 @@ async def _generate_embeddings(


async def _generate_openai_embeddings(
provider,
provider: ProviderConfig,
model_name: str,
inputs: list[str],
shortening: Optional[int],
Expand Down Expand Up @@ -676,9 +693,9 @@ async def _start_chat(
protocol: protocol.HttpProtocol,
request: protocol.HttpRequest,
response: protocol.HttpResponse,
provider,
provider: ProviderConfig,
model_name: str,
messages: list[dict],
messages: list[dict[str, Any]],
stream: bool,
) -> None:
if provider.api_style == "OpenAI":
Expand Down Expand Up @@ -725,7 +742,7 @@ async def _start_openai_like_chat(
response: protocol.HttpResponse,
client: httpx.AsyncClient,
model_name: str,
messages: list[dict],
messages: list[dict[str, Any]],
stream: bool,
) -> None:
if stream:
Expand Down Expand Up @@ -847,9 +864,9 @@ async def _start_openai_chat(
protocol: protocol.HttpProtocol,
request: protocol.HttpRequest,
response: protocol.HttpResponse,
provider,
provider: ProviderConfig,
model_name: str,
messages: list[dict],
messages: list[dict[str, Any]],
stream: bool,
) -> None:
headers = {
Expand Down Expand Up @@ -879,9 +896,9 @@ async def _start_anthropic_chat(
protocol: protocol.HttpProtocol,
request: protocol.HttpRequest,
response: protocol.HttpResponse,
provider,
provider: ProviderConfig,
model_name: str,
messages: list[dict],
messages: list[dict[str, Any]],
stream: bool,
) -> None:
headers = {
Expand Down Expand Up @@ -1010,7 +1027,7 @@ async def handle_request(
db: dbview.Database,
args: list[str],
tenant: srv_tenant.Tenant,
):
) -> None:
if len(args) != 1 or args[0] not in {"rag", "embeddings"}:
response.body = b'Unknown path'
response.status = http.HTTPStatus.NOT_FOUND
Expand Down Expand Up @@ -1239,7 +1256,7 @@ async def _handle_rag_request(
"messages": [],
}

messages: dict[str, list[dict]] = {}
messages: dict[str, list[dict[str, Any]]] = {}
for message in prompt["messages"]:
if message["participant_role"] == "User":
content = message["content"].format(
Expand Down Expand Up @@ -1356,7 +1373,7 @@ async def _edgeql_query_json(
except Exception as iex:
raise iex from None
else:
return content
return cast(list[Any], content)


async def _db_error(
Expand Down Expand Up @@ -1394,12 +1411,20 @@ async def _db_error(
def _get_provider_config(
db: dbview.Database,
provider_name: str,
) -> Any:
) -> ProviderConfig:
cfg = db.lookup_config("ext::ai::Config::providers")

for provider in cfg:
if provider.name == provider_name:
return provider
provider = cast(ProviderConfig, provider)
return ProviderConfig(
name=provider.name,
display_name=provider.display_name,
api_url=provider.api_url,
client_id=provider.client_id,
secret=provider.secret,
api_style=provider.api_style,
)
else:
raise ConfigurationError(
f"provider {provider_name!r} has not been configured"
Expand Down Expand Up @@ -1447,7 +1472,7 @@ async def _get_model_provider(
elif len(models) > 1:
raise InternalError("multiple models defined as requested model")

return models[0]["provider"]
return cast(str, models[0]["provider"])


async def _generate_embeddings_for_type(
Expand Down
20 changes: 18 additions & 2 deletions edb/server/protocol/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#


from __future__ import annotations
from typing import Type, TYPE_CHECKING
import http
import json

Expand All @@ -26,8 +28,17 @@

from . import scram

if TYPE_CHECKING:
from edb.server import tenant as edbtenant
from edb.server.protocol import protocol

async def handle_request(request, response, path_parts, tenant):

async def handle_request(
request: protocol.HttpRequest,
response: protocol.HttpResponse,
path_parts: list[str],
tenant: edbtenant.Tenant,
) -> None:
try:
if path_parts == ["token"]:
if not request.authorization:
Expand Down Expand Up @@ -68,7 +79,12 @@ async def handle_request(request, response, path_parts, tenant):
)


def _response_error(response, status, message, ex_type):
def _response_error(
response: protocol.HttpResponse,
status: http.HTTPStatus,
message: str,
ex_type: Type[errors.EdgeDBError],
) -> None:
err_dct = {
"message": message,
"type": str(ex_type.__name__),
Expand Down
65 changes: 49 additions & 16 deletions edb/server/protocol/auth/scram.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# limitations under the License.
#

from __future__ import annotations
from typing import NamedTuple, Optional, TYPE_CHECKING
import base64
import collections
import hashlib
Expand All @@ -29,13 +31,36 @@
from edb.common import markup
from edb.common import secretkey

if TYPE_CHECKING:
from edb.server import tenant as edbtenant
from edb.server.protocol import protocol

SESSION_TIMEOUT = 30
SESSION_HIGH_WATER_MARK = SESSION_TIMEOUT * 10
sessions: collections.OrderedDict[str, tuple] = collections.OrderedDict()

SESSION_TIMEOUT: float = 30
SESSION_HIGH_WATER_MARK: float = SESSION_TIMEOUT * 10

def handle_request(scheme, auth_str, response, tenant):

class Session(NamedTuple):
time: float
client_nonce: str
server_nonce: str
client_first_bare: bytes
cb_flag: bool
server_first: bytes
verifier: scram.SCRAMVerifier
mock_auth: bool
username: str


sessions: collections.OrderedDict[str, Session] = collections.OrderedDict()


def handle_request(
scheme: str,
auth_str: str,
response: protocol.HttpResponse,
tenant: edbtenant.Tenant,
) -> None:
server = tenant.server
if scheme != "SCRAM-SHA-256":
response.body = (
Expand Down Expand Up @@ -64,19 +89,24 @@ def handle_request(scheme, auth_str, response, tenant):
response.close_connection = True
return

if not server.get_jws_key().has_private:
if not server.get_jws_key().has_private: # type: ignore[union-attr]
response.body = b"Server doesn't support HTTP SCRAM authentication"
response.status = http.HTTPStatus.FORBIDDEN
response.close_connection = True
return

if sid is None:
try:
bare_offset: int
cb_flag: bool
authzid: Optional[bytes]
username_bytes: bytes
client_nonce: str
(
bare_offset,
cb_flag,
authzid,
username,
username_bytes,
client_nonce,
) = scram.parse_client_first_message(data)
except ValueError as ex:
Expand All @@ -87,7 +117,7 @@ def handle_request(scheme, auth_str, response, tenant):
response.close_connection = True
return

username = username.decode("utf-8")
username = username_bytes.decode("utf-8")
client_first_bare = data[bare_offset:]

if isinstance(cb_flag, str):
Expand Down Expand Up @@ -120,16 +150,16 @@ def handle_request(scheme, auth_str, response, tenant):
response.custom_headers["WWW-Authenticate"] = "SCRAM-SHA-256"
return

server_nonce = scram.generate_nonce()
server_first = scram.build_server_first_message(
server_nonce: str = scram.generate_nonce()
server_first: bytes = scram.build_server_first_message(
server_nonce, client_nonce, verifier.salt, verifier.iterations
).encode("utf-8")

if len(sessions) > SESSION_HIGH_WATER_MARK:
while sessions:
key, value = sessions.popitem(last=False)
if value[0] + SESSION_TIMEOUT > time.monotonic():
sessions[key] = value
key, session = sessions.popitem(last=False)
if session.time + SESSION_TIMEOUT > time.monotonic():
sessions[key] = session
sessions.move_to_end(key, last=False)
break

Expand All @@ -139,7 +169,7 @@ def handle_request(scheme, auth_str, response, tenant):
.rstrip("=")
)
assert sid not in sessions
sessions[sid] = (
sessions[sid] = Session(
time.monotonic(),
client_nonce,
server_nonce,
Expand All @@ -151,11 +181,11 @@ def handle_request(scheme, auth_str, response, tenant):
username,
)

server_first = base64.b64encode(server_first).decode("ascii")
server_first_str = base64.b64encode(server_first).decode("ascii")
response.status = http.HTTPStatus.UNAUTHORIZED
response.custom_headers[
"WWW-Authenticate"
] = f"SCRAM-SHA-256 sid={sid}, data={server_first}"
] = f"SCRAM-SHA-256 sid={sid}, data={server_first_str}"

else:
session = sessions.pop(sid)
Expand Down Expand Up @@ -255,7 +285,10 @@ def handle_request(scheme, auth_str, response, tenant):
] = f"sid={sid}, data={server_final}"


def get_scram_verifier(user, tenant):
def get_scram_verifier(
user: str,
tenant: edbtenant.Tenant,
) -> tuple[scram.SCRAMVerifier, bool]:
roles = tenant.get_roles()

rolerec = roles.get(user)
Expand Down
3 changes: 2 additions & 1 deletion edb/server/protocol/auth_ext/apple.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# limitations under the License.
#

from typing import Any
import uuid
import urllib.parse

Expand All @@ -24,7 +25,7 @@


class AppleProvider(base.OpenIDProvider):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(
"apple",
"https://appleid.apple.com",
Expand Down
Loading

0 comments on commit e25404e

Please sign in to comment.