Skip to content

Commit

Permalink
fix(python-sdk): fixes exception raised when accessing local mode user
Browse files Browse the repository at this point in the history
  • Loading branch information
jfeo committed Oct 30, 2024
1 parent a9456f2 commit 3663049
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 6 deletions.
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ dev = [
"mkdocs==1.6.0",
"mkdocstrings==0.26.2",
"mkdocstrings-python==1.12.2",
"mkdocs_snippet_plugin==1.0.2"
"mkdocs_snippet_plugin==1.0.2",
# for testing functionality related to frameworks
"flask",
"fastapi",
"marimo",
"streamlit",
"dash",
"panel",
]

[project.scripts]
Expand Down
6 changes: 3 additions & 3 deletions python/src/numerous/frameworks/marimo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
class MarimoCookieGetter:
def cookies(self) -> dict[str, Any]:
"""Get the cookies associated with the current request."""
cookies = marimo.cookies()
if is_local_mode():
# Update the cookies on the marimo server
user_session.set_user_info_cookie(marimo.cookies(), local_user)
return {key: str(val) for key, val in marimo.cookies().items()}
user_session.set_user_info_cookie(cookies, local_user)
return cookies


def get_session() -> user_session.Session:
Expand Down
4 changes: 2 additions & 2 deletions python/src/numerous/user_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def cookies(self) -> dict[str, Any]:


def encode_user_info(user_id: str, name: str) -> str:
user_info_json = json.dumps({"user_id": user_id, "name": name})
return base64.b64encode(user_info_json.encode("utf-8")).decode("utf-8")
user_info_json = json.dumps({"user_id": user_id, "user_full_name": name})
return base64.b64encode(user_info_json.encode()).decode()


def set_user_info_cookie(cookies: dict[str, str], user: User) -> None:
Expand Down
83 changes: 83 additions & 0 deletions python/tests/test_frameworks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from unittest.mock import Mock

import pytest

from numerous import local


@pytest.fixture(autouse=True)
def _ensure_local_mode(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv("NUMEROUS_API_URL", raising=False)


def test_marimo_get_session_in_local_mode_returns_expected_local_user(
tmp_path_factory: pytest.TempPathFactory,
) -> None:
# patch cookies
from numerous.experimental.marimo._cookies.cookies import use_cookie_storage
from numerous.experimental.marimo._cookies.files import FileCookieStorage

path = tmp_path_factory.mktemp("marimo-cookies")
use_cookie_storage(FileCookieStorage(path, lambda: "test-ident"))

from numerous.frameworks.marimo import get_session

session = get_session()

assert session.user == local.local_user


def test_streamlit_get_session_in_local_mode_returns_expected_local_user() -> None:
from numerous.frameworks.streamlit import get_session

session = get_session()

assert session.user == local.local_user


def test_fastapi_get_session_in_local_mode_returns_expected_local_user() -> None:
from fastapi import Request

from numerous.frameworks.fastapi import get_session

session = get_session(Request(scope={"type": "http", "headers": {}}))

assert session.user == local.local_user


def test_flask_get_session_in_local_mode_returns_expected_local_user() -> None:
from flask import Flask

from numerous.frameworks.flask import get_session

with Flask("test_app").test_request_context():
session = get_session()

assert session.user == local.local_user


def test_dash_get_session_in_local_mode_returns_expected_local_user() -> None:
import dash

from numerous.frameworks.dash import get_session

app = dash.Dash()
with app.server.test_request_context():
session = get_session()

assert session.user == local.local_user


def test_panel_get_session_in_local_mode_returns_expected_local_user() -> None:
from bokeh.document import Document
from panel.io.state import set_curdoc

from numerous.frameworks.panel import get_session

mock_doc = Mock(Document)
mock_doc.session_context.request.cookies = {}

with set_curdoc(mock_doc):
session = get_session()

assert session.user == local.local_user

0 comments on commit 3663049

Please sign in to comment.