Skip to content

Commit

Permalink
feat(python-sdk): user sessions for frameworks (#47)
Browse files Browse the repository at this point in the history
Adds framework specific functions to access a user `Session` object, which represents a session of an
authenticated Numerous user accessing the app. Through the `Session` cookies, user information and
a specific user collection can be accessed.

---------

Co-authored-by: Jens Feodor Nielsen <[email protected]>
  • Loading branch information
Lasse-numerous and jfeo authored Oct 29, 2024
1 parent 17fbca5 commit 4cc900a
Show file tree
Hide file tree
Showing 17 changed files with 460 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def add_marimo_cookie_middleware(
cookies = FileCookieStorage(cookies_dir, session_ident or _ident)
use_cookie_storage(cookies)

@app.middleware("http") # type: ignore[misc]
@app.middleware("http") # type: ignore[misc, unused-ignore]
async def middleware(
request: Request,
call_next: t.Callable[[Request], t.Awaitable[Response]],
Expand Down
4 changes: 2 additions & 2 deletions python/src/numerous/experimental/marimo/_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""

from typing import Any, Dict, Type, TypeVar, Union
from typing import Any, Type, TypeVar, Union

import marimo as mo
from marimo._runtime.state import State as MoState
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(
self,
default: Union[str, float, None] = None,
annotation: Union[type, None] = None,
**kwargs: Dict[str, Any],
**kwargs: dict[str, Any],
) -> None:
"""
Field with a state that can be used in a Marimo app.
Expand Down
6 changes: 3 additions & 3 deletions python/src/numerous/experimental/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
- Field: Class representing a field in the model.
"""

from typing import Any, Dict, Tuple, Union
from typing import Any, Tuple, Union

from pydantic import BaseModel as PydanticBaseModel
from pydantic import Field as PydanticField
Expand Down Expand Up @@ -94,7 +94,7 @@ class BaseModel(_ModelInterface):
"""

def __init__(self, **kwargs: Dict[str, Any]) -> None:
def __init__(self, **kwargs: dict[str, Any]) -> None:
"""
Initialize a model object with the given fields.
Expand Down Expand Up @@ -195,7 +195,7 @@ def __init__(
self,
default: Union[str, float, None] = None,
annotation: Union[type, None] = None,
**kwargs: Dict[str, Any],
**kwargs: dict[str, Any],
) -> None:
"""
Initialize a Field object.
Expand Down
1 change: 1 addition & 0 deletions python/src/numerous/frameworks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Module for integrating Numerous with various frameworks."""
19 changes: 19 additions & 0 deletions python/src/numerous/frameworks/dash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Module for integrating Numerous with Dash."""

from numerous import user_session
from numerous.frameworks.flask import FlaskCookieGetter


class DashCookieGetter(FlaskCookieGetter):
pass


def get_session() -> user_session.Session:
"""
Get the session for the current user.
Returns:
Session: The session for the current user.
"""
return user_session.Session(cg=DashCookieGetter())
30 changes: 30 additions & 0 deletions python/src/numerous/frameworks/fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Module for integrating Numerous with FastAPI."""

from fastapi import Request

from numerous import user_session
from numerous.local import is_local_mode, local_user


class FastAPICookieGetter:
def __init__(self, request: Request) -> None:
self.request = request

def cookies(self) -> dict[str, str]:
"""Get the cookies associated with the current request."""
if is_local_mode():
# Update the cookies on the fastapi server
user_session.set_user_info_cookie(self.request.cookies, local_user)

return {str(key): str(val) for key, val in self.request.cookies.items()}


def get_session(request: Request) -> user_session.Session:
"""
Get the session for the current user.
Returns:
Session: The session for the current user.
"""
return user_session.Session(cg=FastAPICookieGetter(request))
27 changes: 27 additions & 0 deletions python/src/numerous/frameworks/flask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Module for integrating Numerous with Flask."""

from flask import request

from numerous import user_session
from numerous.local import is_local_mode, local_user


class FlaskCookieGetter:
def cookies(self) -> dict[str, str]:
"""Get the cookies associated with the current request."""
cookies = {key: str(val) for key, val in request.cookies.items()}
if is_local_mode():
# Update the cookies on the flask server
user_session.set_user_info_cookie(cookies, local_user)
return cookies


def get_session() -> user_session.Session:
"""
Get the session for the current user.
Returns:
Session: The session for the current user.
"""
return user_session.Session(cg=FlaskCookieGetter())
27 changes: 27 additions & 0 deletions python/src/numerous/frameworks/marimo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Module for integrating Numerous with Marimo."""

from typing import Any

from numerous import user_session
from numerous.experimental import marimo
from numerous.local import is_local_mode, local_user


class MarimoCookieGetter:
def cookies(self) -> dict[str, Any]:
"""Get the cookies associated with the current request."""
if is_local_mode():
# Update the cookies on the marimo server
user_session.set_user_info_cookie(marimo.cookies(), local_user)
return {key: str(val) for key, val in marimo.cookies().items()}


def get_session() -> user_session.Session:
"""
Get the session for the current user.
Returns:
Session: The session for the current user.
"""
return user_session.Session(cg=MarimoCookieGetter())
30 changes: 30 additions & 0 deletions python/src/numerous/frameworks/panel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Module for integrating Numerous with Panel."""

import panel as pn

from numerous import user_session
from numerous.local import is_local_mode, local_user


class PanelCookieGetter:
@staticmethod
def cookies() -> dict[str, str]:
"""Get the cookies associated with the current request."""
if is_local_mode():
# Add the user info to the cookies on panel server
user_session.set_user_info_cookie(pn.state.cookies, local_user)

if pn.state.curdoc and pn.state.curdoc.session_context:
return {key: str(val) for key, val in pn.state.cookies.items()}
return {}


def get_session() -> user_session.Session:
"""
Get the session for the current user.
Returns:
Session: The session for the current user.
"""
return user_session.Session(cg=PanelCookieGetter())
27 changes: 27 additions & 0 deletions python/src/numerous/frameworks/streamlit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Module for integrating Numerous with Streamlit."""

import streamlit as st

from numerous import user_session
from numerous.local import is_local_mode, local_user


class StreamlitCookieGetter:
def cookies(self) -> dict[str, str]:
"""Get the cookies associated with the current request."""
cookies = {key: str(val) for key, val in st.context.cookies.items()}
if is_local_mode():
# Update the cookies on the streamlit server
user_session.set_user_info_cookie(cookies, local_user)
return cookies


def get_session() -> user_session.Session:
"""
Get the session for the current user.
Returns:
Session: The session for the current user.
"""
return user_session.Session(cg=StreamlitCookieGetter())
56 changes: 28 additions & 28 deletions python/src/numerous/generated/graphql/async_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ class AsyncBaseClient:
def __init__(
self,
url: str = "",
headers: Optional[Dict[str, str]] = None,
headers: Optional[dict[str, str]] = None,
http_client: Optional[httpx.AsyncClient] = None,
ws_url: str = "",
ws_headers: Optional[Dict[str, Any]] = None,
ws_headers: Optional[dict[str, Any]] = None,
ws_origin: Optional[str] = None,
ws_connection_init_payload: Optional[Dict[str, Any]] = None,
ws_connection_init_payload: Optional[dict[str, Any]] = None,
) -> None:
self.url = url
self.headers = headers
Expand Down Expand Up @@ -96,7 +96,7 @@ async def execute(
self,
query: str,
operation_name: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
variables: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> httpx.Response:
processed_variables, files, files_map = self._process_variables(variables)
Expand All @@ -118,7 +118,7 @@ async def execute(
**kwargs,
)

def get_data(self, response: httpx.Response) -> Dict[str, Any]:
def get_data(self, response: httpx.Response) -> dict[str, Any]:
if not response.is_success:
raise GraphQLClientHttpError(
status_code=response.status_code, response=response
Expand All @@ -142,19 +142,19 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]:
errors_dicts=errors, data=data
)

return cast(Dict[str, Any], data)
return cast(dict[str, Any], data)

async def execute_ws(
self,
query: str,
operation_name: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
variables: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
) -> AsyncIterator[dict[str, Any]]:
headers = self.ws_headers.copy()
headers.update(kwargs.get("extra_headers", {}))

merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin}
merged_kwargs: dict[str, Any] = {"origin": self.ws_origin}
merged_kwargs.update(kwargs)
merged_kwargs["extra_headers"] = headers

Expand Down Expand Up @@ -185,9 +185,9 @@ async def execute_ws(
yield data

def _process_variables(
self, variables: Optional[Dict[str, Any]]
self, variables: Optional[dict[str, Any]]
) -> Tuple[
Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]]
dict[str, Any], dict[str, Tuple[str, IO[bytes], str]], dict[str, List[str]]
]:
if not variables:
return {}, {}, {}
Expand All @@ -196,8 +196,8 @@ def _process_variables(
return self._get_files_from_variables(serializable_variables)

def _convert_dict_to_json_serializable(
self, dict_: Dict[str, Any]
) -> Dict[str, Any]:
self, dict_: dict[str, Any]
) -> dict[str, Any]:
return {
key: self._convert_value(value)
for key, value in dict_.items()
Expand All @@ -212,11 +212,11 @@ def _convert_value(self, value: Any) -> Any:
return value

def _get_files_from_variables(
self, variables: Dict[str, Any]
self, variables: dict[str, Any]
) -> Tuple[
Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]]
dict[str, Any], dict[str, Tuple[str, IO[bytes], str]], dict[str, List[str]]
]:
files_map: Dict[str, List[str]] = {}
files_map: dict[str, List[str]] = {}
files_list: List[Upload] = []

def separate_files(path: str, obj: Any) -> Any:
Expand Down Expand Up @@ -247,7 +247,7 @@ def separate_files(path: str, obj: Any) -> Any:
return obj

nulled_variables = separate_files("variables", variables)
files: Dict[str, Tuple[str, IO[bytes], str]] = {
files: dict[str, Tuple[str, IO[bytes], str]] = {
str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type)
for i, file_ in enumerate(files_list)
}
Expand All @@ -257,9 +257,9 @@ async def _execute_multipart(
self,
query: str,
operation_name: Optional[str],
variables: Dict[str, Any],
files: Dict[str, Tuple[str, IO[bytes], str]],
files_map: Dict[str, List[str]],
variables: dict[str, Any],
files: dict[str, Tuple[str, IO[bytes], str]],
files_map: dict[str, List[str]],
**kwargs: Any,
) -> httpx.Response:
data = {
Expand All @@ -282,13 +282,13 @@ async def _execute_json(
self,
query: str,
operation_name: Optional[str],
variables: Dict[str, Any],
variables: dict[str, Any],
**kwargs: Any,
) -> httpx.Response:
headers: Dict[str, str] = {"Content-Type": "application/json"}
headers: dict[str, str] = {"Content-Type": "application/json"}
headers.update(kwargs.get("headers", {}))

merged_kwargs: Dict[str, Any] = kwargs.copy()
merged_kwargs: dict[str, Any] = kwargs.copy()
merged_kwargs["headers"] = headers

return await self.http_client.post(
Expand All @@ -305,7 +305,7 @@ async def _execute_json(
)

async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None:
payload: Dict[str, Any] = {
payload: dict[str, Any] = {
"type": GraphQLTransportWSMessageType.CONNECTION_INIT.value
}
if self.ws_connection_init_payload:
Expand All @@ -318,9 +318,9 @@ async def _send_subscribe(
operation_id: str,
query: str,
operation_name: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
variables: Optional[dict[str, Any]] = None,
) -> None:
payload: Dict[str, Any] = {
payload: dict[str, Any] = {
"id": operation_id,
"type": GraphQLTransportWSMessageType.SUBSCRIBE.value,
"payload": {"query": query, "operationName": operation_name},
Expand All @@ -336,7 +336,7 @@ async def _handle_ws_message(
message: Data,
websocket: WebSocketClientProtocol,
expected_type: Optional[GraphQLTransportWSMessageType] = None,
) -> Optional[Dict[str, Any]]:
) -> Optional[dict[str, Any]]:
try:
message_dict = json.loads(message)
except json.JSONDecodeError as exc:
Expand All @@ -356,7 +356,7 @@ async def _handle_ws_message(
if type_ == GraphQLTransportWSMessageType.NEXT:
if "data" not in payload:
raise GraphQLClientInvalidMessageFormat(message=message)
return cast(Dict[str, Any], payload["data"])
return cast(dict[str, Any], payload["data"])

if type_ == GraphQLTransportWSMessageType.COMPLETE:
await websocket.close()
Expand Down
Loading

0 comments on commit 4cc900a

Please sign in to comment.