Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with FastAPI mounts. #1183

Merged
merged 8 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 0 additions & 29 deletions integration_tests/fastapi_mount_test.sh

This file was deleted.

1 change: 0 additions & 1 deletion integration_tests/run_integration_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
set -e # Fail if any of the commands below fail.

./integration_tests/start_server_test.sh
./integration_tests/fastapi_mount_test.sh

echo
echo "CLI integration tests passed."
Expand Down
3 changes: 1 addition & 2 deletions lilac/data/cluster_titling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential

from ..batch_utils import group_by_sorted_key_iter
from ..env import env
from ..schema import (
Item,
)
Expand All @@ -24,7 +25,6 @@
)
from ..tasks import TaskInfo
from ..utils import chunks, log
from ..env import env

_TOP_K_CENTRAL_DOCS = 7
_TOP_K_CENTRAL_TITLES = 20
Expand Down Expand Up @@ -189,7 +189,6 @@ def _openai_client() -> Any:
try:
import openai


except ImportError:
raise ImportError(
'Could not import the "openai" python package. '
Expand Down
6 changes: 5 additions & 1 deletion lilac/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,13 @@ def __init__(self, app: ASGIApp) -> None:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Redirect trailing slashes to non-trailing slashes."""
url = URL(scope=scope).path

root_path = scope.get('root_path') or ''
ends_with_slash = (
url.endswith('/') and url != '/' and root_path and not url.startswith(root_path + '/api')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make this a unit test to enumerate the kinds of URLs that this does/doesn't accept?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

url.endswith('/')
and url != '/'
and url != f'{root_path}/'
and not url.startswith(root_path + '/api')
)

if scope['type'] == 'http' and ends_with_slash:
Expand Down
37 changes: 37 additions & 0 deletions lilac/server_mount_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Tests the FastAPI server can be mounted."""

from fastapi import FastAPI
from fastapi.testclient import TestClient

from .server import app as lilac_app

app = FastAPI()

MOUNT_POINT = '/lilac_sub'


@app.get('/')
def read_main() -> dict:
"""The main endpoint."""
return {'message': 'hello world'}


app.mount(MOUNT_POINT, lilac_app)
client = TestClient(app)


def test_mount_root() -> None:
response = client.get('/', allow_redirects=False)
assert response.status_code == 200
assert response.json() == {'message': 'hello world'}


def test_mount_slash_redirect() -> None:
response = client.get(f'{MOUNT_POINT}/auth_info/', allow_redirects=False)
assert response.status_code == 307
# We should redirect to the URL with slash removed.
assert response.headers['location'] == f'{MOUNT_POINT}/auth_info'

# Allow redirects to follow through.
response = client.get(f'{MOUNT_POINT}/auth_info/', allow_redirects=True)
assert response.status_code == 200
11 changes: 11 additions & 0 deletions lilac/server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,14 @@ def user() -> UserInfo:
),
auth_enabled=True,
)


def test_slash_redirect() -> None:
response = client.get('/auth_info/', allow_redirects=False)
assert response.status_code == 307
# We should redirect to the URL with slash removed.
assert response.headers['location'] == '/auth_info'

# Allow redirects to follow through.
response = client.get('/auth_info/', allow_redirects=True)
assert response.status_code == 200
Loading