From 474f307cd2d59f05712369b9db1b8cb38254c353 Mon Sep 17 00:00:00 2001 From: Samuele Barzaghi Date: Fri, 1 Nov 2024 22:02:44 +0100 Subject: [PATCH 01/13] New "*endpoint" decorators for custom endpoint creations --- core/cat/looking_glass/cheshire_cat.py | 10 +- core/cat/mad_hatter/decorators/__init__.py | 3 +- core/cat/mad_hatter/decorators/endpoint.py | 102 ++++++++++++++++++ core/cat/mad_hatter/mad_hatter.py | 13 ++- core/cat/mad_hatter/plugin.py | 24 ++++- core/cat/routes/plugins.py | 1 + core/cat/startup.py | 5 +- core/tests/mad_hatter/test_endpoints.py | 68 ++++++++++++ core/tests/mocks/mock_plugin/mock_endpoint.py | 26 +++++ 9 files changed, 239 insertions(+), 13 deletions(-) create mode 100644 core/cat/mad_hatter/decorators/endpoint.py create mode 100644 core/tests/mad_hatter/test_endpoints.py create mode 100644 core/tests/mocks/mock_plugin/mock_endpoint.py diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index 06b336bb..a25abaa6 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -54,7 +54,7 @@ class CheshireCat: """ - def __init__(self): + def __init__(self, fastapi_app): """Cat initialization. At init time the Cat executes the bootstrap. @@ -62,14 +62,14 @@ def __init__(self): # bootstrap the Cat! ^._.^ - # load AuthHandler - self.load_auth() - # Start scheduling system self.white_rabbit = WhiteRabbit() # instantiate MadHatter (loads all plugins' hooks and tools) - self.mad_hatter = MadHatter() + self.mad_hatter = MadHatter(fastapi_app) + + # load AuthHandler + self.load_auth() # allows plugins to do something before cat components are loaded self.mad_hatter.execute_hook("before_cat_bootstrap", cat=self) diff --git a/core/cat/mad_hatter/decorators/__init__.py b/core/cat/mad_hatter/decorators/__init__.py index 375ab1f1..a73074c6 100644 --- a/core/cat/mad_hatter/decorators/__init__.py +++ b/core/cat/mad_hatter/decorators/__init__.py @@ -1,5 +1,6 @@ from cat.mad_hatter.decorators.tool import CatTool, tool from cat.mad_hatter.decorators.hook import CatHook, hook +from cat.mad_hatter.decorators.endpoint import CustomEndpoint, endpoint, get_endpoint, post_endpoint from cat.mad_hatter.decorators.plugin_decorator import CatPluginDecorator, plugin -__all__ = ["CatTool", "tool", "CatHook", "hook", "CatPluginDecorator", "plugin"] +__all__ = ["CatTool", "tool", "CatHook", "hook", "CustomEndpoint", "endpoint", "get_endpoint", "post_endpoint", "CatPluginDecorator", "plugin"] diff --git a/core/cat/mad_hatter/decorators/endpoint.py b/core/cat/mad_hatter/decorators/endpoint.py new file mode 100644 index 00000000..d911c8b8 --- /dev/null +++ b/core/cat/mad_hatter/decorators/endpoint.py @@ -0,0 +1,102 @@ +from typing import Callable +from fastapi import APIRouter + +cheshire_cat_api = None + +# class to represent a @endpoint +class CustomEndpoint: + def __init__(self, prefix: str, path: str, function: Callable, **kwargs): + self.prefix = prefix + self.path = path + self.function = function + self.name = self.prefix + self.path + + for k in kwargs: + setattr(self, k, kwargs[k]) + + def __repr__(self) -> str: + return f"CustomEndpoint(path={self.name})" + +# Called from madhatter to inject the fastapi app instance +def _init_endpoint_decorator(new_cheshire_cat_api): + global cheshire_cat_api + + cheshire_cat_api = new_cheshire_cat_api + +# @endpoint decorator. Any function in a plugin decorated by @endpoint will be exposed as FastAPI operation +def endpoint(path, methods, prefix="/custom_endpoints", tags=["custom_endpoints"], **kwargs) -> Callable: + """ + Define a custom API endpoint, parameters are the same as FastAPI path operation. + Examples: + .. code-block:: python + @endpoint(path="/hello", methods=["GET"]) + def my_endpoint() -> str: + return {"Hello":"Alice"} + """ + + global cheshire_cat_api + + def _make_endpoint(endpoint): + custom_endpoint = CustomEndpoint(prefix=prefix, path=path, function=endpoint, **kwargs) + + plugins_router = APIRouter() + plugins_router.add_api_route( + path=path, endpoint=endpoint, methods=methods, tags=tags, **kwargs + ) + + cheshire_cat_api.include_router(plugins_router, prefix=prefix) + + return custom_endpoint + + return _make_endpoint + +# @get_endpoint decorator. Any function in a plugin decorated by @endpoint will be exposed as FastAPI GET operation +def get_endpoint( + path, prefix="/custom_endpoints", response_model=None, tags=["custom_endpoints"], **kwargs +) -> Callable: + """ + Define a custom API endpoint for GET operation, parameters are the same as FastAPI path operation. + Examples: + .. code-block:: python + @get_endpoint(path="/hello") + def my_get_endpoint() -> str: + return {"Hello":"Alice"} + """ + + return endpoint( + path=path, + methods=["GET"], + prefix=prefix, + response_model=response_model, + tags=tags, + **kwargs, + ) + +# @post_endpoint decorator. Any function in a plugin decorated by @endpoint will be exposed as FastAPI POST operation +def post_endpoint( + path, prefix="/custom_endpoints", response_model=None, tags=["custom_endpoints"], **kwargs +) -> Callable: + + """ + Define a custom API endpoint for POST operation, parameters are the same as FastAPI path operation. + Examples: + .. code-block:: python + + from pydantic import BaseModel + + class Item(BaseModel): + name: str + description: str + + @post_endpoint(path="/hello") + def my_post_endpoint(item: Item) -> str: + return {"Hello": item.name, "Description": item.description} + """ + return endpoint( + path=path, + methods=["POST"], + prefix=prefix, + response_model=response_model, + tags=tags, + **kwargs, + ) diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index cbf9f9f6..04a8a172 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -18,6 +18,7 @@ from cat.mad_hatter.plugin import Plugin from cat.mad_hatter.decorators.hook import CatHook from cat.mad_hatter.decorators.tool import CatTool +from cat.mad_hatter.decorators.endpoint import CustomEndpoint, _init_endpoint_decorator from cat.experimental.form import CatForm @@ -29,12 +30,12 @@ @singleton class MadHatter: # loads and execute plugins - # - enter into the plugin folder and loads everthing + # - enter into the plugin folder and loads everything # that is decorated or named properly # - orders plugged in hooks by name and priority # - exposes functionality to the cat - def __init__(self): + def __init__(self, fastapi_app): self.plugins: Dict[str, Plugin] = {} # plugins dictionary self.hooks: Dict[ @@ -42,6 +43,9 @@ def __init__(self): ] = {} # dict of active plugins hooks ( hook_name -> [CatHook, CatHook, ...]) self.tools: List[CatTool] = [] # list of active plugins tools self.forms: List[CatForm] = [] # list of active plugins forms + self.endpoints: List[CustomEndpoint] = [] # list of active plugins endpoints + + _init_endpoint_decorator(fastapi_app) # the endpoint decorator need the fastapi instance self.active_plugins: List[str] = [] @@ -138,15 +142,18 @@ def sync_hooks_tools_and_forms(self): self.hooks = {} self.tools = [] self.forms = [] + self.endpoints = [] for _, plugin in self.plugins.items(): - # load hooks, tools and forms from active plugins + # load hooks, tools, forms and endpoints from active plugins if plugin.id in self.active_plugins: # cache tools self.tools += plugin.tools self.forms += plugin.forms + self.endpoints += plugin.endpoints + # cache hooks (indexed by hook name) for h in plugin.hooks: if h.name not in self.hooks.keys(): diff --git a/core/cat/mad_hatter/plugin.py b/core/cat/mad_hatter/plugin.py index dfeabd7b..528a314b 100644 --- a/core/cat/mad_hatter/plugin.py +++ b/core/cat/mad_hatter/plugin.py @@ -11,7 +11,7 @@ from pydantic import BaseModel, ValidationError from packaging.requirements import Requirement -from cat.mad_hatter.decorators import CatTool, CatHook, CatPluginDecorator +from cat.mad_hatter.decorators import CatTool, CatHook, CatPluginDecorator, CustomEndpoint from cat.experimental.form import CatForm from cat.utils import to_camel_case from cat.log import log @@ -59,6 +59,7 @@ def __init__(self, plugin_path: str): self._hooks: List[CatHook] = [] # list of plugin hooks self._tools: List[CatTool] = [] # list of plugin tools self._forms: List[CatForm] = [] # list of plugin forms + self._endpoints: List[CustomEndpoint] = [] # list of plugin endpoints # list of @plugin decorated functions overriding default plugin behaviour self._plugin_overrides = [] # TODO: make this a dictionary indexed by func name, for faster access @@ -98,6 +99,8 @@ def deactivate(self): self._hooks = [] self._tools = [] + #TODO : clear of forms is missing here? + self._endpoints = [] self._plugin_overrides = [] self._active = False @@ -295,6 +298,7 @@ def _load_decorated_functions(self): hooks = [] tools = [] forms = [] + endpoints = [] plugin_overrides = [] for py_file in self.py_files: @@ -309,6 +313,7 @@ def _load_decorated_functions(self): hooks += getmembers(plugin_module, self._is_cat_hook) tools += getmembers(plugin_module, self._is_cat_tool) forms += getmembers(plugin_module, self._is_cat_form) + endpoints += getmembers(plugin_module, self._is_custom_endpoint) plugin_overrides += getmembers( plugin_module, self._is_cat_plugin_override ) @@ -323,6 +328,7 @@ def _load_decorated_functions(self): self._hooks = list(map(self._clean_hook, hooks)) self._tools = list(map(self._clean_tool, tools)) self._forms = list(map(self._clean_form, forms)) + self._endpoints = list(map(self._clean_endpoint, endpoints)) self._plugin_overrides = list( map(self._clean_plugin_override, plugin_overrides) ) @@ -349,6 +355,12 @@ def _clean_form(self, form: CatForm): f = form[1] f.plugin_id = self._id return f + + def _clean_endpoint(self, endpoint: CustomEndpoint): + # getmembers returns a tuple + f = endpoint[1] + f.plugin_id = self._id + return f def _clean_plugin_override(self, plugin_override): # getmembers returns a tuple @@ -382,6 +394,12 @@ def _is_cat_tool(obj): def _is_cat_plugin_override(obj): return isinstance(obj, CatPluginDecorator) + # a plugin custom endpoint has to be decorated with @endpoint + # (which returns an instance of CustomEndpoint) + @staticmethod + def _is_custom_endpoint(obj): + return isinstance(obj, CustomEndpoint) + @property def path(self): return self._path @@ -409,3 +427,7 @@ def tools(self): @property def forms(self): return self._forms + + @property + def endpoints(self): + return self._endpoints diff --git a/core/cat/routes/plugins.py b/core/cat/routes/plugins.py index 3ab2f319..fc06848f 100644 --- a/core/cat/routes/plugins.py +++ b/core/cat/routes/plugins.py @@ -51,6 +51,7 @@ async def get_available_plugins( {"name": hook.name, "priority": hook.priority} for hook in p.hooks ] manifest["tools"] = [{"name": tool.name} for tool in p.tools] + manifest["endpoints"] = [{"name": endpoint.name} for endpoint in p.endpoints] # filter by query plugin_text = [str(field) for field in manifest.values()] diff --git a/core/cat/startup.py b/core/cat/startup.py index 5613528e..5620203b 100644 --- a/core/cat/startup.py +++ b/core/cat/startup.py @@ -27,7 +27,6 @@ from cat.routes.openapi import get_openapi_configuration_function from cat.looking_glass.cheshire_cat import CheshireCat - @asynccontextmanager async def lifespan(app: FastAPI): @@ -35,10 +34,10 @@ async def lifespan(app: FastAPI): # # loads Cat and plugins # Every endpoint can access the cat instance via request.app.state.ccat - # - Not using midlleware because I can't make it work with both http and websocket; + # - Not using middleware because I can't make it work with both http and websocket; # - Not using Depends because it only supports callables (not instances) # - Starlette allows this: https://www.starlette.io/applications/#storing-state-on-the-app-instance - app.state.ccat = CheshireCat() + app.state.ccat = CheshireCat(cheshire_cat_api) # Dict of pseudo-sessions (key is the user_id) app.state.strays = {} diff --git a/core/tests/mad_hatter/test_endpoints.py b/core/tests/mad_hatter/test_endpoints.py new file mode 100644 index 00000000..6a3d7b3d --- /dev/null +++ b/core/tests/mad_hatter/test_endpoints.py @@ -0,0 +1,68 @@ +import pytest + +from cat.mad_hatter.mad_hatter import MadHatter +from cat.mad_hatter.decorators import CustomEndpoint + +from tests.utils import create_mock_plugin_zip + +# this function will be run before each test function +@pytest.fixture +def mad_hatter(client): # client here injects the monkeypatched version of the cat + + # each test is given the mad_hatter instance (it's a singleton) + mad_hatter = MadHatter() + + # install plugin + new_plugin_zip_path = create_mock_plugin_zip(flat=True) + mad_hatter.install_plugin(new_plugin_zip_path) + + yield mad_hatter + + +def test_endpoints_discovery(mad_hatter): + mock_plugin_endpoints = mad_hatter.plugins["mock_plugin"].endpoints + + assert len(mock_plugin_endpoints) == 4 + + for h in mock_plugin_endpoints: + assert isinstance(h, CustomEndpoint) + assert h.plugin_id == "mock_plugin" + + +def test_endpoint_decorator(client): + + response = client.get("/custom_endpoints/endpoint") + + assert response.status_code == 200 + + assert response.json()["result"] == "endpoint default prefix" + + +def test_endpoint_decorator_prefix(client): + + response = client.get("/tests/endpoint") + + assert response.status_code == 200 + + assert response.json()["result"] == "endpoint prefix tests" + + +def test_get_endpoint(client): + + response = client.get("/tests/get") + + assert response.status_code == 200 + + assert response.json()["result"] == "ok" + assert response.json()["stray_user_id"] == "user" + + +def test_post_endpoint(client): + + payload = {"name": "the cat", "description" : "it's magic"} + response = client.post("/tests/post", json=payload) + + assert response.status_code == 200 + + assert response.json()["name"] == "the cat" + assert response.json()["description"] == "it's magic" \ No newline at end of file diff --git a/core/tests/mocks/mock_plugin/mock_endpoint.py b/core/tests/mocks/mock_plugin/mock_endpoint.py new file mode 100644 index 00000000..3aea7c8d --- /dev/null +++ b/core/tests/mocks/mock_plugin/mock_endpoint.py @@ -0,0 +1,26 @@ +from fastapi import Request, Depends +from pydantic import BaseModel + +from cat.mad_hatter.decorators import endpoint, get_endpoint, post_endpoint +from cat.auth.connection import HTTPAuth +from cat.auth.permissions import AuthPermission, AuthResource + +class Item(BaseModel): + name: str + description: str + +@endpoint(path="/endpoint", methods=["GET"], tags=["Tests"]) +def test_endpoint(): + return {"result":"endpoint default prefix"} + +@endpoint(path="/endpoint", prefix="/tests", methods=["GET"], tags=["Tests"]) +def test_endpoint_prefix(): + return {"result":"endpoint prefix tests"} + +@get_endpoint(path="/get", prefix="/tests", tags=["Tests"]) +def test_get(request: Request, stray=Depends(HTTPAuth(AuthResource.PLUGINS, AuthPermission.LIST))): + return {"result":"ok", "stray_user_id":stray.user_id} + +@post_endpoint(path="/post", prefix="/tests", tags=["Tests"]) +def test_post(item: Item) -> str: + return {"name": item.name, "description": item.description} \ No newline at end of file From b05452995f716c7e11d404c33640d26f5915ea36 Mon Sep 17 00:00:00 2001 From: Samuele Barzaghi Date: Sat, 2 Nov 2024 14:15:05 +0100 Subject: [PATCH 02/13] Changes: - one decorator but with dot syntax @endpoint.get @endpoint.post @endpoint.endpoint - default prefix now "custom-endpoints" (kebab-case) - add tests on tags --- core/cat/mad_hatter/decorators/__init__.py | 4 +- core/cat/mad_hatter/decorators/endpoint.py | 188 ++++++++++-------- core/cat/mad_hatter/mad_hatter.py | 4 +- core/cat/routes/plugins.py | 2 +- core/tests/mad_hatter/test_endpoints.py | 8 +- core/tests/mocks/mock_plugin/mock_endpoint.py | 10 +- 6 files changed, 127 insertions(+), 89 deletions(-) diff --git a/core/cat/mad_hatter/decorators/__init__.py b/core/cat/mad_hatter/decorators/__init__.py index a73074c6..9aa85693 100644 --- a/core/cat/mad_hatter/decorators/__init__.py +++ b/core/cat/mad_hatter/decorators/__init__.py @@ -1,6 +1,6 @@ from cat.mad_hatter.decorators.tool import CatTool, tool from cat.mad_hatter.decorators.hook import CatHook, hook -from cat.mad_hatter.decorators.endpoint import CustomEndpoint, endpoint, get_endpoint, post_endpoint +from cat.mad_hatter.decorators.endpoint import CustomEndpoint, endpoint from cat.mad_hatter.decorators.plugin_decorator import CatPluginDecorator, plugin -__all__ = ["CatTool", "tool", "CatHook", "hook", "CustomEndpoint", "endpoint", "get_endpoint", "post_endpoint", "CatPluginDecorator", "plugin"] +__all__ = ["CatTool", "tool", "CatHook", "hook", "CustomEndpoint", "endpoint", "CatPluginDecorator", "plugin"] diff --git a/core/cat/mad_hatter/decorators/endpoint.py b/core/cat/mad_hatter/decorators/endpoint.py index d911c8b8..7d8114b3 100644 --- a/core/cat/mad_hatter/decorators/endpoint.py +++ b/core/cat/mad_hatter/decorators/endpoint.py @@ -1,14 +1,14 @@ from typing import Callable from fastapi import APIRouter -cheshire_cat_api = None - # class to represent a @endpoint class CustomEndpoint: - def __init__(self, prefix: str, path: str, function: Callable, **kwargs): + def __init__(self, prefix: str, path: str, function: Callable, tags, **kwargs): self.prefix = prefix self.path = path self.function = function + self.tags = tags + self.name = self.prefix + self.path for k in kwargs: @@ -17,86 +17,118 @@ def __init__(self, prefix: str, path: str, function: Callable, **kwargs): def __repr__(self) -> str: return f"CustomEndpoint(path={self.name})" -# Called from madhatter to inject the fastapi app instance -def _init_endpoint_decorator(new_cheshire_cat_api): - global cheshire_cat_api - cheshire_cat_api = new_cheshire_cat_api +class Endpoint: -# @endpoint decorator. Any function in a plugin decorated by @endpoint will be exposed as FastAPI operation -def endpoint(path, methods, prefix="/custom_endpoints", tags=["custom_endpoints"], **kwargs) -> Callable: - """ - Define a custom API endpoint, parameters are the same as FastAPI path operation. - Examples: - .. code-block:: python - @endpoint(path="/hello", methods=["GET"]) - def my_endpoint() -> str: - return {"Hello":"Alice"} - """ + cheshire_cat_api = None - global cheshire_cat_api + default_prefix = "/custom-endpoints" + default_tags = ["Custom Endpoints"] - def _make_endpoint(endpoint): - custom_endpoint = CustomEndpoint(prefix=prefix, path=path, function=endpoint, **kwargs) + # Called from madhatter to inject the fastapi app instance + def _init_decorators(cls, new_cheshire_cat_api): + cls.cheshire_cat_api = new_cheshire_cat_api - plugins_router = APIRouter() - plugins_router.add_api_route( - path=path, endpoint=endpoint, methods=methods, tags=tags, **kwargs + # @endpoint decorator. Any function in a plugin decorated by @endpoint.endpoint will be exposed as FastAPI operation + def endpoint( + cls, + path, + methods, + prefix=default_prefix, + tags=default_tags, + **kwargs, + ) -> Callable: + """ + Define a custom API endpoint, parameters are the same as FastAPI path operation. + Examples: + .. code-block:: python + from cat.mad_hatter.decorators.endpoint import endpoint + + @endpoint.endpoint(path="/hello", methods=["GET"]) + def my_endpoint(): + return {"Hello":"Alice"} + """ + + def _make_endpoint(endpoint): + custom_endpoint = CustomEndpoint( + prefix=prefix, path=path, function=endpoint, tags=tags, **kwargs + ) + + plugins_router = APIRouter() + plugins_router.add_api_route( + path=path, endpoint=endpoint, methods=methods, tags=tags, **kwargs + ) + + cls.cheshire_cat_api.include_router(plugins_router, prefix=prefix) + + return custom_endpoint + + return _make_endpoint + + # @get_endpoint decorator. Any function in a plugin decorated by @endpoint.get will be exposed as FastAPI GET operation + def get( + cls, + path, + prefix=default_prefix, + response_model=None, + tags=default_tags, + **kwargs, + ) -> Callable: + """ + Define a custom API endpoint for GET operation, parameters are the same as FastAPI path operation. + Examples: + .. code-block:: python + from cat.mad_hatter.decorators.endpoint import endpoint + + @endpoint.get(path="/hello") + def my_get_endpoint(): + return {"Hello":"Alice"} + """ + + return cls.endpoint( + path=path, + methods=["GET"], + prefix=prefix, + response_model=response_model, + tags=tags, + **kwargs, ) - cheshire_cat_api.include_router(plugins_router, prefix=prefix) - - return custom_endpoint - - return _make_endpoint - -# @get_endpoint decorator. Any function in a plugin decorated by @endpoint will be exposed as FastAPI GET operation -def get_endpoint( - path, prefix="/custom_endpoints", response_model=None, tags=["custom_endpoints"], **kwargs -) -> Callable: - """ - Define a custom API endpoint for GET operation, parameters are the same as FastAPI path operation. - Examples: - .. code-block:: python - @get_endpoint(path="/hello") - def my_get_endpoint() -> str: - return {"Hello":"Alice"} - """ - - return endpoint( - path=path, - methods=["GET"], - prefix=prefix, - response_model=response_model, - tags=tags, + # @post_endpoint decorator. Any function in a plugin decorated by @endpoint.post will be exposed as FastAPI POST operation + def post( + cls, + path, + prefix=default_prefix, + response_model=None, + tags=default_tags, **kwargs, - ) - -# @post_endpoint decorator. Any function in a plugin decorated by @endpoint will be exposed as FastAPI POST operation -def post_endpoint( - path, prefix="/custom_endpoints", response_model=None, tags=["custom_endpoints"], **kwargs -) -> Callable: - - """ - Define a custom API endpoint for POST operation, parameters are the same as FastAPI path operation. - Examples: - .. code-block:: python - - from pydantic import BaseModel - - class Item(BaseModel): - name: str - description: str - - @post_endpoint(path="/hello") - def my_post_endpoint(item: Item) -> str: - return {"Hello": item.name, "Description": item.description} - """ - return endpoint( - path=path, - methods=["POST"], - prefix=prefix, - response_model=response_model, - tags=tags, - **kwargs, - ) + ) -> Callable: + """ + Define a custom API endpoint for POST operation, parameters are the same as FastAPI path operation. + Examples: + .. code-block:: python + + from cat.mad_hatter.decorators.endpoint import endpoint + from pydantic import BaseModel + + class Item(BaseModel): + name: str + description: str + + @endpoint.post(path="/hello") + def my_post_endpoint(item: Item): + return {"Hello": item.name, "Description": item.description} + """ + return cls.endpoint( + path=path, + methods=["POST"], + prefix=prefix, + response_model=response_model, + tags=tags, + **kwargs, + ) + +endpoint = None + +if not endpoint: + endpoint = Endpoint() diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index 04a8a172..84770352 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -18,7 +18,7 @@ from cat.mad_hatter.plugin import Plugin from cat.mad_hatter.decorators.hook import CatHook from cat.mad_hatter.decorators.tool import CatTool -from cat.mad_hatter.decorators.endpoint import CustomEndpoint, _init_endpoint_decorator +from cat.mad_hatter.decorators.endpoint import CustomEndpoint, endpoint from cat.experimental.form import CatForm @@ -45,7 +45,7 @@ def __init__(self, fastapi_app): self.forms: List[CatForm] = [] # list of active plugins forms self.endpoints: List[CustomEndpoint] = [] # list of active plugins endpoints - _init_endpoint_decorator(fastapi_app) # the endpoint decorator need the fastapi instance + endpoint._init_decorators(fastapi_app) # the endpoint decorators need the fastapi instance self.active_plugins: List[str] = [] diff --git a/core/cat/routes/plugins.py b/core/cat/routes/plugins.py index fc06848f..c8e7d6a9 100644 --- a/core/cat/routes/plugins.py +++ b/core/cat/routes/plugins.py @@ -51,7 +51,7 @@ async def get_available_plugins( {"name": hook.name, "priority": hook.priority} for hook in p.hooks ] manifest["tools"] = [{"name": tool.name} for tool in p.tools] - manifest["endpoints"] = [{"name": endpoint.name} for endpoint in p.endpoints] + manifest["endpoints"] = [{"name": endpoint.name, "tags": endpoint.tags} for endpoint in p.endpoints] # filter by query plugin_text = [str(field) for field in manifest.values()] diff --git a/core/tests/mad_hatter/test_endpoints.py b/core/tests/mad_hatter/test_endpoints.py index 6a3d7b3d..1bcdf825 100644 --- a/core/tests/mad_hatter/test_endpoints.py +++ b/core/tests/mad_hatter/test_endpoints.py @@ -28,10 +28,16 @@ def test_endpoints_discovery(mad_hatter): assert isinstance(h, CustomEndpoint) assert h.plugin_id == "mock_plugin" + if h.name == "/custom-endpoints/endpoint": + assert h.tags == ["Custom Endpoints"] + + if h.name == "/custom-endpoints/tests": + assert h.tags == ["Tests"] + def test_endpoint_decorator(client): - response = client.get("/custom_endpoints/endpoint") + response = client.get("/custom-endpoints/endpoint") assert response.status_code == 200 diff --git a/core/tests/mocks/mock_plugin/mock_endpoint.py b/core/tests/mocks/mock_plugin/mock_endpoint.py index 3aea7c8d..a34b9439 100644 --- a/core/tests/mocks/mock_plugin/mock_endpoint.py +++ b/core/tests/mocks/mock_plugin/mock_endpoint.py @@ -1,7 +1,7 @@ from fastapi import Request, Depends from pydantic import BaseModel -from cat.mad_hatter.decorators import endpoint, get_endpoint, post_endpoint +from cat.mad_hatter.decorators import endpoint from cat.auth.connection import HTTPAuth from cat.auth.permissions import AuthPermission, AuthResource @@ -9,18 +9,18 @@ class Item(BaseModel): name: str description: str -@endpoint(path="/endpoint", methods=["GET"], tags=["Tests"]) +@endpoint.endpoint(path="/endpoint", methods=["GET"]) def test_endpoint(): return {"result":"endpoint default prefix"} -@endpoint(path="/endpoint", prefix="/tests", methods=["GET"], tags=["Tests"]) +@endpoint.endpoint(path="/endpoint", prefix="/tests", methods=["GET"], tags=["Tests"]) def test_endpoint_prefix(): return {"result":"endpoint prefix tests"} -@get_endpoint(path="/get", prefix="/tests", tags=["Tests"]) +@endpoint.get(path="/get", prefix="/tests", tags=["Tests"]) def test_get(request: Request, stray=Depends(HTTPAuth(AuthResource.PLUGINS, AuthPermission.LIST))): return {"result":"ok", "stray_user_id":stray.user_id} -@post_endpoint(path="/post", prefix="/tests", tags=["Tests"]) +@endpoint.post(path="/post", prefix="/tests", tags=["Tests"]) def test_post(item: Item) -> str: return {"name": item.name, "description": item.description} \ No newline at end of file From b1cfea538048fc46577369aa9fc416fdd71d5bd7 Mon Sep 17 00:00:00 2001 From: Samuele Barzaghi Date: Sat, 2 Nov 2024 14:25:24 +0100 Subject: [PATCH 03/13] Fixed docstring --- core/cat/mad_hatter/decorators/endpoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/cat/mad_hatter/decorators/endpoint.py b/core/cat/mad_hatter/decorators/endpoint.py index 7d8114b3..52538317 100644 --- a/core/cat/mad_hatter/decorators/endpoint.py +++ b/core/cat/mad_hatter/decorators/endpoint.py @@ -42,7 +42,7 @@ def endpoint( Define a custom API endpoint, parameters are the same as FastAPI path operation. Examples: .. code-block:: python - from cat.mad_hatter.decorators.endpoint import endpoint + from cat.mad_hatter.decorators import endpoint @endpoint.endpoint(path="/hello", methods=["GET"]) def my_endpoint(): @@ -78,7 +78,7 @@ def get( Define a custom API endpoint for GET operation, parameters are the same as FastAPI path operation. Examples: .. code-block:: python - from cat.mad_hatter.decorators.endpoint import endpoint + from cat.mad_hatter.decorators import endpoint @endpoint.get(path="/hello") def my_get_endpoint(): @@ -108,7 +108,7 @@ def post( Examples: .. code-block:: python - from cat.mad_hatter.decorators.endpoint import endpoint + from cat.mad_hatter.decorators import endpoint from pydantic import BaseModel class Item(BaseModel): From 002b939cbe07274079a3df7a501548238db61f77 Mon Sep 17 00:00:00 2001 From: Samuele Barzaghi Date: Sat, 2 Nov 2024 14:53:25 +0100 Subject: [PATCH 04/13] Fix doc page not updated on plugin activation How to reproduce: - Launch the Cat - Goto /docs page - Activate a plugin with a custom endpoint - The custom endpoint can be used - Refresh page /docs - The new custom endpoint is not there Why: The openapi schema is cached by get_openapi_configuration_function Solution: Each time a function is decorated with `endpoint`, we flush the cache --- core/cat/mad_hatter/decorators/endpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/core/cat/mad_hatter/decorators/endpoint.py b/core/cat/mad_hatter/decorators/endpoint.py index 52538317..d101f1cd 100644 --- a/core/cat/mad_hatter/decorators/endpoint.py +++ b/core/cat/mad_hatter/decorators/endpoint.py @@ -60,6 +60,7 @@ def _make_endpoint(endpoint): ) cls.cheshire_cat_api.include_router(plugins_router, prefix=prefix) + cls.cheshire_cat_api.openapi_schema = None # Flush the cache of openapi schema return custom_endpoint From c878a4ca8d58b0c25d386e389afdbcf3cb6483ee Mon Sep 17 00:00:00 2001 From: Samuele Barzaghi Date: Sat, 2 Nov 2024 15:58:32 +0100 Subject: [PATCH 05/13] Implement endpoints deactivation - Routes manually removed from FastAPI routes (seems there is no official method to do this) - Flush the docs cache --- core/cat/mad_hatter/decorators/endpoint.py | 43 ++++++++++++++++++++-- core/cat/mad_hatter/plugin.py | 7 +++- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/core/cat/mad_hatter/decorators/endpoint.py b/core/cat/mad_hatter/decorators/endpoint.py index d101f1cd..d96dc095 100644 --- a/core/cat/mad_hatter/decorators/endpoint.py +++ b/core/cat/mad_hatter/decorators/endpoint.py @@ -1,13 +1,23 @@ from typing import Callable from fastapi import APIRouter + # class to represent a @endpoint class CustomEndpoint: - def __init__(self, prefix: str, path: str, function: Callable, tags, **kwargs): + def __init__( + self, + prefix: str, + path: str, + function: Callable, + tags, + cheshire_cat_api, + **kwargs, + ): self.prefix = prefix self.path = path self.function = function self.tags = tags + self.cheshire_cat_api = cheshire_cat_api self.name = self.prefix + self.path @@ -17,6 +27,17 @@ def __init__(self, prefix: str, path: str, function: Callable, tags, **kwargs): def __repr__(self) -> str: return f"CustomEndpoint(path={self.name})" + def set_api_route(self, api_route): + self.api_route = api_route + + def remove(self): + # Seems there is no official way to remove a route: + # https://github.com/fastapi/fastapi/discussions/8088 + self.cheshire_cat_api.routes.remove( + self.api_route + ) + self.cheshire_cat_api.openapi_schema = None # Flush the cached openapi schema + class Endpoint: @@ -51,7 +72,12 @@ def my_endpoint(): def _make_endpoint(endpoint): custom_endpoint = CustomEndpoint( - prefix=prefix, path=path, function=endpoint, tags=tags, **kwargs + prefix=prefix, + path=path, + function=endpoint, + tags=tags, + cheshire_cat_api=cls.cheshire_cat_api, + **kwargs, ) plugins_router = APIRouter() @@ -60,7 +86,17 @@ def _make_endpoint(endpoint): ) cls.cheshire_cat_api.include_router(plugins_router, prefix=prefix) - cls.cheshire_cat_api.openapi_schema = None # Flush the cache of openapi schema + cls.cheshire_cat_api.openapi_schema = ( + None # Flush the cache of openapi schema + ) + + # Set the fastapi api_route into the Custom Endpoint + # (The method add_api_route of FastAPI do append, so our new route is + # the last route) + for api_route in cls.cheshire_cat_api.routes: + if api_route.path == custom_endpoint.name: + custom_endpoint.set_api_route(api_route) + break return custom_endpoint @@ -129,6 +165,7 @@ def my_post_endpoint(item: Item): **kwargs, ) + endpoint = None if not endpoint: diff --git a/core/cat/mad_hatter/plugin.py b/core/cat/mad_hatter/plugin.py index 528a314b..0d5adcb3 100644 --- a/core/cat/mad_hatter/plugin.py +++ b/core/cat/mad_hatter/plugin.py @@ -100,7 +100,7 @@ def deactivate(self): self._hooks = [] self._tools = [] #TODO : clear of forms is missing here? - self._endpoints = [] + self._deactivate_endpoints() self._plugin_overrides = [] self._active = False @@ -338,6 +338,11 @@ def plugin_specific_error_message(self): url = self.manifest.get("plugin_url") return f"To resolve any problem related to {name} plugin, contact the creator using github issue at the link {url}" + def _deactivate_endpoints(self): + + for endpoint in self._endpoints: + endpoint.remove() + def _clean_hook(self, hook: CatHook): # getmembers returns a tuple h = hook[1] From 3cc6f5b486995932bd265e752b1230813a835284 Mon Sep 17 00:00:00 2001 From: Samuele Barzaghi Date: Sat, 2 Nov 2024 16:40:58 +0100 Subject: [PATCH 06/13] Test plugin deactivation (only endpoint, no tests on docs) --- core/tests/mad_hatter/test_endpoints.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/core/tests/mad_hatter/test_endpoints.py b/core/tests/mad_hatter/test_endpoints.py index 1bcdf825..7e4ac309 100644 --- a/core/tests/mad_hatter/test_endpoints.py +++ b/core/tests/mad_hatter/test_endpoints.py @@ -71,4 +71,14 @@ def test_post_endpoint(client): assert response.status_code == 200 assert response.json()["name"] == "the cat" - assert response.json()["description"] == "it's magic" \ No newline at end of file + assert response.json()["description"] == "it's magic" + +def test_plugin_deactivation(client, mad_hatter): + + response = client.get("/custom-endpoints/endpoint") + assert response.status_code == 200 + + mad_hatter.toggle_plugin("mock_plugin") + + response = client.get("/custom-endpoints/endpoint") + assert response.status_code == 404 \ No newline at end of file From 2bfb1a0cf3142d63d414843efc5f38d236191a21 Mon Sep 17 00:00:00 2001 From: Samuele Barzaghi Date: Sat, 2 Nov 2024 22:54:42 +0100 Subject: [PATCH 07/13] Refactor, remove FastAPI from MadHatter --- core/cat/looking_glass/cheshire_cat.py | 21 +++++-- core/cat/mad_hatter/decorators/endpoint.py | 68 ++++++++++++---------- core/cat/mad_hatter/mad_hatter.py | 4 +- core/cat/mad_hatter/plugin.py | 4 +- core/tests/mad_hatter/test_endpoints.py | 10 ++-- core/tests/mad_hatter/test_plugin.py | 1 + 6 files changed, 64 insertions(+), 44 deletions(-) diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index a25abaa6..851b940e 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -61,12 +61,13 @@ def __init__(self, fastapi_app): """ # bootstrap the Cat! ^._.^ + self.fastapi_app = fastapi_app # Start scheduling system self.white_rabbit = WhiteRabbit() # instantiate MadHatter (loads all plugins' hooks and tools) - self.mad_hatter = MadHatter(fastapi_app) + self.mad_hatter = MadHatter() # load AuthHandler self.load_auth() @@ -80,10 +81,11 @@ def __init__(self, fastapi_app): # Load memories (vector collections and working_memory) self.load_memory() - # After memory is loaded, we can get/create tools embeddings - # every time the mad_hatter finishes syncing hooks, tools and forms, it will notify the Cat (so it can embed tools in vector memory) - self.mad_hatter.on_finish_plugins_sync_callback = self.embed_procedures - self.embed_procedures() # first time launched manually + # After memory is loaded, we can get/create tools embeddings + self.mad_hatter.on_finish_plugins_sync_callback = self.on_finish_plugins_sync_callback + + # First time launched manually + self.on_finish_plugins_sync_callback() # Main agent instance (for reasoning) self.main_agent = MainAgent() @@ -332,6 +334,15 @@ def build_active_procedures_hashes(self, active_procedures): } return hashes + def on_finish_plugins_sync_callback(self): + self.activate_endpoints() + self.embed_procedures() + + def activate_endpoints(self): + for endpoint in self.mad_hatter.endpoints: + if endpoint.plugin_id in self.mad_hatter.active_plugins: + endpoint.activate(self.fastapi_app) + def embed_procedures(self): # Retrieve from vectorDB all procedural embeddings embedded_procedures, _ = self.memory.vectors.procedural.get_all_points() diff --git a/core/cat/mad_hatter/decorators/endpoint.py b/core/cat/mad_hatter/decorators/endpoint.py index d96dc095..0bbf3b7e 100644 --- a/core/cat/mad_hatter/decorators/endpoint.py +++ b/core/cat/mad_hatter/decorators/endpoint.py @@ -1,6 +1,7 @@ from typing import Callable from fastapi import APIRouter +from cat.log import log # class to represent a @endpoint class CustomEndpoint: @@ -9,36 +10,61 @@ def __init__( prefix: str, path: str, function: Callable, + methods, tags, - cheshire_cat_api, **kwargs, ): self.prefix = prefix self.path = path self.function = function self.tags = tags - self.cheshire_cat_api = cheshire_cat_api - + self.methods = methods + self.kwargs = kwargs self.name = self.prefix + self.path - for k in kwargs: - setattr(self, k, kwargs[k]) - def __repr__(self) -> str: return f"CustomEndpoint(path={self.name})" def set_api_route(self, api_route): self.api_route = api_route - def remove(self): + def activate(self, cheshire_cat_api): + + self.cheshire_cat_api = cheshire_cat_api + + # Set the fastapi api_route into the Custom Endpoint + for api_route in self.cheshire_cat_api.routes: + if api_route.path == self.name: + log.error(f"There is already an endpoint with path {self.name}") + return + + plugins_router = APIRouter() + plugins_router.add_api_route( + path=self.path, + endpoint=self.function, + methods=self.methods, + tags=self.tags, + **self.kwargs, + ) + + self.cheshire_cat_api.include_router(plugins_router, prefix=self.prefix) + self.cheshire_cat_api.openapi_schema = None # Flush the cache of openapi schema + + # Set the fastapi api_route into the Custom Endpoint + for api_route in self.cheshire_cat_api.routes: + if api_route.path == self.name: + self.api_route = api_route + break + + assert api_route.path == self.name + + def deactivate(self): + # Seems there is no official way to remove a route: # https://github.com/fastapi/fastapi/discussions/8088 - self.cheshire_cat_api.routes.remove( - self.api_route - ) + self.cheshire_cat_api.routes.remove(self.api_route) self.cheshire_cat_api.openapi_schema = None # Flush the cached openapi schema - class Endpoint: cheshire_cat_api = None @@ -75,29 +101,11 @@ def _make_endpoint(endpoint): prefix=prefix, path=path, function=endpoint, + methods=methods, tags=tags, - cheshire_cat_api=cls.cheshire_cat_api, **kwargs, ) - plugins_router = APIRouter() - plugins_router.add_api_route( - path=path, endpoint=endpoint, methods=methods, tags=tags, **kwargs - ) - - cls.cheshire_cat_api.include_router(plugins_router, prefix=prefix) - cls.cheshire_cat_api.openapi_schema = ( - None # Flush the cache of openapi schema - ) - - # Set the fastapi api_route into the Custom Endpoint - # (The method add_api_route of FastAPI do append, so our new route is - # the last route) - for api_route in cls.cheshire_cat_api.routes: - if api_route.path == custom_endpoint.name: - custom_endpoint.set_api_route(api_route) - break - return custom_endpoint return _make_endpoint diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index 84770352..55a038d3 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -35,7 +35,7 @@ class MadHatter: # - orders plugged in hooks by name and priority # - exposes functionality to the cat - def __init__(self, fastapi_app): + def __init__(self): self.plugins: Dict[str, Plugin] = {} # plugins dictionary self.hooks: Dict[ @@ -45,8 +45,6 @@ def __init__(self, fastapi_app): self.forms: List[CatForm] = [] # list of active plugins forms self.endpoints: List[CustomEndpoint] = [] # list of active plugins endpoints - endpoint._init_decorators(fastapi_app) # the endpoint decorators need the fastapi instance - self.active_plugins: List[str] = [] self.plugins_folder = utils.get_plugins_path() diff --git a/core/cat/mad_hatter/plugin.py b/core/cat/mad_hatter/plugin.py index 0d5adcb3..954e53e5 100644 --- a/core/cat/mad_hatter/plugin.py +++ b/core/cat/mad_hatter/plugin.py @@ -74,7 +74,7 @@ def activate(self): except Exception as e: raise e - # Load of hooks and tools + # Load of hook, tools, forms and endpoints self._load_decorated_functions() # by default, plugin settings are saved inside the plugin folder @@ -341,7 +341,7 @@ def plugin_specific_error_message(self): def _deactivate_endpoints(self): for endpoint in self._endpoints: - endpoint.remove() + endpoint.deactivate() def _clean_hook(self, hook: CatHook): # getmembers returns a tuple diff --git a/core/tests/mad_hatter/test_endpoints.py b/core/tests/mad_hatter/test_endpoints.py index 7e4ac309..d16c7a48 100644 --- a/core/tests/mad_hatter/test_endpoints.py +++ b/core/tests/mad_hatter/test_endpoints.py @@ -18,6 +18,8 @@ def mad_hatter(client): # client here injects the monkeypatched version of the yield mad_hatter + mad_hatter.uninstall_plugin("mock_plugin") + def test_endpoints_discovery(mad_hatter): mock_plugin_endpoints = mad_hatter.plugins["mock_plugin"].endpoints @@ -35,7 +37,7 @@ def test_endpoints_discovery(mad_hatter): assert h.tags == ["Tests"] -def test_endpoint_decorator(client): +def test_endpoint_decorator(client, mad_hatter): response = client.get("/custom-endpoints/endpoint") @@ -44,7 +46,7 @@ def test_endpoint_decorator(client): assert response.json()["result"] == "endpoint default prefix" -def test_endpoint_decorator_prefix(client): +def test_endpoint_decorator_prefix(client, mad_hatter): response = client.get("/tests/endpoint") @@ -53,7 +55,7 @@ def test_endpoint_decorator_prefix(client): assert response.json()["result"] == "endpoint prefix tests" -def test_get_endpoint(client): +def test_get_endpoint(client, mad_hatter): response = client.get("/tests/get") @@ -63,7 +65,7 @@ def test_get_endpoint(client): assert response.json()["stray_user_id"] == "user" -def test_post_endpoint(client): +def test_post_endpoint(client, mad_hatter): payload = {"name": "the cat", "description" : "it's magic"} response = client.post("/tests/post", json=payload) diff --git a/core/tests/mad_hatter/test_plugin.py b/core/tests/mad_hatter/test_plugin.py index 1d76f947..1a0fc301 100644 --- a/core/tests/mad_hatter/test_plugin.py +++ b/core/tests/mad_hatter/test_plugin.py @@ -54,6 +54,7 @@ def test_create_plugin(plugin): # hooks and tools assert plugin.hooks == [] assert plugin.tools == [] + assert plugin.endpoints == [] def test_activate_plugin(plugin): From 972cd7c75088264acddf47bb3cfa6e3bddefc023 Mon Sep 17 00:00:00 2001 From: Samuele Barzaghi Date: Sun, 3 Nov 2024 14:01:13 +0100 Subject: [PATCH 08/13] Log on endpoint activation/deactivation --- core/cat/mad_hatter/decorators/endpoint.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/cat/mad_hatter/decorators/endpoint.py b/core/cat/mad_hatter/decorators/endpoint.py index 0bbf3b7e..f28e74ce 100644 --- a/core/cat/mad_hatter/decorators/endpoint.py +++ b/core/cat/mad_hatter/decorators/endpoint.py @@ -30,12 +30,14 @@ def set_api_route(self, api_route): def activate(self, cheshire_cat_api): + log.info(f"Activating custom endpoint {self.name}...") + self.cheshire_cat_api = cheshire_cat_api # Set the fastapi api_route into the Custom Endpoint for api_route in self.cheshire_cat_api.routes: if api_route.path == self.name: - log.error(f"There is already an endpoint with path {self.name}") + log.info(f"There is already an active endpoint with path {self.name}") return plugins_router = APIRouter() @@ -60,6 +62,8 @@ def activate(self, cheshire_cat_api): def deactivate(self): + log.info(f"Deactivating custom endpoint {self.name}...") + # Seems there is no official way to remove a route: # https://github.com/fastapi/fastapi/discussions/8088 self.cheshire_cat_api.routes.remove(self.api_route) From 05d4107104003980a9b303e1e374f578f4b6ffd0 Mon Sep 17 00:00:00 2001 From: Samuele Barzaghi Date: Sun, 3 Nov 2024 14:05:55 +0100 Subject: [PATCH 09/13] Fix test_deactivate_plugin Why not using plugin.deactivate()? the "endpoint" decorator needs the reference to the FastAPI app instance for endpoint removing. Calling plugin.deactivate() the method CustomEndpoint.deactivate() raises the exception: "self.cheshire_cat_api is None" because the method on_finish_plugins_sync_callback is not called so the cheshire_cat_api is not injected --- core/tests/conftest.py | 19 ++++++++++++- core/tests/mad_hatter/test_endpoints.py | 36 ++++++------------------- core/tests/mad_hatter/test_plugin.py | 19 ++++++++----- 3 files changed, 39 insertions(+), 35 deletions(-) diff --git a/core/tests/conftest.py b/core/tests/conftest.py index 2930d85d..f504ca9e 100644 --- a/core/tests/conftest.py +++ b/core/tests/conftest.py @@ -21,6 +21,8 @@ from cat.startup import cheshire_cat_api from tests.utils import create_mock_plugin_zip +from cat.mad_hatter.mad_hatter import MadHatter + import time FAKE_TIMESTAMP = 1705855981 @@ -162,4 +164,19 @@ def patch_time_now(monkeypatch): def mytime(): return FAKE_TIMESTAMP - monkeypatch.setattr(time, 'time', mytime) \ No newline at end of file + monkeypatch.setattr(time, 'time', mytime) + +#fixture for mad hatter with mock plugin installed +@pytest.fixture +def mad_hatter_with_mock_plugin(client): # client here injects the monkeypatched version of the cat + + # each test is given the mad_hatter instance (it's a singleton) + mad_hatter = MadHatter() + + # install plugin + new_plugin_zip_path = create_mock_plugin_zip(flat=True) + mad_hatter.install_plugin(new_plugin_zip_path) + + yield mad_hatter + + mad_hatter.uninstall_plugin("mock_plugin") diff --git a/core/tests/mad_hatter/test_endpoints.py b/core/tests/mad_hatter/test_endpoints.py index d16c7a48..b65b5d79 100644 --- a/core/tests/mad_hatter/test_endpoints.py +++ b/core/tests/mad_hatter/test_endpoints.py @@ -1,28 +1,8 @@ -import pytest - -from cat.mad_hatter.mad_hatter import MadHatter from cat.mad_hatter.decorators import CustomEndpoint -from tests.utils import create_mock_plugin_zip - -# this function will be run before each test function -@pytest.fixture -def mad_hatter(client): # client here injects the monkeypatched version of the cat - - # each test is given the mad_hatter instance (it's a singleton) - mad_hatter = MadHatter() - - # install plugin - new_plugin_zip_path = create_mock_plugin_zip(flat=True) - mad_hatter.install_plugin(new_plugin_zip_path) - - yield mad_hatter - - mad_hatter.uninstall_plugin("mock_plugin") - +def test_endpoints_discovery(mad_hatter_with_mock_plugin): -def test_endpoints_discovery(mad_hatter): - mock_plugin_endpoints = mad_hatter.plugins["mock_plugin"].endpoints + mock_plugin_endpoints = mad_hatter_with_mock_plugin.plugins["mock_plugin"].endpoints assert len(mock_plugin_endpoints) == 4 @@ -37,7 +17,7 @@ def test_endpoints_discovery(mad_hatter): assert h.tags == ["Tests"] -def test_endpoint_decorator(client, mad_hatter): +def test_endpoint_decorator(client, mad_hatter_with_mock_plugin): response = client.get("/custom-endpoints/endpoint") @@ -46,7 +26,7 @@ def test_endpoint_decorator(client, mad_hatter): assert response.json()["result"] == "endpoint default prefix" -def test_endpoint_decorator_prefix(client, mad_hatter): +def test_endpoint_decorator_prefix(client, mad_hatter_with_mock_plugin): response = client.get("/tests/endpoint") @@ -55,7 +35,7 @@ def test_endpoint_decorator_prefix(client, mad_hatter): assert response.json()["result"] == "endpoint prefix tests" -def test_get_endpoint(client, mad_hatter): +def test_get_endpoint(client, mad_hatter_with_mock_plugin): response = client.get("/tests/get") @@ -65,7 +45,7 @@ def test_get_endpoint(client, mad_hatter): assert response.json()["stray_user_id"] == "user" -def test_post_endpoint(client, mad_hatter): +def test_post_endpoint(client, mad_hatter_with_mock_plugin): payload = {"name": "the cat", "description" : "it's magic"} response = client.post("/tests/post", json=payload) @@ -75,12 +55,12 @@ def test_post_endpoint(client, mad_hatter): assert response.json()["name"] == "the cat" assert response.json()["description"] == "it's magic" -def test_plugin_deactivation(client, mad_hatter): +def test_plugin_deactivation(client, mad_hatter_with_mock_plugin): response = client.get("/custom-endpoints/endpoint") assert response.status_code == 200 - mad_hatter.toggle_plugin("mock_plugin") + mad_hatter_with_mock_plugin.toggle_plugin("mock_plugin") response = client.get("/custom-endpoints/endpoint") assert response.status_code == 404 \ No newline at end of file diff --git a/core/tests/mad_hatter/test_plugin.py b/core/tests/mad_hatter/test_plugin.py index 1a0fc301..0551b265 100644 --- a/core/tests/mad_hatter/test_plugin.py +++ b/core/tests/mad_hatter/test_plugin.py @@ -94,12 +94,19 @@ def test_activate_plugin(plugin): assert "mock tool example 2" in tool.start_examples -def test_deactivate_plugin(plugin): - # The plugin is non active by default - plugin.activate() - - # deactivate it - plugin.deactivate() +def test_deactivate_plugin(mad_hatter_with_mock_plugin): + + # The plugin is installed and activated by the mad_hatter_with_mock_plugin fixture + + # Get the reference to the mock plugin + plugin = mad_hatter_with_mock_plugin.plugins["mock_plugin"] + + # Deactivate the mock plugin + # Why not using plugin.deactivate()? + # the "endpoint" decorator needs the reference to the FastAPI app instance for endpoint removing. + # Calling plugin.deactivate() the method CustomEndpoint.deactivate() raises the exception: + # "self.cheshire_cat_api is None" + mad_hatter_with_mock_plugin.toggle_plugin("mock_plugin") assert plugin.active is False From c5a42f7945cbd8ec7c2e23b842b8264c1d1fb5ea Mon Sep 17 00:00:00 2001 From: Samuele Barzaghi Date: Sun, 3 Nov 2024 14:13:20 +0100 Subject: [PATCH 10/13] Fix linter, endpoint not used --- core/cat/mad_hatter/mad_hatter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index 55a038d3..5983983a 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -18,7 +18,7 @@ from cat.mad_hatter.plugin import Plugin from cat.mad_hatter.decorators.hook import CatHook from cat.mad_hatter.decorators.tool import CatTool -from cat.mad_hatter.decorators.endpoint import CustomEndpoint, endpoint +from cat.mad_hatter.decorators.endpoint import CustomEndpoint from cat.experimental.form import CatForm From 654171a7ea9f42e7e3ceca057a90dde0c183a521 Mon Sep 17 00:00:00 2001 From: Samuele Barzaghi Date: Sun, 3 Nov 2024 17:39:56 +0100 Subject: [PATCH 11/13] Catch exception while including the custom router --- core/cat/mad_hatter/decorators/endpoint.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/cat/mad_hatter/decorators/endpoint.py b/core/cat/mad_hatter/decorators/endpoint.py index f28e74ce..388ebf9f 100644 --- a/core/cat/mad_hatter/decorators/endpoint.py +++ b/core/cat/mad_hatter/decorators/endpoint.py @@ -49,7 +49,12 @@ def activate(self, cheshire_cat_api): **self.kwargs, ) - self.cheshire_cat_api.include_router(plugins_router, prefix=self.prefix) + try: + self.cheshire_cat_api.include_router(plugins_router, prefix=self.prefix) + except BaseException as e: + log.error(f"Error activating custom endpoint [{self.name}]: {e}") + return + self.cheshire_cat_api.openapi_schema = None # Flush the cache of openapi schema # Set the fastapi api_route into the Custom Endpoint From d418cf60b3b48f09ae5f18e91cfbff4da861db89 Mon Sep 17 00:00:00 2001 From: Samuele Barzaghi Date: Mon, 4 Nov 2024 08:07:50 +0100 Subject: [PATCH 12/13] Is not necessary to postpone AuthHandler load anymore --- core/cat/looking_glass/cheshire_cat.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index 851b940e..a832871c 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -61,17 +61,19 @@ def __init__(self, fastapi_app): """ # bootstrap the Cat! ^._.^ + + # get reference to the FastAPI app self.fastapi_app = fastapi_app + # load AuthHandler + self.load_auth() + # Start scheduling system self.white_rabbit = WhiteRabbit() # instantiate MadHatter (loads all plugins' hooks and tools) self.mad_hatter = MadHatter() - # load AuthHandler - self.load_auth() - # allows plugins to do something before cat components are loaded self.mad_hatter.execute_hook("before_cat_bootstrap", cat=self) From dd485ee2644e93d2e0af723855ac23d05c1cccb8 Mon Sep 17 00:00:00 2001 From: Samuele Barzaghi Date: Tue, 5 Nov 2024 15:50:27 +0100 Subject: [PATCH 13/13] Default prefix "custom" instead of "custom-endpoints" --- core/cat/mad_hatter/decorators/endpoint.py | 2 +- core/tests/mad_hatter/test_endpoints.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/core/cat/mad_hatter/decorators/endpoint.py b/core/cat/mad_hatter/decorators/endpoint.py index 388ebf9f..c42ff754 100644 --- a/core/cat/mad_hatter/decorators/endpoint.py +++ b/core/cat/mad_hatter/decorators/endpoint.py @@ -78,7 +78,7 @@ class Endpoint: cheshire_cat_api = None - default_prefix = "/custom-endpoints" + default_prefix = "/custom" default_tags = ["Custom Endpoints"] # Called from madhatter to inject the fastapi app instance diff --git a/core/tests/mad_hatter/test_endpoints.py b/core/tests/mad_hatter/test_endpoints.py index b65b5d79..44fd36b2 100644 --- a/core/tests/mad_hatter/test_endpoints.py +++ b/core/tests/mad_hatter/test_endpoints.py @@ -10,16 +10,16 @@ def test_endpoints_discovery(mad_hatter_with_mock_plugin): assert isinstance(h, CustomEndpoint) assert h.plugin_id == "mock_plugin" - if h.name == "/custom-endpoints/endpoint": + if h.name == "/custom/endpoint": assert h.tags == ["Custom Endpoints"] - if h.name == "/custom-endpoints/tests": + if h.name == "/custom/tests": assert h.tags == ["Tests"] def test_endpoint_decorator(client, mad_hatter_with_mock_plugin): - response = client.get("/custom-endpoints/endpoint") + response = client.get("/custom/endpoint") assert response.status_code == 200 @@ -57,10 +57,10 @@ def test_post_endpoint(client, mad_hatter_with_mock_plugin): def test_plugin_deactivation(client, mad_hatter_with_mock_plugin): - response = client.get("/custom-endpoints/endpoint") + response = client.get("/custom/endpoint") assert response.status_code == 200 mad_hatter_with_mock_plugin.toggle_plugin("mock_plugin") - response = client.get("/custom-endpoints/endpoint") + response = client.get("/custom/endpoint") assert response.status_code == 404 \ No newline at end of file