Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix startup process twice #938

Merged
merged 1 commit into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 5 additions & 126 deletions core/cat/main.py
Original file line number Diff line number Diff line change
@@ -1,134 +1,13 @@
import uvicorn
import asyncio
from contextlib import asynccontextmanager
from scalar_fastapi import get_scalar_api_reference

from fastapi import FastAPI
from fastapi.routing import APIRoute
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware

from cat.log import log
from cat.env import get_env, fix_legacy_env_variables
from cat.routes import (
base,
auth,
users,
settings,
llm,
embedder,
auth_handler,
plugins,
upload,
websocket,
)
from cat.routes.memory.memory_router import memory_router
from cat.routes.static import admin, static
from cat.routes.openapi import get_openapi_configuration_function
from cat.looking_glass.cheshire_cat import CheshireCat


# TODO: take away in v2
fix_legacy_env_variables()


@asynccontextmanager
async def lifespan(app: FastAPI):
# ^._.^
#
# loads Cat and plugins
# Every endpoint can access the cat instance via request.app.state.ccat
# - Not using midlleware because I can't make it work with both http and websocket;
# - Not using Depends because it only supports callables (not instances)
# - Starlette allows this: https://www.starlette.io/applications/#storing-state-on-the-app-instance
app.state.ccat = CheshireCat()

# Dict of pseudo-sessions (key is the user_id)
app.state.strays = {}

# set a reference to asyncio event loop
app.state.event_loop = asyncio.get_running_loop()

# startup message with admin, public and swagger addresses
log.welcome()

yield


def custom_generate_unique_id(route: APIRoute):
return f"{route.name}"


# REST API
cheshire_cat_api = FastAPI(
lifespan=lifespan, generate_unique_id_function=custom_generate_unique_id,
docs_url=None, redoc_url=None, title="Cheshire-Cat API",
license_info={"name": "GPL-3", "url": "https://www.gnu.org/licenses/gpl-3.0.en.html"},
)

# Configures the CORS middleware for the FastAPI app
cors_allowed_origins_str = get_env("CCAT_CORS_ALLOWED_ORIGINS")
origins = cors_allowed_origins_str.split(",") if cors_allowed_origins_str else ["*"]
cheshire_cat_api.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

# Add routers to the middleware stack.
cheshire_cat_api.include_router(base.router, tags=["Home"])
cheshire_cat_api.include_router(auth.router, tags=["User Auth"], prefix="/auth")
cheshire_cat_api.include_router(users.router, tags=["Users"], prefix="/users")
cheshire_cat_api.include_router(settings.router, tags=["Settings"], prefix="/settings")
cheshire_cat_api.include_router(
llm.router, tags=["Large Language Model"], prefix="/llm"
)
cheshire_cat_api.include_router(embedder.router, tags=["Embedder"], prefix="/embedder")
cheshire_cat_api.include_router(plugins.router, tags=["Plugins"], prefix="/plugins")
cheshire_cat_api.include_router(memory_router, prefix="/memory")
cheshire_cat_api.include_router(
upload.router, tags=["Rabbit Hole"], prefix="/rabbithole"
)
cheshire_cat_api.include_router(
auth_handler.router, tags=["AuthHandler"], prefix="/auth_handler"
)
cheshire_cat_api.include_router(websocket.router, tags=["Websocket"])

# mount static files
# this cannot be done via fastapi.APIrouter:
# https://github.com/tiangolo/fastapi/discussions/9070

# admin single page app (static build)
admin.mount(cheshire_cat_api)
# static files (for plugins and other purposes)
static.mount(cheshire_cat_api)


# error handling
@cheshire_cat_api.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
return JSONResponse(
status_code=400,
content={"error": exc.errors()},
)


# openapi customization
cheshire_cat_api.openapi = get_openapi_configuration_function(cheshire_cat_api)

@cheshire_cat_api.get("/docs", include_in_schema=False)
async def scalar_docs():
return get_scalar_api_reference(
openapi_url=cheshire_cat_api.openapi_url,
title=cheshire_cat_api.title,
scalar_favicon_url="https://cheshirecat.ai/wp-content/uploads/2023/10/Logo-Cheshire-Cat.svg",
)

# RUN!
if __name__ == "__main__":

# TODO: take away in v2
fix_legacy_env_variables()

# debugging utilities, to deactivate put `DEBUG=false` in .env
debug_config = {}
if get_env("CCAT_DEBUG") == "true":
Expand All @@ -146,7 +25,7 @@ async def scalar_docs():
}

uvicorn.run(
"cat.main:cheshire_cat_api",
"cat.startup:cheshire_cat_api",
host="0.0.0.0",
port=80,
use_colors=True,
Expand Down
131 changes: 131 additions & 0 deletions core/cat/startup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import asyncio
from contextlib import asynccontextmanager
from scalar_fastapi import get_scalar_api_reference

from fastapi import FastAPI
from fastapi.routing import APIRoute
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware

from cat.log import log
from cat.env import get_env
from cat.routes import (
base,
auth,
users,
settings,
llm,
embedder,
auth_handler,
plugins,
upload,
websocket,
)
from cat.routes.memory.memory_router import memory_router
from cat.routes.static import admin, static
from cat.routes.openapi import get_openapi_configuration_function
from cat.looking_glass.cheshire_cat import CheshireCat


@asynccontextmanager
async def lifespan(app: FastAPI):

# ^._.^
#
# loads Cat and plugins
# Every endpoint can access the cat instance via request.app.state.ccat
# - Not using midlleware because I can't make it work with both http and websocket;
# - Not using Depends because it only supports callables (not instances)
# - Starlette allows this: https://www.starlette.io/applications/#storing-state-on-the-app-instance
app.state.ccat = CheshireCat()

# Dict of pseudo-sessions (key is the user_id)
app.state.strays = {}

# set a reference to asyncio event loop
app.state.event_loop = asyncio.get_running_loop()

# startup message with admin, public and swagger addresses
log.welcome()

yield


def custom_generate_unique_id(route: APIRoute):
return f"{route.name}"


# REST API
cheshire_cat_api = FastAPI(
lifespan=lifespan,
generate_unique_id_function=custom_generate_unique_id,
docs_url=None,
redoc_url=None,
title="Cheshire-Cat API",
license_info={
"name": "GPL-3",
"url": "https://www.gnu.org/licenses/gpl-3.0.en.html",
},
)

# Configures the CORS middleware for the FastAPI app
cors_allowed_origins_str = get_env("CCAT_CORS_ALLOWED_ORIGINS")
origins = cors_allowed_origins_str.split(",") if cors_allowed_origins_str else ["*"]
cheshire_cat_api.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

# Add routers to the middleware stack.
cheshire_cat_api.include_router(base.router, tags=["Home"])
cheshire_cat_api.include_router(auth.router, tags=["User Auth"], prefix="/auth")
cheshire_cat_api.include_router(users.router, tags=["Users"], prefix="/users")
cheshire_cat_api.include_router(settings.router, tags=["Settings"], prefix="/settings")
cheshire_cat_api.include_router(
llm.router, tags=["Large Language Model"], prefix="/llm"
)
cheshire_cat_api.include_router(embedder.router, tags=["Embedder"], prefix="/embedder")
cheshire_cat_api.include_router(plugins.router, tags=["Plugins"], prefix="/plugins")
cheshire_cat_api.include_router(memory_router, prefix="/memory")
cheshire_cat_api.include_router(
upload.router, tags=["Rabbit Hole"], prefix="/rabbithole"
)
cheshire_cat_api.include_router(
auth_handler.router, tags=["AuthHandler"], prefix="/auth_handler"
)
cheshire_cat_api.include_router(websocket.router, tags=["Websocket"])

# mount static files
# this cannot be done via fastapi.APIrouter:
# https://github.com/tiangolo/fastapi/discussions/9070

# admin single page app (static build)
admin.mount(cheshire_cat_api)
# static files (for plugins and other purposes)
static.mount(cheshire_cat_api)


# error handling
@cheshire_cat_api.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
return JSONResponse(
status_code=400,
content={"error": exc.errors()},
)


# openapi customization
cheshire_cat_api.openapi = get_openapi_configuration_function(cheshire_cat_api)


@cheshire_cat_api.get("/docs", include_in_schema=False)
async def scalar_docs():
return get_scalar_api_reference(
openapi_url=cheshire_cat_api.openapi_url,
title=cheshire_cat_api.title,
scalar_favicon_url="https://cheshirecat.ai/wp-content/uploads/2023/10/Logo-Cheshire-Cat.svg",
)
2 changes: 1 addition & 1 deletion core/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import cat.utils as utils
from cat.memory.vector_memory import VectorMemory
from cat.mad_hatter.plugin import Plugin
from cat.main import cheshire_cat_api
from cat.startup import cheshire_cat_api
from tests.utils import create_mock_plugin_zip

import time
Expand Down
Loading