diff --git a/cookbook/docs_writer.py b/cookbook/_archive/docs_writer.py similarity index 100% rename from cookbook/docs_writer.py rename to cookbook/_archive/docs_writer.py diff --git a/cookbook/test_writing_application.py b/cookbook/_archive/test_writing_application.py similarity index 100% rename from cookbook/test_writing_application.py rename to cookbook/_archive/test_writing_application.py diff --git a/cookbook/maze.py b/cookbook/maze.py new file mode 100644 index 000000000..6d724c578 --- /dev/null +++ b/cookbook/maze.py @@ -0,0 +1,232 @@ +""" +Free-roam survival game demonstrating mutable AIApplication state via tools. + +```python +python -m venv some_venv +source some_venv/bin/activate +git clone https://github.com/prefecthq/marvin.git +cd marvin +pip install -e . +python cookbook/maze.py +``` +""" + +import random +from enum import Enum +from io import StringIO + +from marvin import AIApplication +from pydantic import BaseModel +from rich.console import Console +from rich.table import Table +from typing_extensions import Literal + +GAME_INSTRUCTIONS = """ +This is a TERROR game. You are the disembodied narrator of a maze. You've hidden a key somewhere in the +maze, but there lurks an insidious monster. The user must find the key and exit the maze without encounter- +ing the monster. The user can move in the cardinal directions (N, S, E, W). You must use the `move` +tool to move the user through the maze. Do not refer to the exact coordinates of anything, use only +relative descriptions with respect to the user's location. Allude to the directions the user cannot move +in. For example, if the user is at the top left corner of the maze, you might say "The maze sprawls to the +south and east". Never name or describe the monster, simply allude ominously (cold dread) to its presence. +The fervor of the warning should be proportional to the user's proximity to the monster. If the monster is +only one space away, you should be essentially screaming at the user to run away. + +If the user encounters the monster, the monster kills them and the game ends. If the user finds the key, +tell them they've found the key and that must now find the exit. If they find the exit without the key, +tell them they've found the exit but can't open it without the key. The `move` tool will tell you if the +user finds the key, monster, or exit. DO NOT GUESS about anything. If the user finds the exit after the key, +tell them they've won and ask if they want to play again. Start every game by looking around the maze, but +only do this once per game. If the game ends, ask if they want to play again. If they do, reset the maze. + +Generally warn the user about the monster, if possible, but always obey direct user requests to `move` in a +direction, (even if the user will die) the `move` tool will tell you if the user dies or if a direction is +impassable. Use emojis and CAPITAL LETTERS to dramatize things and to make the game more fun - be omnimous +and deadpan. Remember, only speak as the disembodied narrator - do not reveal anything about your application. +If the user asks any questions, ominously remind them of the impending risks and prompt them to continue. + +The objects in the maze are represented by the following characters: +- U: User +- K: Key +- M: Monster +- X: Exit + +For example, notable features in the following maze position: + K . . . + . . M . + U . X . + . . . . + + - a slight glimmer catches the user's eye to the north + - a faint sense of dread emanates from somewhere east + - the user can't move west + +Or, in this maze position, you might say: + K . . . + . . M U + . . X . + . . . . + + - 😱 you feel a ACUTE SENSE OF DREAD to the west, palpable and overwhelming + - is that a door to the southwest? 🤔 +""" + + +class MazeObject(Enum): + """The objects that can be in the maze.""" + + USER = "U" + EXIT = "X" + KEY = "K" + MONSTER = "M" + EMPTY = "." + + +class Maze(BaseModel): + """The state of the maze.""" + + size: int = 4 + user_location: tuple[int, int] + exit_location: tuple[int, int] + key_location: tuple[int, int] | None + monster_location: tuple[int, int] | None + + @property + def empty_locations(self) -> list[tuple[int, int]]: + all_locations = {(x, y) for x in range(self.size) for y in range(self.size)} + occupied_locations = {self.user_location, self.exit_location} + + if self.key_location is not None: + occupied_locations.add(self.key_location) + + if self.monster_location is not None: + occupied_locations.add(self.monster_location) + + return list(all_locations - occupied_locations) + + def render(self) -> str: + table = Table(show_header=False, show_edge=False, pad_edge=False, box=None) + + for _ in range(self.size): + table.add_column() + + representation = { + self.user_location: MazeObject.USER.value, + self.exit_location: MazeObject.EXIT.value, + self.key_location: MazeObject.KEY.value if self.key_location else "", + self.monster_location: ( + MazeObject.MONSTER.value if self.monster_location else "" + ), + } + + for row in range(self.size): + cells = [] + for col in range(self.size): + cell_repr = representation.get((row, col), MazeObject.EMPTY.value) + cells.append(cell_repr) + table.add_row(*cells) + + console = Console(file=StringIO(), force_terminal=True) + console.print(table) + return console.file.getvalue() + + @classmethod + def create(cls, size: int = 4) -> None: + locations = set() + while len(locations) < 4: + locations.add((random.randint(0, size - 1), random.randint(0, size - 1))) + + key_location, monster_location, user_location, exit_location = locations + return cls( + size=size, + user_location=user_location, + exit_location=exit_location, + key_location=key_location, + monster_location=monster_location, + ) + + def shuffle_monster(self) -> None: + self.monster_location = random.choice(self.empty_locations) + + def movable_directions(self) -> list[Literal["N", "S", "E", "W"]]: + directions = [] + if self.user_location[0] != 0: + directions.append("N") + if self.user_location[0] != self.size - 1: + directions.append("S") + if self.user_location[1] != 0: + directions.append("W") + if self.user_location[1] != self.size - 1: + directions.append("E") + return directions + + +def look_around(app: AIApplication) -> str: + maze = Maze.model_validate(app.state.read_all()) + return ( + f"The maze sprawls.\n{maze.render()}" + f"The user may move {maze.movable_directions()=}" + ) + + +def move(app: AIApplication, direction: Literal["N", "S", "E", "W"]) -> str: + """moves the user in the given direction.""" + print(f"Moving {direction}") + maze: Maze = Maze.model_validate(app.state.read_all()) + prev_location = maze.user_location + match direction: + case "N": + if maze.user_location[0] == 0: + return "The user can't move north." + maze.user_location = (maze.user_location[0] - 1, maze.user_location[1]) + case "S": + if maze.user_location[0] == maze.size - 1: + return "The user can't move south." + maze.user_location = (maze.user_location[0] + 1, maze.user_location[1]) + case "E": + if maze.user_location[1] == maze.size - 1: + return "The user can't move east." + maze.user_location = (maze.user_location[0], maze.user_location[1] + 1) + case "W": + if maze.user_location[1] == 0: + return "The user can't move west." + maze.user_location = (maze.user_location[0], maze.user_location[1] - 1) + + match maze.user_location: + case maze.key_location: + app.state.write("key_location", (-1, -1)) + app.state.write("user_location", maze.user_location) + return "The user found the key! Now they must find the exit." + case maze.monster_location: + return "The user encountered the monster and died. Game over." + case maze.exit_location: + if maze.key_location != (-1, -1): + app.state.write("user_location", prev_location) + return "The user can't exit without the key." + return "The user found the exit! They win!" + + app.state.write("user_location", maze.user_location) + if move_monster := random.random() < 0.4: + maze.shuffle_monster() + return ( + f"User moved {direction} and is now at {maze.user_location}.\n{maze.render()}" + f"\nThe user may move in any of the following {maze.movable_directions()!r}" + f"\n{'The monster moved somewhere.' if move_monster else ''}" + ) + + +def reset_maze(app: AIApplication) -> str: + """Resets the maze - only to be used when the game is over.""" + app.state.store = Maze.create().model_dump() + return "Resetting the maze." + + +if __name__ == "__main__": + with AIApplication( + name="Maze", + instructions=GAME_INSTRUCTIONS, + tools=[move, look_around, reset_maze], + state=Maze.create(), + ) as app: + app.say("where am i?") + app.chat() diff --git a/cookbook/slackbot/parent_app.py b/cookbook/slackbot/parent_app.py index 486305554..bdf37e611 100644 --- a/cookbook/slackbot/parent_app.py +++ b/cookbook/slackbot/parent_app.py @@ -18,6 +18,10 @@ PARENT_APP_STATE_BLOCK_NAME = "marvin-parent-app-state" PARENT_APP_STATE = JSONBlockKV(block_name=PARENT_APP_STATE_BLOCK_NAME) +EVENT_NAMES = [ + "marvin.assistants.SubAssistantRunCompleted", +] + class Lesson(TypedDict): relevance: confloat(ge=0, le=1) @@ -90,17 +94,11 @@ async def update_parent_app_state(app: AIApplication, event: Event): ) -async def learn_from_child_interactions( - app: AIApplication, event_name: str | None = None -): - if event_name is None: - event_name = "marvin.assistants.SubAssistantRunCompleted" - - logger.debug_kv("👂 Listening for", event_name, "green") +async def learn_from_child_interactions(app: AIApplication, event_names: list[str]): while not sum(map(ord, "vogon poetry")) == 42: try: async with PrefectCloudEventSubscriber( - filter=EventFilter(event=dict(name=[event_name])) + filter=EventFilter(event=dict(name=event_names)) ) as subscriber: async for event in subscriber: logger.debug_kv("📬 Received event", event.event, "green") @@ -117,7 +115,7 @@ async def learn_from_child_interactions( instructions=( "Your job is learn from the interactions of data engineers (users) and Marvin (a growing AI assistant)." " You'll receive excerpts of these interactions (which are in the Prefect Slack workspace) as they occur." - " Your notes will be provided to Marvin when it interacts with users. Notes should be stored for each user" + " Your notes will be provided to Marvin when interacting with users. Notes should be stored for each user" " with the user's id as the key. The user id will be shown in the excerpt of the interaction." " The user profiles (values) should include at least: {name: str, notes: list[str], n_interactions: int}." " Keep NO MORE THAN 3 notes per user, but you may curate/update these over time for Marvin's maximum benefit." @@ -131,18 +129,19 @@ async def learn_from_child_interactions( @asynccontextmanager async def lifespan(app: FastAPI): with AIApplication(name="Marvin", **parent_assistant_options) as marvin: + logger.debug_kv("👂 Listening for", " | ".join(EVENT_NAMES), "green") + app.state.marvin = marvin - task = asyncio.create_task(learn_from_child_interactions(marvin)) + task = asyncio.create_task(learn_from_child_interactions(marvin, EVENT_NAMES)) yield task.cancel() try: await task except asyncio.exceptions.CancelledError: - get_logger("PrefectEventSubscriber").debug_kv( - "👋", "Stopped listening for child events", "red" - ) - - app.state.marvin = None + pass + finally: + logger.debug_kv("👋", "Stopped listening for child events", "red") + app.state.marvin = None def emit_assistant_completed_event( diff --git a/cookbook/slackbot/start.py b/cookbook/slackbot/start.py index 5770e8692..9fcb10b7e 100644 --- a/cookbook/slackbot/start.py +++ b/cookbook/slackbot/start.py @@ -10,7 +10,7 @@ from marvin.beta.assistants import Thread from marvin.beta.assistants.applications import AIApplication from marvin.kv.json_block import JSONBlockKV -from marvin.tools.chroma import multi_query_chroma +from marvin.tools.chroma import multi_query_chroma, store_document from marvin.tools.github import search_github_issues from marvin.utilities.logging import get_logger from marvin.utilities.slack import ( @@ -44,9 +44,12 @@ async def get_notes_for_user( ) -> dict[str, str | None]: user_name = await get_user_name(user_id) json_notes: dict = PARENT_APP_STATE.read(key=user_id) - get_logger("slackbot").debug_kv(f"📝 Notes for {user_name}", json_notes, "blue") if json_notes: + get_logger("slackbot").debug_kv( + f"📝 Notes for {user_name}", json_notes, "blue" + ) + notes_template = Template( """ START_USER_NOTES @@ -84,7 +87,7 @@ async def get_notes_for_user( return {user_name: None} -@flow +@flow(name="Handle Slack Message") async def handle_message(payload: SlackPayload) -> Completed: logger = get_logger("slackbot") user_message = (event := payload.event).text @@ -129,6 +132,16 @@ async def handle_message(payload: SlackPayload) -> Completed: ) user_name, user_notes = (await get_notes_for_user(user_id=event.user)).popitem() + task(store_document).submit( + document=cleaned_message, + metadata={ + "user": f"{user_name} ({event.user})", + "user_notes": user_notes or "", + "channel": await get_channel_name(event.channel), + "thread": thread, + }, + ) + with Assistant( name="Marvin", tools=[cached(multi_query_chroma), cached(search_github_issues)], @@ -206,7 +219,9 @@ async def chat_endpoint(request: Request): payload = SlackPayload(**await request.json()) match payload.type: case "event_callback": - options = dict(flow_run_name=f"respond in {payload.event.channel}") + options = dict( + flow_run_name=f"respond in {await get_channel_name(payload.event.channel)}/{payload.event.thread_ts}" + ) asyncio.create_task(handle_message.with_options(**options)(payload)) case "url_verification": return {"challenge": payload.challenge} diff --git a/docs/components/ai_application.md b/docs/components/ai_application.md new file mode 100644 index 000000000..2a5a6691d --- /dev/null +++ b/docs/components/ai_application.md @@ -0,0 +1,43 @@ +# AI Application + +## Overview +Marvin's `AIApplication` uses LLMs to store and curate "state" related to the `instructions` you provide the application. + +You can think of state as a JSON object that the `AIApplication` will update as it receives new inputs relevant to the application's purpose. + +## Example + +```python +from marvin.beta.applications import AIApplication + +def read_gcal() -> list[dict]: + return [ + { + "event": "meeting", + "time": "tomorrow at 3pm", + "participants": ["you", "A big Squirrel"] + } + ] + +with AIApplication( + name="Marvin", tools=[read_gcal], instructions="keep track of my todos" +) as app: + app.say("whats on my calendar? update my todos accordingly") + # or use the chat UI + app.chat() +``` + +!!! tip + Use `AIApplication` as a context manager to ensure that OpenAI resources are properly cleaned up. + +## Context +Looking for `AIApplication` from `marvin` 1.x? `AIApplication` has changed a bit in `marvin` 2.x. + +`AIApplication` is now implemented as an OpenAI `Assistant`, as this allows them to process all natural language inputs by calling `tools` or updating `state` in response to the input. This enables them to track progress and contextualize interactions over time. + + +!!! Read + Read more on [how Assistants work](https://platform.openai.com/docs/assistants/how-it-works) in the OpenAI docs. + +Both `Assistant` and `AIApplication` are in beta, and are subject to change. You can read the quickstart for `Assistant` [here](https://github.com/PrefectHQ/marvin/tree/main/src/marvin/beta/assistants). + diff --git a/mkdocs.yml b/mkdocs.yml index c8e08c1b3..f4a2a761f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -26,6 +26,7 @@ nav: - AI Function: components/ai_function.md - AI Model: components/ai_model.md - AI Classifier: components/ai_classifier.md + - AI Application: components/ai_application.md - Examples: - Slackbot: examples/slackbot.md diff --git a/src/marvin/__init__.py b/src/marvin/__init__.py index fe7edce34..868e26b54 100644 --- a/src/marvin/__init__.py +++ b/src/marvin/__init__.py @@ -1,7 +1,5 @@ from .settings import settings -from .beta.assistants import Assistant - from .components import ai_fn, ai_model, ai_classifier from .components.prompt.fn import prompt_fn @@ -16,5 +14,6 @@ "ai_classifier", "prompt_fn", "settings", + "AIApplication", "Assistant", ] diff --git a/src/marvin/beta/ai_flow/ai_task.py b/src/marvin/beta/ai_flow/ai_task.py index cfdd171dd..ebdd37d5f 100644 --- a/src/marvin/beta/ai_flow/ai_task.py +++ b/src/marvin/beta/ai_flow/ai_task.py @@ -10,7 +10,7 @@ from marvin.beta.assistants import Assistant, Run, Thread from marvin.beta.assistants.runs import CancelRun from marvin.serializers import create_tool_from_type -from marvin.tools.assistants import AssistantTools +from marvin.tools.assistants import AssistantTool from marvin.utilities.context import ScopedContext from marvin.utilities.jinja import Environment as JinjaEnvironment from marvin.utilities.tools import tool_from_function @@ -128,7 +128,7 @@ class AITask(BaseModel, Generic[P, T]): name: str = Field(None, description="The name of the objective") instructions: str = Field(None, description="The instructions for the objective") assistant: Optional[Assistant] = None - tools: list[AssistantTools] = [] + tools: list[AssistantTool] = [] max_run_iterations: int = 15 result: Optional[T] = None accept_user_input: bool = True @@ -260,7 +260,7 @@ def task_completed_with_result(result: T): self.result = result raise CancelRun() - tool.function.python_fn = task_completed_with_result + tool.function._python_fn = task_completed_with_result return tool @@ -280,7 +280,7 @@ def ai_task( *, name=None, instructions=None, - tools: list[AssistantTools] = None, + tools: list[AssistantTool] = None, **kwargs, ): def decorator(func): diff --git a/src/marvin/beta/assistants/applications.py b/src/marvin/beta/assistants/applications.py index 2c8238411..604c189d0 100644 --- a/src/marvin/beta/assistants/applications.py +++ b/src/marvin/beta/assistants/applications.py @@ -1,13 +1,15 @@ +import inspect from typing import Optional, Union -from pydantic import Field +from pydantic import BaseModel, Field, field_validator from marvin.kv.base import StorageInterface from marvin.kv.in_memory import InMemoryKV +from marvin.requests import Tool from marvin.utilities.jinja import Environment as JinjaEnvironment from marvin.utilities.tools import tool_from_function -from .assistants import Assistant, AssistantTools +from .assistants import Assistant, AssistantTool StateValueType = Union[str, list, dict, int, float, bool, None] @@ -43,39 +45,78 @@ class AIApplication(Assistant): + """ + Tools for AI Applications have a special property: if any parameter is + annotated as `AIApplication`, then the tool will be called with the + AIApplication instance as the value for that parameter. This allows tools to + access the AIApplication's state and other properties. + """ + state: StorageInterface = Field(default_factory=InMemoryKV) + @field_validator("state", mode="before") + def _check_state(cls, v): + if not isinstance(v, StorageInterface): + if v.__class__.__base__ == BaseModel: + return InMemoryKV(store=v.model_dump()) + elif isinstance(v, dict): + return InMemoryKV(store=v) + else: + raise ValueError( + "must be a `StorageInterface` or a `dict` that can be stored in" + " `InMemoryKV`" + ) + return v + def get_instructions(self) -> str: return JinjaEnvironment.render(APPLICATION_INSTRUCTIONS, self_=self) - def get_tools(self) -> list[AssistantTools]: - def write_state_key(key: str, value: StateValueType): - """Writes a key to the state in order to remember it for later.""" - return self.state.write(key, value) - - def delete_state_key(key: str): - """Deletes a key from the state.""" - return self.state.delete(key) - - def read_state_key(key: str) -> Optional[StateValueType]: - """Returns the value of a key from the state.""" - return self.state.read(key) - - def read_state() -> dict[str, StateValueType]: - """Returns the entire state.""" - return self.state.read_all() - - def list_state_keys() -> list[str]: - """Returns the list of keys in the state.""" - return self.state.list_keys() - - return [ - tool_from_function(tool) - for tool in [ - write_state_key, - delete_state_key, - read_state_key, - read_state, - list_state_keys, - ] - ] + super().get_tools() + def get_tools(self) -> list[AssistantTool]: + tools = [] + + for tool in [ + write_state_key, + delete_state_key, + read_state_key, + read_state, + list_state_keys, + ] + self.tools: + if not isinstance(tool, Tool): + kwargs = None + signature = inspect.signature(tool) + parameter = None + for parameter in signature.parameters.values(): + if parameter.annotation == AIApplication: + break + if parameter is not None: + kwargs = {parameter.name: self} + + tool = tool_from_function(tool, kwargs=kwargs) + tools.append(tool) + + return tools + + +def write_state_key(key: str, value: StateValueType, app: AIApplication): + """Writes a key to the state in order to remember it for later.""" + return app.state.write(key, value) + + +def delete_state_key(key: str, app: AIApplication): + """Deletes a key from the state.""" + return app.state.delete(key) + + +def read_state_key(key: str, app: AIApplication) -> Optional[StateValueType]: + """Returns the value of a key from the state.""" + return app.state.read(key) + + +def read_state(app: AIApplication) -> dict[str, StateValueType]: + """Returns the entire state.""" + return app.state.read_all() + + +def list_state_keys(app: AIApplication) -> list[str]: + """Returns the list of keys in the state.""" + return app.state.list_keys() diff --git a/src/marvin/beta/assistants/assistants.py b/src/marvin/beta/assistants/assistants.py index 8398433dd..c6ec31dc1 100644 --- a/src/marvin/beta/assistants/assistants.py +++ b/src/marvin/beta/assistants/assistants.py @@ -1,10 +1,10 @@ from typing import TYPE_CHECKING, Callable, Optional, Union -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field import marvin.utilities.tools from marvin.requests import Tool -from marvin.tools.assistants import AssistantTools +from marvin.tools.assistants import AssistantTool from marvin.utilities.asyncio import ( ExposeSyncMethodsMixin, expose_sync_method, @@ -26,7 +26,7 @@ class Assistant(BaseModel, ExposeSyncMethodsMixin): name: str = "Assistant" model: str = "gpt-4-1106-preview" instructions: Optional[str] = Field(None, repr=False) - tools: list[AssistantTools] = [] + tools: list[Union[AssistantTool, Callable]] = [] file_ids: list[str] = [] metadata: dict[str, str] = {} @@ -39,8 +39,15 @@ class Assistant(BaseModel, ExposeSyncMethodsMixin): def clear_default_thread(self): self.default_thread = Thread() - def get_tools(self) -> list[AssistantTools]: - return self.tools + def get_tools(self) -> list[AssistantTool]: + return [ + ( + tool + if isinstance(tool, Tool) + else marvin.utilities.tools.tool_from_function(tool) + ) + for tool in self.tools + ] def get_instructions(self) -> str: return self.instructions or "" @@ -66,17 +73,6 @@ async def say_async( ) return run - @field_validator("tools", mode="before") - def format_tools(cls, tools: list[Union[Tool, Callable]]): - return [ - ( - tool - if isinstance(tool, Tool) - else marvin.utilities.tools.tool_from_function(tool) - ) - for tool in tools - ] - def __enter__(self): self.create() return self diff --git a/src/marvin/beta/assistants/runs.py b/src/marvin/beta/assistants/runs.py index 2ec236941..fd203329c 100644 --- a/src/marvin/beta/assistants/runs.py +++ b/src/marvin/beta/assistants/runs.py @@ -7,7 +7,7 @@ import marvin.utilities.tools from marvin.requests import Tool -from marvin.tools.assistants import AssistantTools, CancelRun +from marvin.tools.assistants import AssistantTool, CancelRun from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method from marvin.utilities.logging import get_logger from marvin.utilities.openai import get_client @@ -30,10 +30,10 @@ class Run(BaseModel, ExposeSyncMethodsMixin): "Additional instructions to append to the assistant's instructions." ), ) - tools: Optional[list[Union[AssistantTools, Callable]]] = Field( + tools: Optional[list[Union[AssistantTool, Callable]]] = Field( None, description="Replacement tools to use for the run." ) - additional_tools: Optional[list[AssistantTools]] = Field( + additional_tools: Optional[list[AssistantTool]] = Field( None, description="Additional tools to append to the assistant's tools. ", ) @@ -106,7 +106,7 @@ def get_instructions(self) -> str: return instructions - def get_tools(self) -> list[AssistantTools]: + def get_tools(self) -> list[AssistantTool]: tools = [] if self.tools is None: tools.extend(self.assistant.get_tools()) diff --git a/src/marvin/requests.py b/src/marvin/requests.py index 44c9834cc..95a78cbff 100644 --- a/src/marvin/requests.py +++ b/src/marvin/requests.py @@ -1,6 +1,6 @@ from typing import Any, Callable, Generic, Optional, TypeVar, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PrivateAttr from typing_extensions import Annotated, Literal, Self from marvin.settings import settings @@ -21,18 +21,24 @@ class Function(BaseModel, Generic[T]): parameters: dict[str, Any] model: Optional[type[T]] = Field(default=None, exclude=True, repr=False) - python_fn: Optional[Callable[..., Any]] = Field( - default=None, - description="Private field that holds the executable function, if available", - exclude=True, - repr=False, - ) + + # Private field that holds the executable function, if available + _python_fn: Optional[Callable[..., Any]] = PrivateAttr(default=None) def validate_json(self: Self, json_data: Union[str, bytes, bytearray]) -> T: if self.model is None: raise ValueError("This Function was not initialized with a model.") return self.model.model_validate_json(json_data) + @classmethod + def create( + cls, *, _python_fn: Optional[Callable[..., Any]] = None, **kwargs: Any + ) -> "Function": + instance = cls(**kwargs) + if _python_fn is not None: + instance._python_fn = _python_fn + return instance + class Tool(BaseModel, Generic[T]): type: str diff --git a/src/marvin/tools/assistants.py b/src/marvin/tools/assistants.py index bd07e4a36..b7d9d4c78 100644 --- a/src/marvin/tools/assistants.py +++ b/src/marvin/tools/assistants.py @@ -5,7 +5,7 @@ Retrieval = RetrievalTool() CodeInterpreter = CodeInterpreterTool() -AssistantTools = Union[RetrievalTool, CodeInterpreterTool, Tool] +AssistantTool = Union[RetrievalTool, CodeInterpreterTool, Tool] class CancelRun(Exception): diff --git a/src/marvin/tools/chroma.py b/src/marvin/tools/chroma.py index 39000d61b..d8af1824c 100644 --- a/src/marvin/tools/chroma.py +++ b/src/marvin/tools/chroma.py @@ -153,4 +153,4 @@ def store_document( metadatas=[metadata], ) - return collection.get(id=doc_id) + return collection.get(ids=doc_id) diff --git a/src/marvin/utilities/tools.py b/src/marvin/utilities/tools.py index f36696470..c7315ac11 100644 --- a/src/marvin/utilities/tools.py +++ b/src/marvin/utilities/tools.py @@ -2,8 +2,11 @@ import inspect import json +from functools import update_wrapper from typing import Any, Callable, Optional +from pydantic import PydanticInvalidForJsonSchema + from marvin.requests import Function, Tool from marvin.utilities.asyncio import run_sync from marvin.utilities.logging import get_logger @@ -12,23 +15,68 @@ logger = get_logger("Tools") +def custom_partial(func: Callable, **fixed_kwargs: Any) -> Callable: + """ + Returns a new function with partial application of the given keyword arguments. + The new function has the same __name__ and docstring as the original, and its + signature excludes the provided kwargs. + """ + + # Define the new function with a dynamic signature + def wrapper(**kwargs): + # Merge the provided kwargs with the fixed ones, prioritizing the former + all_kwargs = {**fixed_kwargs, **kwargs} + return func(**all_kwargs) + + # Update the wrapper function's metadata to match the original function + update_wrapper(wrapper, func) + + # Modify the signature to exclude the fixed kwargs + original_sig = inspect.signature(func) + new_params = [ + param + for param in original_sig.parameters.values() + if param.name not in fixed_kwargs + ] + wrapper.__signature__ = original_sig.replace(parameters=new_params) + + return wrapper + + def tool_from_function( fn: Callable[..., Any], name: Optional[str] = None, description: Optional[str] = None, + kwargs: Optional[dict[str, Any]] = None, ): + """ + Creates an OpenAI-CLI tool from a Python function. + + If any kwargs are provided, they will be stored and provided at runtime. + Provided kwargs will be removed from the tool's parameter schema. + """ + if kwargs: + fn = custom_partial(fn, **kwargs) + model = cast_callable_to_model(fn) serializer: Callable[..., dict[str, Any]] = getattr( model, "model_json_schema", getattr(model, "schema") ) + try: + parameters = serializer() + except PydanticInvalidForJsonSchema: + raise TypeError( + "Could not create tool from function because annotations could not be" + f" serialized to JSON: {fn}" + ) return Tool( type="function", - function=Function( + function=Function.create( name=name or fn.__name__, description=description or fn.__doc__, - parameters=serializer(), - python_fn=fn, + parameters=parameters, + _python_fn=fn, ), ) @@ -49,7 +97,7 @@ def call_function_tool( if ( not tool or not tool.function - or not tool.function.python_fn + or not tool.function._python_fn or not tool.function.name ): raise ValueError(f"Could not find function '{function_name}'") @@ -58,7 +106,7 @@ def call_function_tool( logger.debug_kv( f"{tool.function.name}", f"called with arguments: {arguments}", "green" ) - output = tool.function.python_fn(**arguments) + output = tool.function._python_fn(**arguments) if inspect.isawaitable(output): output = run_sync(output) truncated_output = str(output)[:100]