Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New "*endpoint" decorators for custom endpoint creations #964

Merged
merged 13 commits into from
Nov 13, 2024
23 changes: 18 additions & 5 deletions core/cat/looking_glass/cheshire_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,17 @@ class CheshireCat:

"""

def __init__(self):
def __init__(self, fastapi_app):
"""Cat initialization.

At init time the Cat executes the bootstrap.
"""

# bootstrap the Cat! ^._.^

# get reference to the FastAPI app
self.fastapi_app = fastapi_app

# load AuthHandler
self.load_auth()

Expand All @@ -80,10 +83,11 @@ def __init__(self):
# Load memories (vector collections and working_memory)
self.load_memory()

# After memory is loaded, we can get/create tools embeddings
# every time the mad_hatter finishes syncing hooks, tools and forms, it will notify the Cat (so it can embed tools in vector memory)
self.mad_hatter.on_finish_plugins_sync_callback = self.embed_procedures
self.embed_procedures() # first time launched manually
# After memory is loaded, we can get/create tools embeddings
self.mad_hatter.on_finish_plugins_sync_callback = self.on_finish_plugins_sync_callback

# First time launched manually
self.on_finish_plugins_sync_callback()

# Main agent instance (for reasoning)
self.main_agent = MainAgent()
Expand Down Expand Up @@ -332,6 +336,15 @@ def build_active_procedures_hashes(self, active_procedures):
}
return hashes

def on_finish_plugins_sync_callback(self):
self.activate_endpoints()
self.embed_procedures()

def activate_endpoints(self):
for endpoint in self.mad_hatter.endpoints:
if endpoint.plugin_id in self.mad_hatter.active_plugins:
endpoint.activate(self.fastapi_app)

def embed_procedures(self):
# Retrieve from vectorDB all procedural embeddings
embedded_procedures, _ = self.memory.vectors.procedural.get_all_points()
Expand Down
3 changes: 2 additions & 1 deletion core/cat/mad_hatter/decorators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from cat.mad_hatter.decorators.tool import CatTool, tool
from cat.mad_hatter.decorators.hook import CatHook, hook
from cat.mad_hatter.decorators.endpoint import CustomEndpoint, endpoint
from cat.mad_hatter.decorators.plugin_decorator import CatPluginDecorator, plugin

__all__ = ["CatTool", "tool", "CatHook", "hook", "CatPluginDecorator", "plugin"]
__all__ = ["CatTool", "tool", "CatHook", "hook", "CustomEndpoint", "endpoint", "CatPluginDecorator", "plugin"]
189 changes: 189 additions & 0 deletions core/cat/mad_hatter/decorators/endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
from typing import Callable
from fastapi import APIRouter

from cat.log import log

# class to represent a @endpoint
class CustomEndpoint:
def __init__(
self,
prefix: str,
path: str,
function: Callable,
methods,
tags,
**kwargs,
):
self.prefix = prefix
self.path = path
self.function = function
self.tags = tags
self.methods = methods
self.kwargs = kwargs
self.name = self.prefix + self.path

def __repr__(self) -> str:
return f"CustomEndpoint(path={self.name})"

def set_api_route(self, api_route):
self.api_route = api_route

def activate(self, cheshire_cat_api):

log.info(f"Activating custom endpoint {self.name}...")

self.cheshire_cat_api = cheshire_cat_api

# Set the fastapi api_route into the Custom Endpoint
for api_route in self.cheshire_cat_api.routes:
if api_route.path == self.name:
log.info(f"There is already an active endpoint with path {self.name}")
return

plugins_router = APIRouter()
valentimarco marked this conversation as resolved.
Show resolved Hide resolved
plugins_router.add_api_route(
path=self.path,
endpoint=self.function,
methods=self.methods,
tags=self.tags,
**self.kwargs,
)

try:
self.cheshire_cat_api.include_router(plugins_router, prefix=self.prefix)
except BaseException as e:
log.error(f"Error activating custom endpoint [{self.name}]: {e}")
return

self.cheshire_cat_api.openapi_schema = None # Flush the cache of openapi schema

# Set the fastapi api_route into the Custom Endpoint
for api_route in self.cheshire_cat_api.routes:
if api_route.path == self.name:
self.api_route = api_route
break

assert api_route.path == self.name

def deactivate(self):

log.info(f"Deactivating custom endpoint {self.name}...")

# Seems there is no official way to remove a route:
# https://github.com/fastapi/fastapi/discussions/8088
self.cheshire_cat_api.routes.remove(self.api_route)
self.cheshire_cat_api.openapi_schema = None # Flush the cached openapi schema

class Endpoint:

cheshire_cat_api = None

default_prefix = "/custom"
default_tags = ["Custom Endpoints"]

# Called from madhatter to inject the fastapi app instance
def _init_decorators(cls, new_cheshire_cat_api):
cls.cheshire_cat_api = new_cheshire_cat_api

# @endpoint decorator. Any function in a plugin decorated by @endpoint.endpoint will be exposed as FastAPI operation
def endpoint(
cls,
path,
methods,
prefix=default_prefix,
tags=default_tags,
**kwargs,
) -> Callable:
"""
Define a custom API endpoint, parameters are the same as FastAPI path operation.
Examples:
.. code-block:: python
from cat.mad_hatter.decorators import endpoint

@endpoint.endpoint(path="/hello", methods=["GET"])
def my_endpoint():
return {"Hello":"Alice"}
"""

def _make_endpoint(endpoint):
custom_endpoint = CustomEndpoint(
prefix=prefix,
path=path,
function=endpoint,
methods=methods,
tags=tags,
**kwargs,
)

return custom_endpoint

return _make_endpoint

# @get_endpoint decorator. Any function in a plugin decorated by @endpoint.get will be exposed as FastAPI GET operation
def get(
cls,
path,
prefix=default_prefix,
response_model=None,
tags=default_tags,
**kwargs,
) -> Callable:
"""
Define a custom API endpoint for GET operation, parameters are the same as FastAPI path operation.
Examples:
.. code-block:: python
from cat.mad_hatter.decorators import endpoint

@endpoint.get(path="/hello")
def my_get_endpoint():
return {"Hello":"Alice"}
"""

return cls.endpoint(
path=path,
methods=["GET"],
prefix=prefix,
response_model=response_model,
tags=tags,
**kwargs,
)

# @post_endpoint decorator. Any function in a plugin decorated by @endpoint.post will be exposed as FastAPI POST operation
def post(
cls,
path,
prefix=default_prefix,
response_model=None,
tags=default_tags,
**kwargs,
) -> Callable:
"""
Define a custom API endpoint for POST operation, parameters are the same as FastAPI path operation.
Examples:
.. code-block:: python

from cat.mad_hatter.decorators import endpoint
from pydantic import BaseModel

class Item(BaseModel):
name: str
description: str

@endpoint.post(path="/hello")
def my_post_endpoint(item: Item):
return {"Hello": item.name, "Description": item.description}
"""
return cls.endpoint(
path=path,
methods=["POST"],
prefix=prefix,
response_model=response_model,
tags=tags,
**kwargs,
)


endpoint = None

if not endpoint:
endpoint = Endpoint()
9 changes: 7 additions & 2 deletions core/cat/mad_hatter/mad_hatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from cat.mad_hatter.plugin import Plugin
from cat.mad_hatter.decorators.hook import CatHook
from cat.mad_hatter.decorators.tool import CatTool
from cat.mad_hatter.decorators.endpoint import CustomEndpoint

from cat.experimental.form import CatForm

Expand All @@ -29,7 +30,7 @@
@singleton
class MadHatter:
# loads and execute plugins
# - enter into the plugin folder and loads everthing
# - enter into the plugin folder and loads everything
# that is decorated or named properly
# - orders plugged in hooks by name and priority
# - exposes functionality to the cat
Expand All @@ -42,6 +43,7 @@ def __init__(self):
] = {} # dict of active plugins hooks ( hook_name -> [CatHook, CatHook, ...])
self.tools: List[CatTool] = [] # list of active plugins tools
self.forms: List[CatForm] = [] # list of active plugins forms
self.endpoints: List[CustomEndpoint] = [] # list of active plugins endpoints

self.active_plugins: List[str] = []

Expand Down Expand Up @@ -138,15 +140,18 @@ def sync_hooks_tools_and_forms(self):
self.hooks = {}
self.tools = []
self.forms = []
self.endpoints = []

for _, plugin in self.plugins.items():
# load hooks, tools and forms from active plugins
# load hooks, tools, forms and endpoints from active plugins
if plugin.id in self.active_plugins:
# cache tools
self.tools += plugin.tools

self.forms += plugin.forms

self.endpoints += plugin.endpoints

# cache hooks (indexed by hook name)
for h in plugin.hooks:
if h.name not in self.hooks.keys():
Expand Down
31 changes: 29 additions & 2 deletions core/cat/mad_hatter/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pydantic import BaseModel, ValidationError
from packaging.requirements import Requirement

from cat.mad_hatter.decorators import CatTool, CatHook, CatPluginDecorator
from cat.mad_hatter.decorators import CatTool, CatHook, CatPluginDecorator, CustomEndpoint
from cat.experimental.form import CatForm
from cat.utils import to_camel_case
from cat.log import log
Expand Down Expand Up @@ -59,6 +59,7 @@ def __init__(self, plugin_path: str):
self._hooks: List[CatHook] = [] # list of plugin hooks
self._tools: List[CatTool] = [] # list of plugin tools
self._forms: List[CatForm] = [] # list of plugin forms
self._endpoints: List[CustomEndpoint] = [] # list of plugin endpoints

# list of @plugin decorated functions overriding default plugin behaviour
self._plugin_overrides = [] # TODO: make this a dictionary indexed by func name, for faster access
Expand All @@ -73,7 +74,7 @@ def activate(self):
except Exception as e:
raise e

# Load of hooks and tools
# Load of hook, tools, forms and endpoints
self._load_decorated_functions()

# by default, plugin settings are saved inside the plugin folder
Expand All @@ -98,6 +99,8 @@ def deactivate(self):

self._hooks = []
self._tools = []
#TODO : clear of forms is missing here?
self._deactivate_endpoints()
self._plugin_overrides = []
self._active = False

Expand Down Expand Up @@ -295,6 +298,7 @@ def _load_decorated_functions(self):
hooks = []
tools = []
forms = []
endpoints = []
plugin_overrides = []

for py_file in self.py_files:
Expand All @@ -309,6 +313,7 @@ def _load_decorated_functions(self):
hooks += getmembers(plugin_module, self._is_cat_hook)
tools += getmembers(plugin_module, self._is_cat_tool)
forms += getmembers(plugin_module, self._is_cat_form)
endpoints += getmembers(plugin_module, self._is_custom_endpoint)
plugin_overrides += getmembers(
plugin_module, self._is_cat_plugin_override
)
Expand All @@ -323,6 +328,7 @@ def _load_decorated_functions(self):
self._hooks = list(map(self._clean_hook, hooks))
self._tools = list(map(self._clean_tool, tools))
self._forms = list(map(self._clean_form, forms))
self._endpoints = list(map(self._clean_endpoint, endpoints))
self._plugin_overrides = list(
map(self._clean_plugin_override, plugin_overrides)
)
Expand All @@ -332,6 +338,11 @@ def plugin_specific_error_message(self):
url = self.manifest.get("plugin_url")
return f"To resolve any problem related to {name} plugin, contact the creator using github issue at the link {url}"

def _deactivate_endpoints(self):

for endpoint in self._endpoints:
endpoint.deactivate()

def _clean_hook(self, hook: CatHook):
# getmembers returns a tuple
h = hook[1]
Expand All @@ -349,6 +360,12 @@ def _clean_form(self, form: CatForm):
f = form[1]
f.plugin_id = self._id
return f

def _clean_endpoint(self, endpoint: CustomEndpoint):
# getmembers returns a tuple
f = endpoint[1]
f.plugin_id = self._id
return f

def _clean_plugin_override(self, plugin_override):
# getmembers returns a tuple
Expand Down Expand Up @@ -382,6 +399,12 @@ def _is_cat_tool(obj):
def _is_cat_plugin_override(obj):
return isinstance(obj, CatPluginDecorator)

# a plugin custom endpoint has to be decorated with @endpoint
# (which returns an instance of CustomEndpoint)
@staticmethod
def _is_custom_endpoint(obj):
return isinstance(obj, CustomEndpoint)

@property
def path(self):
return self._path
Expand Down Expand Up @@ -409,3 +432,7 @@ def tools(self):
@property
def forms(self):
return self._forms

@property
def endpoints(self):
return self._endpoints
Loading
Loading