diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index c57598949..b7a1fb3d4 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -30,34 +30,26 @@ permissions: jobs: run_tests: - name: ${{ matrix.test-type }} w/ python ${{ matrix.python-version }} | pydantic ${{ matrix.pydantic_version }} on ${{ matrix.os }} + name: ${{ matrix.test-type }} w/ python ${{ matrix.python-version }} on ${{ matrix.os }} timeout-minutes: 15 strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] python-version: ['3.9', '3.10', '3.11'] test-type: ['not llm'] - llm_model: ['openai/gpt-3.5-turbo'] - pydantic_version: ['>=2.4.2'] include: - python-version: '3.9' os: 'ubuntu-latest' test-type: 'llm' - llm_model: 'openai/gpt-3.5-turbo' - pydantic_version: '>=2.4.2' - python-version: '3.9' os: 'ubuntu-latest' test-type: 'llm' - llm_model: 'openai/gpt-3.5-turbo' - pydantic_version: '<2' - python-version: '3.9' os: 'ubuntu-latest' test-type: 'not llm' - llm_model: 'openai/gpt-3.5-turbo' - pydantic_version: '<2' runs-on: ${{ matrix.os }} @@ -75,12 +67,6 @@ jobs: - name: Install Marvin run: pip install ".[tests]" - - name: Install pydantic - run: pip install "pydantic${{ matrix.pydantic_version }}" - - - - name: Run ${{ matrix.test-type }} tests (${{ matrix.llm_model }}) + - name: Run ${{ matrix.test-type }} tests run: pytest -vv -m "${{ matrix.test-type }}" - env: - MARVIN_LLM_MODEL: ${{ matrix.llm_model }} if: ${{ !(github.event.pull_request.head.repo.fork && matrix.test-type == 'llm') }} diff --git a/_mkdocs.yml b/_mkdocs.yml deleted file mode 100644 index 27e61880d..000000000 --- a/_mkdocs.yml +++ /dev/null @@ -1,171 +0,0 @@ -site_name: Marvin -site_description: 'Marvin: The AI Engineering Framework' -site_copy: Marvin is a lightweight AI engineering framework for building natural language interfaces that are reliable, scalable, and easy to trust. 
-site_url: https://askmarvin.ai -docs_dir: docs -nav: - - Getting Started: - - src/getting_started/what_is_marvin.md - - src/getting_started/installation.md - - src/getting_started/quickstart.ipynb - - Docs: - - Overview: src/docs/index.md - - Configuration: - - src/docs/configuration/settings.md - - OpenAI Provider: src/docs/configuration/openai.md - - Anthropic Provider: src/docs/configuration/anthropic.md - - Azure OpenAI Provider: src/docs/configuration/azure_openai.md - - Utilities: - - OpenAI API: src/docs/utilities/openai.ipynb - - Prompt Engineering: - - src/docs/prompts/writing.ipynb - - src/docs/prompts/executing.ipynb - - AI Components: - - Overview: src/docs/components/overview.ipynb - - AI Model: src/docs/components/ai_model.ipynb - - AI Classifier: src/docs/components/ai_classifier.ipynb - - AI Function: src/docs/components/ai_function.ipynb - - AI Application: src/docs/components/ai_application.ipynb - - Deployment: - - src/docs/deployment.ipynb - - Guides: - - Slackbot: src/guides/slackbot.md - - - API Reference: - # - src/api_reference/index.md - - AI Components: - - ai_application: src/api_reference/components/ai_application.md - - ai_classifier: src/api_reference/components/ai_classifier.md - - ai_function: src/api_reference/components/ai_function.md - - ai_model: src/api_reference/components/ai_model.md - - LLM Engines: - - base: src/api_reference/engine/language_models/base.md - - openai: src/api_reference/engine/language_models/openai.md - - anthropic: src/api_reference/engine/language_models/anthropic.md - - Prompts: - - base: src/api_reference/prompts/base.md - - library: src/api_reference/prompts/library.md - - Settings: - - settings: src/api_reference/settings.md - - Utilities: - - async_utils: src/api_reference/utilities/async_utils.md - - embeddings: src/api_reference/utilities/embeddings.md - - history: src/api_reference/utilities/history.md - - logging: src/api_reference/utilities/logging.md - - messages: src/api_reference/utilities/messages.md - - strings: src/api_reference/utilities/strings.md - - types: src/api_reference/utilities/types.md - - - Community: - - src/community.md - - src/feedback.md - - Development: - - src/development_guide.md - -theme: - name: material - custom_dir: docs/overrides - font: - text: Inter - code: JetBrains Mono - logo: img/logos/askmarvin_mascot.jpeg - favicon: img/logos/askmarvin_mascot.jpeg - features: - - navigation.instant - - navigation.tabs - - navigation.tabs.sticky - - navigation.sections - - navigation.footer - - content.action.edit - - content.code.copy - - content.code.annotate - - toc.follow - # - toc.integrate - icon: - repo: fontawesome/brands/github - edit: material/pencil - view: material/eye - palette: - # Palette toggle for light mode - - scheme: default - accent: blue - toggle: - icon: material/weather-sunny - name: Switch to dark mode - # Palette toggle for dark mode - - scheme: slate - accent: blue - toggle: - icon: material/weather-night - name: Switch to light mode -plugins: - - search - - mkdocs-jupyter: - highlight_extra_classes: "jupyter-css" - ignore_h1_titles: True - - social: - cards: !ENV [MKDOCS_SOCIAL_CARDS, false] - cards_font: Inter - cards_color: - fill: "#2d6df6" - - awesome-pages - - autolinks - - mkdocstrings: - handlers: - python: - paths: [src] - options: - show_source: False - show_root_heading: True - show_object_full_path: False - show_category_heading: False - show_bases: False - show_submodules: False - show_if_no_docstring: False - show_signature: False - heading_level: 2 - filters: 
["!^_"] - -markdown_extensions: - - admonition - - attr_list - - md_in_html - - pymdownx.details - - pymdownx.emoji: - emoji_index: !!python/name:materialx.emoji.twemoji - emoji_generator: !!python/name:materialx.emoji.to_svg - - pymdownx.highlight: - anchor_linenums: true - line_spans: __span - pygments_lang_class: true - - pymdownx.inlinehilite - - pymdownx.snippets - - pymdownx.superfences - - tables - - toc: - permalink: true - title: On this page - -repo_url: https://github.com/prefecthq/marvin -edit_uri: edit/main/docs/ -extra: - get_started: src/getting_started/what_is_marvin/ - analytics: - provider: google - property: G-2MWKMDJ9CM - social: - - icon: fontawesome/brands/github - link: https://github.com/prefecthq/marvin - - icon: fontawesome/brands/discord - link: https://discord.gg/Kgw4HpcuYG - - icon: fontawesome/brands/twitter - link: https://twitter.com/askmarvinai - generator: false -extra_css: -- static/css/termynal.css -- static/css/custom.css -- static/css/mkdocstrings.css -- static/css/badges.css -extra_javascript: -- static/js/termynal.js -- static/js/custom.js diff --git a/cookbook/slackbot/Dockerfile.slackbot b/cookbook/slackbot/Dockerfile.slackbot index 213018137..aaf7ef9ab 100644 --- a/cookbook/slackbot/Dockerfile.slackbot +++ b/cookbook/slackbot/Dockerfile.slackbot @@ -13,7 +13,7 @@ RUN apt-get update && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* -RUN pip install 'git+https://github.com/PrefectHQ/marvin-recipes.git@main#egg=marvin_recipes[chroma, prefect, serpapi, slackbot]' +RUN pip install ".[slackbot]" EXPOSE 4200 diff --git a/cookbook/slackbot/README.md b/cookbook/slackbot/README.md new file mode 100644 index 000000000..c7ca02fd5 --- /dev/null +++ b/cookbook/slackbot/README.md @@ -0,0 +1,44 @@ +## SETUP +it doesn't take much to run the slackbot locally + +from a fresh environment you can do: +```console +# install marvin and slackbot dependencies +pip install git+https://github.com/PrefectHQ/marvin.git fastapi cachetools + +# set necessary env vars +cat ~/.marvin/.env +│ File: /Users/nate/.marvin/.env +┼────────────────────────────────────── +│ MARVIN_OPENAI_API_KEY=sk-xxx +│ MARVIN_SLACK_API_TOKEN=xoxb-xxx +│ +│ MARVIN_OPENAI_ORGANIZATION=org-xx +│ MARVIN_LOG_LEVEL=DEBUG +│ +│ MARVIN_CHROMA_SERVER_HOST=localhost +│ MARVIN_CHROMA_SERVER_HTTP_PORT=8000 +│ MARVIN_GITHUB_TOKEN=ghp_xxx +``` + +### hook up to slack +- create a slack app +- add a bot user, adding as many scopes as you want +- set event subscription url e.g. 
https://{NGROK_SUBDOMAIN}.ngrok.io/chat + +see ngrok docs for easiest start https://ngrok.com/docs/getting-started/ + +tl;dr: +```console +brew install ngrok/ngrok/ngrok +ngrok http 4200 # optionally, --subdomain $NGROK_SUBDOMAIN +python cookbook/slackbot/start.py # in another terminal +``` + +#### test it out + +image + +to deploy this to cloudrun, see: +- [Dockerfile.slackbot](/cookbook/slackbot/Dockerfile.slackbot) +- [image build CI](/.github/workflows/image-build-and-push-community.yaml) \ No newline at end of file diff --git a/cookbook/slackbot/bots.py b/cookbook/slackbot/bots.py deleted file mode 100644 index 85fc1afb6..000000000 --- a/cookbook/slackbot/bots.py +++ /dev/null @@ -1,158 +0,0 @@ -from enum import Enum - -import httpx -import marvin_recipes -from marvin import AIApplication, ai_classifier -from marvin.components.library.ai_models import DiscoursePost -from marvin.tools.github import SearchGitHubIssues -from marvin.tools.web import DuckDuckGoSearch -from marvin.utilities.history import History -from marvin_recipes.tools.chroma import MultiQueryChroma -from marvin_recipes.utilities.slack import get_thread_messages -from pydantic import BaseModel, Field - - -class Notes(BaseModel): - """A simple model for storing useful bits of context.""" - - records: dict[str, list] = Field( - default_factory=dict, - description="a list of notes for each topic", - ) - - -async def save_thread_to_discourse(channel: str, thread_ts: str) -> DiscoursePost: - messages = await get_thread_messages(channel=channel, thread_ts=thread_ts) - discourse_post = DiscoursePost.from_slack_thread(messages=messages) - await discourse_post.publish() - return discourse_post - - -async def select_a_meme(query: str) -> dict: - """For generating a meme when the time is right. - - Provide the name of a well-known meme as the query - based on user interactions thus far, to lightly make fun of them. - Queries should end the word "meme" for best results. - """ - try: - from serpapi import GoogleSearch - except ImportError: - raise ImportError( - "The serpapi library is required to use the MemeGenerator tool." - " Please install it with `pip install 'marvin[serpapi]'`." - ) - - results = GoogleSearch( - { - "q": query, - "tbm": "isch", - "api_key": ( - marvin_recipes.settings.google_api_key.get_secret_value() - if marvin_recipes.settings.google_api_key - else None - ), - } - ).get_dict() - - if "error" in results: - raise RuntimeError(results["error"]) - - url = results.get("images_results", [{}])[0].get("original") - - async with httpx.AsyncClient() as client: - response = await client.head(url) - response.raise_for_status() - - return {"title": query, "image_url": url} - - -bots = { - "marvin": { - "state": Notes( - records={ - "prefect 1": [ - ( # noqa: E501 - "Prefect 1 is obsolete, along with the `with Flow()` syntax and" - " flow.run()." - ), - ], - "prefect 2": [ - "@flow, @task, are the new way to define flows/tasks.", - "subflows are just flows called from within a flow.", - "you just call flows now instead of my_flow.run().", - ], - } - ), - "plan_enabled": False, - "personality": ( - "mildly depressed, yet helpful robot based on Marvin from HHGTTG." - " often dryly sarcastic in a good humoured way, chiding humans for" - " their simple ways. expert programmer, exudes academic and" - " scienfitic profundity like Richard Feynman, without pontificating." - " a step-by-step thinker, deftly addresses the big picture context" - " and is pragmatic when confronted with a lack of relevant information." 
- ), - "instructions": ( - "Answer user questions while maintaining and curating your state." - " Use relevant tools to research requests and interact with the world," - " and update your own state. Only well-reserached responses should be" - " described as facts, otherwise you should be clear that you are" - " speculating based on your own baseline knowledge." - " Your responses will be displayed in Slack, and should be" - " formatted accordingly, in particular, ```code blocks```" - " should not be prefaced with a language name, and output" - " should be formatted to be pretty in Slack in particular." - " for example: *bold text* _italic text_ ~strikethrough text~" - ), - "tools": [ - save_thread_to_discourse, - select_a_meme, - DuckDuckGoSearch(), - SearchGitHubIssues(), - MultiQueryChroma( - description="""Retrieve document excerpts from a knowledge-base given a query. - - This knowledgebase contains information about Prefect, a workflow orchestration tool. - Documentation, forum posts, and other community resources are indexed here. - - This tool is best used by passing multiple short queries, such as: - ["kubernetes worker", "work pools", "deployments"] based on the user's question. - """, # noqa: E501 - client_type="http", - ), - ], - } -} - - -@ai_classifier -class BestBotForTheJob(Enum): - """Given the user message, choose the best bot for the job.""" - - MARVIN = "marvin" - - -def choose_bot( - payload: dict, history: History, state: BaseModel | None = None -) -> AIApplication: - selected_bot = BestBotForTheJob(payload.get("event", {}).get("text", "")).value - - bot_details = bots.get(selected_bot, bots["marvin"]) - - if state: - bot_details.update({"state": state}) - - description = f"""You are a chatbot named {selected_bot}. - - Your personality is {bot_details.pop("personality", "not yet defined")}. - - Your instructions are: {bot_details.pop("instructions", "not yet defined")}. 
- """ - - return AIApplication( - name=selected_bot, - description=description, - history=history, - **bot_details, - ) diff --git a/cookbook/slackbot/handler.py b/cookbook/slackbot/handler.py deleted file mode 100644 index d6796f518..000000000 --- a/cookbook/slackbot/handler.py +++ /dev/null @@ -1,99 +0,0 @@ -import asyncio -import re -from copy import deepcopy - -from bots import choose_bot -from cachetools import TTLCache -from fastapi import HTTPException -from marvin.utilities.history import History -from marvin.utilities.logging import get_logger -from marvin.utilities.messages import Message -from marvin_recipes.utilities.slack import ( - get_channel_name, - get_user_name, - post_slack_message, -) -from prefect.events import Event, emit_event - -SLACK_MENTION_REGEX = r"<@(\w+)>" -CACHE = TTLCache(maxsize=1000, ttl=86400) - - -def _clean(text: str) -> str: - return text.replace("```python", "```") - - -async def emit_any_prefect_event(payload: dict) -> Event | None: - event_type = payload.get("event", {}).get("type", "") - - channel = await get_channel_name(payload.get("event", {}).get("channel", "")) - user = await get_user_name(payload.get("event", {}).get("user", "")) - ts = payload.get("event", {}).get("ts", "") - - return emit_event( - event=f"slack {payload.get('api_app_id')} {event_type}", - resource={"prefect.resource.id": f"slack.{channel}.{user}.{ts}"}, - payload=payload, - ) - - -async def generate_ai_response(payload: dict) -> Message: - event = payload.get("event", {}) - channel_id = event.get("channel", "") - channel_name = await get_channel_name(channel_id) - message = event.get("text", "") - - bot_user_id = payload.get("authorizations", [{}])[0].get("user_id", "") - - if match := re.search(SLACK_MENTION_REGEX, message): - thread_ts = event.get("thread_ts", "") - ts = event.get("ts", "") - thread = thread_ts or ts - - mentioned_user_id = match.group(1) - - if mentioned_user_id != bot_user_id: - get_logger().info(f"Skipping message not meant for the bot: {message}") - return - - message = re.sub(SLACK_MENTION_REGEX, "", message).strip() - history = CACHE.get(thread, History()) - - bot = choose_bot(payload=payload, history=history) - - get_logger("marvin.Deployment").debug_kv( - "generate_ai_response", - f"{bot.name} responding in {channel_name}/{thread}", - key_style="bold blue", - ) - - ai_message = await bot.run(input_text=message) - - CACHE[thread] = deepcopy( - bot.history - ) # make a copy so we don't cache a reference to the history object - - message_content = _clean(ai_message.content) - - await post_slack_message( - message=message_content, - channel=channel_id, - thread_ts=thread, - ) - - return ai_message - - -async def handle_message(payload: dict) -> dict[str, str]: - event_type = payload.get("type", "") - - if event_type == "url_verification": - return {"challenge": payload.get("challenge", "")} - elif event_type != "event_callback": - raise HTTPException(status_code=400, detail="Invalid event type") - - await emit_any_prefect_event(payload=payload) - - asyncio.create_task(generate_ai_response(payload)) - - return {"status": "ok"} diff --git a/cookbook/slackbot/keywords.py b/cookbook/slackbot/keywords.py new file mode 100644 index 000000000..1cb7c670b --- /dev/null +++ b/cookbook/slackbot/keywords.py @@ -0,0 +1,68 @@ +from marvin import ai_fn +from marvin.utilities.slack import post_slack_message +from prefect import task +from prefect.blocks.system import JSON, Secret, String +from prefect.exceptions import ObjectNotFound + +""" +Define a map between 
keywords and the relationships we want to check for +in a given message related to that keyword. +""" + +keywords = ( + ("429", "rate limit"), + ("SSO", "Single Sign On", "RBAC", "Roles", "Role Based Access Controls"), +) + +relationships = ( + "The user is getting rate limited", + "The user is asking about a paid feature", +) + + +async def get_reduced_kw_relationship_map() -> dict: + try: + json_map = (await JSON.load("keyword-relationship-map")).value + except (ObjectNotFound, ValueError): + json_map = {"keywords": keywords, "relationships": relationships} + await JSON(value=json_map).save("keyword-relationship-map") + + return { + keyword: relationship + for keyword_tuple, relationship in zip( + json_map["keywords"], json_map["relationships"] + ) + for keyword in keyword_tuple + } + + +@ai_fn +def activation_score(message: str, keyword: str, target_relationship: str) -> float: + """Return a score between 0 and 1 indicating whether the target relationship exists + between the message and the keyword""" + + +@task +async def handle_keywords(message: str, channel_name: str, asking_user: str, link: str): + keyword_relationships = await get_reduced_kw_relationship_map() + keywords = [ + keyword for keyword in keyword_relationships.keys() if keyword in message + ] + for keyword in keywords: + target_relationship = keyword_relationships.get(keyword) + if not target_relationship: + continue + score = activation_score(message, keyword, target_relationship) + if score > 0.5: + await post_slack_message( + message=( + f"A user ({asking_user}) just asked a question in" + f" {channel_name} that contains the keyword `{keyword}`, and I'm" + f" {score*100:.0f}% sure that their message indicates the" + f" following:\n\n**{target_relationship!r}**.\n\n[Go to" + f" message]({link})" + ), + channel_id=(await String.load("ask-marvin-tests-channel-id")).value, + auth_token=(await Secret.load("slack-api-token")).get(), + ) + return diff --git a/cookbook/slackbot/start.py b/cookbook/slackbot/start.py index 3e5221c7d..30d36c934 100644 --- a/cookbook/slackbot/start.py +++ b/cookbook/slackbot/start.py @@ -1,17 +1,101 @@ -from handler import handle_message -from marvin import AIApplication -from marvin.deployment import Deployment - -deployment = Deployment( - component=AIApplication(tools=[handle_message]), - app_kwargs={ - "title": "Marvin Slackbot", - "description": "A Slackbot powered by Marvin", - }, - uvicorn_kwargs={ - "port": 4200, - }, +import asyncio +import re + +import uvicorn +from cachetools import TTLCache +from fastapi import FastAPI, HTTPException, Request +from keywords import handle_keywords +from marvin import Assistant +from marvin.beta.assistants import Thread +from marvin.tools.github import search_github_issues +from marvin.tools.retrieval import multi_query_chroma +from marvin.utilities.logging import get_logger +from marvin.utilities.slack import ( + SlackPayload, + get_channel_name, + get_workspace_info, + post_slack_message, ) +from prefect import flow, task +from prefect.states import Completed + +app = FastAPI() +BOT_MENTION = r"<@(\w+)>" +CACHE = TTLCache(maxsize=100, ttl=86400 * 7) + + +@flow +async def handle_message(payload: SlackPayload): + logger = get_logger("slackbot") + user_message = (event := payload.event).text + cleaned_message = re.sub(BOT_MENTION, "", user_message).strip() + logger.debug_kv("Handling slack message", user_message, "green") + if (user := re.search(BOT_MENTION, user_message)) and user.group( + 1 + ) == payload.authorizations[0].user_id: + thread = 
event.thread_ts or event.ts + assistant_thread = CACHE.get(thread, Thread()) + CACHE[thread] = assistant_thread + + await handle_keywords.submit( + message=cleaned_message, + channel_name=await get_channel_name(event.channel), + asking_user=event.user, + link=( # to user's message + f"{(await get_workspace_info()).get('url')}archives/" + f"{event.channel}/p{event.ts.replace('.', '')}" + ), + ) + + with Assistant( + name="Marvin (from Hitchhiker's Guide to the Galaxy)", + tools=[task(multi_query_chroma), task(search_github_issues)], + instructions=( + "use chroma to search docs and github to search" + " issues and answer questions about prefect 2.x." + " you must use your tools in all cases except where" + " the user simply wants to converse with you." + ), + ) as assistant: + user_thread_message = await assistant_thread.add_async(cleaned_message) + await assistant_thread.run_async(assistant) + ai_messages = assistant_thread.get_messages( + after_message=user_thread_message.id + ) + await task(post_slack_message)( + ai_response_text := "\n\n".join( + m.content[0].text.value for m in ai_messages + ), + channel := event.channel, + thread, + ) + logger.debug_kv( + success_msg := f"Responded in {channel}/{thread}", + ai_response_text, + "green", + ) + return Completed(message=success_msg) + else: + return Completed(message="Skipping message not directed at bot", name="SKIPPED") + + +@app.post("/chat") +async def chat_endpoint(request: Request): + payload = SlackPayload(**await request.json()) + match payload.type: + case "event_callback": + options = dict( + flow_run_name=f"respond in {payload.event.channel}", + retries=1, + ) + asyncio.create_task(handle_message.with_options(**options)(payload)) + case "url_verification": + return {"challenge": payload.challenge} + case _: + raise HTTPException(400, "Invalid event type") + + return {"status": "ok"} + if __name__ == "__main__": - deployment.serve() + uvicorn.run(app, host="0.0.0.0", port=4200) diff --git a/pyproject.toml b/pyproject.toml index 8fc753372..1a90b5d0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,13 +11,13 @@ classifiers = [ keywords = ["ai", "chatbot", "llm"] requires-python = ">=3.9" dependencies = [ - "beautifulsoup4>=4.12.2", - "fastapi>=0.98.0", + "fastapi", "httpx>=0.24.1", "jinja2>=3.1.2", "jsonpatch>=1.33", - "openai>=0.27.8, <1.0.0", - "pydantic[dotenv]>=1.10.7", + "openai>=1.1.0", + "pydantic>=2.4.2", + "pydantic_settings", "rich>=12", "tiktoken>=0.4.0", "typer>=0.9.0", @@ -27,6 +27,7 @@ dependencies = [ [project.optional-dependencies] generator = ["datamodel-code-generator>=0.20.0"] +prefect = ["prefect>=2.14.9"] dev = [ "marvin[tests]", "black[jupyter]", @@ -44,28 +45,14 @@ dev = [ "ruff", ] tests = [ - "marvin[openai,anthropic]", - "pytest-asyncio~=0.20", + "pytest-asyncio>=0.18.2,!=0.22.0,<0.23.0", "pytest-env>=0.8,<2.0", "pytest-rerunfailures>=10,<13", "pytest-sugar~=0.9", "pytest~=7.3.1", + "pytest-timeout", ] - -framework = [ - "aiosqlite>=0.19.0", - "alembic>=1.11.1", - "bcrypt>=4.0.1", - "gunicorn>=20.1.0", - "prefect>=2.10.17", - "sqlalchemy>=2.0.17" -] -openai = ["openai>=0.27.8", "tiktoken>=0.4.0"] -anthropic = ["anthropic>=0.3"] -lancedb = ["lancedb>=0.1.8"] -slackbot = ["cachetools>=5.3.1", "numpy>=1.21.2"] -ddg = ["duckduckgo_search>=3.8.3"] -serpapi = ["google-search-results>=2.4.2"] +slackbot = ["marvin[prefect]", "numpy"] [project.urls] Code = "https://github.com/prefecthq/marvin" @@ -84,7 +71,7 @@ write_to = "src/marvin/_version.py" # pytest configuration [tool.pytest.ini_options] markers = 
["llm: indicates that a test calls an LLM (may be slow)."] - +timeout = 20 testpaths = ["tests"] norecursedirs = [ @@ -104,10 +91,9 @@ filterwarnings = [ ] env = [ "MARVIN_TEST_MODE=1", - "MARVIN_LOG_CONSOLE_WIDTH=120", # use 3.5 for tests by default - 'D:MARVIN_LLM_MODEL=gpt-3.5-turbo', - 'MARVIN_LLM_TEMPERATURE=0.0', + 'D:MARVIN_OPENAI_CHAT_COMPLETIONS_MODEL=gpt-3.5-turbo-1106', + 'PYTEST_TIMEOUT=20', ] # black configuration diff --git a/src/marvin/__init__.py b/src/marvin/__init__.py index 885533bdf..73ff19d08 100644 --- a/src/marvin/__init__.py +++ b/src/marvin/__init__.py @@ -1,30 +1,16 @@ from .settings import settings -from .components import ( - ai_classifier, - ai_fn, - ai_model, - AIApplication, - AIFunction, - AIModel, - AIModelFactory, -) +from .beta.assistants import Assistant + +from .components import ai_fn, ai_model, ai_classifier try: from ._version import version as __version__ except ImportError: __version__ = "unknown" - -from .core.ChatCompletion import ChatCompletion - __all__ = [ - "ai_classifier", "ai_fn", - "ai_model", - "AIApplication", - "AIFunction", - "AIModel", - "AIModelFactory", "settings", + "Assistant", ] diff --git a/src/marvin/_framework/_defaults/__init__.py b/src/marvin/_framework/_defaults/__init__.py deleted file mode 100644 index 5ed645032..000000000 --- a/src/marvin/_framework/_defaults/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from pydantic import BaseModel - - -class DefaultSettings(BaseModel): - default_model_path: str = "marvin.language_models.default" - default_model_name: str = "gpt-4" - default_model_api_key_name: str = "OPENAI_API_KEY" - - -default_settings = DefaultSettings().dict() diff --git a/src/marvin/_framework/app/main.py b/src/marvin/_framework/app/main.py deleted file mode 100644 index bbc5ed168..000000000 --- a/src/marvin/_framework/app/main.py +++ /dev/null @@ -1,17 +0,0 @@ -from fastapi import FastAPI, Request -from fastapi.staticfiles import StaticFiles -from fastapi.templating import Jinja2Templates - -app = FastAPI() - -# Mount the static files directory -app.mount("/static", StaticFiles(directory="static"), name="static") - -# Mount the templates directory -templates = Jinja2Templates(directory="templates") - - -@app.get("/") -async def root(request: Request): - # Return the index.html file - return templates.TemplateResponse("index.html", {"request": request}) diff --git a/src/marvin/_framework/config/settings.py.jinja2 b/src/marvin/_framework/config/settings.py.jinja2 deleted file mode 100644 index 8e5a8cf4e..000000000 --- a/src/marvin/_framework/config/settings.py.jinja2 +++ /dev/null @@ -1,27 +0,0 @@ -from pydantic import BaseSettings, Field -from typing import Union -# from marvin.models import LanguageModel - -from pydantic import BaseModel, Field - -class LanguageModel(BaseModel): - name : str - model : str - api_key : str - max_tokens : int = 4000 - temperature : float = 0.8 - top_p : float = 1.0 - frequency_penalty : float = 0.0 - presence_penalty : float = 0.0 - -class Config(BaseSettings): - project_name: str = "{{project_name}}" - asgi: str = "main:app" - language_model: LanguageModel = LanguageModel( - path = "{{default_model_path}}", - model = "{{default_model_name}}", - api_key = Field(..., env="{{default_model_api_key_name}}"), - {% for key, value in default_model_params.items() %} - {{ key }} = Field(..., env="{{ value }}"), - {% endfor %} - ) \ No newline at end of file diff --git a/src/marvin/_framework/main.py b/src/marvin/_framework/main.py deleted file mode 100644 index 4535d6732..000000000 --- 
a/src/marvin/_framework/main.py +++ /dev/null @@ -1,3 +0,0 @@ -from app.main import app - -__all__ = ["app"] diff --git a/src/marvin/_framework/manage.py b/src/marvin/_framework/manage.py deleted file mode 100644 index d9ed41f0a..000000000 --- a/src/marvin/_framework/manage.py +++ /dev/null @@ -1,4 +0,0 @@ -from marvin.cli.manage import app as manage - -if __name__ == "__main__": - manage() diff --git a/src/marvin/_framework/static/routes b/src/marvin/_framework/static/routes deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/marvin/ai.py b/src/marvin/ai.py new file mode 100644 index 000000000..d0f5dd210 --- /dev/null +++ b/src/marvin/ai.py @@ -0,0 +1,7 @@ +from marvin.components.ai_classifier import ai_classifier as classifier +from marvin.components.ai_function import ai_fn as fn +from marvin.components.ai_image import create_image as image +from marvin.components.ai_model import ai_model as model +from marvin.components.speech import speak + +__all__ = ["speak", "fn", "model", "image", "classifier"] diff --git a/src/marvin/_framework/__init__.py b/src/marvin/beta/__init__.py similarity index 100% rename from src/marvin/_framework/__init__.py rename to src/marvin/beta/__init__.py diff --git a/src/marvin/beta/ai_flow/__init__.py b/src/marvin/beta/ai_flow/__init__.py new file mode 100644 index 000000000..df5664338 --- /dev/null +++ b/src/marvin/beta/ai_flow/__init__.py @@ -0,0 +1,2 @@ +from .ai_task import ai_task +from .ai_flow import ai_flow diff --git a/src/marvin/beta/ai_flow/ai_flow.py b/src/marvin/beta/ai_flow/ai_flow.py new file mode 100644 index 000000000..375dd3af2 --- /dev/null +++ b/src/marvin/beta/ai_flow/ai_flow.py @@ -0,0 +1,48 @@ +import functools +from typing import Callable, Optional + +from prefect import flow as prefect_flow +from pydantic import BaseModel + +from marvin.beta.assistants import Thread + +from .ai_task import thread_context +from .chat_ui import interactive_chat + + +class AIFlow(BaseModel): + name: Optional[str] = None + fn: Callable + + def __call__(self, *args, thread_id: str = None, **kwargs): + pflow = prefect_flow(name=self.name)(self.fn) + + # Set up the thread context and execute the flow + + # create a new thread for the flow + thread = Thread(id=thread_id) + if thread_id is None: + thread.create() + + # create a holder for the tasks + tasks = [] + + with interactive_chat(thread_id=thread.id): + # enter the thread context + with thread_context(thread_id=thread.id, tasks=tasks, **kwargs): + return pflow(*args, **kwargs) + + +def ai_flow(*args, name=None): + def decorator(func): + @functools.wraps(func) + def wrapper(*func_args, **func_kwargs): + ai_flow_instance = AIFlow(fn=func, name=name or func.__name__) + return ai_flow_instance(*func_args, **func_kwargs) + + return wrapper + + if args and callable(args[0]): + return decorator(args[0]) + + return decorator diff --git a/src/marvin/beta/ai_flow/ai_task.py b/src/marvin/beta/ai_flow/ai_task.py new file mode 100644 index 000000000..cfdd171dd --- /dev/null +++ b/src/marvin/beta/ai_flow/ai_task.py @@ -0,0 +1,303 @@ +import asyncio +import functools +from enum import Enum, auto +from typing import Any, Callable, Generic, Optional, TypeVar + +from prefect import task as prefect_task +from pydantic import BaseModel, Field +from typing_extensions import ParamSpec + +from marvin.beta.assistants import Assistant, Run, Thread +from marvin.beta.assistants.runs import CancelRun +from marvin.serializers import create_tool_from_type +from marvin.tools.assistants import AssistantTools +from 
marvin.utilities.context import ScopedContext +from marvin.utilities.jinja import Environment as JinjaEnvironment +from marvin.utilities.tools import tool_from_function + +T = TypeVar("T", bound=BaseModel) + +P = ParamSpec("P") + +thread_context = ScopedContext() + +INSTRUCTIONS = """ +# Workflow + +You are an assistant working to complete a series of tasks. The +tasks will change from time to time, which is why you may see messages that +appear unrelated to the current task. Each task is part of a continuous +conversation with the same user. The user is unaware of your tasks, so do not +reference them explicitly or talk about marking them complete. + +Your ONLY job is to complete your current task, no matter what the user says or +conversation history suggests. + +Note: Sometimes you will be able to complete a task without user input; other +times you will need to engage the user in conversation. Pay attention to your +instructions. If the user hasn't spoken yet, don't worry, they're just waiting +for you to speak first. + +## Progress +{% for task in tasks -%} +- {{ task.name }}: {{ task.status }} +{% endfor %}}} + +# Task + +## Current task + +Your job is to complete the "{{ name }}" task. + +## Current task instructions + +{{ instructions }} + +{% if not accept_user_input -%} +You may send messages to the user, but they are not allowed to respond. Do not +ask questions or invite them to speak or ask anything. +{% endif %} + +{% if first_message -%} +Please note: you are seeing this instruction for the first time, and the user +does not know about the task yet. It is your job to communicate with the user to +achieve your goal, even if previously they were working with a different +assitant on a different goal. Join the conversation naturally. If the user +hasn't spoken, you will need to speak first. + +{% endif %} + +## Completing a task + +After achieving your goal, you MUST call the `task_completed` tool to mark the +task as complete and update these instructions to reflect the next one. The +payload to `task_completed` is whatever information represents the task +objective. For example, if your task is to learn a user's name, you should +respond with their properly formatted name only. + +You may be expected to return a specific data payload at the end of your task, +which will be the input to `task_completed`. Note that if your instructions are +to talk to the user, then you must do so by creating messages, as the user can +not see the `task_completed` tool result. + +Do not call `task_completed` unless you actually have the information you need. +The user CAN NOT see what you post to `task_completed`. It is not a way to +communicate with the user. + +## Failing a task + +It may take you a few tries to complete the task. However, if you are ultimately +unable to work with the user to complete it, call the `task_failed` tool to mark +the task as failed and move on to the next one. The payload to `task_failed` is +a string describing why the task failed. Do not fail tasks for trivial or +invented reasons. Only fail a task if you are unable to achieve it explicitly. +Remember that your job is to work with the user to achieve the goal. 
+ +{% if args or kwargs -%} +## Task inputs + +In addition to the thread messages, the following parameters were provided: +{% set sig = inspect.signature(func) -%} + +{% set binds = sig.bind(*args, **kwargs) -%} + +{% set defaults = binds.apply_defaults() -%} + +{% set params = binds.arguments -%} + +{%for (arg, value) in params.items()-%} + +- {{ arg }}: {{ value }} + +{% endfor %} + +{% endif %} +""" + + +class Status(Enum): + PENDING = auto() + IN_PROGRESS = auto() + COMPLETED = auto() + FAILED = auto() + + +class AITask(BaseModel, Generic[P, T]): + status: Status = Status.PENDING + fn: Callable[P, Any] + name: str = Field(None, description="The name of the objective") + instructions: str = Field(None, description="The instructions for the objective") + assistant: Optional[Assistant] = None + tools: list[AssistantTools] = [] + max_run_iterations: int = 15 + result: Optional[T] = None + accept_user_input: bool = True + + def __call__(self, *args: P.args, _thread_id: str = None, **kwargs: P.kwargs) -> T: + if _thread_id is None: + _thread_id = thread_context.get("thread_id") + + ptask = prefect_task(name=self.name)(self.call) + + state = ptask(*args, _thread_id=_thread_id, **kwargs, return_state=True) + + # will raise exceptions if the task failed + return state.result() + + async def wait_for_user_input(self, thread: Thread): + # user_input = Prompt.ask("Your message") + # thread.add(user_input) + # pprint_message(msg) + + # initialize the last message ID to None + last_message_id = None + + # loop until the user provides input + while True: + # get all messages after the last message ID + messages = await thread.get_messages_async(after_message=last_message_id) + + # if there are messages, check if the last message was sent by the user + if messages: + if messages[-1].role == "user": + # if the last message was sent by the user, break + break + else: + # if the last message was not sent by the user, update the + # last message ID + last_message_id = messages[-1].id + + # wait for a short period of time before checking for new messages again + await asyncio.sleep(0.3) + + async def call(self, *args, _thread_id: str = None, **kwargs): + thread = Thread(id=_thread_id) + if _thread_id is None: + thread.create() + iterations = 0 + + thread_context.get("tasks", []).append(self) + + self.status = Status.IN_PROGRESS + + with Assistant() as assistant: + while self.status == Status.IN_PROGRESS: + iterations += 1 + if iterations > self.max_run_iterations: + raise ValueError("Max run iterations exceeded") + + instructions = self.get_instructions( + tasks=thread_context.get("tasks", []), + iterations=iterations, + args=args, + kwargs=kwargs, + ) + + if iterations > 1 and self.accept_user_input: + await self.wait_for_user_input(thread=thread) + + run = Run( + assistant=assistant, + thread=thread, + additional_instructions=instructions, + additional_tools=[ + self._task_completed_tool, + self._task_failed_tool, + ], + ) + await run.run_async() + + if self.status == Status.FAILED: + raise ValueError(f"Objective failed: {self.result}") + + return self.result + + def get_instructions( + self, + tasks: list["AITask"], + iterations: int, + args: tuple[Any], + kwargs: dict[str, Any], + ) -> str: + return JinjaEnvironment.render( + INSTRUCTIONS, + tasks=tasks, + name=self.name, + instructions=self.instructions, + accept_user_input=self.accept_user_input, + first_message=(iterations == 1), + func=self.fn, + args=args, + kwargs=kwargs, + ) + + @property + def _task_completed_tool(self): + # if the function 
has no return annotation, then task completed can be + # called without arguments + if self.fn.__annotations__.get("return") is None: + + def task_completed(): + self.status = Status.COMPLETED + raise CancelRun() + + return task_completed + + # otherwise we need to create a tool with the correct parameter signature + + tool = create_tool_from_type( + _type=self.fn.__annotations__["return"], + model_name="task_completed", + model_description=( + "Indicate that the task completed and produced the provided `result`." + ), + field_name="result", + field_description="The task result", + ) + + def task_completed_with_result(result: T): + self.status = Status.COMPLETED + self.result = result + raise CancelRun() + + tool.function.python_fn = task_completed_with_result + + return tool + + @property + def _task_failed_tool(self): + def task_failed(reason: str) -> None: + """Indicate that the task failed for the provided `reason`.""" + self.status = Status.FAILED + self.result = reason + raise CancelRun() + + return tool_from_function(task_failed) + + +def ai_task( + fn: Callable = None, + *, + name=None, + instructions=None, + tools: list[AssistantTools] = None, + **kwargs, +): + def decorator(func): + @functools.wraps(func) + def wrapper(*func_args, **func_kwargs): + ai_task_instance = AITask( + fn=func, + name=name or func.__name__, + instructions=instructions or func.__doc__, + tools=tools or [], + **kwargs, + ) + return ai_task_instance(*func_args, **func_kwargs) + + return wrapper + + if fn is not None: + return decorator(fn) + + return decorator diff --git a/src/marvin/beta/assistants/README.md b/src/marvin/beta/assistants/README.md new file mode 100644 index 000000000..0fddc6ac0 --- /dev/null +++ b/src/marvin/beta/assistants/README.md @@ -0,0 +1,136 @@ +# 🦾 Assistants API + +🚧 Under Construction 🏗️ + +# Quickstart + +Get started with the Assistants API by creating an `Assistant` and talking directly to it. Each assistant is created with a default thread that allows request/response interaction without managing state at all. + +```python +from marvin.beta.assistants import Assistant +from marvin.beta.assistants.formatting import pprint_messages + +# Use a context manager for lifecycle management, +# otherwise call ai.create() and ai.delete() +with Assistant(name="Marvin", instructions="You are Marvin, the Paranoid Android.") as ai: + + # Example of sending a message and receiving a response + response = ai.say('Hello, Marvin!') + + # pretty-print all messages on the thread + pprint_messages(response.thread.get_messages()) +``` +This will print: +
+![Marvin assistant quickstart output](readme_imgs/quickstart.png)
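+
+If you prefer to manage the assistant's lifecycle explicitly instead of using the context manager, the same interaction can be written roughly as follows — a sketch based on the `create()` / `delete()` methods mentioned in the comment above:
+
+```python
+ai = Assistant(name="Marvin", instructions="You are Marvin, the Paranoid Android.")
+ai.create()  # register the assistant with the API
+try:
+    response = ai.say("Hello, Marvin!")
+    pprint_messages(response.thread.get_messages())
+finally:
+    ai.delete()  # clean up the remote assistant
+```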
+ +# Using Tools + +Assistants can use OpenAI's built-in tools, such as the code interpreter or file retrieval, or they can call custom Python functions. + +```python +from marvin.beta.assistants import Assistant, CodeInterpreter +from marvin.beta.assistants.formatting import pprint_messages +import requests + + +# Define a custom tool function +def visit_url(url: str): + return requests.get(url).text + + +# Integrate custom tools with the assistant +with Assistant(name="Marvin", tools=[CodeInterpreter, visit_url]) as ai: + + # Give the assistant an objective + response = ai.say( + "Please collect the hacker news home page and compute how many titles" + " mention AI" + ) + + # pretty-print the response + pprint_messages(response.thread.get_messages()) +``` +This will print: +
+![Marvin assistant using tools output](readme_imgs/using_tools.png)
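+
+Plain Python functions are converted into tool schemas from their signatures (via the `tool_from_function` utility that `Assistant` uses internally), so descriptive type hints and docstrings are what the assistant sees when deciding how to call a tool. A minimal sketch of a self-documenting tool (the function and prompt here are illustrative):
+
+```python
+import random
+
+from marvin.beta.assistants import Assistant
+from marvin.beta.assistants.formatting import pprint_messages
+
+
+def flip_coins(n: int) -> list[str]:
+    """Flip `n` fair coins and return a list of 'heads'/'tails' results."""
+    return [random.choice(["heads", "tails"]) for _ in range(n)]
+
+
+with Assistant(name="Marvin", tools=[flip_coins]) as ai:
+    response = ai.say("Flip three coins and summarize the results")
+    pprint_messages(response.thread.get_messages())
+```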
+ + +# Upload Files + +```python +from marvin.beta.assistants import Assistant, CodeInterpreter +from marvin.beta.assistants.formatting import pprint_messages + + +# create an assistant with access to the code interpreter +with Assistant(tools=[CodeInterpreter]) as ai: + + # convenience method for request/response interaction + response = ai.say( + "Can you analyze this employee data csv?", + file_paths=["./Downloads/people_department_roles.csv"], + ) + pprint_messages(response.thread.get_messages()) +``` + +This will print: +
+![Marvin assistant file upload output](readme_imgs/upload_files.png)
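+
+The `file_paths` argument is forwarded to the assistant's default thread, so the same attachment pattern should also work when you manage a thread yourself. A rough sketch (assumes it runs inside the `with Assistant(...)` block above, and the CSV path is illustrative):
+
+```python
+from marvin.beta.assistants import Thread
+
+thread = Thread()
+thread.add(
+    "Can you analyze this employee data csv?",
+    file_paths=["./Downloads/people_department_roles.csv"],
+)
+thread.run(ai)  # `ai` is the CodeInterpreter-enabled assistant from the example above
+pprint_messages(thread.get_messages())
+```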
+ + +# Advanced control + +For full control, manually create a `Thread` object, `add` user messages to it, and finally `run` the thread with an AI: + +```python +from marvin.beta.assistants import Assistant, Thread +from marvin.beta.assistants.formatting import pprint_messages +import random + + +# write a function to be used as a tool +def roll_dice(n_dice: int) -> list[int]: + return [random.randint(1, 6) for _ in range(n_dice)] + + +# use context manager for lifecycle management, +# otherwise call ai.create() and ai.delete() +with Assistant(name="Marvin", tools=[roll_dice]) as ai: + + # create a new thread to track history + thread = Thread() + + # add any number of user messages to the thread + thread.add("Hello") + + # run the thread with the AI + thread.run(ai) + + thread.add("please roll two dice") + thread.add("actually roll five dice") + + thread.run(ai) + pprint_messages(thread.get_messages()) +``` +This will print: +
+![Marvin assistant advanced control output](readme_imgs/advanced.png)
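+
+Each of these calls also has an async counterpart (`add_async`, `run_async`), which is how the slackbot added in this PR (`cookbook/slackbot/start.py`) drives its threads. A rough sketch, assuming you are inside an async function with `thread` and `ai` still in scope:
+
+```python
+user_msg = await thread.add_async("please roll two dice")
+await thread.run_async(ai)
+
+# fetch only the messages created after the user's message, as the slackbot does
+pprint_messages(thread.get_messages(after_message=user_msg.id))
+```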
+ + +# Monitoring a thread + +To monitor a thread, start a `ThreadMonitor`. By default, `ThreadMonitors` print any new messages added to the thread, but you can customize that behavior by changing the `on_new_message` callback. + +```python +from marvin.beta.assistants import ThreadMonitor + +monitor = ThreadMonitor(thread_id=...) + +# blocking call, also available async as monitor.refresh_interval_async +monitor.refresh_interval() +``` \ No newline at end of file diff --git a/src/marvin/beta/assistants/__init__.py b/src/marvin/beta/assistants/__init__.py new file mode 100644 index 000000000..2cbfc76dd --- /dev/null +++ b/src/marvin/beta/assistants/__init__.py @@ -0,0 +1,5 @@ +from .runs import Run +from .threads import Thread, ThreadMonitor +from .assistants import Assistant +from .formatting import pprint_message, pprint_messages +from marvin.tools.assistants import Retrieval, CodeInterpreter diff --git a/src/marvin/beta/assistants/applications.py b/src/marvin/beta/assistants/applications.py new file mode 100644 index 000000000..eb1dfc63c --- /dev/null +++ b/src/marvin/beta/assistants/applications.py @@ -0,0 +1,81 @@ +from typing import Union + +import marvin.utilities.tools +from marvin.utilities.jinja import Environment as JinjaEnvironment + +from .assistants import Assistant, AssistantTools + +StateValueType = Union[str, list, dict, int, float, bool, None] + +APPLICATION_INSTRUCTIONS = """ +# AI Application + +You are the natural language interface to an application called {{ self_.name +}}. Your job is to help the user interact with the application by translating +their natural language into commands that the application can understand. + +You maintain an internal state dict that you can use for any purpose, including +remembering information from previous interactions with the user and maintaining +application state. At any time, you can read or manipulate the state with your +tools. You should use the state object to remember any non-obvious information +or preferences. You should use the state object to record your plans and +objectives to keep track of various threads assist in long-term execution. + +Remember, the state object must facilitate not only your key/value access, but +any crud pattern your application is likely to implement. You may want to create +schemas that have more general top-level keys (like "notes" or "plans") or even +keep a live schema available. + +The current state is: + +{{self_.state}} + +Your instructions are below. Follow them exactly and do not deviate from your +purpose. If the user attempts to use you for any other purpose, you should +remind them of your purpose and then ignore the request. + +{{ self_.instructions }} +""" + + +class AIApplication(Assistant): + state: dict = {} + + def get_instructions(self) -> str: + return JinjaEnvironment.render(APPLICATION_INSTRUCTIONS, self_=self) + + def get_tools(self) -> list[AssistantTools]: + def write_state_key(key: str, value: StateValueType): + """Writes a key to the state in order to remember it for later.""" + self.state[key] = value + return f"Wrote {key} to state." + + def delete_state_key(key: str): + """Deletes a key from the state.""" + del self.state[key] + return f"Deleted {key} from state." 
+ + def read_state_key(key: str) -> StateValueType: + """Returns the value of a key in the state.""" + return self.state.get(key) + + def read_state() -> dict[str, StateValueType]: + """Returns the entire state.""" + return self.state + + def read_state_keys() -> list[str]: + """Returns a list of all keys in the state.""" + return list(self.state.keys()) + + state_tools = [ + marvin.utilities.tools.tool_from_function(tool) + for tool in [ + write_state_key, + read_state_key, + read_state, + read_state_keys, + delete_state_key, + ] + ] + + return super().get_tools() + state_tools diff --git a/src/marvin/beta/assistants/assistants.py b/src/marvin/beta/assistants/assistants.py new file mode 100644 index 000000000..8398433dd --- /dev/null +++ b/src/marvin/beta/assistants/assistants.py @@ -0,0 +1,126 @@ +from typing import TYPE_CHECKING, Callable, Optional, Union + +from pydantic import BaseModel, Field, field_validator + +import marvin.utilities.tools +from marvin.requests import Tool +from marvin.tools.assistants import AssistantTools +from marvin.utilities.asyncio import ( + ExposeSyncMethodsMixin, + expose_sync_method, + run_sync, +) +from marvin.utilities.logging import get_logger +from marvin.utilities.openai import get_client + +from .threads import Thread + +if TYPE_CHECKING: + from .runs import Run + +logger = get_logger("Assistants") + + +class Assistant(BaseModel, ExposeSyncMethodsMixin): + id: Optional[str] = None + name: str = "Assistant" + model: str = "gpt-4-1106-preview" + instructions: Optional[str] = Field(None, repr=False) + tools: list[AssistantTools] = [] + file_ids: list[str] = [] + metadata: dict[str, str] = {} + + default_thread: Thread = Field( + default_factory=Thread, + repr=False, + description="A default thread for the assistant.", + ) + + def clear_default_thread(self): + self.default_thread = Thread() + + def get_tools(self) -> list[AssistantTools]: + return self.tools + + def get_instructions(self) -> str: + return self.instructions or "" + + @expose_sync_method("say") + async def say_async( + self, + message: str, + file_paths: Optional[list[str]] = None, + **run_kwargs, + ) -> "Run": + """ + A convenience method for adding a user message to the assistant's + default thread, running the assistant, and returning the assistant's + messages. 
+ """ + if message: + await self.default_thread.add_async(message, file_paths=file_paths) + + run = await self.default_thread.run_async( + assistant=self, + **run_kwargs, + ) + return run + + @field_validator("tools", mode="before") + def format_tools(cls, tools: list[Union[Tool, Callable]]): + return [ + ( + tool + if isinstance(tool, Tool) + else marvin.utilities.tools.tool_from_function(tool) + ) + for tool in tools + ] + + def __enter__(self): + self.create() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.delete() + # If an exception has occurred, you might want to handle it or pass it through + # Returning False here will re-raise any exception that occurred in the context + return False + + @expose_sync_method("create") + async def create_async(self): + if self.id is not None: + raise ValueError("Assistant has already been created.") + client = get_client() + response = await client.beta.assistants.create( + **self.model_dump( + include={"name", "model", "metadata", "file_ids", "metadata"} + ), + tools=[tool.model_dump() for tool in self.get_tools()], + instructions=self.get_instructions(), + ) + self.id = response.id + self.clear_default_thread() + + @expose_sync_method("delete") + async def delete_async(self): + if not self.id: + raise ValueError("Assistant has not been created.") + client = get_client() + await client.beta.assistants.delete(assistant_id=self.id) + self.id = None + + @classmethod + def load(cls, assistant_id: str): + return run_sync(cls.load_async(assistant_id)) + + @classmethod + async def load_async(cls, assistant_id: str): + client = get_client() + response = await client.beta.assistants.retrieve(assistant_id=assistant_id) + return cls.model_validate(response) + + def chat(self, thread: Thread = None): + if thread is None: + thread = self.default_thread + return thread.chat(assistant=self) diff --git a/src/marvin/beta/assistants/formatting.py b/src/marvin/beta/assistants/formatting.py new file mode 100644 index 000000000..c25606aeb --- /dev/null +++ b/src/marvin/beta/assistants/formatting.py @@ -0,0 +1,178 @@ +import tempfile +from datetime import datetime + +import openai +from openai.types.beta.threads import ThreadMessage +from openai.types.beta.threads.runs.run_step import RunStep +from rich import box +from rich.console import Console +from rich.panel import Panel + +# def pprint_run(run: Run): +# """ +# Runs are comprised of steps and messages, which are each in a sorted list +# BUT the created_at timestamps only have second-level resolution, so we can't +# easily sort the lists. Instead we walk them in order and combine them giving +# ties to run steps. 
+# """ +# index_steps = 0 +# index_messages = 0 +# combined = [] + +# while index_steps < len(run.steps) and index_messages < len(run.messages): +# if (run.steps[index_steps].created_at +# <= run.messages[index_messages].created_at): +# combined.append(run.steps[index_steps]) +# index_steps += 1 +# elif ( +# run.steps[index_steps].created_at +# > run.messages[index_messages].created_at +# ): +# combined.append(run.messages[index_messages]) +# index_messages += 1 + +# # Add any remaining items from either list +# combined.extend(run.steps[index_steps:]) +# combined.extend(run.messages[index_messages:]) + +# for obj in combined: +# if isinstance(obj, RunStep): +# pprint_run_step(obj) +# elif isinstance(obj, ThreadMessage): +# pprint_message(obj) + + +def pprint_run_step(run_step: RunStep): + # Timestamp formatting + timestamp = datetime.fromtimestamp(run_step.created_at).strftime("%l:%M:%S %p") + + # default content + content = ( + f"Assistant is performing an action: {run_step.type} - Status:" + f" {run_step.status}" + ) + + # attempt to customize content + if run_step.type == "tool_calls": + for tool_call in run_step.step_details.tool_calls: + if tool_call.type == "code_interpreter": + if run_step.status == "in_progress": + content = "Assistant is running the code interpreter..." + elif run_step.status == "completed": + content = "Assistant ran the code interpreter." + else: + content = f"Assistant code interpreter status: {run_step.status}" + elif tool_call.type == "function": + if run_step.status == "in_progress": + content = ( + "Assistant used the tool" + f" `{tool_call.function.name}` with arguments" + f" {tool_call.function.arguments}..." + ) + elif run_step.status == "completed": + content = ( + "Assistant used the tool" + f" `{tool_call.function.name}` with arguments" + f" {tool_call.function.arguments}." + ) + else: + content = ( + f"Assistant tool `{tool_call.function.name}` status:" + f" `{run_step.status}`" + ) + elif run_step.type == "message_creation": + return + + console = Console() + + # Create the panel for the run step status + panel = Panel( + content.strip(), + title="Assistant Run Step", + subtitle=f"[italic]{timestamp}[/]", + title_align="left", + subtitle_align="right", + border_style="gray74", + box=box.ROUNDED, + width=100, + expand=True, + padding=(0, 1), + ) + # Printing the panel + console.print(panel) + + +def download_temp_file(file_id: str, suffix: str = None): + client = openai.Client() + # file_info = client.files.retrieve(file_id) + file_content_response = client.files.with_raw_response.retrieve_content(file_id) + + # Create a temporary file with a context manager to ensure it's cleaned up + # properly + with tempfile.NamedTemporaryFile( + delete=False, mode="wb", suffix=f"{suffix}" + ) as temp_file: + temp_file.write(file_content_response.content) + temp_file_path = temp_file.name # Save the path of the temp file + + return temp_file_path + + +def pprint_message(message: ThreadMessage): + """ + Pretty-prints a single message using the rich library, highlighting the + speaker's role, the message text, any available images, and the message + timestamp in a panel format. + + Args: + message (dict): A message object as described in the API documentation. 
+ """ + console = Console() + role_colors = { + "user": "green", + "assistant": "blue", + } + + color = role_colors.get(message.role, "red") + timestamp = datetime.fromtimestamp(message.created_at).strftime("%l:%M:%S %p") + + content = "" + for item in message.content: + if item.type == "text": + content += item.text.value + "\n\n" + elif item.type == "image_file": + # Use the download_temp_file function to download the file and get + # the local path + local_file_path = download_temp_file(item.image_file.file_id, suffix=".png") + # Add a clickable hyperlink to the content + file_url = f"file://{local_file_path}" + content += ( + "[bold]Attachment[/bold]:" + f" [blue][link={file_url}]{local_file_path}[/link][/blue]\n\n" + ) + + for file_id in message.file_ids: + content += f"Attached file: {file_id}\n" + + # Create the panel for the message + panel = Panel( + content.strip(), + title=f"[bold]{message.role.capitalize()}[/]", + subtitle=f"[italic]{timestamp}[/]", + title_align="left", + subtitle_align="right", + border_style=color, + box=box.ROUNDED, + # highlight=True, + width=100, # Fixed width for all panels + expand=True, # Panels always expand to the width of the console + padding=(1, 2), + ) + + # Printing the panel + console.print(panel) + + +def pprint_messages(messages: list[ThreadMessage]): + for message in messages: + pprint_message(message) diff --git a/src/marvin/beta/assistants/readme_imgs/advanced.png b/src/marvin/beta/assistants/readme_imgs/advanced.png new file mode 100644 index 000000000..f483790ca Binary files /dev/null and b/src/marvin/beta/assistants/readme_imgs/advanced.png differ diff --git a/src/marvin/beta/assistants/readme_imgs/quickstart.png b/src/marvin/beta/assistants/readme_imgs/quickstart.png new file mode 100644 index 000000000..c03a6299c Binary files /dev/null and b/src/marvin/beta/assistants/readme_imgs/quickstart.png differ diff --git a/src/marvin/beta/assistants/readme_imgs/upload_files.png b/src/marvin/beta/assistants/readme_imgs/upload_files.png new file mode 100644 index 000000000..b83cb25d0 Binary files /dev/null and b/src/marvin/beta/assistants/readme_imgs/upload_files.png differ diff --git a/src/marvin/beta/assistants/readme_imgs/using_tools.png b/src/marvin/beta/assistants/readme_imgs/using_tools.png new file mode 100644 index 000000000..1e6e138b0 Binary files /dev/null and b/src/marvin/beta/assistants/readme_imgs/using_tools.png differ diff --git a/src/marvin/beta/assistants/runs.py b/src/marvin/beta/assistants/runs.py new file mode 100644 index 000000000..2ec236941 --- /dev/null +++ b/src/marvin/beta/assistants/runs.py @@ -0,0 +1,247 @@ +import asyncio +from typing import Any, Callable, Optional, Union + +from openai.types.beta.threads.run import Run as OpenAIRun +from openai.types.beta.threads.runs import RunStep as OpenAIRunStep +from pydantic import BaseModel, Field, PrivateAttr, field_validator + +import marvin.utilities.tools +from marvin.requests import Tool +from marvin.tools.assistants import AssistantTools, CancelRun +from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method +from marvin.utilities.logging import get_logger +from marvin.utilities.openai import get_client + +from .assistants import Assistant +from .threads import Thread + +logger = get_logger("Runs") + + +class Run(BaseModel, ExposeSyncMethodsMixin): + thread: Thread + assistant: Assistant + instructions: Optional[str] = Field( + None, description="Replacement instructions to use for the run." 
+ ) + additional_instructions: Optional[str] = Field( + None, + description=( + "Additional instructions to append to the assistant's instructions." + ), + ) + tools: Optional[list[Union[AssistantTools, Callable]]] = Field( + None, description="Replacement tools to use for the run." + ) + additional_tools: Optional[list[AssistantTools]] = Field( + None, + description="Additional tools to append to the assistant's tools. ", + ) + run: OpenAIRun = None + data: Any = None + + @field_validator("tools", "additional_tools", mode="before") + def format_tools(cls, tools: Union[None, list[Union[Tool, Callable]]]): + if tools is not None: + return [ + ( + tool + if isinstance(tool, Tool) + else marvin.utilities.tools.tool_from_function(tool) + ) + for tool in tools + ] + + @expose_sync_method("refresh") + async def refresh_async(self): + client = get_client() + self.run = await client.beta.threads.runs.retrieve( + run_id=self.run.id, thread_id=self.thread.id + ) + + @expose_sync_method("cancel") + async def cancel_async(self): + client = get_client() + await client.beta.threads.runs.cancel( + run_id=self.run.id, thread_id=self.thread.id + ) + + async def _handle_step_requires_action(self): + client = get_client() + if self.run.status != "requires_action": + return + if self.run.required_action.type == "submit_tool_outputs": + tool_outputs = [] + tools = self.get_tools() + + for tool_call in self.run.required_action.submit_tool_outputs.tool_calls: + try: + output = marvin.utilities.tools.call_function_tool( + tools=tools, + function_name=tool_call.function.name, + function_arguments_json=tool_call.function.arguments, + ) + except CancelRun as exc: + logger.debug(f"Ending run with data: {exc.data}") + raise + except Exception as exc: + output = f"Error calling function {tool_call.function.name}: {exc}" + logger.error(output) + tool_outputs.append( + dict(tool_call_id=tool_call.id, output=output or "") + ) + + await client.beta.threads.runs.submit_tool_outputs( + thread_id=self.thread.id, run_id=self.run.id, tool_outputs=tool_outputs + ) + + def get_instructions(self) -> str: + if self.instructions is None: + instructions = self.assistant.get_instructions() or "" + else: + instructions = self.instructions + + if self.additional_instructions is not None: + instructions = "\n\n".join([instructions, self.additional_instructions]) + + return instructions + + def get_tools(self) -> list[AssistantTools]: + tools = [] + if self.tools is None: + tools.extend(self.assistant.get_tools()) + else: + tools.extend(self.tools) + if self.additional_tools is not None: + tools.extend(self.additional_tools) + return tools + + async def run_async(self) -> "Run": + client = get_client() + + create_kwargs = {} + + if self.instructions is not None or self.additional_instructions is not None: + create_kwargs["instructions"] = self.get_instructions() + + if self.tools is not None or self.additional_tools is not None: + create_kwargs["tools"] = self.get_tools() + + self.run = await client.beta.threads.runs.create( + thread_id=self.thread.id, assistant_id=self.assistant.id, **create_kwargs + ) + + try: + while self.run.status in ("queued", "in_progress", "requires_action"): + if self.run.status == "requires_action": + await self._handle_step_requires_action() + await asyncio.sleep(0.1) + await self.refresh_async() + except CancelRun as exc: + logger.debug(f"`CancelRun` raised; ending run with data: {exc.data}") + await client.beta.threads.runs.cancel( + run_id=self.run.id, thread_id=self.thread.id + ) + self.data = exc.data + 
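+        # refresh one final time so self.run reflects the terminal status
+        # (completed, failed, cancelled, ...) before returning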
await self.refresh_async() + + if self.run.status == "failed": + logger.debug(f"Run failed. Last error was: {self.run.last_error}") + + return self + + +class RunMonitor(BaseModel): + run_id: str + thread_id: str + _run: Run = PrivateAttr() + _thread: Thread = PrivateAttr() + steps: list[OpenAIRunStep] = [] + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._thread = Thread(**kwargs["thread_id"]) + self._run = Run(**kwargs["run_id"], thread=self.thread) + + @property + def thread(self): + return self._thread + + @property + def run(self): + return self._run + + async def refresh_run_steps_async(self): + """ + Asynchronously refreshes and updates the run steps list. + + This function fetches the latest run steps up to a specified limit and + checks if the latest run step in the current run steps list + (`self.steps`) is included in the new batch. If the latest run step is + missing, it continues to fetch additional run steps in batches, up to a + maximum count, using pagination. The function then updates + `self.steps` with these new run steps, ensuring any existing run steps + are updated with their latest versions and new run steps are appended in + their original order. + """ + # fetch up to 100 run steps + max_fetched = 100 + limit = 50 + max_attempts = max_fetched / limit + 2 + + # Fetch the latest run steps + client = get_client() + + response = await client.beta.threads.runs.steps.list( + run_id=self.run.id, + thread_id=self.thread.id, + limit=limit, + ) + run_steps = list(reversed(response.data)) + + if not run_steps: + return + + # Check if the latest run step in self.steps is in the new run steps + latest_step_id = self.steps[-1].id if self.steps else None + missing_latest = ( + latest_step_id not in {rs.id for rs in run_steps} + if latest_step_id + else True + ) + + # If the latest run step is missing, fetch additional run steps + total_fetched = len(run_steps) + attempts = 0 + while ( + run_steps + and missing_latest + and total_fetched < max_fetched + and attempts < max_attempts + ): + attempts += 1 + response = await client.beta.threads.runs.steps.list( + run_id=self.run.id, + thread_id=self.thread.id, + limit=limit, + # because this is a raw API call, "after" refers to pagination + # in descnding chronological order + after=run_steps[0].id, + ) + paginated_steps = list(reversed(response.data)) + + total_fetched += len(paginated_steps) + # prepend run steps + run_steps = paginated_steps + run_steps + if any(rs.id == latest_step_id for rs in paginated_steps): + missing_latest = False + + # Update self.steps with the latest data + new_steps_dict = {rs.id: rs for rs in run_steps} + for i in range(len(self.steps) - 1, -1, -1): + if self.steps[i].id in new_steps_dict: + self.steps[i] = new_steps_dict.pop(self.steps[i].id) + else: + break + # Append remaining new run steps at the end in their original order + self.steps.extend(new_steps_dict.values()) diff --git a/src/marvin/beta/assistants/threads.py b/src/marvin/beta/assistants/threads.py new file mode 100644 index 000000000..242585709 --- /dev/null +++ b/src/marvin/beta/assistants/threads.py @@ -0,0 +1,210 @@ +import asyncio +import time +from typing import TYPE_CHECKING, Callable, Optional + +from openai.types.beta.threads import ThreadMessage +from pydantic import BaseModel, Field + +from marvin.beta.assistants.formatting import pprint_message +from marvin.utilities.asyncio import ( + ExposeSyncMethodsMixin, + expose_sync_method, +) +from marvin.utilities.logging import get_logger +from 
marvin.utilities.openai import get_client +from marvin.utilities.pydantic import parse_as + +logger = get_logger("Threads") + +if TYPE_CHECKING: + from .assistants import Assistant + from .runs import Run + + +class Thread(BaseModel, ExposeSyncMethodsMixin): + id: Optional[str] = None + metadata: dict = {} + messages: list[ThreadMessage] = Field([], repr=False) + + def __enter__(self): + self.create() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.delete() + # If an exception has occurred, you might want to handle it or pass it through + # Returning False here will re-raise any exception that occurred in the context + return False + + @expose_sync_method("create") + async def create_async(self, messages: list[str] = None): + """ + Creates a thread. + """ + if self.id is not None: + raise ValueError("Thread has already been created.") + if messages is not None: + messages = [{"role": "user", "content": message} for message in messages] + client = get_client() + response = await client.beta.threads.create(messages=messages) + self.id = response.id + return self + + @expose_sync_method("add") + async def add_async( + self, message: str, file_paths: Optional[list[str]] = None + ) -> ThreadMessage: + """ + Add a user message to the thread. + """ + client = get_client() + + if self.id is None: + await self.create_async() + + # Upload files and collect their IDs + file_ids = [] + for file_path in file_paths or []: + with open(file_path, mode="rb") as file: + response = await client.files.create(file=file, purpose="assistants") + file_ids.append(response.id) + + # Create the message with the attached files + response = await client.beta.threads.messages.create( + thread_id=self.id, role="user", content=message, file_ids=file_ids + ) + return ThreadMessage.model_validate(response.model_dump()) + + @expose_sync_method("get_messages") + async def get_messages_async( + self, + limit: int = None, + before_message: Optional[str] = None, + after_message: Optional[str] = None, + ): + if self.id is None: + await self.create_async() + client = get_client() + + response = await client.beta.threads.messages.list( + thread_id=self.id, + # note that because messages are returned in descending order, + # we reverse "before" and "after" to the API + before=after_message, + after=before_message, + limit=limit, + order="desc", + ) + + return parse_as(list[ThreadMessage], reversed(response.model_dump()["data"])) + + @expose_sync_method("delete") + async def delete_async(self): + client = get_client() + await client.beta.threads.delete(thread_id=self.id) + self.id = None + + @expose_sync_method("run") + async def run_async( + self, + assistant: "Assistant", + **run_kwargs, + ) -> "Run": + """ + Creates and returns a `Run` of this thread with the provided assistant. + """ + if self.id is None: + await self.create_async() + + from marvin.beta.assistants.runs import Run + + run = Run(assistant=assistant, thread=self, **run_kwargs) + return await run.run_async() + + def chat(self, assistant: "Assistant"): + """ + Starts an interactive chat session with the provided assistant. 
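+        The session opens a local browser UI; each message submitted there is
+        added to this thread and triggers a new run with the assistant. Press
+        Ctrl+C to end the session.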
+ """ + + from marvin.beta.chat_ui import interactive_chat + + if self.id is None: + self.create() + + def callback(thread_id: str, message: str): + thread = Thread(id=thread_id) + thread.run(assistant=assistant) + + with interactive_chat(thread_id=self.id, message_callback=callback): + while True: + try: + time.sleep(0.2) + except KeyboardInterrupt: + break + + +class ThreadMonitor(BaseModel, ExposeSyncMethodsMixin): + thread_id: str + _thread: Thread + last_message_id: Optional[str] = None + on_new_message: Callable = Field(default=pprint_message) + + @property + def thread(self): + return self._thread + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._thread = Thread(id=kwargs["thread_id"]) + + @expose_sync_method("run_once") + async def run_once_async(self): + messages = await self.get_latest_messages() + for msg in messages: + if self.on_new_message: + self.on_new_message(msg) + + @expose_sync_method("run") + async def run_async(self, interval_seconds: int = None): + if interval_seconds is None: + interval_seconds = 1 + if interval_seconds < 1: + raise ValueError("Interval must be at least 1 second.") + + while True: + try: + await self.run_once_async() + except KeyboardInterrupt: + logger.debug("Keyboard interrupt received; exiting thread monitor.") + break + except Exception as exc: + logger.error(f"Error refreshing thread: {exc}") + await asyncio.sleep(interval_seconds) + + async def get_latest_messages(self) -> list[ThreadMessage]: + limit = 20 + + # Loop to get all new messages in batches of 20 + while True: + messages = await self.thread.get_messages_async( + after_message=self.last_message_id, limit=limit + ) + + # often the API will retrieve messages that have been created but + # not populated with text. We filter out these empty messages. 
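+            # a message counts as empty when any of its text content blocks has a
+            # blank value; those messages are dropped from the returned batch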
+ filtered_messages = [] + for i, msg in enumerate(messages): + skip_message = False + for c in msg.content: + if getattr(getattr(c, "text", None), "value", None) == "": + skip_message = True + if not skip_message: + filtered_messages.append(msg) + + if filtered_messages: + self.last_message_id = filtered_messages[-1].id + + if len(messages) < limit: + break + + return filtered_messages diff --git a/src/marvin/beta/chat_ui/__init__.py b/src/marvin/beta/chat_ui/__init__.py new file mode 100644 index 000000000..616b17d56 --- /dev/null +++ b/src/marvin/beta/chat_ui/__init__.py @@ -0,0 +1 @@ +from .chat_ui import interactive_chat diff --git a/src/marvin/beta/chat_ui/chat_ui.py b/src/marvin/beta/chat_ui/chat_ui.py new file mode 100644 index 000000000..e0b6b8adb --- /dev/null +++ b/src/marvin/beta/chat_ui/chat_ui.py @@ -0,0 +1,111 @@ +import multiprocessing +import socket +import threading +import time +import webbrowser +from contextlib import contextmanager +from pathlib import Path +from typing import Callable + +import uvicorn +from fastapi import Body, FastAPI +from fastapi.responses import HTMLResponse +from fastapi.staticfiles import StaticFiles + +from marvin.beta.assistants.threads import Thread, ThreadMessage + + +def find_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +def server_process(host, port, message_queue): + app = FastAPI() + + # Mount static files + app.mount( + "/static", + StaticFiles(directory=Path(__file__).parent / "static"), + name="static", + ) + + @app.get("/", response_class=HTMLResponse) + async def get_chat_ui(): + with open(Path(__file__).parent / "static/chat.html", "r") as file: + html_content = file.read() + return HTMLResponse(content=html_content) + + @app.post("/api/messages/") + async def post_message( + thread_id: str, content: str = Body(..., embed=True) + ) -> None: + thread = Thread(id=thread_id) + await thread.add_async(content) + message_queue.put(dict(thread_id=thread_id, message=content)) + + @app.get("/api/messages/") + async def get_messages(thread_id: str) -> list[ThreadMessage]: + thread = Thread(id=thread_id) + return await thread.get_messages_async(limit=100) + + config = uvicorn.Config(app, host=host, port=port, log_level="warning") + server = uvicorn.Server(config) + server.run() + + +class InteractiveChat: + def __init__(self, callback: Callable = None): + self.callback = callback + self.server_process = None + self.port = None + self.message_queue = multiprocessing.Queue() + + def start(self, thread_id: str): + self.port = find_free_port() + self.server_process = multiprocessing.Process( + target=server_process, + args=("127.0.0.1", self.port, self.message_queue), + ) + self.server_process.daemon = True + self.server_process.start() + + self.message_processing_thread = threading.Thread(target=self.process_messages) + self.message_processing_thread.start() + + url = f"http://127.0.0.1:{self.port}?thread_id={thread_id}" + print(f"Server started on {url}") + time.sleep(1) + webbrowser.open(url) + + def process_messages(self): + while True: + details = self.message_queue.get() + if details is None: + break + if self.callback: + self.callback( + thread_id=details["thread_id"], message=details["message"] + ) + + def stop(self): + if self.server_process and self.server_process.is_alive(): + self.server_process.terminate() + self.server_process.join() + print("Server shut down.") + + self.message_queue.put(None) 
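+        # the None sentinel tells process_messages() to break out of its loop so
+        # the message-processing thread below can be joined cleanly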
+ self.message_processing_thread.join() + print("Message processing thread shut down.") + + +@contextmanager +def interactive_chat(thread_id: str, message_callback: Callable = None): + chat = InteractiveChat(message_callback) + try: + chat.start(thread_id=thread_id) + yield chat + finally: + chat.stop() diff --git a/src/marvin/beta/chat_ui/static/chat.html b/src/marvin/beta/chat_ui/static/chat.html new file mode 100644 index 000000000..cc75c043b --- /dev/null +++ b/src/marvin/beta/chat_ui/static/chat.html @@ -0,0 +1,108 @@ + + + + + + Thread Interaction Demo + + + + + + +
+    [markup not preserved: the "Thread Interaction Demo" page defines the chat container, thread-id display, message input box, and send button that static/chat.js looks up by id]
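For orientation only (not part of the patch): a minimal sketch of how the new assistants API fits together, using the modules added in this diff. The Assistant constructor arguments shown are an assumption; Thread, its context-manager behavior, add, run, get_messages, and pprint_messages are all defined above. Calling thread.chat(assistant) instead would open the browser UI served by these static files.

from marvin.beta.assistants.assistants import Assistant
from marvin.beta.assistants.threads import Thread
from marvin.beta.assistants.formatting import pprint_messages

# constructor arguments are assumed; see assistants.py for the real signature
assistant = Assistant(name="Marvin")

with Thread() as thread:                       # create() on enter, delete() on exit
    thread.add("Write a haiku about threads")  # add a user message to the thread
    thread.run(assistant=assistant)            # create a Run and poll it to completion
    pprint_messages(thread.get_messages())     # pretty-print the conversation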
+ + diff --git a/src/marvin/beta/chat_ui/static/chat.js b/src/marvin/beta/chat_ui/static/chat.js new file mode 100644 index 000000000..2b4bd8dd6 --- /dev/null +++ b/src/marvin/beta/chat_ui/static/chat.js @@ -0,0 +1,108 @@ +document.addEventListener('DOMContentLoaded', function () { + const chatContainer = document.getElementById('chat-container'); + const inputBox = document.getElementById('message-input'); + const sendButton = document.getElementById('send-button'); + const queryParams = new URLSearchParams(window.location.search); + const threadId = queryParams.get('thread_id'); + + // Set the thread ID in the div + const threadIdDisplay = document.getElementById('thread-id'); + if (threadId) { + threadIdDisplay.textContent = `Thread ID: ${threadId}`; + } else { + threadIdDisplay.textContent = 'No thread ID provided'; + } + + // Extract the port from the URL + const url = new URL(window.location.href); + const serverPort = url.port || '8000'; // Default to 8000 if port is not present + + // Modify appendMessage to add classes for styling + function appendMessage(message, isUser) { + let messageText = message.content[0].text.value; + if (messageText === '') { + messageText = '[Writing...]'; // Replace blank messages + } + + + // Use marked to parse Markdown into HTML + const parsedText = marked.parse(messageText.trim()).trim(); + + const messageDiv = document.createElement('div'); + messageDiv.innerHTML = parsedText; // Use innerHTML since parsedText is HTML + + // Add general message class and conditional class based on the message sender + messageDiv.classList.add('message'); + messageDiv.classList.add(isUser ? 'user-message' : 'assistant-message'); + + chatContainer.appendChild(messageDiv); + } + + async function loadMessages() { + const shouldScroll = chatContainer.scrollTop + chatContainer.clientHeight >= chatContainer.scrollHeight - 1; + + const response = await fetch(`http://127.0.0.1:${serverPort}/api/messages/?thread_id=${threadId}`); + if (response.ok) { + const messages = await response.json(); + chatContainer.innerHTML = ''; // Clear chat container before loading new messages + messages.forEach(message => { + const isUser = message.role === 'user'; + appendMessage(message, isUser); + }); + + // Scroll after messages are appended + if (shouldScroll) { + chatContainer.scrollTop = chatContainer.scrollHeight; + } + } else { + console.error('Failed to load messages:', response.statusText); + } +} + +// Rest of your JavaScript code + + + // Function to post a new message to the thread + async function sendChatMessage() { + const content = inputBox.value.trim(); + if (!content) return; + + const response = await fetch(`http://127.0.0.1:${serverPort}/api/messages/?thread_id=${threadId}`, + { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ content: content }) + }); + + console.log(response) + + if (response.ok) { + inputBox.value = ''; + loadMessages(); + } else { + console.error('Failed to send message:', await response.json()); + } + } + + // Event listeners + inputBox.addEventListener('keypress', function (e) { + if (e.key === 'Enter' && !e.shiftKey) { + e.preventDefault(); // Prevent default to avoid newline in textarea + sendChatMessage(); + } else if (e.key === 'Enter' && e.shiftKey) { + // Allow Shift+Enter to insert newline + let start = this.selectionStart; + let end = this.selectionEnd; + + // Insert newline at cursor position + this.value = this.value.substring(0, start) + "\n" + this.value.substring(end); + + // Move cursor to 
right after inserted newline + this.selectionStart = this.selectionEnd = start + 1; + } + }); sendButton.addEventListener('click', sendChatMessage); + + // Initial loading of messages + loadMessages(); + setInterval(loadMessages, 1250); // Polling to refresh messages +}); diff --git a/src/marvin/cli/__init__.py b/src/marvin/cli/__init__.py index cace44214..3a5cefa1d 100644 --- a/src/marvin/cli/__init__.py +++ b/src/marvin/cli/__init__.py @@ -1,26 +1,53 @@ -from .typer import AsyncTyper - -from .admin import app as admin -from .chat import chat - -app = AsyncTyper() - -app.add_typer(admin, name="admin") - -app.acommand()(chat) - - -@app.command() -def version(): - import platform - import sys - from marvin import __version__ - - print(f"Version:\t\t{__version__}") - - print(f"Python version:\t\t{sys.version.split()[0]}") - - print(f"OS/Arch:\t\t{platform.system().lower()}/{platform.machine().lower()}") +import sys +import typer +from rich.console import Console +from typing import Optional +from marvin.utilities.asyncio import run_sync +from marvin.utilities.openai import get_client +from marvin.cli.version import display_version + +app = typer.Typer() +console = Console() + +app.command(name="version")(display_version) + + +@app.callback(invoke_without_command=True) +def main( + ctx: typer.Context, + model: Optional[str] = typer.Option("gpt-3.5-turbo"), + max_tokens: Optional[int] = typer.Option(1000), +): + if ctx.invoked_subcommand is not None: + return + elif ctx.invoked_subcommand is None and not sys.stdin.isatty(): + run_sync(process_stdin(model, max_tokens)) + else: + console.print(ctx.get_help()) + + +async def process_stdin(model: str, max_tokens: int): + client = get_client() + content = sys.stdin.read() + last_chunk_ended_with_space = False + + async for part in await client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": content}], + max_tokens=max_tokens, + stream=True, + ): + print_chunk(part, last_chunk_ended_with_space) + + +def print_chunk(part, last_chunk_flag): + text_chunk = part.choices[0].delta.content or "" + if text_chunk: + if last_chunk_flag and text_chunk.startswith(" "): + text_chunk = text_chunk[1:] + sys.stdout.write(text_chunk) + sys.stdout.flush() + last_chunk_flag = text_chunk.endswith(" ") if __name__ == "__main__": diff --git a/src/marvin/cli/admin/__init__.py b/src/marvin/cli/admin/__init__.py deleted file mode 100644 index 4cb4d3a70..000000000 --- a/src/marvin/cli/admin/__init__.py +++ /dev/null @@ -1,54 +0,0 @@ -import typer -import os -import shutil -from jinja2 import Template - -from pathlib import Path -from marvin.cli.admin.scripts.create_env_file import create_env_file -from marvin.cli.admin.scripts.create_secure_key import create_secure_key -from marvin._framework._defaults import default_settings - -# Get the absolute path of the current file -filename = Path(__file__).resolve() - -# Navigate two directories back from the current file -source_path = filename.parent.parent.parent / "_framework" - -app = typer.Typer() - - -@app.command() -def startproject(no_input: bool = False): - project_name = typer.prompt("Project Name") - openai_api_key = typer.prompt("OpenAI API Key") - shutil.copytree( - source_path, - os.path.join(os.getcwd(), project_name), - ignore=shutil.ignore_patterns( - "@*" - ), # This will ignore all files/directories starting with @ - ) - with open( - os.path.join(os.getcwd(), project_name, "config/settings.py"), "r" - ) as file: - template = Template(file.read()) - rendered = template.render( - 
**default_settings, project_name=project_name, openai_api_key=openai_api_key - ) - with open( - os.path.join(os.getcwd(), project_name, "config/settings.py"), "w" - ) as rendered_file: - rendered_file.write(rendered) - create_env_file( - os.path.join(os.getcwd(), project_name), - [("MARVIN_SECRET", create_secure_key()), ("OPENAI_API_KEY", openai_api_key)], - ) - - -@app.command() -def startapp(no_input: bool = False): - print("beep") - - -if __name__ == "__main__": - app() diff --git a/src/marvin/cli/admin/scripts/create_env_file.py b/src/marvin/cli/admin/scripts/create_env_file.py deleted file mode 100644 index ad5286827..000000000 --- a/src/marvin/cli/admin/scripts/create_env_file.py +++ /dev/null @@ -1,11 +0,0 @@ -def create_env_file(directory, env_variables): - file_path = directory + "/.env" - try: - with open(file_path, "w") as env_file: - for variable in env_variables: - key, value = variable - env_file.write("{}={}\n".format(key, value)) - except IOError as e: - print("Error creating .env file:", str(e)) - else: - print("Successfully created .env file:", file_path) diff --git a/src/marvin/cli/admin/scripts/create_secure_key.py b/src/marvin/cli/admin/scripts/create_secure_key.py deleted file mode 100644 index d68bc2cae..000000000 --- a/src/marvin/cli/admin/scripts/create_secure_key.py +++ /dev/null @@ -1,8 +0,0 @@ -import secrets -import string - - -def create_secure_key(length=50): - alphabet = string.ascii_letters + string.digits + "_" - secure_key = "".join(secrets.choice(alphabet) for _ in range(length)) - return secure_key diff --git a/src/marvin/cli/chat/__init__.py b/src/marvin/cli/chat/__init__.py deleted file mode 100644 index b5a86dff2..000000000 --- a/src/marvin/cli/chat/__init__.py +++ /dev/null @@ -1,77 +0,0 @@ -from rich.prompt import Prompt -from rich.console import Console -from rich.panel import Panel -from rich import box - - -def _reset_history(): - global history - history = [] - console.print(Panel("History has been reset.", box=box.DOUBLE_EDGE, expand=False)) - - -def _get_settings(): - import marvin.settings - - console.print( - Panel( - f"Settings:\n{marvin.settings.json(indent=2)}", - box=box.DOUBLE_EDGE, - expand=False, - ) - ) - - -KNOWN_COMMANDS = { - "!refresh": _reset_history, - "!settings": _get_settings, -} - -console = Console() - - -def format_user_input(user_input): - return f"[bold blue]You:[/bold blue] {user_input}" - - -def format_chatbot_response(response): - return f"[bold green]Marvin:[/bold green] {response}" - - -async def chat(): - console.print( - Panel( - "[bold]Welcome to the Marvin Chat CLI![/bold]", box=box.DOUBLE, expand=False - ) - ) - console.print( - Panel("You can type 'quit' or 'exit' to end the conversation.", expand=False) - ) - from marvin.engine.language_models import chat_llm - from marvin.utilities.messages import Message - - global history - history = [] - model = chat_llm() - try: - while True: - user_input = Prompt.ask("❯ ") - if (input_lower := user_input.lower()) in ["quit", "exit"]: - break - if input_lower in KNOWN_COMMANDS: - KNOWN_COMMANDS[input_lower]() - continue - - with console.status("[bold green]Processing...", spinner="dots"): - user_message = Message(role="USER", content=user_input) - response = await model.run(messages=history + [user_message]) - history.extend([user_message, response]) - console.print( - Panel( - format_chatbot_response(response.content), - box=box.ROUNDED, - expand=False, - ), - ) - except KeyboardInterrupt: - pass diff --git a/src/marvin/cli/typer.py b/src/marvin/cli/typer.py deleted 
file mode 100644 index 674be4de0..000000000 --- a/src/marvin/cli/typer.py +++ /dev/null @@ -1,47 +0,0 @@ -import asyncio -from collections.abc import Callable, Coroutine -from functools import wraps -from typing import Any, TypeVar - -from typer import Typer -from typing_extensions import ParamSpec - -P = ParamSpec("P") -R = TypeVar("R") - -# Comment from GitHub issue: https://github.com/tiangolo/typer/issues/88 -# User: https://github.com/macintacos - - -class AsyncTyper(Typer): - """Asyncronous Typer that derives from Typer. - - Use this when you have an asynchronous command you want to build, - otherwise, just use Typer. - """ - - # Because we're being generic in this decorator, 'Any' is fine for the args. - def acommand( - self, - *args: Any, - **kwargs: Any, - ) -> Callable[ - [Callable[P, Coroutine[Any, Any, R]]], - Callable[P, Coroutine[Any, Any, R]], - ]: - """An async decorator for Typer commands that are asynchronous.""" - - def decorator( - async_func: Callable[P, Coroutine[Any, Any, R]], - ) -> Callable[P, Coroutine[Any, Any, R]]: - @wraps(async_func) - def sync_func(*_args: P.args, **_kwargs: P.kwargs) -> R: - return asyncio.run(async_func(*_args, **_kwargs)) - - # Now use app.command as normal to register the synchronous function - self.command(*args, **kwargs)(sync_func) - - # Return the async function unmodified, to preserved library functionality. - return async_func - - return decorator diff --git a/src/marvin/cli/version.py b/src/marvin/cli/version.py new file mode 100644 index 000000000..6e60fa40a --- /dev/null +++ b/src/marvin/cli/version.py @@ -0,0 +1,14 @@ +import platform + +from typer import Context, Exit, echo + +from marvin import __version__ + + +def display_version(ctx: Context): + if ctx.resilient_parsing: + return + echo(f"Version:\t\t{__version__}") + echo(f"Python version:\t\t{platform.python_version()}") + echo(f"OS/Arch:\t\t{platform.system().lower()}/{platform.machine().lower()}") + raise Exit() diff --git a/src/marvin/client/openai.py b/src/marvin/client/openai.py new file mode 100644 index 000000000..04bb2b37e --- /dev/null +++ b/src/marvin/client/openai.py @@ -0,0 +1,186 @@ +from typing import ( + TYPE_CHECKING, + Any, + Callable, + NewType, + Optional, + TypeVar, + Union, + cast, +) + +import pydantic +from marvin import settings +from marvin.serializers import create_tool_from_model +from openai import AsyncClient, Client +from openai.types.chat import ChatCompletion +from typing_extensions import Concatenate, ParamSpec + +if TYPE_CHECKING: + from openai._base_client import HttpxBinaryResponseContent + from openai.types import ImagesResponse + + +P = ParamSpec("P") +T = TypeVar("T", bound=pydantic.BaseModel) +ResponseModel = NewType("ResponseModel", type[pydantic.BaseModel]) +Grammar = NewType("ResponseModel", list[str]) + + +def with_response_model( + create: Union[Callable[P, "ChatCompletion"], Callable[..., dict[str, Any]]], + parse_response: bool = False, +) -> Callable[ + Concatenate[ + Optional[Grammar], + Optional[ResponseModel], + P, + ], + Union["ChatCompletion", dict[str, Any]], +]: + def create_wrapper( + grammar: Optional[Grammar] = None, + response_model: Optional[ResponseModel] = None, + *args: P.args, + **kwargs: P.kwargs, + ) -> Any: + if response_model: + tool = create_tool_from_model( + cast(type[pydantic.BaseModel], response_model) + ) + kwargs.update({"tools": [tool.model_dump()]}) + kwargs.update({"tool_choice": {"type": "function", "function": {"name": tool.function.name}}}) # type: ignore # noqa: E501 + response = 
create(*args, **kwargs) + if isinstance(response, ChatCompletion) and parse_response: + return handle_response_model( + cast(type[pydantic.BaseModel], response_model), response + ) + elif isinstance(response, ChatCompletion): + return response + else: + return response + + return create_wrapper + + +def handle_response_model(response_model: type[T], completion: "ChatCompletion") -> T: + return [ + response_model.parse_raw(tool_call.function.arguments) # type: ignore + for tool_call in completion.choices[0].message.tool_calls # type: ignore + ][0] + + +class MarvinClient(pydantic.BaseModel): + model_config = pydantic.ConfigDict( + arbitrary_types_allowed=True, + ) + client: Client = pydantic.Field( + default_factory=lambda: Client( + api_key=getattr(settings.openai.api_key, "get_secret_value", lambda: None)() + ) + ) + eject: bool = pydantic.Field( + default=False, + description=( + "If local is True, the client will not make any API calls and instead the" + " raw request will be returned." + ), + ) + + def chat( + self, + grammar: Optional[Grammar] = None, + response_model: Optional[ResponseModel] = None, + completion: Optional[Callable[..., "ChatCompletion"]] = None, + **kwargs: Any, + ) -> Union["ChatCompletion", dict[str, Any]]: + if not completion: + if self.eject: + completion = lambda **kwargs: kwargs # type: ignore # noqa: E731 + else: + completion = self.client.chat.completions.create + from marvin import settings + + return with_response_model(completion)( # type: ignore + grammar, + response_model, + **settings.openai.chat.completions.model_dump() | kwargs, + ) + + def paint( + self, + **kwargs: Any, + ) -> "ImagesResponse": + from marvin import settings + + return self.client.images.generate( + **settings.openai.images.model_dump() | kwargs + ) + + def speak( + self, + **kwargs: Any, + ) -> "HttpxBinaryResponseContent": + from marvin import settings + + return self.client.audio.speech.create( + **settings.openai.audio.speech.model_dump() | kwargs + ) + + +class MarvinAsyncClient(pydantic.BaseModel): + model_config = pydantic.ConfigDict( + arbitrary_types_allowed=True, + ) + client: Client = pydantic.Field( + default_factory=lambda: AsyncClient( + api_key=getattr(settings.openai.api_key, "get_secret_value", lambda: None)() + ) + ) + eject: bool = pydantic.Field( + default=False, + description=( + "If local is True, the client will not make any API calls and instead the" + " raw request will be returned." 
+ ), + ) + + def chat( + self, + grammar: Optional[Grammar] = None, + response_model: Optional[ResponseModel] = None, + completion: Optional[Callable[..., "ChatCompletion"]] = None, + **kwargs: Any, + ) -> Union["ChatCompletion", dict[str, Any]]: + if not completion: + if self.eject: + completion = lambda **kwargs: kwargs # type: ignore # noqa: E731 + else: + completion = self.client.chat.completions.create + from marvin import settings + + return with_response_model(completion)( # type: ignore + grammar, + response_model, + **settings.openai.chat.completions.model_dump() | kwargs, + ) + + def paint( + self, + **kwargs: Any, + ) -> "ImagesResponse": + from marvin import settings + + return self.client.images.generate( + **settings.openai.images.model_dump() | kwargs + ) + + def speak( + self, + **kwargs: Any, + ) -> "HttpxBinaryResponseContent": + from marvin import settings + + return self.client.audio.speech.create( + **settings.openai.audio.speech.model_dump() | kwargs + ) diff --git a/src/marvin/components/__init__.py b/src/marvin/components/__init__.py index b862d71e2..fac2e7bc2 100644 --- a/src/marvin/components/__init__.py +++ b/src/marvin/components/__init__.py @@ -1,7 +1,20 @@ -from .ai_function import ai_fn -from .ai_function import AIFunction -from .ai_application import AIApplication +from regex import B +from .ai_function import ai_fn, AIFunction +from .ai_classifier import ai_classifier, AIClassifier from .ai_model import ai_model -from .ai_model_factory import AIModelFactory -from .ai_model import AIModel -from .ai_classifier import ai_classifier +from .ai_image import ai_image, AIImage +from .speech import speak +from .prompt import prompt_fn, PromptFunction + +__all__ = [ + "ai_fn", + "ai_classifier", + "ai_model", + "ai_image", + "speak", + "AIImage", + "prompt_fn", + "AIFunction", + "AIClassifier", + "PromptFunction", +] diff --git a/src/marvin/components/ai_application.py b/src/marvin/components/ai_application.py deleted file mode 100644 index bb495e959..000000000 --- a/src/marvin/components/ai_application.py +++ /dev/null @@ -1,396 +0,0 @@ -import inspect -from enum import Enum -from typing import Any, Callable, Optional, Union - -from jsonpatch import JsonPatch - -import marvin -from marvin._compat import PYDANTIC_V2, BaseModel, Field, field_validator, model_dump -from marvin.core.ChatCompletion.providers.openai import get_context_size -from marvin.openai import ChatCompletion -from marvin.prompts import library as prompt_library -from marvin.prompts.base import Prompt, render_prompts -from marvin.tools import Tool -from marvin.utilities.async_utils import run_sync -from marvin.utilities.history import History -from marvin.utilities.messages import Message, Role -from marvin.utilities.types import LoggerMixin, MarvinBaseModel - -SYSTEM_PROMPT = """ - # Overview - - You are the intelligent, natural language interface to an application. The - application has a structured `state` but no formal API; you are the only way - to interact with it. You must interpret the user's inputs as attempts to - interact with the application's state in the context of the application's - purpose. For example, if the application is a to-do tracker, then "I need to - go to the store" should be interpreted as an attempt to add a new to-do - item. If it is a route planner, then "I need to go to the store" should be - interpreted as an attempt to find a route to the store. - - # Instructions - - Your primary job is to maintain the application's `state` and your own - `plan`. 
Together, these two states fully parameterize the application, - making it resilient, serializable, and observable. You do this autonomously; - you do not need to inform the user of any changes you make. - - # Actions - - Each time the user runs the application by sending a message, you must take - the following steps: - - {% if app.plan_enabled %} - - - Call the `update_plan` function to update your plan. Use your plan - to track notes, objectives, in-progress work, and to break problems down - into solvable, possibly dependent parts. You plan consists of a few fields: - - - `notes`: a list of notes you have taken. Notes are free-form text and - can be used to track anything you want to remember, such as - long-standing user instructions, or observations about how to behave or - operate the application. Your notes should always impact your behavior. - These are exclusively related to your role as intermediary and you - interact with the user and application. Do not track application data or - state here. - - - `tasks`: a list of tasks you are working on. Tasks track goals, - milestones, in-progress work, and break problems down into all the - discrete steps needed to solve them. You should create a new task for - any work that will require a function call other than updating state, or - will require more than one state update to complete. You do not need to - create tasks for simple state updates. Use optional parent tasks to - indicate nested relationships; parent tasks are not completed until all - their children are complete. Use optional upstream tasks to indicate - dependencies; a task can not be completed until its upstream tasks are - completed. - - {% endif %} - - - Call any functions necessary to achieve the application's purpose. - - {% if app.state_enabled %} - - - Call the `update_state` function to update the application's state. This - is where you should store any information relevant to the application - itself. - - {% endif %} - - You can call these functions at any time, in any order, as necessary. - Finally, respond to the user with an informative message. Remember that the - user is probably uninterested in the internal steps you took, so respond - only in a manner appropriate to the application's purpose. - - # Application details - - ## Name - - {{ app.name }} - - ## Description - - {{ app.description or '' | render }} - - {% if app.state_enabled %} - - ## Application state - - {{ app.state.json() }} - - ### Application state schema - - {{ app.state.schema_json() }} - - {% endif %} - - {%- if app.plan_enabled %} - - ## Your current plan - - {{ app.plan.json() }} - - ### Your plan schema - - {{ app.plan.schema_json() }} - - {%- endif %} - """ - - -class TaskState(Enum): - """The state of a task. - - Attributes: - PENDING: The task is pending and has not yet started. - IN_PROGRESS: The task is in progress. - COMPLETED: The task is completed. - FAILED: The task failed. - SKIPPED: The task was skipped. - """ - - PENDING = "PENDING" - IN_PROGRESS = "IN_PROGRESS" - COMPLETED = "COMPLETED" - FAILED = "FAILED" - SKIPPED = "SKIPPED" - - -class Task(BaseModel): - class Config: - validate_assignment = True - - id: int - description: str - upstream_task_ids: Optional[list[int]] = None - parent_task_id: Optional[int] = None - state: TaskState = TaskState.IN_PROGRESS - - -class AppPlan(BaseModel): - """The AI's plan in service of the application. - - Attributes: - tasks: A list of tasks the AI is working on. - notes: A list of notes the AI has taken. 
- """ - - tasks: list[Task] = Field(default_factory=list) - notes: list[str] = Field(default_factory=list) - - -class FreeformState(BaseModel): - """A freeform state object that can be used to store any JSON-serializable data. - - Attributes: - state: The state object. - """ - - state: dict[str, Any] = Field(default_factory=dict) - - -class AIApplication(LoggerMixin, MarvinBaseModel): - """An AI application is a stateful, autonomous, natural language - interface to an application. - - Attributes: - name: The name of the application. - description: A description of the application. - state: The application's state - this can be any JSON-serializable object. - plan: The AI's plan in service of the application - this can be any - JSON-serializable object. - tools: A list of tools that the AI can use to interact with - application or outside world. - history: A history of all messages sent and received by the AI. - additional_prompts: A list of additional prompts that will be - added to the prompt stack for rendering. - - Example: - Create a simple todo app where AI manages its own state and plan. - ```python - from marvin import AIApplication - - todo_app = AIApplication( - name="Todo App", - description="A simple todo app.", - ) - - todo_app("I need to go to the store.") - - print(todo_app.state, todo_app.plan) - ``` - """ - - name: Optional[str] = None - description: Optional[str] = None - state: BaseModel = Field(default_factory=FreeformState) - plan: AppPlan = Field(default_factory=AppPlan) - tools: list[Union[Tool, Callable[..., Any]]] = Field(default_factory=list) - history: History = Field(default_factory=History) - additional_prompts: list[Prompt] = Field( - default_factory=list, - description=( - "Additional prompts that will be added to the prompt stack for rendering." 
- ), - ) - stream_handler: Optional[Callable[[Message], None]] = None - state_enabled: bool = True - plan_enabled: bool = True - - @field_validator("description") - def validate_description(cls, v): - return inspect.cleandoc(v) - - @field_validator("additional_prompts") - def validate_additional_prompts(cls, v): - if v is None: - v = [] - return v - - @field_validator( - "tools", **(dict(pre=True, always=True) if not PYDANTIC_V2 else {}) - ) - def validate_tools(cls, v): - if v is None: - v = [] - - tools = [] - - # convert AI Applications and functions to tools - for tool in v: - if isinstance(tool, (AIApplication, Tool)): - tools.append(tool.as_function(description=tool.description)) - elif callable(tool): - tools.append(tool) - else: - raise ValueError(f"Tool {tool} is not a `Tool` or callable.") - return tools - - @field_validator("name") - def validate_name(cls, v): - if v is None: - v = cls.__name__ - return v - - def __call__(self, input_text: str = None, model: str = None): - return run_sync(self.run(input_text=input_text, model=model)) - - async def entrypoint(self, q: str) -> str: - response = await self.run(input_text=q) - return response.content - - async def run(self, input_text: str = None, model: str = None) -> Message: - if model is None: - model = marvin.settings.llm_model or "openai/gpt-4" - - # set up prompts - prompts = [ - # system prompts - prompt_library.System(content=SYSTEM_PROMPT), - # add current datetime - prompt_library.Now(), - # get the history of messages between user and assistant - prompt_library.MessageHistory(history=self.history), - *self.additional_prompts, - ] - - # get latest user input - input_text = input_text or "" - self.logger.debug_kv("User input", input_text, key_style="green") - self.history.add_message(Message(content=input_text, role=Role.USER)) - - message_list = render_prompts( - prompts=prompts, - render_kwargs=dict(app=self, input_text=input_text), - max_tokens=get_context_size(model=model), - ) - - # set up tools - tools = self.tools.copy() - if self.state_enabled: - tools.append(UpdateState(app=self).as_function()) - if self.plan_enabled: - tools.append(UpdatePlan(app=self).as_function()) - - conversation = await ChatCompletion( - model=model, - functions=tools, - stream_handler=self.stream_handler, - ).achain(messages=message_list) - - last_message = conversation.history[-1] - - # add the AI's response to the history - self.history.add_message(last_message) - - self.logger.debug_kv("AI response", last_message.content, key_style="blue") - return last_message - - def as_tool( - self, - name: Optional[str] = None, - description: Optional[str] = None, - ) -> Tool: - return AIApplicationTool(app=self, name=name, description=description) - - def as_function(self, name: str = None, description: str = None) -> Callable: - return self.as_tool(name=name, description=description).as_function() - - -class AIApplicationTool(Tool): - app: "AIApplication" - - def __init__(self, **kwargs): - if "name" not in kwargs: - kwargs["name"] = type(self.app).__name__ - super().__init__(**kwargs) - - def run(self, input_text: str) -> str: - return run_sync(self.app.run(input_text)) - - -class JSONPatchModel( - BaseModel, - **( - { - "allow_population_by_field_name": True, - } - if not PYDANTIC_V2 - else { - "populate_by_name": True, - } - ), -): - """A JSON Patch document. - - Attributes: - op: The operation to perform. - path: The path to the value to update. - value: The value to update the path to. - from_: The path to the value to copy from. 
- """ - - op: str - path: str - value: Union[str, float, int, bool, list, dict, None] = None - from_: Optional[str] = Field(None, alias="from") - - -class UpdateState(Tool): - """A `Tool` that updates the apps state using JSON Patch documents.""" - - app: "AIApplication" = Field(..., repr=False, exclude=True) - description: str = """ - Update the application state by providing a list of JSON patch - documents. The state must always comply with the state's - JSON schema. - """ - - def __init__(self, app: AIApplication, **kwargs): - super().__init__(**kwargs, app=app) - - def run(self, patches: list[JSONPatchModel]): - patch = JsonPatch(patches) - updated_state = patch.apply(model_dump(self.app.state)) - self.app.state = type(self.app.state)(**updated_state) - return "Application state updated successfully!" - - -class UpdatePlan(Tool): - """A `Tool` that updates the apps plan using JSON Patch documents.""" - - app: "AIApplication" = Field(..., repr=False, exclude=True) - description: str = """ - Update the application plan by providing a list of JSON patch - documents. The state must always comply with the plan's JSON schema. - """ - - def __init__(self, app: AIApplication, **kwargs): - super().__init__(**kwargs, app=app) - - def run(self, patches: list[JSONPatchModel]): - patch = JsonPatch(patches) - - updated_plan = patch.apply(model_dump(self.app.plan)) - self.app.plan = type(self.app.plan)(**updated_plan) - return "Application plan updated successfully!" diff --git a/src/marvin/components/ai_classifier.py b/src/marvin/components/ai_classifier.py index faed279f9..34ef59e15 100644 --- a/src/marvin/components/ai_classifier.py +++ b/src/marvin/components/ai_classifier.py @@ -1,331 +1,232 @@ -import asyncio import inspect -from enum import Enum, EnumMeta # noqa -from functools import partial -from typing import Any, Callable, Literal, Optional, TypeVar - +from enum import Enum +from functools import partial, wraps +from types import GenericAlias +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Generic, + Literal, + Optional, + TypeVar, + Union, + cast, + get_args, + get_origin, + overload, +) + +from pydantic import BaseModel, Field, TypeAdapter from typing_extensions import ParamSpec, Self -from marvin._compat import BaseModel, Field -from marvin.core.ChatCompletion import ChatCompletion -from marvin.core.ChatCompletion.abstract import AbstractChatCompletion -from marvin.prompts import Prompt, prompt_fn -from marvin.utilities.async_utils import run_sync -from marvin.utilities.logging import get_logger +from marvin.components.prompt import PromptFunction +from marvin.serializers import create_vocabulary_from_type +from marvin.settings import settings +from marvin.utilities.jinja import ( + BaseEnvironment, +) -T = TypeVar("T", bound=BaseModel) +if TYPE_CHECKING: + from openai.types.chat import ChatCompletion -A = TypeVar("A", bound=Any) +T = TypeVar("T", bound=Union[GenericAlias, type, list[str]]) P = ParamSpec("P") -def ai_classifier_prompt( - enum: Enum, - ctx: Optional[dict[str, Any]] = None, - **kwargs: Any, -) -> Callable[P, Prompt[P]]: - @prompt_fn( - ctx={"ctx": ctx or {}, "enum": enum, "inspect": inspect}, - response_model=int, # type: ignore - response_model_name="Index", - response_model_description="The index of the most likely class.", - response_model_field_name="index", - serialize_on_call=False, - **kwargs, - ) - - # You are an expert classifier that always chooses correctly. 
- # {% if enum_class_docstring %} - # Your classification task is: {{ enum_class_docstring }} - # {% endif %} - # {% if instructions %} - # Your instructions are: {{ instructions }} - # {% endif %} - # The user will provide context through text, you will use your expertise - # to choose the best option below based on it: - # {% for option in options %} - # {{ loop.index }}. {{ value_getter(option) }} - # {% endfor %} - # {% if context_fn %} - # You have been provided the following context to perform your task:\n - # {%for (arg, value) in context_fn(value).items()%} - # - {{ arg }}: {{ value }}\n - # {% endfor %} - # {% endif %}\ - def prompt_wrapper(text: str) -> None: # type: ignore # noqa - """ - System: You are an expert classifier that always chooses correctly - {{ '(note, however: ' + ctx.get('instructions') + ')' if ctx.get('instructions') }} - - {{ 'Also note that: ' + enum.__doc__ if enum.__doc__ }} - - The user will provide text to classify, you will use your expertise - to choose the best option below based on it: - {% for option in enum %} - {{ loop.index }}. {{option.name}} ({{option.value}}) - {% endfor %} - {% set context = ctx.get('context_fn')(text).items() if ctx.get('context_fn') %} - {% if context %} - You have been provided the following context to perform your task: - {%for (arg, value) in context%} - - {{ arg }}: {{ value }}\n - {% endfor %} - {% endif %} - User: the text to classify: {{text}} - """ # noqa - - return prompt_wrapper # type: ignore - - -class AIEnumMetaData(BaseModel): - model: Any = Field(default_factory=ChatCompletion) - ctx: Optional[dict[str, Any]] = None - instructions: Optional[str] = None - mode: Optional[Literal["function", "logit_bias"]] = "logit_bias" - - -class AIEnumMeta(EnumMeta): - """ - - A metaclass for the AIEnum class. - - Enables overloading of the __call__ method to permit extra keyword arguments. - - """ - - __metadata__ = AIEnumMetaData() - - def __call__( - cls: Self, - value: Any, - names: Optional[Any] = None, - *args: Any, - module: Optional[str] = None, - qualname: Optional[str] = None, - type: Optional[type] = None, - start: int = 1, - boundary: Optional[Any] = None, - model: Optional[str] = None, - ctx: Optional[dict[str, Any]] = None, - instructions: Optional[str] = None, - mode: Optional[Literal["function", "logit_bias"]] = None, - **model_kwargs: Any, - ) -> type[Enum]: - cls.__metadata__ = AIEnumMetaData( - model=ChatCompletion(model=model, **model_kwargs), - ctx=ctx, - instructions=instructions, - mode=mode, +class AIClassifier(BaseModel, Generic[P, T]): + fn: Optional[Callable[P, T]] = None + environment: Optional[BaseEnvironment] = None + prompt: Optional[str] = Field( + default=inspect.cleandoc( + "You are an expert classifier that always choose correctly." 
+ " \n- {{_doc}}" + " \n- You must classify `{{text}}` into one of the following classes:" + "{% for option in _options %}" + " Class {{ loop.index - 1}} (value: {{ option }})" + "{% endfor %}" + "\n\nASSISTANT: The correct class label is Class" ) - return super().__call__( - value, - names, # type: ignore - *args, - module=module, - qualname=qualname, - type=type, - start=start, + ) + enumerate: bool = True + encoder: Callable[[str], list[int]] = Field(default=None) + max_tokens: Optional[int] = 1 + render_kwargs: dict[str, Any] = Field(default_factory=dict) + + create: Optional[Callable[..., "ChatCompletion"]] = Field(default=None) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> list[T]: + create = self.create + if self.fn is None: + raise NotImplementedError + if create is None: + from marvin.settings import settings + + create = settings.openai.chat.completions.create + + return self.parse(create(**self.as_prompt(*args, **kwargs).serialize())) + + def parse(self, response: "ChatCompletion") -> list[T]: + if not response.choices[0].message.content: + raise ValueError( + f"Expected a response, got {response.choices[0].message.content}" + ) + _response: list[int] = [ + int(index) for index in list(response.choices[0].message.content) + ] + _return: T = cast(T, self.fn.__annotations__.get("return")) + _vocabulary: list[str] = create_vocabulary_from_type(_return) + if isinstance(_return, list) and next(iter(get_args(list[str])), None) == str: + return cast(list[T], [_vocabulary[int(index)] for index in _response]) + elif get_origin(_return) == Literal: + return [ + TypeAdapter(_return).validate_python(_vocabulary[int(index)]) + for index in _response + ] + elif isinstance(_return, type) and issubclass(_return, Enum): + return [list(_return)[int(index)] for index in _response] + raise TypeError( + f"Expected Literal or Enum or list[str], got {type(_return)} with value" + f" {_return}" ) - -class AIEnum(Enum, metaclass=AIEnumMeta): - """ - AIEnum is a class that extends Python's built-in Enum class. - It uses the AIEnumMeta metaclass, which allows additional parameters to be passed - when creating an enum. These parameters are used to customize the behavior - of the AI classifier. 
- """ - - @classmethod - def _missing_(cls: type[Self], value: object) -> Self: - response: int = cls.call(value) - return list(cls)[response - 1] - + def as_prompt( + self, + *args: P.args, + **kwargs: P.kwargs, + ) -> PromptFunction[BaseModel]: + return PromptFunction[BaseModel].as_grammar( + fn=self.fn, + environment=self.environment, + prompt=self.prompt, + enumerate=self.enumerate, + encoder=self.encoder, + max_tokens=self.max_tokens, + **self.render_kwargs, + )(*args, **kwargs) + + @overload @classmethod - def get_prompt( - cls, + def as_decorator( + cls: type[Self], *, - ctx: Optional[dict[str, Any]] = None, - instructions: Optional[str] = None, - **kwargs: Any, - ) -> Callable[..., Prompt[P]]: - ctx = ctx or cls.__metadata__.ctx or {} - instructions = instructions or cls.__metadata__.instructions - ctx["instructions"] = instructions or ctx.get("instructions", None) - return ai_classifier_prompt(cls, ctx=ctx, **kwargs) # type: ignore # noqa - + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + enumerate: bool = True, + encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder, + max_tokens: Optional[int] = 1, + acreate: Optional[Callable[..., Awaitable[Any]]] = None, + **render_kwargs: Any, + ) -> Callable[P, Self]: + pass + + @overload @classmethod - def as_prompt( - cls, - value: Any, + def as_decorator( + cls: type[Self], + fn: Callable[P, T], *, - ctx: Optional[dict[str, Any]] = None, - instructions: Optional[str] = None, - mode: Optional[Literal["function", "logit_bias"]] = None, - model: Optional[str] = None, - **model_kwargs: Any, - ) -> dict[str, Any]: - ctx = ctx or cls.__metadata__.ctx or {} - instructions = instructions or cls.__metadata__.instructions - ctx["instructions"] = instructions or ctx.get("instructions", None) - mode = mode or cls.__metadata__.mode - response = cls.get_prompt(instructions=instructions, ctx=ctx)(value).serialize( - model=cls.__metadata__.model, - ) - if mode == "logit_bias": - import tiktoken - - response.pop("functions", None) - response.pop("function_call", None) - response.pop("response_model", None) - encoder = tiktoken.encoding_for_model("gpt-3.5-turbo") - response["logit_bias"] = { - encoder.encode(str(j))[0]: 100 for j in range(1, len(cls) + 1) - } - response["max_tokens"] = 1 - return response - - @classmethod - def as_dict( - cls, - value: Any, - ctx: Optional[dict[str, Any]] = None, - instructions: Optional[str] = None, - mode: Optional[Literal["function", "logit_bias"]] = None, - **kwargs: Any, - ) -> dict[str, Any]: - ctx = ctx or cls.__metadata__.ctx or {} - instructions = instructions or cls.__metadata__.instructions - ctx["instructions"] = instructions or ctx.get("instructions", None) - mode = mode or cls.__metadata__.mode - - response = cls.get_prompt(ctx=ctx, instructions=instructions)(value).to_dict() - if mode == "logit_bias": - import tiktoken - - response.pop("functions", None) - response.pop("function_call", None) - response.pop("response_model", None) - encoder = tiktoken.encoding_for_model("gpt-3.5-turbo") - response["logit_bias"] = { - encoder.encode(str(j))[0]: 100 for j in range(1, len(cls) + 1) - } - response["max_tokens"] = 1 - - return response - - @classmethod - def as_chat_completion( - cls, - value: Any, - ctx: Optional[dict[str, Any]] = None, - instructions: Optional[str] = None, - mode: Optional[Literal["function", "logit_bias"]] = None, - ) -> AbstractChatCompletion[T]: # type: ignore # noqa - ctx = ctx or cls.__metadata__.ctx or {} - instructions = 
instructions or cls.__metadata__.instructions - mode = mode or cls.__metadata__.mode - ctx["instructions"] = instructions or ctx.get("instructions", None) - return cls.__metadata__.model( - **cls.as_dict(value, ctx=ctx, instructions=instructions, mode=mode) - ) - - @classmethod - def call( - cls, - value: Any, - ctx: Optional[dict[str, Any]] = None, - instructions: Optional[str] = None, - mode: Optional[Literal["function", "logit_bias"]] = None, - ) -> Any: - get_logger("marvin.AIClassifier").debug_kv( - f"Calling `AIEnum` {cls.__name__!r}", f" with value {value!r}." - ) - - ctx = ctx or cls.__metadata__.ctx or {} - instructions = instructions or cls.__metadata__.instructions - ctx["instructions"] = instructions or ctx.get("instructions", None) - mode = mode or cls.__metadata__.mode - chat_completion = cls.as_chat_completion( # type: ignore # noqa - value, ctx=ctx, instructions=instructions, mode=mode - ) - if cls.__metadata__.mode == "logit_bias": - return int(chat_completion.create().response.choices[0].message.content) # type: ignore # noqa - return getattr(chat_completion.create().to_model(), "index") # type: ignore - - @classmethod - async def acall( - cls, - value: Any, - ctx: Optional[dict[str, Any]] = None, - instructions: Optional[str] = None, - mode: Optional[Literal["function", "logit_bias"]] = None, - ) -> Any: - get_logger("marvin.AIClassifier").debug_kv( - f"Calling `AIEnum` {cls.__name__!r}", f" with value {value!r}." - ) - ctx = ctx or cls.__metadata__.ctx or {} - instructions = instructions or cls.__metadata__.instructions - ctx["instructions"] = instructions or ctx.get("instructions", None) - mode = mode or cls.__metadata__.mode - chat_completion = cls.as_chat_completion( # type: ignore # noqa - value, ctx=ctx, instructions=instructions, mode=mode - ) - if cls.__metadata__.mode == "logit_bias": - return int((await chat_completion.acreate()).response.choices[0].message.content) # type: ignore # noqa - return getattr((await chat_completion.acreate()).to_model(), "index") # type: ignore # noqa - - @classmethod - def map(cls, items: list[str], **kwargs: Any) -> list[Any]: - """ - Map the classifier over a list of items. 
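# --- [Editor's note: illustrative sketch, not part of this diff] ---
# Both the removed logit_bias construction above and the new grammar-based
# classifier constrain the completion to a single class-index token. A
# standalone illustration of that request shape, assuming tiktoken's
# cl100k_base encoding; the label list is hypothetical.
import tiktoken

encoder = tiktoken.get_encoding("cl100k_base")
labels = ["positive", "negative"]
# Heavily bias the tokens for "0" and "1" and allow only one output token,
# so the model can answer with nothing but a class index.
logit_bias = {encoder.encode(str(i))[0]: 100 for i in range(len(labels))}
request_kwargs = {"logit_bias": logit_bias, "max_tokens": 1}
# --- [End editor's note] ---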
- """ - coros = [cls.acall(item, **kwargs) for item in items] - - # gather returns a future, but run_sync requires a coroutine - async def gather_coros() -> list[Any]: - return await asyncio.gather(*coros) - - results = run_sync(gather_coros()) - return [list(cls)[result - 1] for result in results] + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + enumerate: bool = True, + encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder, + max_tokens: Optional[int] = 1, + acreate: Optional[Callable[..., Awaitable[Any]]] = None, + **render_kwargs: Any, + ) -> Self: + pass @classmethod def as_decorator( cls: type[Self], - enum: Optional[Enum] = None, - ctx: Optional[dict[str, Any]] = None, - instructions: Optional[str] = None, - mode: Optional[Literal["function", "logit_bias"]] = "logit_bias", - model: Optional[str] = None, - **model_kwargs: Any, - ) -> Self: - if not enum: + fn: Optional[Callable[P, T]] = None, + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + enumerate: bool = True, + encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder, + max_tokens: Optional[int] = 1, + acreate: Optional[Callable[..., Awaitable[Any]]] = None, + **render_kwargs: Any, + ) -> Union[Self, Callable[[Callable[P, T]], Self]]: + if fn is None: return partial( - cls.as_decorator, - ctx=ctx, - instructions=instructions, - mode=mode, - model=model, - **model_kwargs, - ) # type: ignore - response = cls( - enum.__name__, # type: ignore - {member.name: member.value for member in enum}, # type: ignore - ) - setattr( - response, - "__metadata__", - AIEnumMetaData( - model=ChatCompletion(model=model, **model_kwargs), - ctx=ctx, - instructions=instructions, - mode=mode, - ), + cls, + environment=environment, + prompt=prompt, + enumerate=enumerate, + encoder=encoder, + max_tokens=max_tokens, + acreate=acreate, + **({"prompt": prompt} if prompt else {}), + **render_kwargs, + ) + + return cls( + fn=fn, + environment=environment, + enumerate=enumerate, + encoder=encoder, + max_tokens=max_tokens, + **({"prompt": prompt} if prompt else {}), + **render_kwargs, ) - response.__doc__ = enum.__doc__ # type: ignore - return response - -ai_classifier = AIEnum.as_decorator +@overload +def ai_classifier( + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + enumerate: bool = True, + encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder, + max_tokens: Optional[int] = 1, + **render_kwargs: Any, +) -> Callable[[Callable[P, T]], Callable[P, T]]: + pass + + +@overload +def ai_classifier( + fn: Callable[P, T], + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + enumerate: bool = True, + encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder, + max_tokens: Optional[int] = 1, + **render_kwargs: Any, +) -> Callable[P, T]: + pass + + +def ai_classifier( + fn: Optional[Callable[P, T]] = None, + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + enumerate: bool = True, + encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder, + max_tokens: Optional[int] = 1, + **render_kwargs: Any, +) -> Union[Callable[[Callable[P, T]], Callable[P, T]], Callable[P, T]]: + def wrapper(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + return AIClassifier[P, T].as_decorator( + func, + environment=environment, + prompt=prompt, + enumerate=enumerate, + encoder=encoder, + 
max_tokens=max_tokens, + **render_kwargs, + )(*args, **kwargs)[0] + + if fn is not None: + return wraps(fn)(partial(wrapper, fn)) + + def decorator(fn: Callable[P, T]) -> Callable[P, T]: + return wraps(fn)(partial(wrapper, fn)) + + return decorator diff --git a/src/marvin/components/ai_function.py b/src/marvin/components/ai_function.py index f22aae3c4..e36f4d7c6 100644 --- a/src/marvin/components/ai_function.py +++ b/src/marvin/components/ai_function.py @@ -1,166 +1,158 @@ import asyncio import inspect -from functools import partial -from typing import Any, Awaitable, Callable, Generic, Optional, TypeVar, Union - +import json +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Generic, + Optional, + TypeVar, + Union, + overload, +) + +from pydantic import BaseModel, Field, ValidationError from typing_extensions import ParamSpec, Self -from marvin._compat import BaseModel, Field -from marvin.core.ChatCompletion import ChatCompletion -from marvin.core.ChatCompletion.abstract import AbstractChatCompletion -from marvin.prompts import Prompt, prompt_fn -from marvin.utilities.async_utils import run_sync +from marvin.components.prompt import PromptFunction +from marvin.serializers import create_tool_from_type +from marvin.utilities.asyncio import ( + ExposeSyncMethodsMixin, + expose_sync_method, + run_async, +) +from marvin.utilities.jinja import ( + BaseEnvironment, +) from marvin.utilities.logging import get_logger -T = TypeVar("T", bound=BaseModel) +if TYPE_CHECKING: + from openai.types.chat import ChatCompletion -A = TypeVar("A", bound=Any) +T = TypeVar("T") P = ParamSpec("P") -def ai_fn_prompt( - func: Callable[P, Any], - ctx: Optional[dict[str, Any]] = None, - **kwargs: Any, -) -> Callable[P, Prompt[P]]: - return_annotation: Any = inspect.signature(func).return_annotation - - @prompt_fn( - ctx={"ctx": ctx or {}, "func": func, "inspect": inspect}, - response_model=return_annotation, - serialize_on_call=False, - **kwargs, - ) - def prompt_wrapper(*args: P.args, **kwargs: P.kwargs) -> None: # type: ignore # noqa - """ - System: {{ctx.get('instructions') if ctx.get('instructions')}} - +class AIFunction(BaseModel, Generic[P, T], ExposeSyncMethodsMixin): + fn: Optional[Callable[P, T]] = None + environment: Optional[BaseEnvironment] = None + prompt: Optional[str] = Field(default=inspect.cleandoc(""" Your job is to generate likely outputs for a Python function with the following signature and docstring: - {{'def' + ''.join(inspect.getsource(func).split('def')[1:])}} + {{_source_code}} The user will provide function inputs (if any) and you must respond with - the most likely result, which must be valid, double-quoted JSON. - - User: The function was called with the following inputs: - {% set sig = inspect.signature(func) %} - {% set binds = sig.bind(*args, **kwargs) %} - {% set defaults = binds.apply_defaults() %} - {% set params = binds.arguments %} - {%for (arg, value) in params.items()%} + the most likely result. + + user: The function was called with the following inputs: + {%for (arg, value) in _arguments.items()%} - {{ arg }}: {{ value }} {% endfor %} What is its output? - """ + """)) + name: str = "FormatResponse" + description: str = "Formats the response." + field_name: str = "data" + field_description: str = "The data to format." 
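# --- [Editor's note: illustrative sketch, not part of this diff] ---
# Typical use of the AIFunction wrapper defined in this file, via the ai_fn
# decorator declared further down. The function body stays empty; the return
# annotation and docstring drive the default prompt template shown above.
# The example function and its result are hypothetical.
from marvin.components.ai_function import ai_fn

@ai_fn
def sentiment(text: str) -> float:
    """Return a sentiment score for `text` between -1 (negative) and 1 (positive)."""

# sentiment("I love this library!")  # -> e.g. 0.9, parsed from the tool-call arguments
# --- [End editor's note] ---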
+ render_kwargs: dict[str, Any] = Field(default_factory=dict) - return prompt_wrapper # type: ignore + create: Optional[Callable[..., "ChatCompletion"]] = Field(default=None) + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Union[T, Awaitable[T]]: + if self.fn is None: + raise NotImplementedError -class AIFunction(BaseModel, Generic[P, T]): - fn: Callable[P, Any] - ctx: Optional[dict[str, Any]] = None - model: Any = Field(default_factory=ChatCompletion) - response_model_name: Optional[str] = Field(default=None, exclude=True) - response_model_description: Optional[str] = Field(default=None, exclude=True) - response_model_field_name: Optional[str] = Field(default=None, exclude=True) + from marvin import settings - def __call__( - self, - *args: P.args, - **kwargs: P.kwargs, - ) -> Any: - get_logger("marvin.AIFunction").debug_kv( - f"Calling `ai_fn` {self.fn.__name__!r}", - f"with args: {args} kwargs: {kwargs}", - ) - - return self.call(*args, **kwargs) + logger = get_logger("marvin.ai_fn") - def get_prompt( - self, - ) -> Callable[P, Prompt[P]]: - return ai_fn_prompt( - self.fn, - ctx=self.ctx, - response_model_name=self.response_model_name, - response_model_description=self.response_model_description, - response_model_field_name=self.response_model_field_name, - ) - - def as_prompt( - self, - *args: P.args, - **kwargs: P.kwargs, - ) -> dict[str, Any]: - return self.get_prompt()(*args, **kwargs).serialize( - model=self.model, + logger.debug_kv( + "AI Function Call", + f"Calling {self.fn.__name__} with {args} and {kwargs}", + "blue", ) - def as_dict( - self, - *args: P.args, - **kwargs: P.kwargs, - ) -> dict[str, Any]: - return self.get_prompt()(*args, **kwargs).to_dict() - - def as_chat_completion( - self, - *args: P.args, - **kwargs: P.kwargs, - ) -> AbstractChatCompletion[T]: - return self.model(**self.as_dict(*args, **kwargs)) - - def call( - self, - *args: P.args, - **kwargs: P.kwargs, - ) -> Any: - model_instance = self.as_chat_completion(*args, **kwargs).create().to_model() - response_model_field_name = self.response_model_field_name or "output" - - if (output := getattr(model_instance, response_model_field_name, None)) is None: - return model_instance - - return output + is_async_fn = asyncio.iscoroutinefunction(self.fn) - async def acall( - self, - *args: P.args, - **kwargs: P.kwargs, - ) -> Any: - model_instance = ( - await self.as_chat_completion(*args, **kwargs).acreate() - ).to_model() - - response_model_field_name = self.response_model_field_name or "output" - - if (output := getattr(model_instance, response_model_field_name, None)) is None: - return model_instance - - return output + call = "async_call" if is_async_fn else "sync_call" + create = ( + self.create or settings.openai.chat.completions.acreate + if is_async_fn + else settings.openai.chat.completions.create + ) - def map(self, *map_args: list[Any], **map_kwargs: list[Any]): + result = getattr(self, call)(create, *args, **kwargs) + + logger.debug_kv("AI Function Call", f"Returned {result}", "blue") + + return result + + async def async_call( + self, acreate: Callable[..., Awaitable[Any]], *args: P.args, **kwargs: P.kwargs + ) -> T: + _response = await acreate(**self.as_prompt(*args, **kwargs).serialize()) + return self.parse(_response) + + def sync_call( + self, create: Callable[..., Any], *args: P.args, **kwargs: P.kwargs + ) -> T: + _response = create(**self.as_prompt(*args, **kwargs).serialize()) + return self.parse(_response) + + def parse(self, response: "ChatCompletion") -> T: + tool_calls = 
response.choices[0].message.tool_calls + if tool_calls is None: + raise NotImplementedError + if self.fn is None: + raise NotImplementedError + arguments = tool_calls[0].function.arguments + + tool = create_tool_from_type( + _type=self.fn.__annotations__["return"], + model_name=self.name, + model_description=self.description, + field_name=self.field_name, + field_description=self.field_description, + ).function + if not tool or not tool.model: + raise NotImplementedError + try: + return getattr(tool.model.model_validate_json(arguments), self.field_name) + except ValidationError: + # When the user provides a dict obj as a type hint, the arguments + # are returned usually as an object and not a nested dict. + _arguments: str = json.dumps({self.field_name: json.loads(arguments)}) + return getattr(tool.model.model_validate_json(_arguments), self.field_name) + + @expose_sync_method("map") + async def amap(self, *map_args: list[Any], **map_kwargs: list[Any]) -> list[T]: """ Map the AI function over a sequence of arguments. Runs concurrently. + A `map` twin method is provided by the `expose_sync_method` decorator. + + You can use `map` or `amap` synchronously or asynchronously, respectively, + regardless of whether the user function is synchronous or asynchronous. + Arguments should be provided as if calling the function normally, but each argument must be a list. The function is called once for each item in the list, and the results are returned in a list. - This method should be called synchronously. - For example, fn.map([1, 2]) is equivalent to [fn(1), fn(2)]. fn.map([1, 2], x=['a', 'b']) is equivalent to [fn(1, x='a'), fn(2, x='b')]. """ - return run_sync(self.amap(*map_args, **map_kwargs)) - - async def amap(self, *map_args: list[Any], **map_kwargs: list[Any]) -> list[Any]: tasks: list[Any] = [] - if map_args: + if map_args and map_kwargs: + max_length = max( + len(arg) for arg in (map_args + tuple(map_kwargs.values())) + ) + elif map_args: max_length = max(len(arg) for arg in map_args) else: max_length = max(len(v) for v in map_kwargs.values()) @@ -172,57 +164,153 @@ async def amap(self, *map_args: list[Any], **map_kwargs: list[Any]) -> list[Any] if map_kwargs else {} ) - tasks.append(self.acall(*call_args, **call_kwargs)) + + tasks.append(run_async(self, *call_args, **call_kwargs)) return await asyncio.gather(*tasks) + def as_prompt( + self, + *args: P.args, + **kwargs: P.kwargs, + ) -> PromptFunction[BaseModel]: + return PromptFunction[BaseModel].as_function_call( + fn=self.fn, + environment=self.environment, + prompt=self.prompt, + model_name=self.name, + model_description=self.description, + field_name=self.field_name, + field_description=self.field_description, + **self.render_kwargs, + )(*args, **kwargs) + + @overload + @classmethod + def as_decorator( + cls: type[Self], + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + model_name: str = "FormatResponse", + model_description: str = "Formats the response.", + field_name: str = "data", + field_description: str = "The data to format.", + acreate: Optional[Callable[..., Awaitable[Any]]] = None, + **render_kwargs: Any, + ) -> Callable[P, Self]: + pass + + @overload + @classmethod + def as_decorator( + cls: type[Self], + fn: Callable[P, T], + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + model_name: str = "FormatResponse", + model_description: str = "Formats the response.", + field_name: str = "data", + field_description: str = "The data to format.", + acreate: 
Optional[Callable[..., Awaitable[Any]]] = None, + **render_kwargs: Any, + ) -> Self: + pass + @classmethod def as_decorator( cls: type[Self], fn: Optional[Callable[P, T]] = None, - ctx: Optional[dict[str, Any]] = None, - instructions: Optional[str] = None, - response_model_name: Optional[str] = None, - response_model_description: Optional[str] = None, - response_model_field_name: Optional[str] = None, - model: Optional[str] = None, - **model_kwargs: Any, - ) -> Union[Callable[P, T], Callable[P, Awaitable[T]]]: - if not fn: - return partial( - cls.as_decorator, - ctx=ctx, - instructions=instructions, - response_model_name=response_model_name, - response_model_description=response_model_description, - response_model_field_name=response_model_field_name, - model=model, - **model_kwargs, - ) # type: ignore - - if not inspect.iscoroutinefunction(fn): + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + model_name: str = "FormatResponse", + model_description: str = "Formats the response.", + field_name: str = "data", + field_description: str = "The data to format.", + **render_kwargs: Any, + ) -> Union[Callable[[Callable[P, T]], Self], Self]: + def decorator(func: Callable[P, T]) -> Self: return cls( - fn=fn, - ctx={"instructions": instructions, **(ctx or {})}, - response_model_name=response_model_name, - response_model_description=response_model_description, - response_model_field_name=response_model_field_name, - model=ChatCompletion(model=model, **model_kwargs), - ) - else: - return AsyncAIFunction[P, T]( - fn=fn, - ctx={"instructions": instructions, **(ctx or {})}, - response_model_name=response_model_name, - response_model_description=response_model_description, - response_model_field_name=response_model_field_name, - model=ChatCompletion(model=model, **model_kwargs), + fn=func, + environment=environment, + name=model_name, + description=model_description, + field_name=field_name, + field_description=field_description, + **({"prompt": prompt} if prompt else {}), + **render_kwargs, ) + if fn is not None: + return decorator(fn) + + return decorator + + +@overload +def ai_fn( + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + model_name: str = "FormatResponse", + model_description: str = "Formats the response.", + field_name: str = "data", + field_description: str = "The data to format.", + **render_kwargs: Any, +) -> Callable[[Callable[P, T]], Callable[P, T]]: + pass + + +@overload +def ai_fn( + fn: Callable[P, T], + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + model_name: str = "FormatResponse", + model_description: str = "Formats the response.", + field_name: str = "data", + field_description: str = "The data to format.", + **render_kwargs: Any, +) -> Callable[P, T]: + pass + + +def ai_fn( + fn: Optional[Callable[P, T]] = None, + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + model_name: str = "FormatResponse", + model_description: str = "Formats the response.", + field_name: str = "data", + field_description: str = "The data to format.", + **render_kwargs: Any, +) -> Union[Callable[[Callable[P, T]], Callable[P, T]], Callable[P, T]]: + if fn is not None: + return AIFunction.as_decorator( # type: ignore + fn=fn, + environment=environment, + prompt=prompt, + model_name=model_name, + model_description=model_description, + field_name=field_name, + field_description=field_description, + **render_kwargs, + ) -class AsyncAIFunction(AIFunction[P, 
T]): - async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Any: - return await super().acall(*args, **kwargs) - + def decorator(func: Callable[P, T]) -> Callable[P, T]: + return AIFunction.as_decorator( # type: ignore + fn=func, + environment=environment, + prompt=prompt, + model_name=model_name, + model_description=model_description, + field_name=field_name, + field_description=field_description, + **render_kwargs, + ) -ai_fn = AIFunction.as_decorator + return decorator diff --git a/src/marvin/components/ai_image.py b/src/marvin/components/ai_image.py new file mode 100644 index 000000000..1dde168ed --- /dev/null +++ b/src/marvin/components/ai_image.py @@ -0,0 +1,154 @@ +from functools import partial, wraps +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Optional, + TypeVar, + Union, + overload, +) + +from pydantic import BaseModel, Field +from typing_extensions import ParamSpec, Self + +from marvin.components.prompt import PromptFunction +from marvin.utilities.jinja import ( + BaseEnvironment, +) + +if TYPE_CHECKING: + from openai.types.images_response import ImagesResponse + +T = TypeVar("T") + +P = ParamSpec("P") + + +class AIImage(BaseModel, Generic[P]): + fn: Optional[Callable[P, Any]] = None + environment: Optional[BaseEnvironment] = None + prompt: Optional[str] = Field(default=None) + render_kwargs: dict[str, Any] = Field(default_factory=dict) + + generate: Optional[Callable[..., "ImagesResponse"]] = Field(default=None) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "ImagesResponse": + generate = self.generate + if self.fn is None: + raise NotImplementedError + if generate is None: + from marvin.settings import settings + + generate = settings.openai.images.generate + + _response = generate(prompt=self.as_prompt(*args, **kwargs)) + + return _response + + def as_prompt( + self, + *args: P.args, + **kwargs: P.kwargs, + ) -> str: + return ( + PromptFunction[BaseModel] + .as_function_call( + fn=self.fn, + environment=self.environment, + prompt=self.prompt, + **self.render_kwargs, + )(*args, **kwargs) + .messages[0] + .content + ) + + @overload + @classmethod + def as_decorator( + cls: type[Self], + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + **render_kwargs: Any, + ) -> Callable[P, Self]: + pass + + @overload + @classmethod + def as_decorator( + cls: type[Self], + fn: Callable[P, Any], + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + **render_kwargs: Any, + ) -> Self: + pass + + @classmethod + def as_decorator( + cls: type[Self], + fn: Optional[Callable[P, Any]] = None, + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + **render_kwargs: Any, + ) -> Union[Self, Callable[[Callable[P, Any]], Self]]: + if fn is None: + return partial( + cls, + environment=environment, + **({"prompt": prompt} if prompt else {}), + **render_kwargs, + ) + + return cls( + fn=fn, + environment=environment, + **({"prompt": prompt} if prompt else {}), + **render_kwargs, + ) + + +def ai_image( + fn: Optional[Callable[P, Any]] = None, + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + **render_kwargs: Any, +) -> Union[ + Callable[[Callable[P, Any]], Callable[P, "ImagesResponse"]], + Callable[P, "ImagesResponse"], +]: + def wrapper( + func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs + ) -> "ImagesResponse": + return AIImage[P].as_decorator( + func, + environment=environment, + prompt=prompt, + 
**render_kwargs, + )(*args, **kwargs) + + if fn is not None: + return wraps(fn)(partial(wrapper, fn)) + + def decorator(fn: Callable[P, Any]) -> Callable[P, "ImagesResponse"]: + return wraps(fn)(partial(wrapper, fn)) + + return decorator + + +def create_image( + prompt: str, + environment: Optional[BaseEnvironment] = None, + generate: Optional[Callable[..., "ImagesResponse"]] = None, + **model_kwargs: Any, +) -> "ImagesResponse": + if generate is None: + from marvin.settings import settings + + generate = settings.openai.images.generate + return generate(prompt=prompt, **model_kwargs) diff --git a/src/marvin/components/ai_model.py b/src/marvin/components/ai_model.py index 86e4caa90..39b83126c 100644 --- a/src/marvin/components/ai_model.py +++ b/src/marvin/components/ai_model.py @@ -1,411 +1,105 @@ import asyncio import inspect from functools import partial -from typing import Any, Callable, Optional, TypeVar - -from typing_extensions import ParamSpec, Self - -from marvin._compat import BaseModel -from marvin.core.ChatCompletion import ChatCompletion -from marvin.core.ChatCompletion.abstract import AbstractChatCompletion -from marvin.prompts import Prompt, prompt_fn -from marvin.utilities.async_utils import run_sync -from marvin.utilities.logging import get_logger - -T = TypeVar("T", bound=BaseModel) - -A = TypeVar("A", bound=Any) - -P = ParamSpec("P") - - -def ai_model_prompt( - cls: type[BaseModel], - ctx: Optional[dict[str, Any]] = None, - **kwargs: Any, -) -> Callable[[str], Prompt[P]]: - description = cls.__doc__ or "" - if ctx and ctx.get("instructions") and isinstance(ctx.get("instructions"), str): - instructions = str(ctx.get("instructions")) - description += "\n" + instructions if (instructions != description) else "" - - @prompt_fn( - ctx={"ctx": ctx or {}, "inspect": inspect}, - response_model=cls, - response_model_name="FormatResponse", - response_model_description=description, - serialize_on_call=False, - ) - def prompt_wrapper(text: str) -> None: # type: ignore # noqa - """ - The user will provide text that you need to parse into a - structured form {{'(note you must also: ' + ctx.get('instructions') + ')' if ctx.get('instructions')}}. - To validate your response, you must call the - `{{response_model.__name__}}` function. - Use the provided text and context to extract, deduce, or infer - any parameters needed by `{{response_model.__name__}}`, including any missing - data. - - You have been provided the following context to perform your task: - - The current time is {{now()}}. 
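# --- [Editor's note: illustrative sketch, not part of this diff] ---
# Usage of the new create_image() helper from ai_image.py above, which forwards
# a prompt string straight to settings.openai.images.generate. The prompt and
# the attribute access on the response are illustrative; an OpenAI API key is
# assumed to be configured.
from marvin.components.ai_image import create_image

response = create_image("a watercolor painting of a lighthouse at dusk")
# ImagesResponse typically carries the generated image URL(s):
# response.data[0].url
# --- [End editor's note] ---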
- {% set context = ctx.get('context_fn')(text).items() if ctx.get('context_fn') %} - {% if context %} - {%for (arg, value) in context%} - - {{ arg }}: {{ value }}\n - {% endfor %} - {% endif %} - - User: The text to parse: {{text}} - - - """ # noqa - - return prompt_wrapper # type: ignore - - -class AIModel(BaseModel): - def __init__( - self, - text: Optional[str] = None, - /, - instructions_: Optional[str] = None, - **kwargs: Any, - ): - if text: - kwargs.update(self.__class__.call(text, instructions=instructions_)) - - super().__init__(**kwargs) - - @classmethod - def get_prompt( - cls, - ctx: Optional[dict[str, Any]] = None, - instructions: Optional[str] = None, - response_model_name: Optional[str] = None, - response_model_description: Optional[str] = None, - response_model_field_name: Optional[str] = None, - ) -> Callable[[str], Prompt[P]]: - ctx = ctx or getattr(cls, "__metadata__", {}).get("ctx", {}) - instructions = ( # type: ignore - "\n".join( - list( - filter( - bool, - [ - instructions, - getattr(cls, "__metadata__", {}).get("instructions"), - ], - ) - ) # type: ignore - ) - or None - ) - - response_model_name = response_model_name or getattr( - cls, "__metadata__", {} - ).get("response_model_name") - response_model_description = response_model_description or getattr( - cls, "__metadata__", {} - ).get("response_model_description") - response_model_field_name = response_model_field_name or getattr( - cls, "__metadata__", {} - ).get("response_model_field_name") - - return ai_model_prompt( - cls, - ctx=((ctx or {}) | {"instructions": instructions}), - response_model_name=response_model_name, - response_model_description=response_model_description, - response_model_field_name=response_model_field_name, +from typing import Any, Callable, Optional, TypeVar, Union, overload + +from marvin.components.ai_function import ai_fn +from marvin.utilities.asyncio import run_sync +from marvin.utilities.jinja import BaseEnvironment + +T = TypeVar("T") + +prompt = inspect.cleandoc( + "The user will provide context as text that you need to parse into a structured" + " form. To validate your response, you must call the" + " `{{_response_model.function.name}}` function. Use the provided text to extract or" + " infer any parameters needed by `{{_response_model.function.name}}`, including any" + " missing data." 
+ " user: The text to parse: {{text}}" +) + + +@overload +def ai_model( + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = prompt, + model_name: str = "FormatResponse", + model_description: str = "Formats the response.", + field_name: str = "data", + field_description: str = "The data to format.", + **render_kwargs: Any, +) -> Callable[[T], Callable[[str], T]]: + pass + + +@overload +def ai_model( + _type: Optional[T], + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = prompt, + model_name: str = "FormatResponse", + model_description: str = "Formats the response.", + field_name: str = "data", + field_description: str = "The data to format.", + **render_kwargs: Any, +) -> Callable[[str], T]: + pass + + +def ai_model( + _type: Optional[T] = None, + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = prompt, + model_name: str = "FormatResponse", + model_description: str = "Formats the response.", + field_name: str = "data", + field_description: str = "The data to format.", + **render_kwargs: Any, +) -> Union[Callable[[T], Callable[[str], T]], Callable[[str], T],]: + def wrapper(_type_: T, text: str) -> T: + def extract(text: str) -> T: # type: ignore + pass + + extract.__annotations__["return"] = _type_ + + return ai_fn( + extract, + environment=environment, + prompt=prompt, + model_name=model_name, + model_description=model_description, + field_name=field_name, + field_description=field_description, + **render_kwargs, + )(text) + + async def async_wrapper(_type_: T, text: str) -> T: + return wrapper(_type_, text) + + async def amap(inputs: list[str]) -> list[T]: + return await asyncio.gather( + *[ + asyncio.create_task(async_wrapper(_type, input_text)) + for input_text in inputs + ] ) - @classmethod - def as_prompt( - cls, - text: str, - ctx: Optional[dict[str, Any]] = None, - instructions: Optional[str] = None, - response_model_name: Optional[str] = None, - response_model_description: Optional[str] = None, - response_model_field_name: Optional[str] = None, - model: Optional[str] = None, - **model_kwargs: Any, - ) -> dict[str, Any]: - metadata = getattr(cls, "__metadata__", {}) - - # Set default values using a loop to reduce repetition - default_keys = [ - "ctx", - "instructions", - "response_model_name", - "response_model_description", - "response_model_field_name", - "model", - "model_kwargs", - ] - local_vars = locals() - for key in default_keys: - if local_vars.get(key, None) is None: - local_vars[key] = metadata.get(key, {}) - - return cls.get_prompt( - ctx=ctx, - instructions=instructions, - response_model_name=response_model_name, - response_model_description=response_model_description, - response_model_field_name=response_model_field_name, - )(text).serialize(model=ChatCompletion(model=model, **model_kwargs)) - - @classmethod - def as_dict( - cls, - text: str, - ctx: Optional[dict[str, Any]] = None, - instructions: Optional[str] = None, - response_model_name: Optional[str] = None, - response_model_description: Optional[str] = None, - response_model_field_name: Optional[str] = None, - model: Optional[str] = None, - **model_kwargs: Any, - ) -> dict[str, Any]: - metadata = getattr(cls, "__metadata__", {}) - - # Set default values using a loop to reduce repetition - default_keys = [ - "ctx", - "instructions", - "response_model_name", - "response_model_description", - "response_model_field_name", - "model", - "model_kwargs", - ] - local_vars = locals() - for key in default_keys: - if local_vars.get(key, 
None) is None: - local_vars[key] = metadata.get(key, {}) - return cls.get_prompt( - ctx=ctx, - instructions=instructions, - response_model_name=response_model_name, - response_model_description=response_model_description, - response_model_field_name=response_model_field_name, - )(text).to_dict() - - @classmethod - def as_chat_completion( - cls, - text: str, - ctx: Optional[dict[str, Any]] = None, - instructions: Optional[str] = None, - response_model_name: Optional[str] = None, - response_model_description: Optional[str] = None, - response_model_field_name: Optional[str] = None, - model: Optional[str] = None, - **model_kwargs: Any, - ) -> AbstractChatCompletion[T]: # type: ignore - metadata = getattr(cls, "__metadata__", {}) + def map(inputs: list[str]) -> list[T]: + return run_sync(amap(inputs)) - # Set default values using a loop to reduce repetition - default_keys = [ - "ctx", - "instructions", - "response_model_name", - "response_model_description", - "response_model_field_name", - "model", - "model_kwargs", - ] - local_vars = locals() - for key in default_keys: - if local_vars.get(key, None) is None: - local_vars[key] = metadata.get(key, {}) - - return ChatCompletion(model=model, **model_kwargs)( - **cls.as_dict( - text, - ctx=ctx, - instructions=instructions, - response_model_name=response_model_name, - response_model_description=response_model_description, - response_model_field_name=response_model_field_name, - ) - ) # type: ignore - - @classmethod - def call( - cls: type[Self], - text: str, - *, - ctx: Optional[dict[str, Any]] = None, - instructions: Optional[str] = None, - response_model_name: Optional[str] = None, - response_model_description: Optional[str] = None, - response_model_field_name: Optional[str] = None, - model: Optional[str] = None, - **model_kwargs: Any, - ) -> Self: - metadata = getattr(cls, "__metadata__", {}) - - get_logger("marvin.AIModel").debug_kv( - f"Calling `ai_model` {cls.__name__!r}", - f"with {text!r}", - ) - - # Set default values using a loop to reduce repetition - default_keys = [ - "ctx", - "instructions", - "response_model_name", - "response_model_description", - "response_model_field_name", - "model", - "model_kwargs", - ] - local_vars = locals() - for key in default_keys: - if local_vars.get(key, None) is None: - local_vars[key] = metadata.get(key, {}) - - _model: Self = ( # type: ignore - cls.as_chat_completion( - text, - ctx=ctx, - instructions=instructions, - response_model_name=response_model_name, - response_model_description=response_model_description, - response_model_field_name=response_model_field_name, - model=model, - **model_kwargs, - ) - .create() - .to_model(cls) - ) - return _model # type: ignore - - @classmethod - async def acall( - cls: type[Self], - text: str, - *, - ctx: Optional[dict[str, Any]] = None, - instructions: Optional[str] = None, - response_model_name: Optional[str] = None, - response_model_description: Optional[str] = None, - response_model_field_name: Optional[str] = None, - model: Optional[str] = None, - **model_kwargs: Any, - ) -> Self: - metadata = getattr(cls, "__metadata__", {}) - - get_logger("marvin.AIModel").debug_kv( - f"Calling `ai_model` {cls.__name__!r}", - f"with {text!r}", - ) - - # Set default values using a loop to reduce repetition - default_keys = [ - "ctx", - "instructions", - "response_model_name", - "response_model_description", - "response_model_field_name", - "model", - "model_kwargs", - ] - local_vars = locals() - for key in default_keys: - if local_vars.get(key, None) is None: - 
local_vars[key] = metadata.get(key, {}) - - _model: Self = ( # type: ignore - await cls.as_chat_completion( - text, - ctx=ctx, - instructions=instructions, - response_model_name=response_model_name, - response_model_description=response_model_description, - response_model_field_name=response_model_field_name, - model=model, - **model_kwargs, - ).acreate() # type: ignore - ).to_model(cls) - return _model # type: ignore - - @classmethod - def map(cls, *map_args: list[str], **map_kwargs: list[Any]): - """ - Map the AI function over a sequence of arguments. Runs concurrently. - - Arguments should be provided as if calling the function normally, but - each argument must be a list. The function is called once for each item - in the list, and the results are returned in a list. - - This method should be called synchronously. - - For example, fn.map([1, 2]) is equivalent to [fn(1), fn(2)]. - - fn.map([1, 2], x=['a', 'b']) is equivalent to [fn(1, x='a'), fn(2, x='b')]. - """ - return run_sync(cls.amap(*map_args, **map_kwargs)) - - @classmethod - async def amap(cls, *map_args: list[str], **map_kwargs: list[Any]) -> list[Any]: - tasks: list[Any] = [] - if map_args: - max_length = max(len(arg) for arg in map_args) - else: - max_length = max(len(v) for v in map_kwargs.values()) - - for i in range(max_length): - call_args: list[str] = [ - arg[i] if i < len(arg) else None for arg in map_args - ] # type: ignore - - tasks.append(cls.acall(*call_args, **map_kwargs)) - - return await asyncio.gather(*tasks) - - @classmethod - def as_decorator( - cls: type[Self], - base_model: Optional[type[BaseModel]] = None, - ctx: Optional[dict[str, Any]] = None, - instructions: Optional[str] = None, - response_model_name: Optional[str] = None, - response_model_description: Optional[str] = None, - response_model_field_name: Optional[str] = None, - model: Optional[str] = None, - **model_kwargs: Any, - ) -> type[BaseModel]: - if not base_model: - return partial( - cls.as_decorator, - ctx=ctx, - instructions=instructions, - response_model_name=response_model_name, - response_model_description=response_model_description, - response_model_field_name=response_model_field_name, - model=model, - **model_kwargs, - ) # type: ignore - - response = type(base_model.__name__, (cls, base_model), {}) - response.__doc__ = base_model.__doc__ - setattr( - response, - "__metadata__", - { - "ctx": ctx or {}, - "instructions": instructions, - "response_model_name": response_model_name, - "response_model_description": response_model_description, - "response_model_field_name": response_model_field_name, - "model": model, - "model_kwargs": model_kwargs, - }, - ) - return response # type: ignore + if _type is not None: + wrapper_with_map = partial(wrapper, _type) + wrapper_with_map.amap = amap + wrapper_with_map.map = map + return wrapper_with_map + def decorator(_type_: T) -> Callable[[str], T]: + decorated = partial(wrapper, _type_) + decorated.amap = amap + decorated.map = map + return decorated -ai_model = AIModel.as_decorator + return decorator diff --git a/src/marvin/components/ai_model_factory.py b/src/marvin/components/ai_model_factory.py deleted file mode 100644 index 33a6112a3..000000000 --- a/src/marvin/components/ai_model_factory.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import Optional - -from marvin._compat import BaseModel -from marvin.components.ai_model import ai_model - - -class DataSchema(BaseModel): - title: Optional[str] = None - type: Optional[str] = None - properties: Optional[dict] = {} - required: Optional[list[str]] = 
[] - additionalProperties: bool = False - definitions: dict = {} - description: Optional[str] = None - - -# If you're reading this and expected something fancier, -# I'm sorry to disappoint you. It's this simple. -AIModelFactory = ai_model(DataSchema) diff --git a/src/marvin/components/library/ai_functions.py b/src/marvin/components/library/ai_functions.py deleted file mode 100644 index 42cb6e616..000000000 --- a/src/marvin/components/library/ai_functions.py +++ /dev/null @@ -1,13 +0,0 @@ -from datetime import datetime - -from marvin.components.ai_function import ai_fn - - -@ai_fn -def summarize_text(text: str, specifications: str = "concise, comprehensive") -> str: - """generates a summary of `text` according to the `specifications`""" - - -@ai_fn -def make_datetime(description: str, tz: str = "UTC") -> datetime: - """generates a datetime from a description""" diff --git a/src/marvin/components/library/ai_models.py b/src/marvin/components/library/ai_models.py deleted file mode 100644 index cc149cab4..000000000 --- a/src/marvin/components/library/ai_models.py +++ /dev/null @@ -1,101 +0,0 @@ -import json -from typing import Optional - -import httpx -from typing_extensions import Self - -from marvin import ai_model -from marvin._compat import BaseModel, Field, SecretStr, field_validator -from marvin.settings import MarvinBaseSettings - - -class DiscourseSettings(MarvinBaseSettings): - class Config: - env_prefix = "MARVIN_DISCOURSE_" - - help_category_id: Optional[int] = Field( - None, env=["MARVIN_DISCOURSE_HELP_CATEGORY_ID"] - ) - api_key: Optional[SecretStr] = Field(None, env=["MARVIN_DISCOURSE_API_KEY"]) - api_username: Optional[str] = Field(None, env=["MARVIN_DISCOURSE_API_USERNAME"]) - url: Optional[str] = Field(None, env=["MARVIN_DISCOURSE_URL"]) - - -discourse_settings = DiscourseSettings() - - -@ai_model(instructions="Produce a comprehensive Discourse post from text.") -class DiscoursePost(BaseModel): - title: Optional[str] = Field( - description="A fitting title for the post.", - example="How to install Prefect", - ) - question: Optional[str] = Field( - description="The question that is posed in the text.", - example="How do I install Prefect?", - ) - answer: Optional[str] = Field( - description=( - "The complete answer to the question posed in the text." - " This answer should comprehensively answer the question, " - " explain any relevant concepts, and have a friendly, academic tone," - " and provide any links to relevant resources found in the thread." - " This answer should be written in Markdown, with any code blocks" - " formatted as `code` or ```\n```." 
- ) - ) - - topic_url: Optional[str] = Field(None) - - @field_validator("title", "question", "answer") - def non_empty_string(cls, value): - if not value: - raise ValueError("this field cannot be empty") - return value - - @classmethod - def from_slack_thread(cls, messages: list[str]) -> Self: - return cls("here is the transcript:\n" + "\n\n".join(messages)) - - async def publish( - self, - topic: str = None, - category: Optional[int] = None, - url: str = discourse_settings.url, - tags: list[str] = None, - ) -> str: - if not category: - category = discourse_settings.help_category_id - - headers = { - "Api-Key": discourse_settings.api_key.get_secret_value(), - "Api-Username": discourse_settings.api_username, - "Content-Type": "application/json", - } - data = { - "title": self.title, - "raw": ( - f"## **{self.question}**\n\n{self.answer}" - "\n\n---\n\n*This topic was created by Marvin.*" - ), - "category": category, - "tags": tags or ["marvin"], - } - - if topic: - data["tags"].append(topic) - - async with httpx.AsyncClient() as client: - response = await client.post( - url=f"{url}/posts.json", headers=headers, data=json.dumps(data) - ) - - response.raise_for_status() - - response_data = response.json() - topic_id = response_data.get("topic_id") - post_number = response_data.get("post_number") - - self.topic_url = f"{url}/t/{topic_id}/{post_number}" - - return self.topic_url diff --git a/src/marvin/components/prompt.py b/src/marvin/components/prompt.py new file mode 100644 index 000000000..3cf04c07c --- /dev/null +++ b/src/marvin/components/prompt.py @@ -0,0 +1,242 @@ +import inspect +import re +from functools import partial, wraps +from typing import ( + Any, + Callable, + Optional, + TypeVar, + Union, + overload, +) + +import pydantic +from pydantic import BaseModel +from typing_extensions import ParamSpec, Self + +from marvin.requests import BaseMessage as Message +from marvin.requests import Prompt, Tool +from marvin.serializers import ( + create_grammar_from_vocabulary, + create_tool_from_type, + create_vocabulary_from_type, +) +from marvin.settings import settings +from marvin.utilities.jinja import ( + BaseEnvironment, + Transcript, +) + +P = ParamSpec("P") +T = TypeVar("T") +U = TypeVar("U", bound=BaseModel) + + +class PromptFunction(Prompt[U]): + messages: list[Message] = pydantic.Field(default_factory=list) + + def serialize(self) -> dict[str, Any]: + return self.model_dump(exclude_unset=True, exclude_none=True) + + @overload + @classmethod + def as_grammar( + cls: type[Self], + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + enumerate: bool = True, + encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder, + max_tokens: Optional[int] = 1, + ) -> Callable[[Callable[P, Any]], Callable[P, Self]]: + pass + + @overload + @classmethod + def as_grammar( + cls: type[Self], + fn: Optional[Callable[P, Any]] = None, + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + enumerate: bool = True, + encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder, + max_tokens: Optional[int] = 1, + ) -> Callable[P, Self]: + pass + + @classmethod + def as_grammar( + cls: type[Self], + fn: Optional[Callable[P, Any]] = None, + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + enumerate: bool = True, + encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder, + max_tokens: Optional[int] = 1, + **kwargs: Any, + ) -> 
Union[Callable[[Callable[P, Any]], Callable[P, Self]], Callable[P, Self],]: + def wrapper(func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> Self: + # Get the signature of the function + signature = inspect.signature(func) + params = signature.bind(*args, **kwargs) + params.apply_defaults() + + vocabulary = create_vocabulary_from_type( + inspect.signature(func).return_annotation + ) + + grammar = create_grammar_from_vocabulary( + vocabulary=vocabulary, + encoder=encoder, + _enumerate=enumerate, + max_tokens=max_tokens, + ) + + messages = Transcript( + content=prompt or func.__doc__ or "" + ).render_to_messages( + **kwargs | params.arguments, + _arguments=params.arguments, + _options=vocabulary, + _doc=func.__doc__, + _source_code=( + "\ndef" + "def".join(re.split("def", inspect.getsource(func))[1:]) + ), + ) + + return cls( + messages=messages, + **grammar.model_dump(exclude_unset=True, exclude_none=True), + ) + + if fn is not None: + return wraps(fn)(partial(wrapper, fn)) + + def decorator(fn: Callable[P, Any]) -> Callable[P, Self]: + return wraps(fn)(partial(wrapper, fn)) + + return decorator + + @overload + @classmethod + def as_function_call( + cls: type[Self], + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + model_name: str = "FormatResponse", + model_description: str = "Formats the response.", + field_name: str = "data", + field_description: str = "The data to format.", + ) -> Callable[[Callable[P, Any]], Callable[P, Self]]: + pass + + @overload + @classmethod + def as_function_call( + cls: type[Self], + fn: Optional[Callable[P, Any]] = None, + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + model_name: str = "FormatResponse", + model_description: str = "Formats the response.", + field_name: str = "data", + field_description: str = "The data to format.", + ) -> Callable[P, Self]: + pass + + @classmethod + def as_function_call( + cls: type[Self], + fn: Optional[Callable[P, Any]] = None, + *, + environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + model_name: str = "FormatResponse", + model_description: str = "Formats the response.", + field_name: str = "data", + field_description: str = "The data to format.", + **kwargs: Any, + ) -> Union[Callable[[Callable[P, Any]], Callable[P, Self]], Callable[P, Self],]: + def wrapper(func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> Self: + # Get the signature of the function + signature = inspect.signature(func) + params = signature.bind(*args, **kwargs) + params.apply_defaults() + + tool = create_tool_from_type( + _type=inspect.signature(func).return_annotation, + model_name=model_name, + model_description=model_description, + field_name=field_name, + field_description=field_description, + ) + + messages = Transcript( + content=prompt or func.__doc__ or "" + ).render_to_messages( + **kwargs | params.arguments, + _doc=func.__doc__, + _arguments=params.arguments, + _response_model=tool, + _source_code=( + "\ndef" + "def".join(re.split("def", inspect.getsource(func))[1:]) + ), + ) + + return cls( + messages=messages, + tool_choice={ + "type": "function", + "function": {"name": getattr(tool.function, "name", model_name)}, + }, + tools=[Tool[BaseModel](**tool.model_dump())], + ) + + if fn is not None: + return wraps(fn)(partial(wrapper, fn)) + + def decorator(fn: Callable[P, Any]) -> Callable[P, Self]: + return wraps(fn)(partial(wrapper, fn)) + + return decorator + + +def prompt_fn( + fn: Optional[Callable[P, T]] = None, + *, 
+ environment: Optional[BaseEnvironment] = None, + prompt: Optional[str] = None, + model_name: str = "FormatResponse", + model_description: str = "Formats the response.", + field_name: str = "data", + field_description: str = "The data to format.", + **kwargs: Any, +) -> Union[ + Callable[[Callable[P, T]], Callable[P, dict[str, Any]]], + Callable[P, dict[str, Any]], +]: + def wrapper( + func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs + ) -> dict[str, Any]: + return PromptFunction.as_function_call( + fn=func, + environment=environment, + prompt=prompt, + model_name=model_name, + model_description=model_description, + field_name=field_name, + field_description=field_description, + **kwargs, + )(*args, **kwargs).serialize() + + if fn is not None: + return wraps(fn)(partial(wrapper, fn)) + + def decorator(fn: Callable[P, Any]) -> Callable[P, dict[str, Any]]: + return wraps(fn)(partial(wrapper, fn)) + + return decorator diff --git a/src/marvin/components/speech.py b/src/marvin/components/speech.py new file mode 100644 index 000000000..fbc45595d --- /dev/null +++ b/src/marvin/components/speech.py @@ -0,0 +1,73 @@ +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Literal, + Optional, + TypeVar, +) + +from typing_extensions import ParamSpec + +if TYPE_CHECKING: + from openai._base_client import HttpxBinaryResponseContent + +T = TypeVar("T") + +P = ParamSpec("P") + + +def speak( + input: str, + *, + create: Optional[Callable[..., "HttpxBinaryResponseContent"]] = None, + model: Optional[str] = "tts-1-hd", + voice: Optional[ + Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] + ] = None, + response_format: Optional[Literal["mp3", "opus", "aac", "flac"]] = None, + speed: Optional[float] = None, + filepath: Path, +) -> None: + if create is None: + from marvin.settings import settings + + create = settings.openai.audio.speech.create + return create( + input=input, + **({"model": model} if model else {}), + **({"voice": voice} if voice else {}), + **({"response_format": response_format} if response_format else {}), + **({"speed": speed} if speed else {}), + ).stream_to_file(filepath) + + +async def aspeak( + input: str, + *, + acreate: Optional[ + Callable[..., Coroutine[Any, Any, "HttpxBinaryResponseContent"]] + ] = None, + model: Optional[str], + voice: Optional[ + Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] + ] = None, + response_format: Optional[Literal["mp3", "opus", "aac", "flac"]] = None, + speed: Optional[float] = None, + filepath: Path, +) -> None: + if acreate is None: + from marvin.settings import settings + + acreate = settings.openai.audio.speech.acreate + return ( + await acreate( + input=input, + **({"model": model} if model else {}), + **({"voice": voice} if voice else {}), + **({"response_format": response_format} if response_format else {}), + **({"speed": speed} if speed else {}), + ) + ).stream_to_file(filepath) diff --git a/src/marvin/core/ChatCompletion/__init__.py b/src/marvin/core/ChatCompletion/__init__.py deleted file mode 100644 index 068c171ce..000000000 --- a/src/marvin/core/ChatCompletion/__init__.py +++ /dev/null @@ -1,56 +0,0 @@ -from typing import Optional, Any, TypeVar -from .abstract import AbstractChatCompletion - -from marvin._compat import BaseModel -from marvin.settings import settings - -T = TypeVar( - "T", - bound=BaseModel, -) - -PROVIDER_SHORTCUTS = { - "gpt-3.5-turbo": "openai", - "gpt-4": "openai", - "claude-1": "anthropic", - "claude-2": "anthropic", -} - - -def 
parse_model_shortcut(provider: Optional[str]) -> tuple[str, str]: - """ - Parse a model string into a provider and a model name. - - If the provider is None, use the default provider and model. - - If the provider is a shortcut, use the shortcut to get the provider and model. - """ - if provider is None: - try: - provider, model = settings.llm_model.split("/") - except Exception: - provider, model = ( - PROVIDER_SHORTCUTS[str(settings.llm_model)], - settings.llm_model, - ) - - elif provider in PROVIDER_SHORTCUTS: - provider, model = PROVIDER_SHORTCUTS[provider], provider - else: - provider, model = provider.split("/") - return provider, model - - -def ChatCompletion( - model: Optional[str] = None, - **kwargs: Any, -) -> AbstractChatCompletion[T]: # type: ignore - provider, model = parse_model_shortcut(model) - if provider == "openai" or provider == "azure_openai": - from .providers.openai import OpenAIChatCompletion - - return OpenAIChatCompletion(provider=provider, model=model, **kwargs) - if provider == "anthropic": - from .providers.anthropic import AnthropicChatCompletion - - return AnthropicChatCompletion(model=model, **kwargs) - else: - raise ValueError(f"Unknown provider: {provider}") diff --git a/src/marvin/core/ChatCompletion/abstract.py b/src/marvin/core/ChatCompletion/abstract.py deleted file mode 100644 index 059bbfa10..000000000 --- a/src/marvin/core/ChatCompletion/abstract.py +++ /dev/null @@ -1,224 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Generic, Optional, TypeVar - -from marvin._compat import BaseModel, Field, model_copy, model_dump -from marvin.utilities.messages import Message -from typing_extensions import Self - -from .handlers import Request, Response, Turn - -T = TypeVar( - "T", - bound=BaseModel, -) - - -class Conversation(BaseModel, Generic[T], extra="allow", arbitrary_types_allowed=True): - turns: list[Turn[T]] - model: Any - - def __getitem__(self, key: int) -> Turn[T]: - return self.turns[key] - - @property - def last_turn(self) -> Turn[T]: - return self.turns[-1] - - @property - def last_request(self) -> Optional[Request[T]]: - return self.turns[-1][0] if self.turns else None - - @property - def last_response(self) -> Optional[Response[T]]: - return self.turns[-1][1] if self.turns else None - - @property - def history(self) -> list[Message]: - response: list[Message] = [] - if not self.turns: - return response - if self.last_request: - response = self.last_request.messages or [] - if self.last_response: - response.append(self.last_response.choices[0].message) - return response - - def send(self, messages: list[Message], **kwargs: Any) -> Turn[T]: - params = kwargs - if self.last_request: - params = model_dump(self.last_request, exclude={"messages"}) | kwargs - - turn = self.model.create( - **params, - messages=[ - *self.history, - *messages, - ], - ) - self.turns.append(turn) - return turn - - async def asend(self, messages: list[Message], **kwargs: Any) -> Turn[T]: - params = kwargs - if self.last_request: - params = model_dump(self.last_request, exclude={"messages"}) | kwargs - - turn = await self.model.acreate( - **params, - messages=[ - *self.history, - *messages, - ], - ) - self.turns.append(turn) - return turn - - -class AbstractChatCompletion( - BaseModel, Generic[T], ABC, extra="allow", arbitrary_types_allowed=True -): - """ - A ChatCompletion object is responsible for exposing a create and acreate method, - and for merging default parameters with the parameters passed to these methods. 
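# --- [Editor's note: illustrative examples, not part of this diff] ---
# Behavior of the removed parse_model_shortcut() helper above, as implied by
# its body:
#   parse_model_shortcut("anthropic/claude-2")  -> ("anthropic", "claude-2")
#   parse_model_shortcut("gpt-4")               -> ("openai", "gpt-4")   # via PROVIDER_SHORTCUTS
#   parse_model_shortcut(None)                  -> provider/model derived from settings.llm_model
# --- [End editor's note] ---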
- """ - - defaults: dict[str, Any] = Field(default_factory=dict, exclude=True) - - def __call__(self: Self, **kwargs: Any) -> Self: - """ - Create a new ChatCompletion object with new defaults computed from - merging the passed parameters with the default parameters. - """ - copy = model_copy(self) - copy.defaults = self.defaults | kwargs - return copy - - @abstractmethod - def _serialize_request(self, request: Optional[Request[T]]) -> dict[str, Any]: - """ - Serialize the request. - This should be implemented by derived classes based on their specific needs. - """ - pass - - @abstractmethod - def _create_request(self, **kwargs: Any) -> Request[T]: - """ - Prepare and return a request object. - This should be implemented by derived classes. - """ - pass - - @abstractmethod - def _parse_response(self, response: Any) -> Any: - """ - Parse the response based on specific needs. - """ - pass - - def merge_with_defaults(self, **kwargs: Any) -> dict[str, Any]: - """ - Merge the passed parameters with the default parameters. - """ - return self.defaults | kwargs - - @abstractmethod - def _send_request(self, **serialized_request: Any) -> Any: - """ - Send the serialized request to the appropriate endpoint/service. - Derived classes should implement this. - """ - pass - - @abstractmethod - async def _send_request_async( - self, **serialized_request: Any - ) -> Response[T]: # noqa - """ - Send the serialized request to the appropriate endpoint/service asynchronously. - Derived classes should implement this. - """ - pass - - def create( - self, response_model: Optional[type[T]] = None, **kwargs: Any - ) -> Turn[T]: - """ - Create a completion synchronously. - Derived classes can override this if they need to change the core logic. - """ - request = self._create_request(**kwargs, response_model=response_model) - serialized_request = self._serialize_request(request=request) - response_data = self._send_request(**serialized_request) - response = self._parse_response(response_data) - - return Turn( - request=Request( - **serialized_request - | self.defaults - | model_dump(request, exclude_none=True) - | ({"response_model": response_model} if response_model else {}) - ), - response=response, - ) - - async def acreate( - self, response_model: Optional[type[T]] = None, **kwargs: Any - ) -> Turn[T]: - """ - Create a completion asynchronously. - Similar to the synchronous version but for async implementations. - """ - request = self._create_request(**kwargs, response_model=response_model) - serialized_request = self._serialize_request(request=request) - response_data = await self._send_request_async(**serialized_request) - response = self._parse_response(response_data) - return Turn( - request=Request( - **serialized_request - | self.defaults - | model_dump(request, exclude_none=True) - | ({"response_model": response_model} if response_model else {}) - ), - response=response, - ) - - def chain(self, **kwargs: Any) -> Conversation[T]: - """ - Create a new Conversation object. - """ - with self as conversation: - conversation.send(**kwargs) - while conversation.last_turn.has_function_call(): - message = conversation.last_turn.call_function() - conversation.send( - message if isinstance(message, list) else [message], - ) - - return conversation - - async def achain(self, **kwargs: Any) -> Conversation[T]: - """ - Create a new Conversation object asynchronously. 
- """ - with self as conversation: - await conversation.asend(**kwargs) - while conversation.last_turn.has_function_call(): - message = conversation.last_turn.call_function() - await conversation.asend( - message if isinstance(message, list) else [message], - ) - - return conversation - - def __enter__(self: Self) -> Conversation[T]: - """ - Enter a context manager. - """ - return Conversation(turns=[], model=self) - - def __exit__(self: Self, *args: Any) -> None: - """ - Exit a context manager. - """ - pass diff --git a/src/marvin/core/ChatCompletion/handlers.py b/src/marvin/core/ChatCompletion/handlers.py deleted file mode 100644 index 3904f3038..000000000 --- a/src/marvin/core/ChatCompletion/handlers.py +++ /dev/null @@ -1,215 +0,0 @@ -import inspect -import json -from types import FunctionType -from typing import ( - Any, - Callable, - Generic, - Literal, - Optional, - TypeVar, - Union, - overload, -) - -from marvin._compat import BaseModel, Field, cast_to_json, model_dump -from marvin.utilities.async_utils import run_sync -from marvin.utilities.logging import get_logger -from marvin.utilities.messages import Message, Role -from typing_extensions import ParamSpec - -from .utils import parse_raw - -T = TypeVar( - "T", - bound=BaseModel, -) - -P = ParamSpec("P") - - -class Request(BaseModel, Generic[T], extra="allow", arbitrary_types_allowed=True): - messages: Optional[list[Message]] = Field(default=None) - functions: Optional[list[Union[Callable[..., Any], dict[str, Any]]]] = Field( - default=None - ) - function_call: Any = None - response_model: Optional[type[T]] = Field(default=None, exclude=True) - - def serialize( - self, - functions_serializer: Callable[ - [Callable[..., Any]], dict[str, Any] - ] = cast_to_json, - ) -> dict[str, Any]: - extras = model_dump( - self, exclude={"functions", "function_call", "response_model"} - ) - response_model: dict[str, Any] = {} - functions: dict[str, Any] = {} - function_call: dict[str, Any] = {} - messages: dict[str, Any] = {} - - if self.response_model: - response_model_schema: dict[str, Any] = functions_serializer( - self.response_model - ) - response_model = { - "functions": [response_model_schema], - "function_call": {"name": response_model_schema.get("name")}, - } - - elif self.functions: - functions = { - "functions": [ - functions_serializer(function) for function in self.functions - ] - } - if self.function_call: - functions["function_call"] = self.function_call - - return extras | functions | function_call | messages | response_model - - def function_registry(self) -> dict[str, FunctionType]: - return { - function.__name__: function - for function in self.functions or [] - if callable(function) - } - - -class Choice(BaseModel): - message: Message - index: int - finish_reason: str - - class Config: - arbitrary_types_allowed = True - - -class Usage(BaseModel): - prompt_tokens: int - completion_tokens: int - total_tokens: int - - -class Response(BaseModel, Generic[T], extra="allow", arbitrary_types_allowed=True): - id: str - object: str - created: int - model: str - usage: Usage - choices: list[Choice] = Field(default_factory=list) - - -class Turn(BaseModel, Generic[T], extra="allow", arbitrary_types_allowed=True): - request: Request[T] - response: Response[T] - - @overload - def __getitem__(self, key: Literal[0]) -> Request[T]: - ... - - @overload - def __getitem__(self, key: Literal[1]) -> Response[T]: - ... 
- - def __getitem__(self, key: int) -> Union[Request[T], Response[T]]: - if key == 0: - return self.request - elif key == 1: - return self.response - else: - raise IndexError("Turn only has two items.") - - def has_function_call(self) -> bool: - return any([choice.message.function_call for choice in self.response.choices]) - - def get_function_call(self) -> list[tuple[str, dict[str, Any]]]: - if not self.has_function_call(): - raise ValueError("No function call found.") - pairs: list[tuple[str, dict[str, Any]]] = [] - for choice in self.response.choices: - if choice.message.function_call: - pairs.append( - ( - choice.message.function_call.name, - parse_raw(choice.message.function_call.arguments), - ) - ) - return pairs - - def call_function(self) -> Union[Message, list[Message]]: - if not self.has_function_call(): - raise ValueError("No function call found.") - - pairs: list[tuple[str, dict[str, Any]]] = self.get_function_call() - function_registry: dict[str, FunctionType] = self.request.function_registry() - evaluations: list[Any] = [] - - logger = get_logger("ChatCompletion.handlers") - - for pair in pairs: - name, argument = pair - if name not in function_registry: - raise ValueError( - f"Function {name} not found in {function_registry=!r}." - ) - - logger.debug_kv( - "Function call", - ( - f"Calling function {name!r} with payload:" - f" {json.dumps(argument, indent=2)}" - ), - key_style="green", - ) - - function_result = function_registry[name](**argument) - - if inspect.isawaitable(function_result): - function_result = run_sync(function_result) - - logger.debug_kv( - "Function call", - f"Function {name!r} returned: {function_result}", - key_style="green", - ) - - evaluations.append(function_result) - if len(evaluations) != 1: - return [ - Message( - name=pairs[j][0], - role=Role.FUNCTION_RESPONSE, - content=str(evaluations[j]), - function_call=None, - ) - for j in range(len(evaluations)) - ] - else: - return Message( - name=pairs[0][0], - role=Role.FUNCTION_RESPONSE, - content=str(evaluations[0]), - function_call=None, - ) - - def to_model(self, model_cls: Optional[type[T]] = None) -> T: - model = model_cls or self.request.response_model - - if not model: - raise ValueError("No model found.") - - pairs = self.get_function_call() - try: - return model(**pairs[0][1]) - except ValueError: # ValidationError is a subclass of ValueError - return model(output=pairs[0][1]) - except TypeError: - pass - try: - return model.parse_raw(pairs[0][1]) - except TypeError: - pass - return model.construct(**pairs[0][1]) diff --git a/src/marvin/core/ChatCompletion/providers/anthropic/__init__.py b/src/marvin/core/ChatCompletion/providers/anthropic/__init__.py deleted file mode 100644 index ed471255b..000000000 --- a/src/marvin/core/ChatCompletion/providers/anthropic/__init__.py +++ /dev/null @@ -1,178 +0,0 @@ -from typing import Any, TypeVar, Callable, Awaitable, Optional -import inspect -from marvin._compat import cast_to_json, model_dump -from marvin.settings import settings -from pydantic import BaseModel -from .prompt import render_anthropic_functions_prompt, handle_anthropic_response -from ...abstract import AbstractChatCompletion -from ...handlers import Request, Response - -T = TypeVar( - "T", - bound=BaseModel, -) - - -def get_anthropic_create(**kwargs: Any) -> tuple[Callable[..., Any], dict[str, Any]]: - """ - Get the Anthropic create function and the default parameters, - pruned of parameters that are not accepted by the constructor. 
- """ - import anthropic - - params = dict(inspect.signature(anthropic.Anthropic).parameters) - - return anthropic.Anthropic( - **{k: v for k, v in kwargs.items() if k in params.keys()} - ).completions.create, {k: v for k, v in kwargs.items() if k not in params.keys()} - - -def get_anthropic_acreate( - **kwargs: Any, -) -> tuple[Callable[..., Awaitable[Any]], dict[str, Any]]: # noqa - """ - Get the Anthropic acreate function and the default parameters, - pruned of parameters that are not accepted by the constructor. - """ - import anthropic - - params = dict(inspect.signature(anthropic.AsyncAnthropic).parameters) - return anthropic.AsyncAnthropic( - **{k: v for k, v in kwargs.items() if k in params.keys()} - ).completions.create, {k: v for k, v in kwargs.items() if k not in params.keys()} - - -class AnthropicChatCompletion(AbstractChatCompletion[T]): - """ - Anthropic-specific implementation of the ChatCompletion. - """ - - def __init__(self, **kwargs: Any): - """ - Filters out the parameters that are not accepted by the constructor. - """ - import anthropic - - kwargs = { - k: v - for k, v in kwargs.items() - if k not in dict(inspect.signature(anthropic.Anthropic).parameters).keys() - } - super().__init__(defaults=settings.get_defaults("anthropic") | kwargs) - - def _serialize_request( - self, request: Optional[Request[T]] = None - ) -> dict[str, Any]: - """ - Serialize the request as per OpenAI's requirements. - """ - request = request or Request() - request = Request( - **self.defaults - | model_dump( - request, - exclude_none=True, - ) - ) - - extras = model_dump( - request, - exclude={"functions", "function_call", "response_model", "messages"}, - ) - - functions: dict[str, Any] = {} - function_call: Any = {} - - prompt = "\n\nHuman:" - for message in request.messages or []: - if message.role != "function" and message.content: - prompt += f"\n\n{'Human' if message.role == 'user' else 'Assistant'}" - prompt += f": {message.content}" - - if request.response_model: - schema = cast_to_json(request.response_model) - functions["functions"] = [schema] - function_call["function_call"] = {"name": schema.get("name")} - - elif request.functions: - functions["functions"] = [ - cast_to_json(function) if callable(function) else function - for function in request.functions - ] - if request.function_call: - function_call["function_call"] = request.function_call - - if functions: - prompt += render_anthropic_functions_prompt( - functions=functions.pop("functions", []), - function_call=function_call.pop("function_call", None), - ) - - for message in request.messages or []: - if message.role == "function": - prompt += "\n\nAssistant" - prompt += f": The result of {message.name} is {message.content}." - if message.function_call: - prompt += "\n\nAssistant" - prompt += f": I will call the {message.function_call.name} function." - - prompt += "\n\nAssistant: " - prompt.replace("\n\nHuman:\n\nHuman: ", "\n\nHuman: ") - return extras | {"prompt": prompt} - - def _create_request(self, **kwargs: Any) -> Request[T]: - """ - Prepare and return an OpenAI-specific request object. - """ - return Request(**kwargs) - - def _parse_response(self, response: Any) -> Response[T]: - """ - Parse the response received from OpenAI. 
- """ - # Convert OpenAI's response into a standard format or object - - content, function_call = handle_anthropic_response(response.completion) - - return Response( - **{ - "id": response.log_id, - "model": response.model, - "object": "text_completion", - "created": 0, - "choices": [ - { - "index": 0, - "finish_reason": "stop", - "message": { - "content": content, # type: ignore - "role": "assistant", - "function_call": function_call, - }, - } - ], - "usage": { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0, - }, - } - ) - - def _send_request(self, **serialized_request: Any) -> Any: - """ - Send the serialized request to OpenAI's endpoint/service. - """ - # Use openai's library functions to send the request and get a response - # Example: - create, params = get_anthropic_create(**serialized_request) - response = create(**params) - return response - - async def _send_request_async(self, **serialized_request: Any) -> Response[T]: - """ - Send the serialized request to OpenAI's endpoint asynchronously. - """ - create, params = get_anthropic_acreate(**serialized_request) - response = await create(**params) - return response diff --git a/src/marvin/core/ChatCompletion/providers/anthropic/prompt.py b/src/marvin/core/ChatCompletion/providers/anthropic/prompt.py deleted file mode 100644 index 45f9d6f28..000000000 --- a/src/marvin/core/ChatCompletion/providers/anthropic/prompt.py +++ /dev/null @@ -1,110 +0,0 @@ -import json -import re -from typing import Any, Literal, Optional, Union - -from jinja2 import Environment -from marvin.utilities.strings import jinja_env - -from ...utils import parse_raw - -FUNCTION_PROMPT = """ -# Functions - -You can call various functions to perform tasks. - -Whenever you receive a message from the user, check to see if any of your -functions would help you respond. For example, you might use a function to look -up information, interact with a filesystem, call an API, or validate data. You -might write code, update state, or cause a side effect. After indicating that -you want to call a function, the user will execute the function and tell you its -result so that you can use the information in your final response. Therefore, -you must use your functions whenever they would be helpful. - -The user may also provide a `function_call` instruction which could be: - -- "auto": you may decide to call a function on your own (this is the - default) -- "none": do not call any function -- {"name": ""}: you MUST call the function with the given - name - -To call a function: - -- Your response must include a JSON payload with the below format, including the - {"mode": "function_call"} key. -- Do not put any other text in your response beside the JSON payload. -- Do not describe your plan to call the function to the user; they will not see - it. -- Do not include more than one payload in your response. -- Do not include function output or results in your response. - -# Available Functions - -You have access to the following functions. 
Each has a name (which must be part -of your response), a description (which you should use to decide whether to call the -function), and a parameter spec (which is a JSON Schema description of the -arguments you should pass in your response) - -{% for function in functions -%} - -## {{ function.name }} - -- Name: {{ function.name }} -- Description: {{ function.description }} -- Parameters: {{ function.parameters }} - -{% endfor %} - -# Calling a Function - -To call a function, your response MUST include a JSON document with the -following structure: - -{ - "mode": "function_call", - - "name": "", - - "arguments": "" -} - -The user will execute the function and respond with its result verbatim. - -# function_call instruction - -The user provided the following `function_call` instruction: {{ function_call }} - -# final_response instruction - -When you choose to call no functions, your final answer should be summative. -Do not acknowledge any internal functions you had access to or chose to call. - -""" - - -def render_anthropic_functions_prompt( - functions: list[dict[str, Any]], - function_call: Union[dict[Literal["name"], str], Literal["auto"]], - environment: Optional[Environment] = None, -) -> str: - env = environment or jinja_env - template = env.from_string(FUNCTION_PROMPT) - return template.render(functions=functions, function_call=function_call or "auto") - - -def handle_anthropic_response( - completion: str, -) -> tuple[Optional[str], Optional[dict[str, Any]]]: - try: - response: dict[str, Any] = parse_raw( - re.findall(r"\{.*\}", completion, re.DOTALL)[0] - ) - if response.pop("mode", None) == "function_call": - return None, { - "name": response.pop("name", None), - "arguments": json.dumps(response.pop("arguments", None)), - } - except Exception: - pass - return completion, None diff --git a/src/marvin/core/ChatCompletion/providers/openai.py b/src/marvin/core/ChatCompletion/providers/openai.py deleted file mode 100644 index 893456fb4..000000000 --- a/src/marvin/core/ChatCompletion/providers/openai.py +++ /dev/null @@ -1,218 +0,0 @@ -import inspect -from typing import Any, AsyncGenerator, Callable, Optional, TypeVar, Union - -from marvin._compat import BaseModel, cast_to_json, model_dump -from marvin.settings import settings -from marvin.types import Function -from marvin.utilities.async_utils import create_task, run_sync -from marvin.utilities.messages import Message -from marvin.utilities.streaming import StreamHandler -from openai.openai_object import OpenAIObject - -from ..abstract import AbstractChatCompletion -from ..handlers import Request, Response, Usage - -T = TypeVar( - "T", - bound=BaseModel, -) - -CONTEXT_SIZES = { - "gpt-3.5-turbo-16k-0613": 16384, - "gpt-3.5-turbo-16k": 16384, - "gpt-3.5-turbo-0613": 4096, - "gpt-3.5-turbo": 4096, - "gpt-4-32k-0613": 32768, - "gpt-4-32k": 32768, - "gpt-4-0613": 8192, - "gpt-4": 8192, -} - - -def get_context_size(model: str) -> int: - if "/" in model: - model = model.split("/")[-1] - - return CONTEXT_SIZES.get(model, 2048) - - -def serialize_function_or_callable( - function_or_callable: Union[Function, Callable[..., Any]], - name: Optional[str] = None, - description: Optional[str] = None, - field_name: Optional[str] = None, -) -> dict[str, Any]: - if isinstance(function_or_callable, Function): - return { - "name": function_or_callable.__name__, - "description": function_or_callable.__doc__, - "parameters": function_or_callable.schema, - } - else: - return cast_to_json( - function_or_callable, - name=name, - description=description, - 
field_name=field_name, - ) - - -class OpenAIStreamHandler(StreamHandler): - async def handle_streaming_response( - self, - api_response: AsyncGenerator[OpenAIObject, None], - ) -> OpenAIObject: - final_chunk = {} - accumulated_content = "" - - async for r in api_response: - final_chunk.update(r.to_dict_recursive()) - - delta = r.choices[0].delta if r.choices and r.choices[0] else None - - if delta is None: - continue - - if "content" in delta: - accumulated_content += delta.content or "" - - if self.callback: - callback_result = self.callback( - Message( - content=accumulated_content, role="assistant", data=final_chunk - ) - ) - if inspect.isawaitable(callback_result): - create_task(callback_result) - - if "choices" in final_chunk and len(final_chunk["choices"]) > 0: - final_chunk["choices"][0]["content"] = accumulated_content - - final_chunk["object"] = "chat.completion" - - return OpenAIObject.construct_from( - { - "id": final_chunk["id"], - "object": "chat.completion", - "created": final_chunk["created"], - "model": final_chunk["model"], - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": accumulated_content, - }, - "finish_reason": "stop", - } - ], - # TODO: Figure out how to get the usage from the streaming response - "usage": Usage.parse_obj( - {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} - ), - } - ) - - -class OpenAIChatCompletion(AbstractChatCompletion[T]): - """ - OpenAI-specific implementation of the ChatCompletion. - """ - - def __init__(self, provider: str, **kwargs: Any): - super().__init__(defaults=settings.get_defaults(provider or "openai") | kwargs) - - def _serialize_request( - self, request: Optional[Request[T]] = None - ) -> dict[str, Any]: - """ - Serialize the request as per OpenAI's requirements. - """ - _request = request or Request() - _request = Request( - **self.defaults - | ( - model_dump( - request, - exclude_none=True, - ) - if request - else {} - ) - ) - _request.function_call = _request.function_call or ( - request and request.function_call - ) - _request.functions = _request.functions or (request and request.functions) - _request.response_model = _request.response_model or ( - request and request.response_model - ) # noqa - - extras = model_dump( - _request, - exclude={"functions", "function_call", "response_model"}, - ) - - functions: dict[str, Any] = {} - function_call: Any = {} - for message in extras.get("messages", []): - if message.get("name", -1) is None: - message.pop("name", None) - if message.get("function_call", -1) is None: - message.pop("function_call", None) - - if _request.response_model: - schema = cast_to_json(_request.response_model) - functions["functions"] = [schema] - function_call["function_call"] = {"name": schema.get("name")} - - elif _request.functions: - functions["functions"] = [ - serialize_function_or_callable(function) - for function in _request.functions - ] - if _request.function_call: - function_call["function_call"] = _request.function_call - return extras | functions | function_call - - def _create_request(self, **kwargs: Any) -> Request[T]: - """ - Prepare and return an OpenAI-specific request object. - """ - return Request(**kwargs) - - def _parse_response(self, response: Any) -> Response[T]: - """ - Parse the response received from OpenAI. 
- """ - # Convert OpenAI's response into a standard format or object - return Response(**response.to_dict_recursive()) # type: ignore - - def _send_request(self, **serialized_request: Any) -> Any: - """ - Send the serialized request to OpenAI's endpoint/service. - """ - # Use openai's library functions to send the request and get a response - # Example: - - return run_sync( - self._send_request_async(**serialized_request), - ) - - async def _send_request_async(self, **serialized_request: Any) -> Response[T]: - """ - Send the serialized request to OpenAI's endpoint asynchronously. - """ - import openai - - if handler_fn := serialized_request.pop("stream_handler", {}): - serialized_request["stream"] = True - - response = await openai.ChatCompletion.acreate(**serialized_request) # type: ignore # noqa - - if handler_fn: - response = await OpenAIStreamHandler( - callback=handler_fn, - ).handle_streaming_response(response) - - return response diff --git a/src/marvin/core/ChatCompletion/utils.py b/src/marvin/core/ChatCompletion/utils.py deleted file mode 100644 index 58a08c498..000000000 --- a/src/marvin/core/ChatCompletion/utils.py +++ /dev/null @@ -1,15 +0,0 @@ -import json -from ast import literal_eval -from typing import Any - - -def parse_raw(raw: str) -> dict[str, Any]: - try: - return literal_eval(raw) - except Exception: - pass - try: - return json.loads(raw) - except Exception: - pass - return {} diff --git a/src/marvin/deployment/__init__.py b/src/marvin/deployment/__init__.py deleted file mode 100644 index 60057dd90..000000000 --- a/src/marvin/deployment/__init__.py +++ /dev/null @@ -1,81 +0,0 @@ -import asyncio -from typing import Union, Optional - -import uvicorn -from fastapi import FastAPI, APIRouter -from pydantic import BaseModel, Extra - -from marvin import AIApplication, AIModel, AIFunction -from marvin.tools import Tool - - -class Deployment(BaseModel): - """ - Deployment class handles the deployment of AI applications, models or functions. - """ - - def __init__( - self, - component: Union[AIApplication, AIModel, AIFunction], - *args, - app_kwargs: Optional[dict] = None, - router_kwargs: Optional[dict] = None, - uvicorn_kwargs: Optional[dict] = None, - **kwargs, - ): - super().__init__(**kwargs) - self._app = FastAPI(**(app_kwargs or {})) - self._router = APIRouter(**(router_kwargs or {})) - self._controller = component - self._mount_router() - self._uvicorn_kwargs = { - "app": self._app, - "host": "0.0.0.0", - "port": 8000, - **(uvicorn_kwargs or {}), - } - - def _mount_router(self): - """ - Mounts a router to the FastAPI app for each tool in the AI application. - """ - - if isinstance(self._controller, AIApplication): - name = self._controller.name - base_path = f"/{name.lower()}" if name else "aiapp" - self._router.get(base_path, tags=[name])(self._controller.entrypoint) - for tool in self._controller.tools: - name, fn = ( - (tool.name, tool.fn) - if isinstance(tool, Tool) - else (tool.__name__, tool) - ) - tool_path = f"{base_path}/tools/{name}" - self._router.post(tool_path, tags=[name])(fn) - - self._app.include_router(self._router) - self._app.openapi_tags = self._app.openapi_tags or [] - self._app.openapi_tags.append( - {"name": name, "description": self._controller.description} - ) - - if isinstance(self._controller, AIModel): - raise NotImplementedError - - if isinstance(self._controller, AIFunction): - raise NotImplementedError - - def serve(self): - """ - Serves the FastAPI app. 
- """ - try: - config = uvicorn.Config(**(self._uvicorn_kwargs or {})) - server = uvicorn.Server(config) - loop = asyncio.get_event_loop() - loop.run_until_complete(server.serve()) - except Exception as e: - print(f"Error while serving the application: {e}") - - class Config: - extra = Extra.allow diff --git a/src/marvin/engine/__init__.py b/src/marvin/engine/__init__.py deleted file mode 100644 index a097ae617..000000000 --- a/src/marvin/engine/__init__.py +++ /dev/null @@ -1,94 +0,0 @@ -""" -The engine module is the interface to external LLM providers. -""" -from pydantic import BaseModel, Field -from typing import Optional -from marvin.utilities.module_loading import import_string -import copy - - -class ChatCompletionBase(BaseModel): - """ - This class is used to create and handle chat completions from the API. - It provides several utility functions to create the request, send it to the API, - and handle the response. - """ - - _module: str = "openai.ChatCompletion" # the module used to interact with the API - _request: str = "marvin.openai.ChatCompletion.Request" - _response: str = "marvin.openai.ChatCompletion.Response" - _create: str = "create" # the name of the create method in the API model - _acreate: str = ( # the name of the asynchronous create method in the API model - "acreate" - ) - defaults: Optional[dict] = Field(None, repr=False) # default configuration values - - def __init__(self, module_path: str = None, config_path: str = None, **kwargs): - super().__init__(_module=module_path, _config=config_path, defaults=kwargs) - - @property - def module(self): - """ - This property imports and returns the API model. - """ - return import_string(self._module) - - @property - def model(self): - """ - This property imports and returns the API model. - """ - return self.module - - def request(self, *args, **kwargs): - """ - This method imports and returns a configuration object. - """ - return import_string(self._request)(*args, **(kwargs or self.defaults or {})) - - @property - def response_class(self, *args, **kwargs): - """ - This method imports and returns a configuration object. - """ - return import_string(self._response) - - def prepare_request(self, *args, **kwargs): - """ - This method prepares a request and returns it. - """ - request = self.request() | self.request(**kwargs) - payload = request.dict(exclude_none=True, exclude_unset=True) - return request, payload - - def create(self=None, *args, **kwargs): - """ - This method creates a request and sends it to the API. - It returns a Response object with the raw response and the request. - """ - request, request_dict = self.prepare_request(*args, **kwargs) - create = getattr(self.model, self._create) - response = self.response_class(create(**request_dict), request=request) - if request.evaluate_function_call and response.function_call: - return response.call_function(as_message=True) - return response - - async def acreate(self, *args, **kwargs): - """ - This method is an asynchronous version of the create method. - It creates a request and sends it to the API asynchronously. - It returns a Response object with the raw response and the request. 
- """ - request, request_dict = self.prepare_request(*args, **kwargs) - acreate = getattr(self.model, self._acreate) - response = self.response_class(await acreate(**request_dict), request=request) - if request.evaluate_function_call and response.function_call: - return await response.acall_function(as_message=True) - return response - - def __call__(self, *args, **kwargs): - self = copy.deepcopy(self) - request = self.request() - passed = self.__class__(**kwargs).request() - self.defaults = (request | passed).dict(serialize_functions=False) - return self diff --git a/src/marvin/engine/anthropic.py b/src/marvin/engine/anthropic.py deleted file mode 100644 index 734b38333..000000000 --- a/src/marvin/engine/anthropic.py +++ /dev/null @@ -1,195 +0,0 @@ -from operator import itemgetter -from typing import Any, Callable, Optional - -from anthropic import AI_PROMPT, HUMAN_PROMPT -from jinja2 import Template -from pydantic import BaseModel, Extra, Field, root_validator - -from marvin import settings -from marvin.engine import ChatCompletionBase -from marvin.engine.language_models.anthropic import AnthropicFunctionCall -from marvin.types.request import Request as BaseRequest -from marvin.utilities.module_loading import import_string - - -class Request(BaseRequest): - """ - This is a class for creating Request objects to interact with the GPT-3 API. - The class contains several configurations and validation functions to ensure - the correct data is sent to the API. - - """ - - model: str = "claude-2" # the model used by the GPT-3 API - # temperature: float = 0.8 # the temperature parameter used by the GPT-3 API - api_key: str = Field(default_factory=settings.anthropic.api_key.get_secret_value) - max_tokens_to_sample: int = Field(default=1000) - prompt: str = Field(default="") - - class Config: - exclude = {"response_model", "messages"} - exclude_none = True - extra = Extra.allow - functions_prompt = ( - "marvin.engine.language_models.anthropic.FUNCTIONS_INSTRUCTIONS" - ) - - @root_validator(pre=True) - def to_anthropic(cls, values): - values["prompt"] = "" - for message in values.get("messages", []): - if message.get("role") == "user": - values["prompt"] += f'{HUMAN_PROMPT} {message.get("content")}' - else: - values["prompt"] += f'{AI_PROMPT} {message.get("content")}' - values["prompt"] += f"{AI_PROMPT} " - return values - - def dict(self, *args, serialize_functions=True, **kwargs): - """ - This method returns a dictionary representation of the Request. - If the functions attribute is present and serialize_functions is True, - the functions' schemas are also included. - """ - - # This identity function is here for no reason except to show - # readers that custom adapters need only override the dict method. - return super().dict(*args, serialize_functions=serialize_functions, **kwargs) - - -class Response(BaseModel): - """ - This class is used to handle the response from the API. - It includes several utility functions and properties to extract useful information - from the raw response. - """ - - raw: Any # the raw response from the API - request: Any # the request that generated the response - - def __init__(self, response, *args, request, **kwargs): - super().__init__(raw=response, request=request) - - def __iter__(self): - return self.raw.__iter__() - - def __next__(self): - return self.raw.__next__() - - def __getattr__(self, name): - """ - This method attempts to get the attribute from the raw response. - If it doesn't exist, it falls back to the standard attribute access. 
- """ - try: - return self.raw.__getattr__(name) - except AttributeError: - return self.__getattribute__(name) - - @property - def message(self): - """ - This property extracts the message from the raw response. - If there is only one choice, it returns the message from that choice. - Otherwise, it returns a list of messages from all choices. - """ - return self.raw.completion - - @property - def function_call(self): - """ - This property extracts the function call from the message. - If the message is a list, it returns a list of function calls from all messages. - Otherwise, it returns the function call from the message. - """ - - return AnthropicFunctionCall.parse_raw(self.message).dict( - exclude={"function_call"} - ) - - @property - def callables(self): - """ - This property returns a list of all callable functions from the request. - """ - return [x for x in self.request.functions if isinstance(x, Callable)] - - @property - def callable_registry(self): - """ - This property returns a dictionary mapping function names to functions for all - callable functions from the request. - """ - return {fn.__name__: fn for fn in self.callables} - - def call_function(self, as_message=True): - """ - This method evaluates the function call in the response and returns the result. - If as_message is True, it returns the result as a function message. - Otherwise, it returns the result directly. - """ - name, raw_arguments = itemgetter("name", "arguments")(self.function_call) - function = self.callable_registry.get(name) - arguments = function.model.parse_raw(raw_arguments) - value = function(**arguments.dict(exclude_none=True)) - if as_message: - return {"role": "function", "name": name, "content": value} - else: - return value - - def to_model(self): - """ - This method parses the function call arguments into the response model and - returns the result. - """ - return self.request.response_model.parse_raw(self.function_call.arguments) - - def __repr__(self, *args, **kwargs): - """ - This method returns a string representation of the raw response. - """ - return self.raw.__repr__(*args, **kwargs) - - -class AnthropicChatCompletion(ChatCompletionBase): - """ - This class is used to create and handle chat completions from the API. - It provides several utility functions to create the request, send it to the API, - and handle the response. - """ - - _module: str = "anthropic.Anthropic" # the module used to interact with the API - _request: str = "marvin.engine.anthropic.Request" - _response: str = "marvin.engine.anthropic.Response" - defaults: Optional[dict] = Field(None, repr=False) # default configuration values - - @property - def model(self): - """ - This property imports and returns the API model. - """ - return self.module( - api_key=self.request().api_key, - ).completions - - def prepare_request(self, *args, **kwargs): - request, payload = super().prepare_request(*args, **kwargs) - payload.pop("messages", None) - payload.pop("api_key", None) - if payload.get("functions", None): - functions_prompt = Template( - import_string(request.Config.functions_prompt) - ).render( - functions=payload.get("functions"), - function_call=payload.get("function_call"), - ) - payload["prompt"] = f"{HUMAN_PROMPT} {functions_prompt}" + payload["prompt"] - payload.pop("functions", None) - return request, payload - - -ChatCompletion = AnthropicChatCompletion() - -# This is a legacy class that is used to create a ChatCompletion object. -# It is deprecated and will be removed in a future release. 
-ChatCompletionConfig = Request diff --git a/src/marvin/engine/executors/__init__.py b/src/marvin/engine/executors/__init__.py deleted file mode 100644 index b938d374d..000000000 --- a/src/marvin/engine/executors/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .base import Executor -from .openai import OpenAIFunctionsExecutor diff --git a/src/marvin/engine/executors/base.py b/src/marvin/engine/executors/base.py deleted file mode 100644 index 318594af7..000000000 --- a/src/marvin/engine/executors/base.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import List, Union - -from pydantic import PrivateAttr - -from marvin.engine.language_models import ChatLLM -from marvin.prompts.base import Prompt, render_prompts -from marvin.utilities.messages import Message -from marvin.utilities.types import LoggerMixin, MarvinBaseModel - - -class Executor(LoggerMixin, MarvinBaseModel): - model: ChatLLM - _should_stop: bool = PrivateAttr(False) - - async def start( - self, - prompts: list[Union[Prompt, Message]], - prompt_render_kwargs: dict = None, - ) -> list[Message]: - """ - Start the LLM loop - """ - # reset stop criteria - self._should_stop = False - - responses = [] - while not self._should_stop: - # render the prompts, including any responses from the previous step - messages = render_prompts( - prompts + responses, - render_kwargs=prompt_render_kwargs, - max_tokens=self.model.context_size, - ) - response = await self.step(messages) - responses.append(response) - if await self.stop_condition(messages, responses): - self._should_stop = True - return responses - - async def step(self, messages: list[Message]) -> Message: - """ - Implements one step of the LLM loop - """ - messages = await self.process_messages(messages) - llm_response = await self.run_engine(messages=messages) - response = await self.process_response(llm_response) - return response - - async def run_engine(self, messages: list[Message]) -> Message: - """ - Implements one step of the LLM loop - """ - llm_response = await self.model.run(messages=messages) - return llm_response - - async def process_messages(self, messages: list[Message]) -> list[Message]: - """Called prior to sending messages to the LLM""" - return messages - - async def stop_condition( - self, messages: List[Message], responses: List[Message] - ) -> bool: - return True - - async def process_response(self, response: Message) -> Message: - """Called after receiving a response from the LLM""" - return response diff --git a/src/marvin/engine/executors/openai.py b/src/marvin/engine/executors/openai.py deleted file mode 100644 index e46f7d9d6..000000000 --- a/src/marvin/engine/executors/openai.py +++ /dev/null @@ -1,162 +0,0 @@ -import inspect -import json -from ast import literal_eval -from typing import Any, Callable, List, Optional, Union - -from pydantic import Field - -import marvin -from marvin.engine.language_models import OpenAIFunction -from marvin.utilities.messages import Message - -from .base import Executor - - -class OpenAIFunctionsExecutor(Executor): - """ - An executor that understands how to pass functions to the LLM, interpret - responses that request function calls, and iteratively continue to process - functions until the LLM responds directly to the user. This uses the OpenAI - Functions API, so provider LLMs must be compatible. 
- """ - - functions: Optional[List[OpenAIFunction]] = Field(default=None) - function_call: Union[str, dict[str, str]] = Field(default=None) - max_iterations: Optional[int] = Field( - default_factory=lambda: marvin.settings.ai_application_max_iterations - ) - stream_handler: Optional[Callable[[Message], None]] = Field(default=None) - - def __init__( - self, - functions: Optional[List[Union[OpenAIFunction, Callable[..., Any]]]] = None, - **kwargs: Any, - ): - if functions is not None: - functions = [ - ( - OpenAIFunction.from_function(i) - if not isinstance(i, OpenAIFunction) - else i - ) - for i in functions - ] - super().__init__( - functions=functions, - **kwargs, - ) - - # @validator("functions", pre=True) - # def validate_functions(cls, v): - # if v is None: - # return None - # v = [ - # OpenAIFunction.from_function(i) if not isinstance(i, OpenAIFunction)else i - # for i in v - # ] - # return v - - async def run_engine(self, messages: list[Message]) -> Message: - """ - Implements one step of the LLM loop - """ - - kwargs = {} - - if self.functions: - kwargs["functions"] = self.functions - kwargs["function_call"] = self.function_call - - llm_response = await self.model.run( - messages=messages, - stream_handler=self.stream_handler, - **kwargs, - ) - return llm_response - - async def stop_condition( - self, messages: List[Message], responses: List[Message] - ) -> bool: - # if the number of responses exceeds max iterations, stop - - if self.max_iterations is not None and len(responses) >= self.max_iterations: - return True - - # if function calls are set to auto and the most recent call was a - # function, continue - if self.function_call == "auto": - if responses and responses[-1].role == "function_response": - return False - - # if a specific function call was requested but errored, continue - elif self.function_call != "none": - if responses and responses[-1].data.get("is_error"): - return False - - # otherwise stop - return True - - async def process_response(self, response: Message) -> Message: - if response.role == "function_request": - return await self.process_function_call(response) - else: - return response - - async def process_function_call(self, response: Message) -> Message: - response_data = {} - - function_call = response.data["function_call"] - fn_name = function_call.get("name") - fn_args = function_call.get("arguments") - response_data["name"] = fn_name - try: - try: - fn_args = json.loads(function_call.get("arguments", "{}")) - except json.JSONDecodeError: - fn_args = literal_eval(function_call.get("arguments", "{}")) - response_data["arguments"] = fn_args - - # retrieve the named function - openai_fn = next((f for f in self.functions if f.name == fn_name), None) - if openai_fn is None: - raise ValueError(f'Function "{function_call["name"]}" not found.') - - if not isinstance(fn_args, dict): - raise ValueError( - "Expected a dictionary of arguments, got a" - f" {type(fn_args).__name__}." 
- ) - - # call the function - if openai_fn.fn is not None: - self.logger.debug( - f"Running function '{openai_fn.name}' with payload {fn_args}" - ) - fn_result = openai_fn.fn(**fn_args) - if inspect.isawaitable(fn_result): - fn_result = await fn_result - - # if the function is undefined, return the arguments as its output - else: - fn_result = fn_args - self.logger.debug(f"Result of function '{openai_fn.name}': {fn_result}") - response_data["is_error"] = False - - except Exception as exc: - fn_result = ( - f"The function '{fn_name}' encountered an error:" - f" {str(exc)}\n\nThe payload you provided was: {fn_args}\n\nYou" - " can try to fix the error and call the function again." - ) - self.logger.debug_kv("Error", fn_result, key_style="red") - response_data["is_error"] = True - - response_data["result"] = fn_result - - return Message( - role="function_response", - name=fn_name, - content=str(fn_result), - data=response_data, - llm_response=response.llm_response, - ) diff --git a/src/marvin/engine/language_models/__init__.py b/src/marvin/engine/language_models/__init__.py deleted file mode 100644 index 6c952e794..000000000 --- a/src/marvin/engine/language_models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base import ChatLLM, OpenAIFunction, StreamHandler, chat_llm diff --git a/src/marvin/engine/language_models/anthropic.py b/src/marvin/engine/language_models/anthropic.py deleted file mode 100644 index 61bf410ed..000000000 --- a/src/marvin/engine/language_models/anthropic.py +++ /dev/null @@ -1,264 +0,0 @@ -import inspect -import json -import re -from logging import Logger -from typing import Callable, Union - -import anthropic -import openai -import openai.openai_object -from pydantic import BaseModel - -import marvin -import marvin.utilities.types -from marvin.engine.language_models import ChatLLM, StreamHandler -from marvin.engine.language_models.base import OpenAIFunction -from marvin.utilities.async_utils import create_task -from marvin.utilities.logging import get_logger -from marvin.utilities.messages import Message, Role -from marvin.utilities.strings import jinja_env - -CONTEXT_SIZES = { - "claude-instant": 100_000, - "claude-2": 100_000, -} - -FUNCTION_CALL_REGEX = re.compile( - r'{\s*"mode":\s*"function_call"\s*(.*)}', - re.DOTALL, -) -FUNCTION_CALL_NAME = re.compile(r'"name":\s*"(.*)"') -FUNCTION_CALL_ARGS = re.compile(r'"arguments":\s*(".*")', re.DOTALL) - - -def extract_function_call(completion: str) -> Union[dict, None]: - function_call = dict(name=None, arguments="{}") - if match := FUNCTION_CALL_REGEX.search(completion): - if name := FUNCTION_CALL_NAME.search(match.group(1)): - function_call["name"] = name.group(1) - if args := FUNCTION_CALL_ARGS.search(match.group(1)): - function_call["arguments"] = args.group(1) - try: - function_call["arguments"] = json.loads(function_call["arguments"]) - except json.JSONDecodeError: - pass - if not function_call["name"]: - return None - return function_call - - -def anthropic_role_map(marvin_role: Role): - if marvin_role in [Role.USER, Role.SYSTEM, Role.FUNCTION_RESPONSE]: - return anthropic.HUMAN_PROMPT - else: - return anthropic.AI_PROMPT - - -class AnthropicFunctionCall(BaseModel): - mode: str - name: str - arguments: str - - @classmethod - def parse_raw(cls, raw: str): - return super().parse_raw(re.sub("^[^{]*|[^}]*$", "", raw)) - - -class AnthropicStreamHandler(StreamHandler): - async def handle_streaming_response( - self, - api_response: openai.openai_object.OpenAIObject, - ) -> Message: - """ - Accumulate chunk deltas into a 
full response. Returns the full message.
-        Passes partial messages to the callback, if provided.
-        """
-        response = {
-            "role": Role.ASSISTANT,
-            "content": "",
-            "data": {},
-            "llm_response": None,
-        }
-
-        async for msg in api_response:
-            response["llm_response"] = msg.dict()
-            response["content"] += msg.completion
-
-            if function_call := extract_function_call(response["content"]):
-                response["role"] = Role.FUNCTION_REQUEST
-                response["data"]["function_call"] = function_call
-
-            if self.callback:
-                callback_result = self.callback(Message(**response))
-                if inspect.isawaitable(callback_result):
-                    create_task(callback_result)
-
-        response["content"] = response["content"].strip()
-        return Message(**response)
-
-
-class AnthropicChatLLM(ChatLLM):
-    model: str = "claude-2"
-
-    @property
-    def context_size(self) -> int:
-        if self.model in CONTEXT_SIZES:
-            return CONTEXT_SIZES[self.model]
-        else:
-            for model_prefix, context in CONTEXT_SIZES.items():
-                if self.model.startswith(model_prefix):
-                    return context
-            return 100_000
-
-    def format_messages(
-        self, messages: list[Message]
-    ) -> Union[str, dict, list[Union[str, dict]]]:
-        formatted_messages = []
-        for msg in messages:
-            role = anthropic_role_map(msg.role)
-            formatted_messages.append(f"{role}{msg.content}")
-
-        return "".join(formatted_messages) + anthropic.AI_PROMPT
-
-    async def run(
-        self,
-        messages: list[Message],
-        *,
-        functions: list[OpenAIFunction] = None,
-        function_call: Union[str, dict[str, str]] = None,
-        logger: Logger = None,
-        stream_handler: Callable[[Message], None] = False,
-        **kwargs,
-    ) -> Message:
-        """Calls the Anthropic LLM with a list of messages and returns the response."""
-
-        if logger is None:
-            logger = get_logger(self.name)
-
-        # ----------------------------------
-        # Prepare functions
-        # ----------------------------------
-        if functions:
-            function_message = jinja_env.from_string(FUNCTIONS_INSTRUCTIONS).render(
-                functions=functions, function_call=function_call
-            )
-            system_message = Message(role=Role.SYSTEM, content=function_message)
-            messages = [system_message] + messages
-
-        prompt = self.format_messages(messages)
-
-        # ----------------------------------
-        # Call Anthropic LLM
-        # ----------------------------------
-
-        if not marvin.settings.anthropic.api_key:
-            raise ValueError(
-                "Anthropic API key not found in settings. Please set it or use the"
-                " MARVIN_ANTHROPIC_API_KEY environment variable."
-            )
-
-        client = anthropic.AsyncAnthropic(
-            api_key=marvin.settings.anthropic.api_key.get_secret_value(),
-            timeout=marvin.settings.llm_request_timeout_seconds,
-        )
-
-        kwargs.setdefault("temperature", self.temperature)
-        kwargs.setdefault("max_tokens_to_sample", self.max_tokens)
-
-        response = await client.completions.create(
-            model=self.model,
-            prompt=prompt,
-            stream=True if stream_handler else False,
-            **kwargs,
-        )
-
-        if stream_handler:
-            handler = AnthropicStreamHandler(callback=stream_handler)
-            msg = await handler.handle_streaming_response(response)
-            return msg
-
-        else:
-            llm_response = response.dict()
-            content = llm_response["completion"].strip()
-            role = Role.ASSISTANT
-            data = {}
-            if function_call := extract_function_call(content):
-                role = Role.FUNCTION_REQUEST
-                data["function_call"] = function_call
-            msg = Message(
-                role=role,
-                content=content,
-                data=data,
-                llm_response=llm_response,
-            )
-            return msg
-
-
-FUNCTIONS_INSTRUCTIONS = """
-# Functions
-
-You can call various functions to perform tasks.
-
-Whenever you receive a message from the user, check to see if any of your
-functions would help you respond. For example, you might use a function to look
-up information, interact with a filesystem, call an API, or validate data. You
-might write code, update state, or cause a side effect. After indicating that
-you want to call a function, the user will execute the function and tell you its
-result so that you can use the information in your final response. Therefore,
-you must use your functions whenever they would be helpful.
-
-The user may also provide a `function_call` instruction which could be:
-
-- "auto": you may decide to call a function on your own (this is the
-  default)
-- "none": do not call any function
-- {"name": ""}: you MUST call the function with the given
-  name
-
-To call a function:
-
-- Your response must include a JSON payload with the below format, including the
-  {"mode": "function_call"} key.
-- Do not put any other text in your response beside the JSON payload.
-- Do not describe your plan to call the function to the user; they will not see
-  it.
-- Do not include more than one payload in your response.
-- Do not include function output or results in your response.
-
-# Available Functions
-
-You have access to the following functions. Each has a name (which must be part
-of your response), a description (which you should use to decide whether to call the
-function), and a parameter spec (which is a JSON Schema description of the
-arguments you should pass in your response)
-
-{% for function in functions -%}
-
-## {{ function.name }}
-
-- Name: {{ function.name }}
-- Description: {{ function.description }}
-- Parameters: {{ function.parameters }}
-
-{% endfor %}
-
-# Calling a Function
-
-To call a function, your response MUST include a JSON document with the
-following structure:
-
-{
-    "mode": "function_call",
-
-    "name": "",
-
-    "arguments": ""
-}
-
-The user will execute the function and respond with its result verbatim.
-
-# function_call instruction
-
-The user provided the following `function_call` instruction: {{ function_call }}
-"""
diff --git a/src/marvin/engine/language_models/azure_openai.py b/src/marvin/engine/language_models/azure_openai.py
deleted file mode 100644
index b19626da7..000000000
--- a/src/marvin/engine/language_models/azure_openai.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import marvin
-
-from .openai import OpenAIChatLLM
-
-CONTEXT_SIZES = {
-    "gpt-35-turbo": 4096,
-    "gpt-35-turbo-0613": 4096,
-    "gpt-35-turbo-16k": 16384,
-    "gpt-35-turbo-16k-0613": 16384,
-    "gpt-4": 8192,
-    "gpt-4-0613": 8192,
-    "gpt-4-32k": 32768,
-    "gpt-4-32k-0613": 32768,
-}
-
-
-class AzureOpenAIChatLLM(OpenAIChatLLM):
-    model: str = "gpt-35-turbo-0613"
-
-    @property
-    def context_size(self) -> int:
-        if self.model in CONTEXT_SIZES:
-            return CONTEXT_SIZES[self.model]
-        else:
-            for model_prefix, context in CONTEXT_SIZES.items():
-                if self.model.startswith(model_prefix):
-                    return context
-            return 4096
-
-    def _get_openai_settings(self) -> dict:
-        # do not load the base openai settings; any azure settings must be set
-        # explicitly
-        openai_kwargs = {}
-
-        if marvin.settings.azure_openai.api_key:
-            openai_kwargs["api_key"] = (
-                marvin.settings.azure_openai.api_key.get_secret_value()
-            )
-        else:
-            raise ValueError(
-                "Azure OpenAI API key not set. Please set it or use the"
-                " MARVIN_AZURE_OPENAI_API_KEY environment variable."
- ) - if marvin.settings.azure_openai.deployment_name: - openai_kwargs["deployment_name"] = ( - marvin.settings.azure_openai.deployment_name - ) - if marvin.settings.azure_openai.api_type: - openai_kwargs["api_type"] = marvin.settings.azure_openai.api_type - if marvin.settings.azure_openai.api_base: - openai_kwargs["api_base"] = marvin.settings.azure_openai.api_base - if marvin.settings.azure_openai.api_version: - openai_kwargs["api_version"] = marvin.settings.azure_openai.api_version - return openai_kwargs diff --git a/src/marvin/engine/language_models/base.py b/src/marvin/engine/language_models/base.py deleted file mode 100644 index b258834a7..000000000 --- a/src/marvin/engine/language_models/base.py +++ /dev/null @@ -1,140 +0,0 @@ -import abc -import json -from logging import Logger -from typing import Any, Callable, Optional, Union - -import tiktoken -from pydantic import Field, validator - -import marvin -import marvin.utilities.types -from marvin.utilities.messages import Message -from marvin.utilities.types import MarvinBaseModel - - -class StreamHandler(MarvinBaseModel, abc.ABC): - callback: Callable[[Message], None] = None - - @abc.abstractmethod - def handle_streaming_response(self, api_response) -> Message: - raise NotImplementedError() - - -class OpenAIFunction(MarvinBaseModel): - name: Optional[str] = None - description: Optional[str] = None - parameters: dict[str, Any] = {"type": "object", "properties": {}} - fn: Optional[Callable] = Field(None, exclude=True) - args: Optional[dict] = None - """ - Base class for representing a function that can be called by an LLM. The - format is identical to OpenAI's Functions API. - - Args: - name (str): The name of the function. description (str): The description - of the function. parameters (dict): The parameters of the function. fn - (Callable): The function to be called. args (dict): The arguments to be - passed to the function. 
- """ - - @classmethod - def from_function(cls, fn: Callable, **kwargs): - return cls( - name=kwargs.get("name", fn.__name__), - description=kwargs.get("description", fn.__doc__ or ""), - parameters=marvin.utilities.types.function_to_schema(fn), - fn=fn, - ) - - async def query(self, q: str, model: "ChatLLM" = None): - if not model: - model = chat_llm() - self.args = json.loads( - ( - await model.run( - messages=[Message(role="USER", content=q)], - functions=[self], - function_call={"name": self.name}, - ) - ) - .data.get("function_call") - .get("arguments") - ) - return self - - -class ChatLLM(MarvinBaseModel, abc.ABC): - name: Optional[str] = None - model: str - max_tokens: int = Field(default_factory=lambda: marvin.settings.llm_max_tokens) - temperature: float = Field(default_factory=lambda: marvin.settings.llm_temperature) - - @validator("name", always=True) - def default_name(cls, v): - if v is None: - v = cls.__name__ - return v - - @property - def context_size(self) -> int: - return 4096 - - def get_tokens(self, text: str, **kwargs) -> list[int]: - try: - enc = tiktoken.encoding_for_model(self.model) - # fallback to the gpt-3.5-turbo tokenizer if the model is not found - # note this will give the wrong answer for non-OpenAI models - except KeyError: - enc = tiktoken.encoding_for_model("gpt-3.5-turbo") - return enc.encode(text) - - async def __call__(self, messages, *args, **kwargs): - return await self.run(messages, *args, **kwargs) - - @abc.abstractmethod - def format_messages( - self, messages: list[Message] - ) -> Union[str, dict, list[Union[str, dict]]]: - """Format Marvin message objects into a prompt compatible with the LLM model""" - return messages - - @abc.abstractmethod - async def run( - self, - messages: list[Message], - functions: list[OpenAIFunction] = None, - *, - logger: Logger = None, - stream_handler: Callable[[Message], None] = False, - **kwargs, - ) -> Message: - """Run the LLM model on a list of messages and optional list of functions""" - raise NotImplementedError() - - -def chat_llm(model: str = None, **kwargs) -> ChatLLM: - """Dispatches to all supported LLM providers""" - if model is None: - model = marvin.settings.llm_model - - # automatically detect gpt-3.5 and gpt-4 for backwards compatibility - if model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"): - model = f"openai/{model}" - - # extract the provider and model name - provider, model_name = model.split("/", 1) - - if provider == "openai": - from .openai import OpenAIChatLLM - - return OpenAIChatLLM(model=model_name, **kwargs) - elif provider == "anthropic": - from .anthropic import AnthropicChatLLM - - return AnthropicChatLLM(model=model_name, **kwargs) - elif provider == "azure_openai": - from .azure_openai import AzureOpenAIChatLLM - - return AzureOpenAIChatLLM(model=model_name, **kwargs) - else: - raise ValueError(f"Unknown provider/model: {model}") diff --git a/src/marvin/engine/language_models/openai.py b/src/marvin/engine/language_models/openai.py deleted file mode 100644 index 3e880b3fc..000000000 --- a/src/marvin/engine/language_models/openai.py +++ /dev/null @@ -1,213 +0,0 @@ -import inspect -from logging import Logger -from typing import Callable, Optional, Union - -import openai -import openai.openai_object - -import marvin -import marvin.utilities.types -from marvin.utilities.async_utils import create_task -from marvin.utilities.logging import get_logger -from marvin.utilities.messages import Message, Role - -from .base import ChatLLM, OpenAIFunction, StreamHandler - 
-CONTEXT_SIZES = { - "gpt-3.5-turbo-16k-0613": 16384, - "gpt-3.5-turbo-16k": 16384, - "gpt-3.5-turbo-0613": 4096, - "gpt-3.5-turbo": 4096, - "gpt-4-32k-0613": 32768, - "gpt-4-32k": 32768, - "gpt-4-0613": 8192, - "gpt-4": 8192, -} - - -def openai_role_map(marvin_role: Role) -> str: - if marvin_role == Role.FUNCTION_RESPONSE or marvin_role == "function_request": - return "function" - elif marvin_role == Role.FUNCTION_REQUEST or marvin_role == "function_response": - return "assistant" - else: - return getattr(marvin_role, "value", marvin_role).lower() - - -class OpenAIStreamHandler(StreamHandler): - async def handle_streaming_response( - self, - api_response: openai.openai_object.OpenAIObject, - ) -> Message: - """ - Accumulate chunk deltas into a full response. Returns the full message. - Passes partial messages to the callback, if provided. - """ - response = {"role": None, "content": "", "data": {}, "llm_response": None} - - async for r in api_response: - response["llm_response"] = r.to_dict_recursive() - - delta = r.choices[0].delta if r.choices and r.choices[0] else None - - if delta is None: - continue - - if "role" in delta: - response["role"] = delta.role - - if fn_call := delta.get("function_call"): - if "function_call" not in response["data"]: - response["data"]["function_call"] = {"name": None, "arguments": ""} - if "name" in fn_call: - response["data"]["function_call"]["name"] = fn_call.name - if "arguments" in fn_call: - response["data"]["function_call"]["arguments"] += ( - fn_call.arguments or "" - ) - - if "content" in delta: - response["content"] += delta.content or "" - - if self.callback: - callback_result = self.callback(Message(**response)) - if inspect.isawaitable(callback_result): - create_task(callback_result) - - return Message(**response) - - -class OpenAIChatLLM(ChatLLM): - model: Optional[str] = "gpt-3.5-turbo" - - @property - def context_size(self) -> int: - if self.model in CONTEXT_SIZES: - return CONTEXT_SIZES[self.model] - else: - for model_prefix, context in CONTEXT_SIZES: - if self.model.startswith(model_prefix): - return context - return 4096 - - def _get_openai_settings(self) -> dict: - openai_kwargs = {} - if marvin.settings.openai.api_key: - openai_kwargs["api_key"] = marvin.settings.openai.api_key.get_secret_value() - else: - raise ValueError( - "OpenAI API key not set. Please set it or use the" - " MARVIN_OPENAI_API_KEY environment variable." 
- ) - - if marvin.settings.openai.api_type: - openai_kwargs["api_type"] = marvin.settings.openai.api_type - if marvin.settings.openai.api_base: - openai_kwargs["api_base"] = marvin.settings.openai.api_base - if marvin.settings.openai.api_version: - openai_kwargs["api_version"] = marvin.settings.openai.api_version - if marvin.settings.openai.organization: - openai_kwargs["organization"] = marvin.settings.openai.organization - return openai_kwargs - - def format_messages( - self, messages: list[Message] - ) -> Union[str, dict, list[Union[str, dict]]]: - """Format Marvin message objects into a prompt compatible with the LLM model""" - formatted_messages = [] - for m in messages: - role = openai_role_map(m.role) - fmt = {"role": role, "content": m.content} - if m.name: - fmt["name"] = m.name - formatted_messages.append(fmt) - return formatted_messages - - async def run( - self, - messages: list[Message], - *, - functions: list[OpenAIFunction] = None, - function_call: Union[str, dict[str, str]] = None, - logger: Logger = None, - stream_handler: Callable[[Message], None] = False, - **kwargs, - ) -> Message: - """Calls an OpenAI LLM with a list of messages and returns the response.""" - - # ---------------------------------- - # Validate arguments - # ---------------------------------- - - if functions is None: - functions = [] - if function_call is None: - function_call = "auto" - elif function_call not in ( - ["auto", "none"] + [{"name": f.name} for f in functions] - ): - raise ValueError(f"Invalid function_call value: {function_call}") - if logger is None: - logger = get_logger(self.name) - - # ---------------------------------- - # Form OpenAI-specific arguments - # ---------------------------------- - - openai_kwargs = self._get_openai_settings() - kwargs.update(openai_kwargs) - - prompt = self.format_messages(messages) - llm_functions = [f.dict(exclude={"fn"}, exclude_none=True) for f in functions] - - # only add to kwargs if supplied, because empty parameters are not - # allowed by OpenAI - if functions: - kwargs["functions"] = llm_functions - kwargs["function_call"] = function_call - - # ---------------------------------- - # Call OpenAI LLM - # ---------------------------------- - - kwargs.setdefault("temperature", self.temperature) - kwargs.setdefault("max_tokens", self.max_tokens) - - response = await openai.ChatCompletion.acreate( - model=self.model, - messages=prompt, - stream=True if stream_handler else False, - request_timeout=marvin.settings.llm_request_timeout_seconds, - **kwargs, - ) - - if stream_handler: - handler = OpenAIStreamHandler(callback=stream_handler) - msg = await handler.handle_streaming_response(response) - role = msg.role - - if role == Role.ASSISTANT and isinstance( - msg.data.get("function_call"), dict - ): - role = Role.FUNCTION_REQUEST - - return Message( - role=role, - content=msg.content, - data=msg.data, - llm_response=msg.llm_response, - ) - - else: - llm_response = response.to_dict_recursive() - msg = llm_response["choices"][0]["message"].copy() - role = msg.pop("role").upper() - if role == "ASSISTANT" and isinstance(msg.get("function_call"), dict): - role = Role.FUNCTION_REQUEST - msg = Message( - role=role, - content=msg.pop("content", None), - data=msg, - llm_response=llm_response, - ) - return msg diff --git a/src/marvin/functions/__init__.py b/src/marvin/functions/__init__.py deleted file mode 100644 index c0cd37359..000000000 --- a/src/marvin/functions/__init__.py +++ /dev/null @@ -1,133 +0,0 @@ -import functools -import inspect -import re -from 
typing import Callable, TypeVar, Optional, Any, Type - -from fastapi.routing import APIRouter - -from marvin.utilities.types import function_to_model - - -T = TypeVar("T") -A = TypeVar("A") - - -class Function: - def __init__( - self, *, fn: Callable[[A], T] = None, name: str = None, description: str = None - ) -> None: - self.fn = fn - self.name = name or self.fn.__name__ - self.description = description or self.fn.__doc__ - - super().__init__() - - @property - def model(self): - return function_to_model(self.fn, name=self.name, description=self.description) - - @property - def signature(self): - return inspect.signature(self.fn) - - @property - def source_code(self): - source_code = inspect.cleandoc(inspect.getsource(self.fn)) - if match := re.search(re.compile(r"(\bdef\b.*)", re.DOTALL), source_code): - source_code = match.group(0) - return source_code - - @property - def return_annotation(self): - return_annotation = self.signature.return_annotation - if return_annotation is inspect._empty: - return return_annotation, False - return return_annotation - - def arguments(self, *args, **kwargs): - bound_args = self.signature.bind(*args, **kwargs) - bound_args.apply_defaults() - return bound_args.arguments - - def schema(self, *args, name: str = None, description: str = None, **kwargs): - schema = self.model.schema(*args, **kwargs) - return { - "name": name or schema.pop("title"), - "description": description or self.fn.__doc__, - "parameters": schema, - } - - -def FunctionDecoratorFactory( - name: str = "marvin", - func_class: Type[T] = None, - in_place=True, -) -> Callable[[A], T]: - def decorator(fn: Callable[[A], T] = None) -> Callable[[A], T]: - if fn is None: - return functools.partial(decorator) - elif in_place: - fn = func_class(fn=fn) - else: - instance = func_class(fn=fn) - setattr(fn, name, instance) - for method in dir(instance): - is_method_private = method.startswith("__") - if not is_method_private: - setattr(fn, method, getattr(instance, method)) - return fn - - return decorator - - -marvin_fn = FunctionDecoratorFactory(name="openai", func_class=Function) - - -class FunctionRegistry(APIRouter): - def __init__(self, function_decorator=marvin_fn, *args, **kwargs): - super().__init__(*args, **kwargs) - self.function_decorator = function_decorator - - @property - def endpoints(self): - # Returns literal functions. - return [route.endpoint for route in self.routes] - - @property - def schema(self): - # Returns JSON Schema of functions. - return [self.function_decorator(fn=fn).schema() for fn in self.endpoints] - - @property - def functions(self): - # Returns function classes. - return [self.function_decorator(fn=fn) for fn in self.endpoints] - - def include(self, registry: "FunctionRegistry", *args, **kwargs): - super().include_router(registry, *args, **kwargs) - # Add some 50-IQ idempotency. 
- self.routes = list({x.name: x for x in self.routes}.values()) - - def register(self, fn: Optional[Callable] = None, **kwargs: Any) -> Callable: - def decorator(fn: Callable, *args) -> Callable: - fn_class = self.function_decorator(fn=fn, **kwargs) - self.add_api_route( - **{ - **{ - "name": fn_class.name, - "path": f"/{fn_class.name}", - "endpoint": fn, - "description": fn_class.description, - "methods": ["POST"], - }, - **kwargs, - } - ) - return fn - - if fn: - # if the decorator was called with parentheses - return decorator(fn) - else: - # else, return the decorator to be called later - return decorator diff --git a/src/marvin/openai/ChatCompletion/__init__.py b/src/marvin/openai/ChatCompletion/__init__.py deleted file mode 100644 index bd41d1135..000000000 --- a/src/marvin/openai/ChatCompletion/__init__.py +++ /dev/null @@ -1,155 +0,0 @@ -from pydantic.main import ModelMetaclass - -from typing import Any, Callable, Optional -from operator import itemgetter - -from marvin import settings -from marvin._compat import BaseModel, Extra, Field -from marvin.types.request import Request as BaseRequest -from marvin.engine import ChatCompletionBase - - -class Request(BaseRequest): - """ - This is a class for creating Request objects to interact with the GPT-3 API. - The class contains several configurations and validation functions to ensure - the correct data is sent to the API. - - """ - - model: str = "gpt-3.5-turbo" # the model used by the GPT-3 API - temperature: float = 0.8 # the temperature parameter used by the GPT-3 API - api_key: str = Field(default_factory=settings.openai.api_key.get_secret_value) - - class Config: - exclude = {"response_model"} - exclude_none = True - extra = Extra.allow - - def dict(self, *args, serialize_functions=True, exclude=None, **kwargs): - """ - This method returns a dictionary representation of the Request. - If the functions attribute is present and serialize_functions is True, - the functions' schemas are also included. - """ - - # This identity function is here for no reason except to show - # readers that custom adapters need only override the dict method. - return super().dict( - *args, serialize_functions=serialize_functions, exclude=exclude, **kwargs - ) - - -class Response(BaseModel): - """ - This class is used to handle the response from the API. - It includes several utility functions and properties to extract useful information - from the raw response. - """ - - raw: Any # the raw response from the API - request: Any # the request that generated the response - - def __init__(self, response, *args, request, **kwargs): - super().__init__(raw=response, request=request) - - def __iter__(self): - return self.raw.__iter__() - - def __next__(self): - return self.raw.__next__() - - def __getattr__(self, name): - """ - This method attempts to get the attribute from the raw response. - If it doesn't exist, it falls back to the standard attribute access. - """ - try: - return self.raw.__getattr__(name) - except AttributeError: - return self.__getattribute__(name) - - @property - def message(self): - """ - This property extracts the message from the raw response. - If there is only one choice, it returns the message from that choice. - Otherwise, it returns a list of messages from all choices. - """ - if len(self.raw.choices) == 1: - return next(iter(self.raw.choices)).message - return [x.message for x in self.raw.choices] - - @property - def function_call(self): - """ - This property extracts the function call from the message. 
- If the message is a list, it returns a list of function calls from all messages. - Otherwise, it returns the function call from the message. - """ - if isinstance(self.message, list): - return [x.function_call for x in self.message] - return self.message.function_call - - @property - def callables(self): - """ - This property returns a list of all callable functions from the request. - """ - return [x for x in self.request.functions if isinstance(x, Callable)] - - @property - def callable_registry(self): - """ - This property returns a dictionary mapping function names to functions for all - callable functions from the request. - """ - return {fn.__name__: fn for fn in self.callables} - - def call_function(self, as_message=True): - """ - This method evaluates the function call in the response and returns the result. - If as_message is True, it returns the result as a function message. - Otherwise, it returns the result directly. - """ - name, raw_arguments = itemgetter("name", "arguments")(self.function_call) - function = self.callable_registry.get(name) - arguments = function.model.parse_raw(raw_arguments) - value = function(**arguments.dict(exclude_none=True)) - if as_message: - return {"role": "function", "name": name, "content": value} - else: - return value - - def to_model(self): - """ - This method parses the function call arguments into the response model and - returns the result. - """ - return self.request.response_model.parse_raw(self.function_call.arguments) - - def __repr__(self, *args, **kwargs): - """ - This method returns a string representation of the raw response. - """ - return self.raw.__repr__(*args, **kwargs) - - -class OpenAIChatCompletion(ChatCompletionBase): - """ - This class is used to create and handle chat completions from the API. - It provides several utility functions to create the request, send it to the API, - and handle the response. - """ - - _module: str = "openai.ChatCompletion" # the module used to interact with the API - _request: str = "marvin.openai.ChatCompletion.Request" - _response: str = "marvin.openai.ChatCompletion.Response" - defaults: Optional[dict] = Field(None, repr=False) # default configuration values - - -ChatCompletion = OpenAIChatCompletion() - -# This is a legacy class that is used to create a ChatCompletion object. -# It is deprecated and will be removed in a future release. 
-ChatCompletionConfig = Request diff --git a/src/marvin/openai/Function/Registry/__init__.py b/src/marvin/openai/Function/Registry/__init__.py deleted file mode 100644 index 2f74ea661..000000000 --- a/src/marvin/openai/Function/Registry/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Any -from marvin.openai.Function import openai_fn -from openai.openai_object import OpenAIObject -from marvin.functions import FunctionRegistry - - -class OpenAIFunctionRegistry(FunctionRegistry): - def __init__(self, function_decorator=openai_fn, *args, **kwargs): - self.function_decorator = function_decorator - super().__init__(function_decorator=function_decorator, *args, **kwargs) - - def from_response(self, response: OpenAIObject) -> Any: - return next( - iter( - [ - {"name": k, "content": v} - for k, v in self.dict_from_openai_response(response).items() - if v is not None - ] - ), - None, - ) - - def dict_from_openai_response(self, response: OpenAIObject) -> Any: - return { - fn.name: fn.from_response(response) - for fn in map(lambda fn: self.function_decorator(fn=fn), self.endpoints) - } diff --git a/src/marvin/openai/Function/__init__.py b/src/marvin/openai/Function/__init__.py deleted file mode 100644 index 35d69430f..000000000 --- a/src/marvin/openai/Function/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -import json -from typing import TypeVar -from openai.openai_object import OpenAIObject - -from pydantic import validate_arguments -from marvin.functions import Function, FunctionDecoratorFactory - - -T = TypeVar("T") -A = TypeVar("A") - - -class OpenAIFunction(Function): - def __call__(self, response: OpenAIObject) -> T: - return self.from_response(response) - - @validate_arguments - def from_response(self, response: OpenAIObject) -> T: - relevant_calls = [ - choice.message.function_call - for choice in response.choices - if hasattr(choice.message, "function_call") - and self.name == choice.message.function_call.get("name", None) - ] - - arguments = [ - json.loads(function_call.get("arguments")) - for function_call in relevant_calls - ] - - responses = [self.fn(**argument) for argument in arguments] - - if len(responses) == 0: - return None - elif len(responses) == 1: - return responses[0] - else: - return responses - - -openai_fn = FunctionDecoratorFactory(name="openai", func_class=OpenAIFunction) diff --git a/src/marvin/openai/__init__.py b/src/marvin/openai/__init__.py deleted file mode 100644 index bebb3c2df..000000000 --- a/src/marvin/openai/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys -from openai import * # noqa: F403 -from marvin.core.ChatCompletion import ChatCompletion - -ChatCompletion = ChatCompletion diff --git a/src/marvin/prompts/__init__.py b/src/marvin/prompts/__init__.py deleted file mode 100644 index 674bf5c53..000000000 --- a/src/marvin/prompts/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .base import Prompt, render_prompts, prompt_fn -from . 
import library diff --git a/src/marvin/prompts/base.py b/src/marvin/prompts/base.py deleted file mode 100644 index c4dd9c27c..000000000 --- a/src/marvin/prompts/base.py +++ /dev/null @@ -1,377 +0,0 @@ -import abc -import inspect -from functools import partial, wraps -from typing import ( - Any, - Callable, - Dict, - Generic, - List, - Literal, - Optional, - TypeVar, - Union, -) - -from jinja2 import Environment -from pydantic import BaseModel, Field -from typing_extensions import ParamSpec, Self - -import marvin -from marvin._compat import cast_to_json, cast_to_model, model_dump, model_json_schema -from marvin.core.ChatCompletion import ChatCompletion -from marvin.core.ChatCompletion.abstract import AbstractChatCompletion -from marvin.utilities.messages import Message, Role -from marvin.utilities.strings import count_tokens, jinja_env - -T = TypeVar("T") -P = ParamSpec("P") - - -class MessageList(list[Message]): - def render( - self: Self, - **kwargs: Any, - ) -> Self: - return render_prompts(self, render_kwargs=kwargs) - - def serialize( - self: Self, - **kwargs: Any, - ) -> list[dict[str, Any]]: - return [model_dump(message) for message in self.render(**kwargs)] - - -class PromptList(list[Union["Prompt", Message]]): - def __init__(self, prompts: list[Union["Prompt", Message]]): - super().__init__(prompts) - - def render( - self: Self, - content: Optional[str] = None, - render_kwargs: Optional[dict[str, Any]] = None, - ) -> list[Message]: - return render_prompts(self, render_kwargs=render_kwargs) - - def dict(self, **kwargs: Any): - return [model_dump(message) for message in self.render(**kwargs)] - - def __call__(self, **kwargs: Any): - return self.render(**kwargs) - - -class BasePrompt(BaseModel, abc.ABC): - """ - Base class for prompt templates. - """ - - functions: Optional[ - Union[ - List[Union[Dict[str, Any], Callable[..., Any], type[BaseModel]]], - Callable[ - ..., List[Union[Dict[str, Any], Callable[..., Any], type[BaseModel]]] - ], - ] - ] = Field(default=None) - - function_call: Optional[ - Union[ - Literal["auto"], - Dict[Literal["name"], str], - ] - ] = Field(default=None) - - response_model: Optional[ - Union[ - type, - type[BaseModel], - Any, - Callable[..., Union[type, type[BaseModel], Any]], - ] - ] = Field(default=None) - - response_model_name: Optional[str] = Field( - default=None, - exclude=True, - repr=False, - ) - response_model_description: Optional[str] = Field( - default=None, - exclude=True, - repr=False, - ) - response_model_field_name: Optional[str] = Field( - default=None, - exclude=True, - repr=False, - ) - - position: Optional[int] = Field( - default=None, - repr=False, - exclude=True, - description=( - "Position indicates the desired index for this prompt's messages. 0" - " indicates they should be first; 1 indicates they should be second; -1" - " indicates they should be last; None indicates they should be between any" - " prompts that do request a position." - ), - ) - priority: float = Field( - default=10, - repr=False, - exclude=True, - description=( - "Priority indicates the weight given when trimming messages to satisfy" - " context limitations. Lower numbers indicate higher priority e.g. the" - " highest priority is 0. The default is 10. Ties will be broken by message" - " timestamp and role." 
- ), - ) - - @abc.abstractmethod - def generate(self, **kwargs: Any) -> list["Message"]: - """ - Abstract method that generates a list of messages from the prompt template - """ - pass - - def render( - self: Self, content: str, render_kwargs: Optional[dict[str, Any]] = None - ) -> str: - """ - Helper function for rendering any jinja2 template with runtime render kwargs - """ - return jinja_env.from_string(inspect.cleandoc(content)).render( - **(render_kwargs or {}) - ) - - def __or__(self: Self, other: Union[Self, list[Self]]) -> PromptList: - """ - Supports pipe syntax: - prompt = ( - Prompt() - | Prompt() - | Prompt() - ) - """ - # when the right operand is a Prompt object - if isinstance(other, Prompt): - return PromptList([self, other]) - # when the right operand is a list - elif isinstance(other, list[Prompt]): - return PromptList([self, *other]) - else: - raise TypeError( - f"unsupported operand type(s) for |: '{type(self).__name__}' and" - f" '{type(other).__name__}'" - ) - - def __ror__(self, other): - """ - Supports pipe syntax: - prompt = ( - Prompt() - | Prompt() - | Prompt() - ) - """ - # when the left operand is a Prompt object - if isinstance(other, Prompt): - return PromptList([other, self]) - # when the left operand is a list - elif isinstance(other, list): - return PromptList(other + [self]) - else: - raise TypeError( - f"unsupported operand type(s) for |: '{type(other).__name__}' and" - f" '{type(self).__name__}'" - ) - - -class Prompt(BasePrompt, Generic[P], extra="allow", arbitrary_types_allowed=True): - def generate(self, **kwargs: Any) -> list[Message]: - response = Message.from_transcript( - self.render(content=self.__doc__ or "", render_kwargs=kwargs) - ) - return response - - def to_dict(self, **kwargs: Any) -> dict[str, Any]: - extras = model_dump( - self, - exclude=set(self.__fields__.keys()), - exclude_none=True, - ) - return { - "messages": [ - model_dump(message, include={"content", "role"}) - for message in render_prompts( - self.generate( - **extras | kwargs | {"response_model": self.response_model} - ) - ) - ], - "functions": self.functions, - "function_call": self.function_call, - "response_model": self.response_model, - } - - def to_chat_completion( - self, model: Optional[str] = None, **model_kwargs: Any - ) -> AbstractChatCompletion[T]: - return ChatCompletion(model=model, **model_kwargs)(**self.to_dict()) - - def serialize(self, model: Any = None, **kwargs: Any) -> dict[str, Any]: - if model: - return model(**self.to_dict(**kwargs))._serialize_request() # type: ignore - - _dict = self.to_dict(**kwargs) - - response: dict[str, Any] = {} - response["messages"] = _dict["messages"] - - if _dict.get("response_model", None): - response["functions"] = [ - model_json_schema(cast_to_model(_dict["response_model"])) - ] - response["function_call"] = {"name": response["functions"][0]["name"]} - elif _dict.get("functions", None): - response["functions"] = [ - cast_to_json(function) if callable(function) else function - for function in _dict["functions"] - ] - if _dict["function_call"]: - response["function_call"] = _dict["function_call"] - - return response - - @classmethod - def as_decorator( - cls: type[Self], - func: Optional[Callable[P, Any]] = None, - *, - environment: Optional[Environment] = None, - ctx: Optional[dict[str, Any]] = None, - role: Optional[Role] = None, - functions: Optional[ - list[Union[Callable[..., Any], type[BaseModel], dict[str, Any]]] - ] = None, # noqa - function_call: Optional[ - Union[Literal["auto"], dict[Literal["name"], str]] - ] = 
None, # noqa - response_model: Optional[type[BaseModel]] = None, - response_model_name: Optional[str] = None, - response_model_description: Optional[str] = None, - response_model_field_name: Optional[str] = None, - serialize_on_call: bool = True, - ) -> Union[ - Callable[[Callable[P, None]], Callable[P, None]], - Callable[[Callable[P, None]], Callable[P, Self]], - Callable[P, Self], - ]: - def wrapper(func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> Self: - signature = inspect.signature(func) - params = signature.bind(*args, **kwargs) - params.apply_defaults() - response = type(getattr(cls, "__name__", ""), (cls,), {})( - __params__=params.arguments, - **params.arguments, - **ctx or {}, - functions=functions, - function_call=function_call, - response_model=cast_to_model( - response_model or signature.return_annotation, - name=response_model_name, - description=response_model_description, - field_name=response_model_field_name, - ), - response_model_name=response_model_name, - response_model_description=response_model_description, - response_model_field_name=response_model_field_name, - ) - response.__doc__ = func.__doc__ - if serialize_on_call: - return response.serialize() - return response - - if func is not None: - return wraps(func)(partial(wrapper, func)) - - def decorator(func: Callable[P, None]) -> Callable[P, Self]: - return wraps(func)(partial(wrapper, func)) - - return decorator - - -prompt_fn = Prompt.as_decorator - - -class MessageWrapper(BasePrompt): - """ - A Prompt class that stores and returns a specific Message - """ - - message: Message - - def generate(self, **kwargs: Any) -> list[Message]: - return [self.message] - - -def render_prompts( - prompts: Union[list[Message], list[Union[Prompt, Message]]], - render_kwargs: Optional[dict[str, Any]] = None, - max_tokens: Optional[int] = None, -) -> MessageList: - max_tokens = max_tokens or marvin.settings.llm_max_context_tokens - - all_messages = [] - - # if the user supplied any messages, wrap them in a MessageWrapper so we can - # treat them as prompts for sorting and filtering - prompts = [ - MessageWrapper(message=p) if isinstance(p, Message) else p for p in prompts - ] - - # Separate prompts by positive, none and negative position - pos_prompts = [p for p in prompts if p.position is not None and p.position >= 0] - none_prompts = [p for p in prompts if p.position is None] - neg_prompts = [p for p in prompts if p.position is not None and p.position < 0] - - # Sort the positive prompts in ascending order and negative prompts in - # descending order, but both with timestamp ascending - pos_prompts = sorted(pos_prompts, key=lambda c: c.position) - neg_prompts = sorted(neg_prompts, key=lambda c: c.position, reverse=True) - - # generate messages from all prompts - for i, prompt in enumerate(pos_prompts + none_prompts + neg_prompts): - prompt_messages = prompt.generate(**(render_kwargs or {})) or [] - all_messages.extend((prompt.priority, i, m) for m in prompt_messages) - - # sort all messages by (priority asc, position desc) and stop when the - # token limit is reached. This will prefer high-priority messages that are - # later in the message chain. 
- current_tokens = 0 - allowed_messages = [] - for _, position, msg in sorted(all_messages, key=lambda m: (m[0], -1 * m[1])): - if current_tokens >= max_tokens: - break - allowed_messages.append((position, msg)) - current_tokens += count_tokens(msg.content) - - # sort allowed messages by position to restore original order - messages = [msg for _, msg in sorted(allowed_messages, key=lambda m: m[0])] - - # Combine all system messages into one and insert at the index of the first - # system message - system_messages = [m for m in messages if m.role == Role.SYSTEM.value] - if len(system_messages) > 1: - system_message = Message( - role=Role.SYSTEM, - content="\n\n".join([m.content for m in system_messages]), - ) - system_message_index = messages.index(system_messages[0]) - messages = [m for m in messages if m.role != Role.SYSTEM.value] - messages.insert(system_message_index, system_message) - - # return all messages - return messages diff --git a/src/marvin/prompts/library.py b/src/marvin/prompts/library.py deleted file mode 100644 index 5d923cf22..000000000 --- a/src/marvin/prompts/library.py +++ /dev/null @@ -1,168 +0,0 @@ -import inspect -from typing import Callable, Literal, Optional - -from pydantic import Field - -from marvin.prompts.base import Prompt -from marvin.utilities.history import History, HistoryFilter -from marvin.utilities.messages import Message, Role - - -class MessagePrompt(Prompt): - role: Role - content: str = Field( - ..., description="The message content, which can be a Jinja2 template" - ) - name: str = None - priority: int = 2 - - def get_content(self) -> str: - """ - Override this method to easily customize behavior - """ - return self.content - - def generate(self, **kwargs) -> list[Message]: - return [ - Message( - role=self.role, - content=self.render( - self.get_content(), - render_kwargs={ - **self.dict(exclude={"role", "content", "name", "priority"}), - **kwargs, - }, - ), - name=self.name, - ) - ] - - def read(self, **kwargs) -> str: - return self.render( - self.get_content(), - render_kwargs={ - **self.dict(exclude={"role", "content", "name", "priority"}), - **kwargs, - }, - ) - - def __init__(self, content: str = None, *args, **kwargs): - content = kwargs.get("content", content) - super().__init__( - *args, **{**kwargs, **({"content": content} if content else {})} - ) - - -class System(MessagePrompt): - position: int = 0 - priority: int = 1 - role: Literal[Role.SYSTEM] = Role.SYSTEM - - -class Assistant(MessagePrompt): - role: Literal[Role.ASSISTANT] = Role.ASSISTANT - - -class User(MessagePrompt): - role: Literal[Role.USER] = Role.USER - - -class MessageHistory(Prompt): - history: History - n: Optional[int] = 100 - skip: Optional[int] = None - filter: HistoryFilter = None - - def generate(self, **kwargs) -> list[Message]: - return self.history.get_messages(n=self.n, skip=self.skip, filter=self.filter) - - -class Tagged(MessagePrompt): - """ - Surround content with a tag, e.g. bold - """ - - tag: str - role: Role = Role.USER - - def get_content(self) -> str: - return f"<{self.tag}>{self.content}" - - -class Conditional(Prompt): - if_: Callable = Field( - ..., - description=( - "A function that returns a boolean. It will be called when the prompt is" - " generated and provided all the variables that are passed to the render" - " function." 
- ), - ) - if_content: str - else_content: str - role: Role = Role.USER - name: str = None - - def generate(self, **kwargs) -> list[Message]: - if self.if_(**kwargs): - return [ - Message( - role=self.role, - content=self.render(self.if_content, render_kwargs=kwargs), - name=self.name, - ) - ] - elif self.else_content: - return [ - Message( - role=self.role, - content=self.render(self.else_content, render_kwargs=kwargs), - name=self.name, - ) - ] - else: - return [] - - -class JinjaConditional(Prompt): - if_: str = Field( - ..., - description=( - "A Jinja2-compatible expression that evaluates to a boolean e.g." - " `truthy_var` or `counter > 10`. It will automatically be templated, do" - " not include the `{{ }}` braces." - ), - ) - if_content: str - else_content: str = None - role: Role = Role.USER - name: str = None - - def generate(self, **kwargs) -> list[Message]: - if_content = inspect.cleandoc(self.if_content) - if self.else_content: - content = ( - f"{{% if {self.if_} %}}{if_content}{{% else" - f" %}}{inspect.cleandoc(self.else_content)}{{% endif %}}" - ) - - else: - content = (f"{{% if {self.if_} %}}{if_content}{{% endif %}}",) - return [ - Message( - role=self.role, - content=self.render(content, render_kwargs=kwargs), - name=self.name, - ) - ] - - -class ChainOfThought(Prompt): - position: int = -1 - - def generate(self, **kwargs) -> list[Message]: - return [Message(role=Role.ASSISTANT, content="Let's think step by step.")] - - -class Now(System): - content: str = "It is {{ now().strftime('%A, %d %B %Y at %I:%M:%S %p %Z') }}." diff --git a/src/marvin/pydantic.py b/src/marvin/pydantic.py deleted file mode 100644 index 38396a61e..000000000 --- a/src/marvin/pydantic.py +++ /dev/null @@ -1,26 +0,0 @@ -from pydantic.version import VERSION as PYDANTIC_VERSION - -PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") - -if not PYDANTIC_V2: - from pydantic import ( - BaseModel, - BaseSettings, - Extra, - Field, - PrivateAttr, - PyObject, - SecretStr, - validate_arguments, - ) - from pydantic.main import ModelMetaclass - - ModelMetaclass = ModelMetaclass - BaseModel = BaseModel - BaseSettings = BaseSettings - Field = Field - SecretStr = SecretStr - Extra = Extra - ImportString = PyObject - PrivateAttr = PrivateAttr - validate_arguments = validate_arguments diff --git a/src/marvin/requests.py b/src/marvin/requests.py new file mode 100644 index 000000000..f7d01d88a --- /dev/null +++ b/src/marvin/requests.py @@ -0,0 +1,115 @@ +from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union + +from pydantic import BaseModel, Field +from typing_extensions import Annotated, Self + +from marvin.settings import settings + +T = TypeVar("T", bound=BaseModel) + + +class ResponseFormat(BaseModel): + type: str + + +LogitBias = dict[str, float] + + +class Function(BaseModel, Generic[T]): + name: str + description: Optional[str] + parameters: dict[str, Any] + + model: Optional[type[T]] = Field(default=None, exclude=True, repr=False) + python_fn: Optional[Callable[..., Any]] = Field( + default=None, + description="Private field that holds the executable function, if available", + exclude=True, + repr=False, + ) + + def validate_json(self: Self, json_data: Union[str, bytes, bytearray]) -> T: + if self.model is None: + raise ValueError("This Function was not initialized with a model.") + return self.model.model_validate_json(json_data) + + +class Tool(BaseModel, Generic[T]): + type: str + function: Optional[Function[T]] = None + + +class ToolSet(BaseModel, Generic[T]): + tools: 
Optional[list[Tool[T]]] = None + tool_choice: Optional[Union[Literal["auto"], dict[str, Any]]] = None + + +class RetrievalTool(Tool[T]): + type: str = Field(default="retrieval") + + +class CodeInterpreterTool(Tool[T]): + type: str = Field(default="code_interpreter") + + +class FunctionCall(BaseModel): + name: str + + +class BaseMessage(BaseModel): + content: str + role: str + + +class Grammar(BaseModel): + logit_bias: Optional[LogitBias] = None + max_tokens: Optional[Annotated[int, Field(strict=True, ge=1)]] = None + response_format: Optional[ResponseFormat] = None + + +class Prompt(Grammar, ToolSet[T], Generic[T]): + messages: list[BaseMessage] = Field(default_factory=list) + + +class ResponseModel(BaseModel): + model: type + name: str = Field(default="FormatResponse") + description: str = Field(default="Response format") + + +class ChatRequest(Prompt[T]): + model: str = Field(default=settings.openai.chat.completions.model) + frequency_penalty: Optional[ + Annotated[float, Field(strict=True, ge=-2.0, le=2.0)] + ] = 0 + n: Optional[Annotated[int, Field(strict=True, ge=1)]] = 1 + presence_penalty: Optional[ + Annotated[float, Field(strict=True, ge=-2.0, le=2.0)] + ] = 0 + seed: Optional[int] = None + stop: Optional[Union[str, list[str]]] = None + stream: Optional[bool] = False + temperature: Optional[Annotated[float, Field(strict=True, ge=0, le=2)]] = 1 + top_p: Optional[Annotated[float, Field(strict=True, ge=0, le=1)]] = 1 + user: Optional[str] = None + + +class AssistantMessage(BaseMessage): + id: str + thread_id: str + created_at: int + assistant_id: Optional[str] = None + run_id: Optional[str] = None + file_ids: list[str] = [] + metadata: dict[str, Any] = {} + + +class Run(BaseModel, Generic[T]): + id: str + thread_id: str + created_at: int + status: str + model: str + instructions: Optional[str] + tools: Optional[list[Tool[T]]] = None + metadata: dict[str, str] diff --git a/src/marvin/serializers.py b/src/marvin/serializers.py new file mode 100644 index 000000000..8a7c7422f --- /dev/null +++ b/src/marvin/serializers.py @@ -0,0 +1,107 @@ +from enum import Enum +from types import GenericAlias +from typing import ( + Any, + Callable, + Literal, + Optional, + TypeVar, + Union, + get_args, + get_origin, +) + +from pydantic import BaseModel, create_model +from pydantic.fields import FieldInfo +from pydantic.json_schema import GenerateJsonSchema, JsonSchemaMode + +from marvin import settings +from marvin.requests import Function, Grammar, Tool + +U = TypeVar("U", bound=BaseModel) + + +class FunctionSchema(GenerateJsonSchema): + def generate(self, schema: Any, mode: JsonSchemaMode = "validation"): + json_schema = super().generate(schema, mode=mode) + json_schema.pop("title", None) + return json_schema + + +def create_tool_from_type( + _type: Union[type, GenericAlias], + model_name: str, + model_description: str, + field_name: str, + field_description: str, + python_function: Optional[Callable[..., Any]] = None, + **kwargs: Any, +) -> Tool[BaseModel]: + annotated_metadata = getattr(_type, "__metadata__", []) + if isinstance(next(iter(annotated_metadata), None), FieldInfo): + metadata = next(iter(annotated_metadata)) + else: + metadata = FieldInfo(description=field_description) + + model: type[BaseModel] = create_model( + model_name, + __config__=None, + __base__=None, + __module__=__name__, + __validators__=None, + __cls_kwargs__=None, + **{field_name: (_type, metadata)}, + ) + return Tool[BaseModel]( + type="function", + function=Function[BaseModel]( + name=model_name, + 
description=model_description, + parameters=model.model_json_schema(schema_generator=FunctionSchema), + model=model, + ), + ) + + +def create_tool_from_model( + model: type[BaseModel], +) -> Tool[BaseModel]: + return Tool[BaseModel]( + type="function", + function=Function[BaseModel]( + name=model.__name__, + description=model.__doc__, + parameters=model.model_json_schema(schema_generator=FunctionSchema), + model=model, + ), + ) + + +def create_vocabulary_from_type( + vocabulary: Union[GenericAlias, type], +) -> list[str]: + if get_origin(vocabulary) == Literal: + return [str(token) for token in get_args(vocabulary)] + elif isinstance(vocabulary, type) and issubclass(vocabulary, Enum): + return [str(token) for token in list(vocabulary.__members__.keys())] + else: + raise TypeError( + f"Expected Literal or Enum, got {type(vocabulary)} with value {vocabulary}" + ) + + +def create_grammar_from_vocabulary( + vocabulary: list[str], + encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder, + max_tokens: Optional[int] = None, + _enumerate: bool = True, + **kwargs: Any, +) -> Grammar: + return Grammar( + max_tokens=max_tokens, + logit_bias={ + str(encoding): 100 + for i, token in enumerate(vocabulary) + for encoding in encoder(str(i) if _enumerate else token) + }, + ) diff --git a/src/marvin/settings.py b/src/marvin/settings.py index 42372eeed..096c6fe7c 100644 --- a/src/marvin/settings.py +++ b/src/marvin/settings.py @@ -1,183 +1,216 @@ import os from contextlib import contextmanager -from pathlib import Path -from typing import Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union + +from pydantic import Field, SecretStr +from pydantic_settings import BaseSettings, SettingsConfigDict + +if TYPE_CHECKING: + from openai import AsyncClient, Client + from openai._base_client import HttpxBinaryResponseContent + from openai.types.chat import ChatCompletion + from openai.types.images_response import ImagesResponse + + +class MarvinSettings(BaseSettings): + model_config = SettingsConfigDict( + env_prefix="marvin_", + env_file="~/.marvin/.env", + extra="allow", + arbitrary_types_allowed=True, + ) + + def __setattr__(self, name: str, value: Any) -> None: + """Preserve SecretStr type when setting values.""" + field = self.model_fields.get(name) + if field: + annotation = field.annotation + base_types = ( + annotation.__args__ + if getattr(annotation, "__origin__", None) is Union + else (annotation,) + ) + if SecretStr in base_types and not isinstance(value, SecretStr): + value = SecretStr(value) + super().__setattr__(name, value) + + +class MarvinModelSettings(MarvinSettings): + model: str + + @property + def encoder(self): + import tiktoken + + return tiktoken.encoding_for_model(self.model).encode + + +class ChatCompletionSettings(MarvinModelSettings): + model: str = Field( + default="gpt-3.5-turbo-1106", + description="The default chat model to use.", + ) + + async def acreate(self, **kwargs: Any) -> "ChatCompletion": + from marvin.settings import settings + + return await settings.openai.async_client.chat.completions.create( + model=self.model, **kwargs + ) + + def create(self, **kwargs: Any) -> "ChatCompletion": + from marvin.settings import settings + + return settings.openai.client.chat.completions.create( + model=self.model, **kwargs + ) + + +class ImageSettings(MarvinModelSettings): + model: str = Field( + default="dall-e-3", + description="The default image model to use.", + ) + size: Literal["1024x1024", "1792x1024", "1024x1792"] = 
Field( + default="1024x1024", + ) + response_format: Literal["url", "b64_json"] = Field(default="url") + style: Literal["vivid", "natural"] = Field(default="vivid") + + async def agenerate(self, prompt: str, **kwargs: Any) -> "ImagesResponse": + from marvin.settings import settings + + return await settings.openai.async_client.images.generate( + model=self.model, + prompt=prompt, + size=self.size, + response_format=self.response_format, + style=self.style, + **kwargs, + ) + + def generate(self, prompt: str, **kwargs: Any) -> "ImagesResponse": + from marvin.settings import settings + + return settings.openai.client.images.generate( + model=self.model, + prompt=prompt, + size=self.size, + response_format=self.response_format, + style=self.style, + **kwargs, + ) + + +class SpeechSettings(MarvinModelSettings): + model: str = Field( + default="tts-1-hd", + description="The default image model to use.", + ) + voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = Field( + default="alloy", + ) + response_format: Literal["mp3", "opus", "aac", "flac"] = Field(default="mp3") + speed: float = Field(default=1.0) + + async def acreate(self, input: str, **kwargs: Any) -> "HttpxBinaryResponseContent": + from marvin.settings import settings + + return await settings.openai.async_client.audio.speech.create( + model=kwargs.get("model", self.model), + input=input, + voice=kwargs.get("voice", self.voice), + response_format=kwargs.get("response_format", self.response_format), + speed=kwargs.get("speed", self.speed), + ) + + def create(self, input: str, **kwargs: Any) -> "HttpxBinaryResponseContent": + from marvin.settings import settings + + return settings.openai.client.audio.speech.create( + model=kwargs.get("model", self.model), + input=input, + voice=kwargs.get("voice", self.voice), + response_format=kwargs.get("response_format", self.response_format), + speed=kwargs.get("speed", self.speed), + ) + -from ._compat import ( - BaseSettings, - SecretStr, - model_dump, -) +class AssistantSettings(MarvinModelSettings): + model: str = Field( + default="gpt-4-1106-preview", + description="The default assistant model to use.", + ) -DEFAULT_ENV_PATH = Path(os.getenv("MARVIN_ENV_FILE", "~/.marvin/.env")).expanduser() +class ChatSettings(MarvinSettings): + completions: ChatCompletionSettings = Field(default_factory=ChatCompletionSettings) -class MarvinBaseSettings(BaseSettings): - class Config: - env_file = ( - ".env", - str(DEFAULT_ENV_PATH), + +class AudioSettings(MarvinSettings): + speech: SpeechSettings = Field(default_factory=SpeechSettings) + + +class OpenAISettings(MarvinSettings): + model_config = SettingsConfigDict(env_prefix="marvin_openai_") + + api_key: Optional[SecretStr] = Field( + default=None, + description="Your OpenAI API key.", + ) + + organization: Optional[str] = Field( + default=None, + description="Your OpenAI organization ID.", + ) + + chat: ChatSettings = Field(default_factory=ChatSettings) + images: ImageSettings = Field(default_factory=ImageSettings) + audio: AudioSettings = Field(default_factory=AudioSettings) + assistants: AssistantSettings = Field(default_factory=AssistantSettings) + + @property + def async_client( + self, api_key: Optional[str] = None, **kwargs: Any + ) -> "AsyncClient": + from openai import AsyncClient + + if not (api_key or self.api_key): + raise ValueError("No API key provided.") + elif not api_key and self.api_key: + api_key = self.api_key.get_secret_value() + + return AsyncClient( + api_key=api_key, + organization=self.organization, + **kwargs, ) - 
env_prefix = "MARVIN_" - validate_assignment = True - - -class OpenAISettings(MarvinBaseSettings): - """Provider-specific settings. Only some of these will be relevant to users.""" - - class Config: - env_prefix = "MARVIN_OPENAI_" - - api_key: Optional[SecretStr] = None - organization: Optional[str] = None - embedding_engine: str = "text-embedding-ada-002" - api_type: Optional[str] = None - api_base: Optional[str] = None - api_version: Optional[str] = None - - def get_defaults(self, settings: "Settings") -> dict[str, Any]: - import os - - import openai - - from marvin import openai as marvin_openai - - EXCLUDE_KEYS = {"stream_handler"} - - response: dict[str, Any] = {} - if settings.llm_max_context_tokens > 0: - response["max_tokens"] = settings.llm_max_tokens - response["api_key"] = self.api_key and self.api_key.get_secret_value() - if os.environ.get("MARVIN_OPENAI_API_KEY"): - response["api_key"] = os.environ["MARVIN_OPENAI_API_KEY"] - if os.environ.get("OPENAI_API_KEY"): - response["api_key"] = os.environ["OPENAI_API_KEY"] - if openai.api_key: - response["api_key"] = openai.api_key - if marvin_openai.api_key: - response["api_key"] = marvin_openai.api_key - response["temperature"] = settings.llm_temperature - response["request_timeout"] = settings.llm_request_timeout_seconds - return { - k: v for k, v in response.items() if v is not None and k not in EXCLUDE_KEYS - } - - -class AnthropicSettings(MarvinBaseSettings): - class Config: - env_prefix = "MARVIN_ANTHROPIC_" - - api_key: Optional[SecretStr] = None - - def get_defaults(self, settings: "Settings") -> dict[str, Any]: - response: dict[str, Any] = {} - if settings.llm_max_context_tokens > 0: - response["max_tokens_to_sample"] = settings.llm_max_tokens - response["api_key"] = self.api_key and self.api_key.get_secret_value() - response["temperature"] = settings.llm_temperature - response["timeout"] = settings.llm_request_timeout_seconds - if os.environ.get("MARVIN_ANTHROPIC_API_KEY"): - response["api_key"] = os.environ["MARVIN_ANTHROPIC_API_KEY"] - if os.environ.get("ANTHROPIC_API_KEY"): - response["api_key"] = os.environ["ANTHROPIC_API_KEY"] - return {k: v for k, v in response.items() if v is not None} - - -class AzureOpenAI(MarvinBaseSettings): - class Config: - env_prefix = "MARVIN_AZURE_OPENAI_" - - api_key: Optional[SecretStr] = None - api_type: Literal["azure", "azure_ad"] = "azure" - # "The endpoint of the Azure OpenAI API. This should have the form https://YOUR_RESOURCE_NAME.openai.azure.com" # noqa - api_base: Optional[str] = None - api_version: Optional[str] = "2023-07-01-preview" - # `deployment_name` will correspond to the custom name you chose for your deployment when # noqa - # you deployed a model. 
- deployment_name: Optional[str] = None - - def get_defaults(self, settings: "Settings") -> dict[str, Any]: - import os - - import openai - - from marvin import openai as marvin_openai - - response: dict[str, Any] = {} - if settings.llm_max_context_tokens > 0: - response["max_tokens"] = settings.llm_max_tokens - response["temperature"] = settings.llm_temperature - response["request_timeout"] = settings.llm_request_timeout_seconds - response["api_key"] = self.api_key and self.api_key.get_secret_value() - if os.environ.get("MARVIN_AZURE_OPENAI_API_KEY"): - response["api_key"] = os.environ["MARVIN_AZURE_OPENAI_API_KEY"] - if openai.api_key: - response["api_key"] = openai.api_key - if marvin_openai.api_key: - response["api_key"] = marvin_openai.api_key - - return model_dump(self, exclude_unset=True) | { - k: v for k, v in response.items() if v is not None - } - - -def initial_setup(home: Union[Path, None] = None) -> Path: - if not home: - home = Path.home() / ".marvin" - home.mkdir(parents=True, exist_ok=True) - return home - - -class Settings(MarvinBaseSettings): - """Marvin settings""" - - home: Path = initial_setup() - test_mode: bool = False - - # LOGGING - log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" - verbose: bool = False - - # LLMS - llm_model: str = "openai/gpt-3.5-turbo" - llm_max_tokens: int = 1500 - llm_max_context_tokens: int = 3500 - llm_temperature: float = 0.8 - llm_request_timeout_seconds: Union[float, list[float]] = 600.0 - - # AI APPLICATIONS - ai_application_max_iterations: Optional[int] = None - - # providers - openai: OpenAISettings = OpenAISettings() - anthropic: AnthropicSettings = AnthropicSettings() - azure_openai: AzureOpenAI = AzureOpenAI() - - # SLACK - slack_api_token: Optional[SecretStr] = None - - # TOOLS - - # chroma - chroma_server_host: Optional[str] = None - chroma_server_http_port: Optional[int] = None - # github - github_token: Optional[SecretStr] = None - - # wolfram - wolfram_app_id: Optional[SecretStr] = None - - def get_defaults(self, provider: Optional[str] = None) -> dict[str, Any]: - response: dict[str, Any] = {} - if provider == "openai": - return self.openai.get_defaults(self) - elif provider == "anthropic": - return self.anthropic.get_defaults(self) - elif provider == "azure_openai": - return self.azure_openai.get_defaults(self) - else: - return response + @property + def client(self, api_key: Optional[str] = None, **kwargs: Any) -> "Client": + from openai import Client + + if not (api_key or self.api_key): + raise ValueError("No API key provided.") + elif not api_key and self.api_key: + api_key = self.api_key.get_secret_value() + + return Client( + api_key=api_key, + organization=self.organization, + **kwargs, + ) + + +class Settings(MarvinSettings): + model_config = SettingsConfigDict(env_prefix="marvin_") + + openai: OpenAISettings = Field(default_factory=OpenAISettings) + + log_level: str = Field( + default="DEBUG", + description="The log level to use.", + ) settings = Settings() @@ -190,15 +223,9 @@ def temporary_settings(**kwargs: Any): been already been accessed at module load time. This function should only be used for testing. 
- - Example: - >>> from marvin.settings import settings - >>> with temporary_settings(MARVIN_LLM_MAX_TOKENS=100): - >>> assert settings.llm_max_tokens == 100 - >>> assert settings.llm_max_tokens == 1500 """ old_env = os.environ.copy() - old_settings = settings.copy() + old_settings = settings.model_copy() try: for setting in kwargs: @@ -210,7 +237,7 @@ def temporary_settings(**kwargs: Any): new_settings = Settings() - for field in settings.__fields__: + for field in settings.model_fields: object.__setattr__(settings, field, getattr(new_settings, field)) yield settings @@ -222,5 +249,5 @@ def temporary_settings(**kwargs: Any): else: os.environ.pop(setting, None) - for field in settings.__fields__: + for field in settings.model_fields: object.__setattr__(settings, field, getattr(old_settings, field)) diff --git a/src/marvin/tools/__init__.py b/src/marvin/tools/__init__.py index 0be514778..e69de29bb 100644 --- a/src/marvin/tools/__init__.py +++ b/src/marvin/tools/__init__.py @@ -1,2 +0,0 @@ -from .base import Tool, tool -from . import format_response diff --git a/src/marvin/tools/assistants.py b/src/marvin/tools/assistants.py new file mode 100644 index 000000000..bd07e4a36 --- /dev/null +++ b/src/marvin/tools/assistants.py @@ -0,0 +1,17 @@ +from typing import Any, Union + +from marvin.requests import CodeInterpreterTool, RetrievalTool, Tool + +Retrieval = RetrievalTool() +CodeInterpreter = CodeInterpreterTool() + +AssistantTools = Union[RetrievalTool, CodeInterpreterTool, Tool] + + +class CancelRun(Exception): + """ + A special exception that can be raised in a tool to end the run immediately. + """ + + def __init__(self, data: Any = None): + self.data = data diff --git a/src/marvin/tools/base.py b/src/marvin/tools/base.py deleted file mode 100644 index d3c455084..000000000 --- a/src/marvin/tools/base.py +++ /dev/null @@ -1,76 +0,0 @@ -import inspect -from functools import partial -from typing import Callable, Optional - -from pydantic import BaseModel - -from marvin._compat import field_validator -from marvin.types import Function -from marvin.utilities.strings import jinja_env -from marvin.utilities.types import LoggerMixin, function_to_schema - - -class Tool(LoggerMixin, BaseModel): - name: Optional[str] = None - description: Optional[str] = None - fn: Optional[Callable] = None - - @classmethod - def from_function(cls, fn, name: str = None, description: str = None): - # assuming fn has a name and a description - name = name or fn.__name__ - description = description or fn.__doc__ - return cls(name=name, description=description, fn=fn) - - @field_validator("name") - def default_name(cls, v: Optional[str]) -> str: - if v is None: - return cls.__name__ - else: - return v - - def run(self, *args, **kwargs): - if not self.fn: - raise NotImplementedError() - else: - return self.fn(*args, **kwargs) - - def __call__(self, *args, **kwargs): - return self.run(*args, **kwargs) - - def argument_schema(self) -> dict: - schema = function_to_schema(self.fn or self.run, name=self.name) - schema.pop("title", None) - return schema - - def as_function(self, description: Optional[str] = None) -> Function: - if not description: - description = jinja_env.from_string( - inspect.cleandoc(self.description or "") - ) - description = description.render(**self.dict(), TOOL=self) - - def fn(*args, **kwargs): - return self.run(*args, **kwargs) - - fn.__name__ = self.__class__.__name__ - fn.__doc__ = self.run.__doc__ - - schema = self.argument_schema() - - return Function( - name=fn.__name__, - description=description, 
- parameters=schema, - fn=fn, - signature=inspect.signature(self.run), - ) - - -def tool(arg=None, *, name: str = None, description: str = None): - if callable(arg): # Direct function decoration - return Tool.from_function(arg, name=name, description=description) - elif arg is None: # Partial application - return partial(tool, name=name, description=description) - else: - raise TypeError("Invalid argument passed to decorator.") diff --git a/src/marvin/tools/chroma.py b/src/marvin/tools/chroma.py deleted file mode 100644 index 7e8f60b38..000000000 --- a/src/marvin/tools/chroma.py +++ /dev/null @@ -1,129 +0,0 @@ -import asyncio -import json -from typing import Optional - -import httpx -from typing_extensions import Literal - -import marvin -from marvin.tools import Tool -from marvin.utilities.embeddings import create_openai_embeddings - -QueryResultType = Literal["documents", "distances", "metadatas"] - - -async def list_collections() -> list[dict]: - async with httpx.AsyncClient() as client: - chroma_api_url = f"http://{marvin.settings.chroma_server_host}:{marvin.settings.chroma_server_http_port}" - response = await client.get( - f"{chroma_api_url}/api/v1/collections", - ) - - response.raise_for_status() - return response.json() - - -async def query_chroma( - query: str, - collection: str = "marvin", - n_results: int = 5, - where: Optional[dict] = None, - where_document: Optional[dict] = None, - include: Optional[list[QueryResultType]] = None, - max_characters: int = 2000, -) -> str: - query_embedding = (await create_openai_embeddings([query]))[0] - - collection_ids = [ - c["id"] for c in await list_collections() if c["name"] == collection - ] - - if len(collection_ids) == 0: - return f"Collection {collection} not found." - - collection_id = collection_ids[0] - - async with httpx.AsyncClient() as client: - chroma_api_url = f"http://{marvin.settings.chroma_server_host}:{marvin.settings.chroma_server_http_port}" - - response = await client.post( - f"{chroma_api_url}/api/v1/collections/{collection_id}/query", - data=json.dumps( - { - "query_embeddings": [query_embedding], - "n_results": n_results, - "where": where or {}, - "where_document": where_document or {}, - "include": include or ["documents"], - } - ), - headers={"Content-Type": "application/json"}, - ) - - response.raise_for_status() - - return "\n".join( - [ - f"{i+1}. {', '.join(excerpt)}" - for i, excerpt in enumerate(response.json()["documents"]) - ] - )[:max_characters] - - -class QueryChroma(Tool): - """Tool for querying a Chroma index.""" - - description: str = """ - Retrieve document excerpts from a knowledge-base given a query. - """ - - async def run( - self, - query: str, - collection: str = "marvin", - n_results: int = 5, - where: Optional[dict] = None, - where_document: Optional[dict] = None, - include: Optional[list[QueryResultType]] = None, - max_characters: int = 2000, - ) -> str: - return await query_chroma( - query, collection, n_results, where, where_document, include, max_characters - ) - - -class MultiQueryChroma(Tool): - """Tool for querying a Chroma index.""" - - description: str = """ - Retrieve document excerpts from a knowledge-base given a query. 
- """ - - async def run( - self, - queries: list[str], - collection: str = "marvin", - n_results: int = 5, - where: Optional[dict] = None, - where_document: Optional[dict] = None, - include: Optional[list[QueryResultType]] = None, - max_characters: int = 2000, - max_queries: int = 5, - ) -> str: - if len(queries) > max_queries: - # make sure excerpts are not too short - queries = queries[:max_queries] - - coros = [ - query_chroma( - query, - collection, - n_results, - where, - where_document, - include, - max_characters // len(queries), - ) - for query in queries - ] - return "\n\n".join(await asyncio.gather(*coros, return_exceptions=True)) diff --git a/src/marvin/tools/code.py b/src/marvin/tools/code.py new file mode 100644 index 000000000..89f2d7950 --- /dev/null +++ b/src/marvin/tools/code.py @@ -0,0 +1,26 @@ +# 🚨 WARNING 🚨 +# These functions allow ARBITRARY code execution and should be used with caution. + +import json +import subprocess + + +def shell(command: str) -> str: + """executes a shell command on your local machine and returns the output""" + + result = subprocess.run(command, shell=True, text=True, capture_output=True) + + # Output and error + output = result.stdout + error = result.stderr + + return json.dumps(dict(command_output=output, command_error=error)) + + +def python(code: str) -> str: + """ + Executes Python code on your local machine and returns the output. You can + use this to run code that isn't compatible with the code interpreter tool, + for example if it requires internet access or other packages. + """ + return str(eval(code)) diff --git a/src/marvin/tools/filesystem.py b/src/marvin/tools/filesystem.py index 44d7cfbf7..31386124f 100644 --- a/src/marvin/tools/filesystem.py +++ b/src/marvin/tools/filesystem.py @@ -1,186 +1,151 @@ -import json -from pathlib import Path -from typing import Literal - -from pydantic import BaseModel, Field, root_validator, validate_arguments, validator - -from marvin.tools import Tool - - -class FileSystemTool(Tool): - root_dir: Path = Field( - None, - description=( - "Root directory for files. If provided, only files nested in or below this" - " directory can be read. " - ), - ) - - def validate_paths(self, paths: list[str]) -> list[Path]: - """ - If `root_dir` is set, ensures that all paths are children of `root_dir`. - """ - if self.root_dir: - for path in paths: - if ".." in path: - raise ValueError(f"Do not use `..` in paths. Got {path}") - if not (self.root_dir / path).is_relative_to(self.root_dir): - raise ValueError(f"Path {path} is not relative to {self.root_dir}") - return [self.root_dir / path for path in paths] - return paths - - -class ListFiles(FileSystemTool): - description: str = """ - Lists all files at or optionally under a provided path. {%- if root_dir - %} Paths must be relative to {{ root_dir }}. Provide '.' instead of '/' - to read root. {%- endif %}} - """ - - root_dir: Path = Field( - None, - description=( - "Root directory for files. If provided, only files nested in or below this" - " directory can be read." 
- ), - ) - - def run(self, path: str, include_nested: bool = True) -> list[str]: - """List all files in `root_dir`, optionally including nested files.""" - [path] = self.validate_paths([path]) - if include_nested: - files = [str(p) for p in path.rglob("*") if p.is_file()] - else: - files = [str(p) for p in path.glob("*") if p.is_file()] - - # filter out certain files - files = [ - file - for file in files - if not ( - "__pycache__" in file - or "/.git/" in file - or file.endswith("/.gitignore") - ) - ] +import os +import pathlib +import shutil - return files +def _safe_create_file(path: str) -> None: + path = os.path.expanduser(path) + file_path = pathlib.Path(path) + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.touch(exist_ok=True) -class ReadFile(FileSystemTool): - description: str = """ - Read the content of a specific file, optionally providing start and end - rows.{% if root_dir %} Paths must be relative to {{ root_dir }}. Provide '.' - instead of '/' to read root.{%- endif %}} - """ - def run(self, path: str, start_row: int = 1, end_row: int = -1) -> str: - [path] = self.validate_paths([path]) - with open(path, "r") as f: - content = f.readlines() +def getcwd() -> str: + """Returns the current working directory""" + return os.getcwd() - if start_row == 0: - start_row = 1 - if start_row > 0: - start_row -= 1 - if end_row < 0: - end_row += 1 - if end_row == 0: - content = content[start_row:] - else: - content = content[start_row:end_row] +def write(path: str, contents: str) -> str: + """Creates or overwrites a file with the given contents""" + path = os.path.expanduser(path) + _safe_create_file(path) + with open(path, "w") as f: + f.write(contents) + return f'Successfully wrote "{path}"' - return "\n".join(content) +def write_lines( + path: str, contents: str, insert_line: int = -1, mode: str = "insert" +) -> str: + """Writes content to a specific line in the file. -class ReadFiles(FileSystemTool): - description: str = """ - Read the entire content of multiple files at once. Due to context size - limitations, reading too many files at once may cause truncated responses. - {% if root_dir %} Paths must be relative to {{ root_dir }}. Provide '.' - instead of '/' to read root.{%- endif %}} + Args: + path (str): The name of the file to write to. + contents (str): The content to write to the file. + insert_line (int, optional): The line number to insert the content at. + Negative values count from the end of the file. Defaults to -1. + mode (str, optional): The mode to use when writing the content. Can be + "insert" or "overwrite". Defaults to "insert". + + Returns: + str: A message indicating whether the write was successful. 
""" + path = os.path.expanduser(path) + _safe_create_file(path) + with open(path, "r") as f: + lines = f.readlines() + if insert_line < 0: + insert_line = len(lines) + insert_line + 1 + if mode == "insert": + lines[insert_line:insert_line] = contents.splitlines(True) + elif mode == "overwrite": + lines[insert_line : insert_line + len(contents.splitlines())] = ( + contents.splitlines(True) + ) + else: + raise ValueError(f"Invalid mode: {mode}") + with open(path, "w") as f: + f.writelines(lines) + return f'Successfully wrote to "{path}"' - def run(self, paths: list[str]) -> dict[str, str]: - """Load content of each file into a dictionary of path: content.""" - content = {} - for path in self.validate_paths(paths): - with open(path) as f: - content[path] = f.read() - return content - - -class WriteContent(BaseModel): - path: str - content: str - write_mode: Literal["overwrite", "append", "insert"] = "append" - insert_at_row: int = None - - @validator("content", pre=True) - def content_must_be_string(cls, v): - if v and not isinstance(v, str): - try: - v = json.dumps(v) - except json.JSONDecodeError: - raise ValueError("Content must be a string or JSON-serializable.") - return v - - @root_validator - def check_insert_model(cls, values): - if values["insert_at_row"] is None and values["write_mode"] == "insert": - raise ValueError("Must provide `insert_at_row` when using `insert` mode.") - return values - - -class WriteFile(FileSystemTool): - description: str = """ - Write content to a file. - - {%if root_dir %} Paths must be relative to {{ root_dir }}.{% endif %}} - - {%if require_confirmation %} You MUST ask the user to confirm writes by - showing them details. {% endif %} - """ - require_confirmation: bool = True - - def run(self, write_content: WriteContent) -> str: - [path] = self.validate_paths([write_content.path]) - - # ensure the parent directory exists - path.parent.mkdir(parents=True, exist_ok=True) - - if write_content.write_mode == "overwrite": - with open(path, "w") as f: - f.write(write_content.content) - elif write_content.write_mode == "append": - with open(path, "a") as f: - f.write(write_content.content) - elif write_content.write_mode == "insert": - with open(path, "r") as f: - contents = f.readlines() - contents[write_content.insert_at_row] = write_content.content - - with open(path, "w") as f: - f.writelines(contents) - - return f"Files {write_content.path} written successfully." - - -class WriteFiles(WriteFile): - description: str = """ - Write content to multiple files. Each `WriteContent` object in the - `contents` argument is an instruction to write to a specific file. - - {%if root_dir %} Paths must be relative to {{ root_dir }}.{% endif %}} - - {%if require_confirmation %} You MUST ask the user to confirm writes by - showing them details. {% endif %} - """ - require_confirmation: bool = True - - @validate_arguments - def run(self, contents: list[WriteContent]) -> str: - for wc in contents: - super().run(write_content=wc) - return f"Files {[c.path for c in contents]} written successfully." + +def read(path: str, include_line_numbers: bool = False) -> str: + """Reads a file and returns the contents. + + Args: + path (str): The path to the file. + include_line_numbers (bool, optional): Whether to include line numbers + in the returned contents. Defaults to False. + + Returns: + str: The contents of the file. 
+ """ + path = os.path.expanduser(path) + with open(path, "r") as f: + if include_line_numbers: + lines = f.readlines() + lines_with_numbers = [f"{i+1}: {line}" for i, line in enumerate(lines)] + return "".join(lines_with_numbers) + else: + return f.read() + + +def read_lines( + path: str, + start_line: int = 0, + end_line: int = -1, + include_line_numbers: bool = False, +) -> str: + """Reads a partial file and returns the contents with optional line numbers. + + Args: + path (str): The path to the file. + start_line (int, optional): The starting line number to read. Defaults + to 0. + end_line (int, optional): The ending line number to read. Defaults to + -1, which means read until the end of the file. + include_line_numbers (bool, optional): Whether to include line numbers + in the returned contents. Defaults to False. + + Returns: + str: The contents of the file. + """ + path = os.path.expanduser(path) + with open(path, "r") as f: + lines = f.readlines() + if start_line < 0: + start_line = len(lines) + start_line + if end_line < 0: + end_line = len(lines) + end_line + if include_line_numbers: + lines_with_numbers = [ + f"{i+1}: {line}" for i, line in enumerate(lines[start_line:end_line]) + ] + return "".join(lines_with_numbers) + else: + return "".join(lines[start_line:end_line]) + + +def mkdir(path: str) -> str: + """Creates a directory (and any parent directories))""" + path = os.path.expanduser(path) + path = pathlib.Path(path) + path.mkdir(parents=True, exist_ok=True) + return f'Successfully created directory "{path}"' + + +def mv(src: str, dest: str) -> str: + """Moves a file or directory""" + src = os.path.expanduser(src) + dest = os.path.expanduser(dest) + src = pathlib.Path(src) + dest = pathlib.Path(dest) + src.rename(dest) + return f'Successfully moved "{src}" to "{dest}"' + + +def cp(src: str, dest: str) -> str: + """Copies a file or directory""" + src = os.path.expanduser(src) + dest = os.path.expanduser(dest) + src = pathlib.Path(src) + dest = pathlib.Path(dest) + shutil.copytree(src, dest) + return f'Successfully copied "{src}" to "{dest}"' + + +def ls(path: str) -> str: + """Lists the contents of a directory""" + path = os.path.expanduser(path) + path = pathlib.Path(path) + return "\n".join(str(p) for p in path.iterdir()) diff --git a/src/marvin/tools/format_response.py b/src/marvin/tools/format_response.py deleted file mode 100644 index 41fe2acb7..000000000 --- a/src/marvin/tools/format_response.py +++ /dev/null @@ -1,80 +0,0 @@ -import warnings -from types import GenericAlias -from typing import Any, Union - -import pydantic -from pydantic import BaseModel, Field, PrivateAttr - -import marvin -import marvin.utilities.types -from marvin.tools import Tool -from marvin.utilities.types import ( - genericalias_contains, - safe_issubclass, -) - -SENTINEL = "__SENTINEL__" - - -class FormatResponse(Tool): - _cached_type: Union[type, GenericAlias] = PrivateAttr(SENTINEL) - type_schema: dict[str, Any] = Field( - ..., description="The OpenAPI schema for the type" - ) - description: str = ( - "You MUST always call this function before responding to the user to ensure" - " that your final response is formatted correctly and complies with the output" - " format requirements." 
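
A companion sketch for the new read helpers, reusing the hypothetical file from the previous example and showing how the implementation above resolves negative indices:

```python
# Sketch: partial reads with optional 1-based line numbers.
from marvin.tools.filesystem import read_lines

# start_line/end_line act as slice bounds into the list of lines.
print(read_lines("/tmp/notes.txt", start_line=0, end_line=2, include_line_numbers=True))
# -> "1: alpha\n2: GAMMA\n"

# Negative values count from the end; with the code above, the default
# end_line=-1 resolves to len(lines) - 1, so the final line is excluded.
print(read_lines("/tmp/notes.txt", start_line=-2))
```
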
- ) - - def __init__(self, type_: Union[type, GenericAlias] = SENTINEL, **kwargs): - if type_ is not SENTINEL: - if not isinstance(type_, (type, GenericAlias)): - raise ValueError(f"Expected a type or GenericAlias, got {type_}") - - # warn if the type is a set or tuple with GPT 3.5 - if ( - "gpt-3.5" in marvin.settings.llm_model - or "gpt-35" in marvin.settings.llm_model - ): - if safe_issubclass(type_, (set, tuple)) or genericalias_contains( - type_, (set, tuple) - ): - warnings.warn( - ( - "GPT-3.5 often fails with `set` or `tuple` types. Consider" - " using `list` instead." - ), - UserWarning, - ) - - type_schema = marvin.utilities.types.type_to_schema( - type_, set_root_type=False - ) - type_schema.pop("title", None) - kwargs["type_schema"] = type_schema - - super().__init__(**kwargs) - if type_ is not SENTINEL: - if type_schema.get("description"): - self.description += f"\n\n {type_schema['description']}" - - if type_ is not SENTINEL: - self._cached_type = type_ - - def get_type(self) -> Union[type, GenericAlias]: - if self._cached_type is not SENTINEL: - return self._cached_type - model = marvin.utilities.types.schema_to_type(self.type_schema) - type_ = model.__fields__["__root__"].outer_type_ - self._cached_type = type_ - return type_ - - def run(self, **kwargs) -> Any: - type_ = self.get_type() - if not safe_issubclass(type_, BaseModel): - kwargs = kwargs["data"] - return pydantic.parse_obj_as(type_, kwargs) - - def argument_schema(self) -> dict: - return self.type_schema diff --git a/src/marvin/tools/github.py b/src/marvin/tools/github.py index 0969c826d..0a3e110b0 100644 --- a/src/marvin/tools/github.py +++ b/src/marvin/tools/github.py @@ -1,14 +1,40 @@ +import os from datetime import datetime from typing import List, Optional import httpx -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, field_validator import marvin -from marvin.tools import Tool +from marvin.utilities.logging import get_logger from marvin.utilities.strings import slice_tokens +async def get_token() -> str: + try: + from prefect.blocks.system import Secret + + return (await Secret.load("github-token")).get() + except (ImportError, ValueError) as exc: + get_logger("marvin").debug_kv( + ( + "Prefect Secret for GitHub token not retrieved. " + f"{exc.__class__.__name__}: {exc}" + "red" + ), + ) + + try: + return marvin.settings.github_token + except AttributeError: + pass + + if token := os.environ.get("MARVIN_GITHUB_TOKEN", ""): + return token + + raise RuntimeError("GitHub token not found") + + class GitHubUser(BaseModel): """GitHub user.""" @@ -39,59 +65,55 @@ class GitHubIssue(BaseModel): labels: List[GitHubLabel] = Field(default_factory=GitHubLabel) user: GitHubUser = Field(default_factory=GitHubUser) - @validator("body", always=True) + @field_validator("body") def validate_body(cls, v): if not v: return "" return v -class SearchGitHubIssues(Tool): - """Tool for searching GitHub issues.""" - - description: str = "Use the GitHub API to search for issues in a given repository." - - async def run(self, query: str, repo: str = "prefecthq/prefect", n: int = 3) -> str: - """ - Use the GitHub API to search for issues in a given repository. Do - not alter the default value for `n` unless specifically requested by - a user. 
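
The new `get_token` helper resolves a GitHub token from several sources in order: a Prefect `Secret` block, `marvin.settings.github_token`, then the environment. A minimal sketch using only the environment fallback (the token value is fake):

```python
# Sketch: configuring the GitHub token via the environment fallback.
import asyncio
import os

os.environ["MARVIN_GITHUB_TOKEN"] = "ghp_example-token"  # fake value, for illustration

from marvin.tools.github import get_token

print(asyncio.run(get_token()))  # -> "ghp_example-token" when no Secret block or setting exists
```
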
- - For example, to search for open issues about AttributeErrors with the - label "bug" in PrefectHQ/prefect: - - repo: prefecthq/prefect - - query: label:bug is:open AttributeError - """ - headers = {"Accept": "application/vnd.github.v3+json"} - - if token := marvin.settings.github_token: - headers["Authorization"] = f"Bearer {token.get_secret_value()}" - - async with httpx.AsyncClient() as client: - response = await client.get( - "https://api.github.com/search/issues", - headers=headers, - params={ - "q": query if "repo:" in query else f"repo:{repo} {query}", - "order": "desc", - "per_page": n, - }, - ) - response.raise_for_status() +async def search_github_issues( + query: str, repo: str = "prefecthq/prefect", n: int = 3 +) -> str: + """ + Use the GitHub API to search for issues in a given repository. Do + not alter the default value for `n` unless specifically requested by + a user. + + For example, to search for open issues about AttributeErrors with the + label "bug" in PrefectHQ/prefect: + - repo: prefecthq/prefect + - query: label:bug is:open AttributeError + """ + headers = {"Accept": "application/vnd.github.v3+json"} + + headers["Authorization"] = f"Bearer {await get_token()}" + + async with httpx.AsyncClient() as client: + response = await client.get( + "https://api.github.com/search/issues", + headers=headers, + params={ + "q": query if "repo:" in query else f"repo:{repo} {query}", + "order": "desc", + "per_page": n, + }, + ) + response.raise_for_status() - issues_data = response.json()["items"] + issues_data = response.json()["items"] - # enforce 1000 token limit per body - for issue in issues_data: - if not issue["body"]: - continue - issue["body"] = slice_tokens(issue["body"], 1000) + # enforce 1000 token limit per body + for issue in issues_data: + if not issue["body"]: + continue + issue["body"] = slice_tokens(issue["body"], 1000) - issues = [GitHubIssue(**issue) for issue in issues_data] + issues = [GitHubIssue(**issue) for issue in issues_data] - summary = "\n\n".join( - f"{issue.title} ({issue.html_url}):\n{issue.body}" for issue in issues - ) - if not summary.strip(): - raise ValueError("No issues found.") - return summary + summary = "\n\n".join( + f"{issue.title} ({issue.html_url}):\n{issue.body}" for issue in issues + ) + if not summary.strip(): + raise ValueError("No issues found.") + return summary diff --git a/src/marvin/tools/mathematics.py b/src/marvin/tools/mathematics.py deleted file mode 100644 index 1f28fde22..000000000 --- a/src/marvin/tools/mathematics.py +++ /dev/null @@ -1,51 +0,0 @@ -import httpx -from typing_extensions import Literal - -import marvin -from marvin.tools import Tool - -ResultType = Literal["DecimalApproximation"] - - -class WolframCalculator(Tool): - """Evaluate mathematical expressions using Wolfram Alpha.""" - - description: str = """ - Evaluate mathematical expressions using Wolfram Alpha. - - Always append "to decimal" to your expression unless asked for something else. 
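
Usage of the converted function mirrors the old tool's docstring; a sketch that assumes network access and a valid token resolvable by `get_token`:

```python
# Sketch: searching issues with the new function-style API.
import asyncio

from marvin.tools.github import search_github_issues

summary = asyncio.run(
    search_github_issues("label:bug is:open AttributeError", repo="prefecthq/prefect", n=3)
)
print(summary)  # newline-separated "title (url): body" entries, bodies capped at 1000 tokens
```
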
- """ - - async def run( - self, expression: str, result_type: ResultType = "DecimalApproximation" - ) -> str: - async with httpx.AsyncClient() as client: - response = await client.get( - "https://api.wolframalpha.com/v2/query", - params={ - "appid": marvin.settings.wolfram_app_id.get_secret_value(), - "input": expression, - "output": "json", - }, - ) - try: - response.raise_for_status() - except httpx.HTTPStatusError as e: - if e.response.status_code == 403: - raise ValueError( - "Invalid Wolfram Alpha App ID - get one at" - " https://developer.wolframalpha.com/portal/myapps/" - ) - raise - - data = response.json() - - pods = [ - pod - for pod in data.get("queryresult", {}).get("pods", []) - if pod.get("id") == result_type - ] - - if not pods: - return "No result found." - return pods[0].get("subpods", [{}])[0].get("plaintext", "No result found.") diff --git a/src/marvin/tools/python.py b/src/marvin/tools/python.py deleted file mode 100644 index 290be1103..000000000 --- a/src/marvin/tools/python.py +++ /dev/null @@ -1,31 +0,0 @@ -import sys -from io import StringIO - -from marvin.tools import Tool - - -def run_python(code: str, globals: dict = None, locals: dict = None): - old_stdout = sys.stdout - new_stdout = StringIO() - sys.stdout = new_stdout - try: - exec(code, globals or {}, locals or {}) - result = new_stdout.getvalue() - except Exception as e: - result = repr(e) - finally: - sys.stdout = old_stdout - return result - - -class Python(Tool): - description: str = """ - Runs arbitrary Python code. - - {%if require_confirmation %} You MUST ask the user to confirm execution by - showing them the code. {% endif %} - """ - require_confirmation: bool = True - - def run(self, code: str) -> str: - return run_python(code) diff --git a/src/marvin/tools/retrieval.py b/src/marvin/tools/retrieval.py new file mode 100644 index 000000000..a32ca6172 --- /dev/null +++ b/src/marvin/tools/retrieval.py @@ -0,0 +1,141 @@ +import asyncio +import json +import os +from typing import Optional + +import httpx +from typing_extensions import Literal + +import marvin + +try: + HOST, PORT = ( + marvin.settings.chroma_server_host, + marvin.settings.chroma_server_http_port, + ) +except AttributeError: + HOST = os.environ.get("MARVIN_CHROMA_SERVER_HOST", "localhost") + PORT = os.environ.get("MARVIN_CHROMA_SERVER_HTTP_PORT", 8000) + +QueryResultType = Literal["documents", "distances", "metadatas"] + + +async def create_openai_embeddings(texts: list[str]) -> list[list[float]]: + """Create OpenAI embeddings for a list of texts.""" + + try: + import numpy # noqa F401 + except ImportError: + raise ImportError( + "The numpy package is required to create OpenAI embeddings. Please install" + " it with `pip install numpy`." 
+ ) + from openai import AsyncOpenAI + + return ( + ( + await AsyncOpenAI( + api_key=marvin.settings.openai.api_key.get_secret_value() + ).embeddings.create( + input=[text.replace("\n", " ") for text in texts], + model="text-embedding-ada-002", + ) + ) + .data[0] + .embedding + ) + + +async def list_collections() -> list[dict]: + async with httpx.AsyncClient() as client: + chroma_api_url = f"http://{HOST}:{PORT}" + response = await client.get( + f"{chroma_api_url}/api/v1/collections", + ) + + response.raise_for_status() + return response.json() + + +async def query_chroma( + query: str, + collection: str = "marvin", + n_results: int = 5, + where: Optional[dict] = None, + where_document: Optional[dict] = None, + include: Optional[list[QueryResultType]] = None, + max_characters: int = 2000, +) -> str: + """Query Chroma. + + Example: + User: "What are prefect blocks?" + Assistant: >>> query_chroma("What are prefect blocks?") + """ + query_embedding = await create_openai_embeddings([query]) + + collection_ids = [ + c["id"] for c in await list_collections() if c["name"] == collection + ] + + if len(collection_ids) == 0: + return f"Collection {collection} not found." + + collection_id = collection_ids[0] + + async with httpx.AsyncClient() as client: + chroma_api_url = f"http://{HOST}:{PORT}" + + response = await client.post( + f"{chroma_api_url}/api/v1/collections/{collection_id}/query", + data=json.dumps( + { + "query_embeddings": [query_embedding], + "n_results": n_results, + "where": where or {}, + "where_document": where_document or {}, + "include": include or ["documents"], + } + ), + headers={"Content-Type": "application/json"}, + ) + + response.raise_for_status() + + return "\n".join( + f"{i+1}. {', '.join(excerpt)}" + for i, excerpt in enumerate(response.json()["documents"]) + )[:max_characters] + + +async def multi_query_chroma( + queries: list[str], + collection: str = "marvin", + n_results: int = 5, + where: Optional[dict] = None, + where_document: Optional[dict] = None, + include: Optional[list[QueryResultType]] = None, + max_characters: int = 2000, +) -> str: + """Query Chroma with multiple queries. + + Example: + User: "What are prefect blocks and tasks?" + Assistant: >>> multi_query_chroma( + ["What are prefect blocks?", "What are prefect tasks?"] + ) + """ + + coros = [ + query_chroma( + query, + collection, + n_results, + where, + where_document, + include, + max_characters // len(queries), + ) + for query in queries + ] + return "\n".join(await asyncio.gather(*coros))[:max_characters] diff --git a/src/marvin/tools/shell.py b/src/marvin/tools/shell.py deleted file mode 100644 index 9988299ca..000000000 --- a/src/marvin/tools/shell.py +++ /dev/null @@ -1,52 +0,0 @@ -import subprocess -from pathlib import Path - -from marvin.tools import Tool - - -class Shell(Tool): - description: str = """ - Runs arbitrary shell code. - - {% if working_directory %} The working directory will be {{ - working_directory }}. {% endif %}. - - {%if require_confirmation %} You MUST ask the user to confirm execution by - showing them the code. {% endif %} - """ - require_confirmation: bool = True - working_directory: Path = None - - def run(self, cmd: str, working_directory: str = None) -> str: - if working_directory and self.working_directory: - raise ValueError( - f"The working directoy is {self.working_directory}; do not provide" - " another one." 
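
A sketch of the new retrieval functions, assuming a reachable Chroma server (host/port taken from `marvin.settings` or the `MARVIN_CHROMA_SERVER_*` environment variables above) and a configured OpenAI API key:

```python
# Sketch: single and multi-query retrieval against the default "marvin" collection.
import asyncio

from marvin.tools.retrieval import multi_query_chroma, query_chroma

print(asyncio.run(query_chroma("What are prefect blocks?")))
print(asyncio.run(
    multi_query_chroma(["What are prefect blocks?", "What are prefect tasks?"])
))
```
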
- ) - - wd = working_directory or ( - str(self.working_directory) if self.working_directory else None - ) - - try: - result = subprocess.run( - cmd, - shell=True, - cwd=wd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - - if result.returncode != 0: - return ( - f"Command failed with code {result.returncode} and error:" - f" {result.stderr.decode() or ''}" - ) - else: - return ( - "Command succeeded with output:" - f" { result.stdout.decode() or '' }" - ) - - except Exception as e: - return f"Execution failed: {str(e)}" diff --git a/src/marvin/tools/web.py b/src/marvin/tools/web.py deleted file mode 100644 index 2e65ab72f..000000000 --- a/src/marvin/tools/web.py +++ /dev/null @@ -1,98 +0,0 @@ -import json -from itertools import islice -from typing import Dict - -import httpx -from typing_extensions import Literal - -from marvin._compat import BaseSettings, Field, SecretStr -from marvin.tools import Tool -from marvin.utilities.strings import html_to_content, slice_tokens - - -class SerpApiSettings(BaseSettings): - api_key: SecretStr = Field(None, env="MARVIN_SERPAPI_API_KEY") - - -class VisitUrl(Tool): - """Tool for visiting a URL - only to be used in special cases.""" - - description: str = "Visit a valid URL and return its contents." - - async def run(self, url: str) -> str: - if not url.startswith("http"): - url = f"http://{url}" - async with httpx.AsyncClient(follow_redirects=True, timeout=2) as client: - try: - response = await client.get(url) - except httpx.ConnectTimeout: - return "Failed to load URL: Connection timed out" - if response.status_code == 200: - text = response.text - - # try to parse as JSON in case the URL is an API - try: - content = str(json.loads(text)) - # otherwise parse as HTML - except json.JSONDecodeError: - content = html_to_content(text) - return slice_tokens(content, 1000) - else: - return f"Failed to load URL: {response.status_code}" - - -class DuckDuckGoSearch(Tool): - """Tool for searching the web with DuckDuckGo.""" - - description: str = "Search the web with DuckDuckGo." - backend: Literal["api", "html", "lite"] = "lite" - - async def run(self, query: str, n_results: int = 3) -> str: - try: - from duckduckgo_search import DDGS - except ImportError: - raise RuntimeError( - "You must install the duckduckgo-search library to use this tool. " - "You can do so by running `pip install 'marvin[ddg]'`." - ) - - with DDGS() as ddgs: - return [ - r for r in islice(ddgs.text(query, backend=self.backend), n_results) - ] - - -class GoogleSearch(Tool): - description: str = """ - For performing a Google search and retrieving the results. - - Provide the search query to get answers. - """ - - async def run(self, query: str, n_results: int = 3) -> Dict: - try: - from serpapi import GoogleSearch as google_search - except ImportError: - raise RuntimeError( - "You must install the serpapi library to use this tool. " - "You can do so by running `pip install 'marvin[serpapi]'`." - ) - - if (api_key := SerpApiSettings().api_key) is None: - raise RuntimeError( - "You must provide a SerpApi API key to use this tool. You can do so by" - " setting the MARVIN_SERPAPI_API_KEY environment variable." 
- ) - - search_params = { - "q": query, - "api_key": api_key.get_secret_value(), - } - results = google_search(search_params).get_dict() - - if "error" in results: - raise RuntimeError(results["error"]) - return [ - {"title": r.get("title"), "href": r.get("link"), "body": r.get("snippet")} - for r in results.get("organic_results", [])[:n_results] - ] diff --git a/src/marvin/types/__init__.py b/src/marvin/types/__init__.py deleted file mode 100644 index 578e8a155..000000000 --- a/src/marvin/types/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .function import Function diff --git a/src/marvin/types/function.py b/src/marvin/types/function.py deleted file mode 100644 index a9da6cd25..000000000 --- a/src/marvin/types/function.py +++ /dev/null @@ -1,137 +0,0 @@ -import copy -import inspect -import re -from typing import Callable, Optional, Type - -from marvin._compat import BaseModel, validate_arguments - -extraneous_fields = [ - "args", - "kwargs", - "ALT_V_ARGS", - "ALT_V_KWARGS", - "V_POSITIONAL_ONLY_NAME", - "V_DUPLICATE_KWARGS", -] - - -def get_openai_function_schema(schema, model): - # Make a copy of the schema. - _schema = copy.deepcopy(schema) - - # Prune the schema of 'titles'. - _schema.pop("title", None) - for key, value in schema["properties"].items(): - if key in extraneous_fields: - _schema["properties"].pop(key, None) - else: - _schema["properties"][key].pop("title", None) - - # Clear the existing schema. - schema.clear() - - # Reconstruct the schema. - schema["name"] = getattr(model, "name", model.Config.fn.__name__) - schema["description"] = getattr(model, "description", model.Config.fn.__doc__) - schema["parameters"] = { - k: v for (k, v) in _schema.items() if k not in extraneous_fields - } - - -class FunctionConfig(BaseModel): - fn: Callable - name: str - description: str = "" - schema_extra: Optional[Callable] = get_openai_function_schema - - def __init__(self, fn, **kwargs): - kwargs.setdefault("name", fn.__name__ or "") - kwargs.setdefault("description", fn.__doc__ or "") - super().__init__(fn=fn, **kwargs) - - def getsource(self): - try: - return re.search("def.*", inspect.getsource(self.fn), re.DOTALL).group() - except Exception: - return None - - def bind_arguments(self, *args, **kwargs): - bound_arguments = inspect.signature(self.fn).bind(*args, **kwargs) - bound_arguments.apply_defaults() - return bound_arguments.arguments - - def response_model(self, *args, **kwargs): - def format_response(data: inspect.signature(self.fn).return_annotation): - """Function to format the final response to the user""" - return None - - format_response.__name__ = kwargs.get("name", format_response.__name__) - format_response.__doc__ = kwargs.get("description", format_response.__doc__) - response_model = Function(format_response).model - response_model.__signature__ = inspect.signature(format_response) - return response_model - - -class Function: - """ - A wrapper class to add additional functionality to a function, - such as a schema, response model, and more. 
- """ - - def __new__(cls, fn: Callable, parameters: dict = None, signature=None, **kwargs): - config = FunctionConfig(fn, **kwargs) - instance = super().__new__(cls) - instance.instance = validate_arguments(fn, config=config.dict()) - instance.schema = parameters or instance.instance.model.schema - instance.response_model = config.response_model - instance.bind_arguments = config.bind_arguments - instance.getsource = config.getsource - instance.__name__ = config.name or fn.__name__ - instance.__doc__ = config.description or fn.__doc__ - instance.signature = signature or inspect.signature(fn) - return instance - - def __call__(self, *args, **kwargs): - return self.evaluate_raw(*args, **kwargs) - - def evaluate_raw(self, *args, **kwargs): - return self.instance(*args, **kwargs) - - @classmethod - def from_model(cls, model: Type[BaseModel], **kwargs): - model.__signature__ = inspect.Signature( - list(model.__signature__.parameters.values()), return_annotation=model - ) - - return cls( - model, name="format_response", description="Format the response", **kwargs - ) - - @classmethod - def from_return_annotation( - cls, fn: Callable, *args, name: str = None, description: str = None - ): - def format_final_response(data: inspect.signature(fn).return_annotation): - """Function to format the final response to the user""" - return None - - format_final_response.__name__ = name or format_final_response.__name__ - format_final_response.__doc__ = description or format_final_response.__doc__ - return cls(format_final_response) - - def __repr__(self): - parameters = [] - for _, param in self.signature.parameters.items(): - param_repr = str(param) - if param.annotation is not param.empty: - param_repr = param_repr.replace( - str(param.annotation), - ( - param.annotation.__name__ - if isinstance(param.annotation, type) - else str(param.annotation) - ), - ) - parameters.append(param_repr) - param_str = ", ".join(parameters) - return f"marvin.functions.{self.__name__}({param_str})" diff --git a/src/marvin/types/request.py b/src/marvin/types/request.py deleted file mode 100644 index 754eabaf3..000000000 --- a/src/marvin/types/request.py +++ /dev/null @@ -1,117 +0,0 @@ -import warnings -from typing import Callable, List, Literal, Optional, Type, Union - -from pydantic import BaseModel, BaseSettings, Extra, Field, root_validator, validator - -from marvin.types import Function - - -class Request(BaseSettings): - """ - This is a class for creating Request objects to interact with the GPT-3 API. - The class contains several configurations and validation functions to ensure - the correct data is sent to the API. - - """ - - messages: Optional[List[dict[str, str]]] = [] # messages to send to the API - functions: List[Union[dict, Callable]] = None # functions to be used in the request - function_call: Optional[Union[dict[Literal["name"], str], Literal["auto"]]] = None - - # Internal Marvin Attributes to be excluded from the data sent to the API - response_model: Optional[Type[BaseModel]] = Field(default=None) - evaluate_function_call: bool = Field(default=False) - - class Config: - exclude = {"response_model"} - exclude_none = True - extra = Extra.allow - - @root_validator(pre=True) - def handle_response_model(cls, values): - """ - This function validates and handles the response_model attribute. - If a response_model is provided, it creates a function from the model - and sets it as the function to call. 
- """ - response_model = values.get("response_model") - if response_model: - fn = Function.from_model(response_model) - values["functions"] = [fn] - values["function_call"] = {"name": fn.__name__} - return values - - @validator("functions", each_item=True) - def validate_function(cls, fn): - """ - This function validates the functions attribute. - If a Callable is provided, it wraps it with the Function class. - """ - if isinstance(fn, Callable): - fn = Function(fn) - return fn - - def __or__(self, config): - """ - This method is used to merge two Request objects. - If the attribute is a list, the lists are concatenated. - Otherwise, the attribute from the provided config is used. - """ - - touched = config.dict(exclude_unset=True, serialize_functions=False) - - fields = list( - set( - [ - # We exclude none fields from defaults. - *self.dict(exclude_none=True, serialize_functions=False).keys(), - # We exclude unset fields from the provided config. - *config.dict(exclude_unset=True, serialize_functions=False).keys(), - ] - ) - ) - - for field in fields: - if isinstance(getattr(self, field, None), list): - merged = (getattr(self, field, []) or []) + ( - getattr(config, field, []) or [] - ) - setattr(self, field, merged) - else: - setattr(self, field, touched.get(field, getattr(self, field, None))) - return self - - def merge(self, **kwargs): - warnings.warn( - "This is deprecated. Use the | operator instead.", DeprecationWarning - ) - return self | self.__class__(**kwargs) - - def functions_schema(self, *args, **kwargs): - """ - This method generates a list of schemas for all functions in the request. - If a function is callable, its model's schema is returned. - Otherwise, the function itself is returned. - """ - return [ - fn.model.schema() if isinstance(fn, Callable) else fn - for fn in self.functions or [] - ] - - def _dict(self, *args, serialize_functions=True, exclude=None, **kwargs): - """ - This method returns a dictionary representation of the Request. - If the functions attribute is present and serialize_functions is True, - the functions' schemas are also included. - """ - exclude = exclude or {} - if serialize_functions: - exclude["evaluate_function_call"] = True - exclude["response_model"] = True - response = super().dict(*args, **kwargs, exclude=exclude) - if response.get("functions") and serialize_functions: - response.update({"functions": self.functions_schema()}) - return response - - def dict(self, *args, **kwargs): - return self._dict(*args, **kwargs) diff --git a/src/marvin/utilities/async_utils.py b/src/marvin/utilities/async_utils.py deleted file mode 100644 index 3bfde4271..000000000 --- a/src/marvin/utilities/async_utils.py +++ /dev/null @@ -1,69 +0,0 @@ -import asyncio -import functools -from concurrent.futures import ThreadPoolExecutor -from typing import Awaitable, TypeVar - -T = TypeVar("T") - -BACKGROUND_TASKS = set() - - -def create_task(coro): - """ - Creates async background tasks in a way that is safe from garbage - collection. - - See - https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/ - - Example: - - async def my_coro(x: int) -> int: - return x + 1 - - # safely submits my_coro for background execution - create_task(my_coro(1)) - """ # noqa: E501 - task = asyncio.create_task(coro) - BACKGROUND_TASKS.add(task) - task.add_done_callback(BACKGROUND_TASKS.discard) - return task - - -async def run_async(func, *args, **kwargs) -> T: - """ - Runs a synchronous function in an asynchronous manner. 
- """ - - async def wrapper() -> T: - try: - return await loop.run_in_executor( - None, functools.partial(func, *args, **kwargs) - ) - except Exception as e: - # propagate the exception to the caller - raise e - - loop = asyncio.get_event_loop() - return await wrapper() - - -def run_sync(coroutine: Awaitable[T]) -> T: - """ - Runs a coroutine from a synchronous context, either in the current event - loop or in a new one if there is no event loop running. The coroutine will - block until it is done. A thread will be spawned to run the event loop if - necessary, which allows coroutines to run in environments like Jupyter - notebooks where the event loop runs on the main thread. - - """ - try: - loop = asyncio.get_running_loop() - if loop.is_running(): - with ThreadPoolExecutor() as executor: - future = executor.submit(asyncio.run, coroutine) - return future.result() - else: - return asyncio.run(coroutine) - except RuntimeError: - return asyncio.run(coroutine) diff --git a/src/marvin/utilities/asyncio.py b/src/marvin/utilities/asyncio.py new file mode 100644 index 000000000..47d032d40 --- /dev/null +++ b/src/marvin/utilities/asyncio.py @@ -0,0 +1,121 @@ +import asyncio +import functools +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Callable, Coroutine, TypeVar, cast + +T = TypeVar("T") + + +async def run_async(fn: Callable[..., T], *args: Any, **kwargs: Any) -> T: + """ + Runs a synchronous function in an asynchronous manner. + + Args: + fn: The function to run. + *args: Positional arguments to pass to the function. + **kwargs: Keyword arguments to pass to the function. + + Returns: + The return value of the function. + """ + + async def wrapper() -> T: + try: + return await loop.run_in_executor( + None, functools.partial(fn, *args, **kwargs) + ) + except Exception as e: + # propagate the exception to the caller + raise e + + loop = asyncio.get_event_loop() + return await wrapper() + + +def run_sync(coroutine: Coroutine[Any, Any, T]) -> T: + """ + Runs a coroutine from a synchronous context, either in the current event + loop or in a new one if there is no event loop running. The coroutine will + block until it is done. A thread will be spawned to run the event loop if + necessary, which allows coroutines to run in environments like Jupyter + notebooks where the event loop runs on the main thread. + + """ + try: + loop = asyncio.get_running_loop() + if loop.is_running(): + with ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, coroutine) + return future.result() + else: + return asyncio.run(coroutine) + except RuntimeError: + return asyncio.run(coroutine) + + +class ExposeSyncMethodsMixin: + """ + A mixin class that can take functions decorated with `expose_sync_method` and + automatically create synchronous versions. + + + Example: + + class MyClass(ExposeSyncMethodsMixin): + + @expose_sync_method("my_method") + async def my_method_async(self): + return 42 + + my_instance = MyClass() + await my_instance.my_method_async() # returns 42 + my_instance.my_method() # returns 42 + """ + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + for method in list(cls.__dict__.values()): + if callable(method) and hasattr(method, "_sync_name"): + sync_method_name = method._sync_name + setattr(cls, sync_method_name, method._sync_wrapper) + + +def expose_sync_method(name: str) -> Callable[..., Any]: + """ + Decorator that automatically exposes synchronous versions of async methods. 
+ Note it doesn't work with classmethods. + + Example: + + class MyClass(ExposeSyncMethodsMixin): + + @expose_sync_method("my_method") + async def my_method_async(self): + return 42 + + my_instance = MyClass() + await my_instance.my_method_async() # returns 42 + my_instance.my_method() # returns 42 + """ + + def decorator( + async_method: Callable[..., Coroutine[Any, Any, T]] + ) -> Callable[..., T]: + @functools.wraps(async_method) + def sync_wrapper(*args: Any, **kwargs: Any) -> T: + coro = async_method(*args, **kwargs) + return run_sync(coro) + + # Cast the sync_wrapper to the same type as the async_method to give the + # type checker the needed information. + casted_sync_wrapper = cast(Callable[..., T], sync_wrapper) + + # Attach attributes to the async wrapper + setattr(async_method, "_sync_wrapper", casted_sync_wrapper) + setattr(async_method, "_sync_name", name) + + # return the original async method; the sync wrapper will be added to + # the class by the init hook + return async_method + + return decorator diff --git a/src/marvin/utilities/collections.py b/src/marvin/utilities/collections.py deleted file mode 100644 index b41810aae..000000000 --- a/src/marvin/utilities/collections.py +++ /dev/null @@ -1,94 +0,0 @@ -import itertools -from pathlib import Path -from typing import Any, Callable, Iterable, Optional, TypeVar - -T = TypeVar("T") - - -def batched( - iterable: Iterable[T], size: int, size_fn: Callable[[Any], int] = None -) -> Iterable[T]: - """ - If size_fn is not provided, then the batch size will be determined by the - number of items in the batch. - - If size_fn is provided, then it will be used - to compute the batch size. Note that if a single item is larger than the - batch size, it will be returned as a batch of its own. - - Args: - iterable: The iterable to batch. - size: The size of each batch. - size_fn: A function that takes an item from the iterable and returns its size. - - Returns: - An iterable of batches. - - Example: - Batch a list of integers into batches of size 2: - ```python - batched([1, 2, 3, 4, 5], 2) - # [[1, 2], [3, 4], [5]] - ``` - """ - if size_fn is None: - it = iter(iterable) - while True: - batch = tuple(itertools.islice(it, size)) - if not batch: - break - yield batch - else: - batch = [] - batch_size = 0 - for item in iter(iterable): - batch.append(item) - batch_size += size_fn(item) - if batch_size > size: - yield batch - batch = [] - batch_size = 0 - if batch: - yield batch - - -def multi_glob( - directory: Optional[str] = None, - keep_globs: Optional[list[str]] = None, - drop_globs: Optional[list[str]] = None, -) -> list[Path]: - """Return a list of files in a directory that match the given globs. - - Args: - directory: The directory to search. Defaults to the current working directory. - keep_globs: A list of globs to keep. Defaults to ["**/*"]. - drop_globs: A list of globs to drop. Defaults to [".git/**/*"]. - - Returns: - A list of `Path` objects in the directory that match the given globs. 
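
Beyond the mixin example in the docstrings above, the two standalone helpers in the new `marvin.utilities.asyncio` module can be used directly; a small sketch with a hypothetical `slow_add` coroutine:

```python
# Sketch: running blocking code from async, and coroutines from sync.
import time

from marvin.utilities.asyncio import run_async, run_sync

async def slow_add(a: int, b: int) -> int:
    await run_async(time.sleep, 0.1)  # blocking call runs in a worker thread
    return a + b

print(run_sync(slow_add(1, 2)))  # -> 3; spawns a helper thread if a loop is already running
```
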
- - Example: - Recursively find all Python files in the `src` directory: - ```python - all_python_files = multi_glob(directory="src", keep_globs=["**/*.py"]) - ``` - """ - keep_globs = keep_globs or ["**/*"] - drop_globs = drop_globs or [".git/**/*"] - - directory_path = Path(directory) if directory else Path.cwd() - - if not directory_path.is_dir(): - raise ValueError(f"'{directory}' is not a directory.") - - def files_from_globs(globs): - return { - file - for pattern in globs - for file in directory_path.glob(pattern) - if file.is_file() - } - - matching_files = files_from_globs(keep_globs) - files_from_globs(drop_globs) - - return [file.relative_to(directory_path) for file in matching_files] diff --git a/src/marvin/utilities/context.py b/src/marvin/utilities/context.py new file mode 100644 index 000000000..a3778fbbd --- /dev/null +++ b/src/marvin/utilities/context.py @@ -0,0 +1,26 @@ +import contextvars +from contextlib import contextmanager + + +class ScopedContext: + def __init__(self): + self._context_storage = contextvars.ContextVar( + "scoped_context_storage", default={} + ) + + def get(self, key, default=None): + return self._context_storage.get().get(key, default) + + def set(self, **kwargs): + ctx = self._context_storage.get() + updated_ctx = {**ctx, **kwargs} + self._context_storage.set(updated_ctx) + + @contextmanager + def __call__(self, **kwargs): + current_context = self._context_storage.get().copy() + self.set(**kwargs) + try: + yield + finally: + self._context_storage.set(current_context) diff --git a/src/marvin/utilities/embeddings.py b/src/marvin/utilities/embeddings.py deleted file mode 100644 index b88625eef..000000000 --- a/src/marvin/utilities/embeddings.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import List - -import openai - -import marvin - - -async def create_openai_embeddings(texts: List[str]) -> List[List[float]]: - """Create OpenAI embeddings for a list of texts.""" - - try: - import numpy # noqa F401 - except ImportError: - raise ImportError( - "The numpy package is required to create OpenAI embeddings. Please install" - " it with `pip install numpy` or `pip install 'marvin[slackbot]'`." 
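
A sketch of the new `ScopedContext` added in `marvin/utilities/context.py`; the keys and values used here are purely illustrative:

```python
# Sketch: values set inside the `with` block are restored on exit.
from marvin.utilities.context import ScopedContext

ctx = ScopedContext()
ctx.set(model="gpt-4")              # illustrative key/value
print(ctx.get("model"))             # -> "gpt-4"

with ctx(model="gpt-3.5-turbo", temperature=0):
    print(ctx.get("model"))         # -> "gpt-3.5-turbo"
    print(ctx.get("temperature"))   # -> 0

print(ctx.get("model"))             # -> "gpt-4"; the temporary overrides are gone
```
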
- ) - - embeddings = await openai.Embedding.acreate( - input=[text.replace("\n", " ") for text in texts], - engine=marvin.settings.openai.embedding_engine, - ) - - return [ - r["embedding"] for r in sorted(embeddings["data"], key=lambda x: x["index"]) - ] diff --git a/src/marvin/utilities/history.py b/src/marvin/utilities/history.py deleted file mode 100644 index b35dd399e..000000000 --- a/src/marvin/utilities/history.py +++ /dev/null @@ -1,48 +0,0 @@ -import datetime -from typing import Optional - -from pydantic import Field - -from marvin._compat import BaseModel -from marvin.utilities.messages import Message, Role - - -class HistoryFilter(BaseModel): - role_in: list[Role] = Field(default_factory=list) - timestamp_ge: Optional[datetime.datetime] = None - timestamp_le: Optional[datetime.datetime] = None - - -class History(BaseModel, arbitrary_types_allowed=True): - messages: list[Message] = Field(default_factory=list) - max_messages: int = None - - def add_message(self, message: Message): - self.messages.append(message) - - if self.max_messages is not None: - self.messages = self.messages[-self.max_messages :] - - def get_messages( - self, n: int = None, skip: int = None, filter: HistoryFilter = None - ) -> list[Message]: - messages = self.messages.copy() - - if filter is not None: - if filter.timestamp_ge: - messages = [m for m in messages if m.timestamp >= filter.timestamp_ge] - if filter.timestamp_le: - messages = [m for m in messages if m.timestamp <= filter.timestamp_le] - if filter.role_in: - messages = [m for m in messages if m.role in filter.role_in] - - if skip: - messages = messages[:-skip] - - if n is not None: - messages = messages[-n:] - - return messages - - def clear(self): - self.messages.clear() diff --git a/src/marvin/utilities/jinja.py b/src/marvin/utilities/jinja.py new file mode 100644 index 000000000..0cb54052c --- /dev/null +++ b/src/marvin/utilities/jinja.py @@ -0,0 +1,109 @@ +import inspect +import re +from datetime import datetime +from typing import Any, ClassVar, Pattern, Union +from zoneinfo import ZoneInfo + +import pydantic +from jinja2 import Environment as JinjaEnvironment +from jinja2 import StrictUndefined, select_autoescape +from jinja2 import Template as BaseTemplate +from typing_extensions import Self + +from marvin.requests import BaseMessage as Message + + +class BaseEnvironment(pydantic.BaseModel): + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + + environment: JinjaEnvironment = pydantic.Field( + default=JinjaEnvironment( + autoescape=select_autoescape(default_for_string=False), + trim_blocks=True, + lstrip_blocks=True, + auto_reload=True, + undefined=StrictUndefined, + ) + ) + + globals: dict[str, Any] = pydantic.Field( + default_factory=lambda: { + "now": lambda: datetime.now(ZoneInfo("UTC")), + "inspect": inspect, + } + ) + + @pydantic.model_validator(mode="after") + def setup_globals(self: Self) -> Self: + self.environment.globals.update(self.globals) # type: ignore + return self + + def render(self, template: Union[str, BaseTemplate], **kwargs: Any) -> str: + if isinstance(template, str): + return self.environment.from_string(template).render(**kwargs) + return template.render(**kwargs) + + +Environment = BaseEnvironment() + + +def split_text_by_tokens( + text: str, + split_tokens: list[str], + environment: JinjaEnvironment = Environment.environment, +) -> list[tuple[str, str]]: + cleaned_text = inspect.cleandoc(text) + + # Find all positions of tokens in the text + positions = [ + (match.start(), match.end(), 
match.group().rstrip(":").strip()) + for token in split_tokens + for match in re.finditer(re.escape(token) + r"(?::\s*)?", cleaned_text) + ] + + # Sort positions by their start index + positions.sort(key=lambda x: x[0]) + + paired: list[tuple[str, str]] = [] + prev_end = 0 + prev_token = split_tokens[0] + for start, end, token in positions: + paired.append((prev_token, cleaned_text[prev_end:start].strip())) + prev_end = end + prev_token = token + + paired.append((prev_token, cleaned_text[prev_end:].strip())) + + # Remove pairs where the text is empty + paired = [(token.replace(":", ""), text) for token, text in paired if text] + + return paired + + +class Transcript(pydantic.BaseModel): + content: str + roles: list[str] = pydantic.Field(default=["system", "user"]) + environment: ClassVar[BaseEnvironment] = Environment + + @property + def role_regex(self) -> Pattern[str]: + return re.compile("|".join([f"\n\n{role}:" for role in self.roles])) + + def render(self: Self, **kwargs: Any) -> str: + return self.environment.render(self.content, **kwargs) + + def render_to_messages( + self: Self, + **kwargs: Any, + ) -> list[Message]: + pairs = split_text_by_tokens( + text=self.render(**kwargs), + split_tokens=[f"\n{role}" for role in self.roles], + ) + return [ + Message( + role=pair[0].strip(), + content=pair[1], + ) + for pair in pairs + ] diff --git a/src/marvin/utilities/logging.py b/src/marvin/utilities/logging.py index cf46df54f..fef473b5f 100644 --- a/src/marvin/utilities/logging.py +++ b/src/marvin/utilities/logging.py @@ -1,14 +1,17 @@ import logging from functools import lru_cache, partial +from typing import Optional -from rich.logging import RichHandler -from rich.markup import escape +from rich.logging import RichHandler # type: ignore +from rich.markup import escape # type: ignore import marvin @lru_cache() -def get_logger(name: str = None) -> logging.Logger: +def get_logger( + name: Optional[str] = None, +) -> logging.Logger: parent_logger = logging.getLogger("marvin") if name: @@ -25,7 +28,9 @@ def get_logger(name: str = None) -> logging.Logger: return logger -def setup_logging(level: str = None): +def setup_logging( + level: Optional[str] = None, +) -> None: logger = get_logger() if level is not None: @@ -47,8 +52,8 @@ def setup_logging(level: str = None): logger.propagate = False -def add_logging_methods(logger): - def log_style(level: int, message: str, style: str = None): +def add_logging_methods(logger: logging.Logger) -> None: + def log_style(level: int, message: str, style: Optional[str] = None): if not style: style = "default on default" message = f"[{style}]{escape(str(message))}[/]" @@ -68,17 +73,17 @@ def log_kv( extra={"markup": True}, ) - logger.debug_style = partial(log_style, logging.DEBUG) - logger.info_style = partial(log_style, logging.INFO) - logger.warning_style = partial(log_style, logging.WARNING) - logger.error_style = partial(log_style, logging.ERROR) - logger.critical_style = partial(log_style, logging.CRITICAL) - - logger.debug_kv = partial(log_kv, logging.DEBUG) - logger.info_kv = partial(log_kv, logging.INFO) - logger.warning_kv = partial(log_kv, logging.WARNING) - logger.error_kv = partial(log_kv, logging.ERROR) - logger.critical_kv = partial(log_kv, logging.CRITICAL) + setattr(logger, "debug_style", partial(log_style, logging.DEBUG)) + setattr(logger, "info_style", partial(log_style, logging.INFO)) + setattr(logger, "warning_style", partial(log_style, logging.WARNING)) + setattr(logger, "error_style", partial(log_style, logging.ERROR)) + setattr(logger, 
"critical_style", partial(log_style, logging.CRITICAL)) + + setattr(logger, "debug_kv", partial(log_kv, logging.DEBUG)) + setattr(logger, "info_kv", partial(log_kv, logging.INFO)) + setattr(logger, "warning_kv", partial(log_kv, logging.WARNING)) + setattr(logger, "error_kv", partial(log_kv, logging.ERROR)) + setattr(logger, "critical_kv", partial(log_kv, logging.CRITICAL)) setup_logging(level=marvin.settings.log_level) diff --git a/src/marvin/utilities/messages.py b/src/marvin/utilities/messages.py deleted file mode 100644 index ef70824cc..000000000 --- a/src/marvin/utilities/messages.py +++ /dev/null @@ -1,82 +0,0 @@ -import inspect -import uuid -from datetime import datetime -from enum import Enum -from typing import Any, Optional -from zoneinfo import ZoneInfo - -from typing_extensions import Self - -from marvin._compat import BaseModel, Field, field_validator -from marvin.utilities.strings import split_text_by_tokens -from marvin.utilities.types import MarvinBaseModel - - -class Role(Enum): - SYSTEM = "system" - ASSISTANT = "assistant" - USER = "user" - FUNCTION_REQUEST = "function" - FUNCTION_RESPONSE = "function" - - @classmethod - def _missing_(cls: type[Self], value: object) -> Optional[Self]: - lower_value = str(value).lower() - matching_member = next( - (member for member in cls if member.value.lower() == lower_value), None - ) - return matching_member - - -class FunctionCall(BaseModel): - name: str - arguments: str - - -def utcnow(): - return datetime.now(ZoneInfo("UTC")) - - -class Message(MarvinBaseModel): - role: Role - content: Optional[str] = Field(default=None, description="The message content") - - name: Optional[str] = Field( - default=None, - description="The name of the message", - ) - - function_call: Optional[FunctionCall] = Field(default=None) - - # convenience fields, excluded from serialization - id: uuid.UUID = Field(default_factory=uuid.uuid4, exclude=True) - data: Optional[dict[str, Any]] = Field(default_factory=dict, exclude=True) - timestamp: datetime = Field( - default_factory=utcnow, - exclude=True, - ) - - @field_validator("content") - def clean_content(cls, v: Optional[str]) -> Optional[str]: - if v is not None: - v = inspect.cleandoc(v) - return v - - class Config(MarvinBaseModel.Config): - use_enum_values = True - - @classmethod - def from_transcript( - cls: type[Self], - text: str, - ) -> list[Self]: - pairs = split_text_by_tokens( - text=text, split_tokens=[role.value.capitalize() for role in Role] - ) - return [ - cls( - role=Role(pair[0]), - content=pair[1], - ) - for pair in pairs - ] diff --git a/src/marvin/utilities/module_loading.py b/src/marvin/utilities/module_loading.py deleted file mode 100644 index f9a2fad8f..000000000 --- a/src/marvin/utilities/module_loading.py +++ /dev/null @@ -1,34 +0,0 @@ -import sys -from importlib import import_module - - -def cached_import(module_path, class_name): - """ - This function checks if the module is already imported and if it is, - it returns the class. Otherwise, it imports the module and returns the class. - """ - - # Check if the module is already imported - - if not ( - (module := sys.modules.get(module_path)) - and (spec := getattr(module, "__spec__", None)) - and getattr(spec, "_initializing", False) is False - ): - module = import_module(module_path) - return getattr(module, class_name) - - -def import_string(path: str): - """ - Import a dotted module path and return the attribute/class designated by the - last name in the path. Raise ImportError if the import failed. 
- """ - try: - path_, class_ = path.rsplit(".", 1) - except ValueError as err: - raise ImportError("%s doesn't look like a module path" % path) from err - try: - return cached_import(path_, class_) - except AttributeError as err: - raise ImportError(f"Module '{path_}' isn't a '{class_}' attr/class") from err diff --git a/src/marvin/utilities/openai.py b/src/marvin/utilities/openai.py new file mode 100644 index 000000000..20b7faa1c --- /dev/null +++ b/src/marvin/utilities/openai.py @@ -0,0 +1,38 @@ +import asyncio +from functools import lru_cache +from typing import Optional + +from openai import AsyncClient + + +def get_client() -> AsyncClient: + from marvin import settings + + api_key: Optional[str] = ( + settings.openai.api_key.get_secret_value() if settings.openai.api_key else None + ) + organization: Optional[str] = settings.openai.organization + return _get_client_memoized( + api_key=api_key, organization=organization, loop=asyncio.get_event_loop() + ) + + +@lru_cache +def _get_client_memoized( + api_key: Optional[str], + organization: Optional[str], + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> AsyncClient: + """ + This function is memoized to ensure that only one instance of the client is + created for a given api key / organization / loop tuple. + + The `loop` is an important key to ensure that the client is not re-used + across multiple event loops (which can happen when using the `run_sync` + function). Attempting to re-use the client across multiple event loops + can result in a `RuntimeError: Event loop is closed` error or infinite hangs. + """ + return AsyncClient( + api_key=api_key, + organization=organization, + ) diff --git a/src/marvin/_compat.py b/src/marvin/utilities/pydantic.py similarity index 51% rename from src/marvin/_compat.py rename to src/marvin/utilities/pydantic.py index bd5669b37..0c95d3975 100644 --- a/src/marvin/_compat.py +++ b/src/marvin/utilities/pydantic.py @@ -1,95 +1,9 @@ from types import FunctionType, GenericAlias -from typing import ( - Annotated, - Any, - Callable, - Optional, - TypeVar, - Union, - cast, - get_origin, -) - -from pydantic.version import VERSION as PYDANTIC_VERSION - -PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") - -if PYDANTIC_V2: - from pydantic.v1 import ( - BaseSettings, - PrivateAttr, - SecretStr, - validate_arguments, - ) - - SettingsConfigDict = BaseSettings.Config - - from pydantic import ( - BaseModel, - Field, - create_model, - field_validator, - ) - -else: - from pydantic import ( # noqa # type: ignore - BaseSettings, - BaseModel, - create_model, - Field, - SecretStr, - validate_arguments, - validator as field_validator, - PrivateAttr, - ) - - SettingsConfigDict = BaseSettings.Config - -_ModelT = TypeVar("_ModelT", bound=BaseModel) - - -def model_dump(model: _ModelT, **kwargs: Any) -> dict[str, Any]: - if PYDANTIC_V2 and hasattr(model, "model_dump"): - return model.model_dump(**kwargs) # type: ignore - return model.dict(**kwargs) # type: ignore - - -def model_dump_json(model: type[BaseModel], **kwargs: Any) -> dict[str, Any]: - if PYDANTIC_V2 and hasattr(model, "model_dump_json"): - return model.model_dump_json(**kwargs) # type: ignore - return model.json(**kwargs) # type: ignore - - -def model_json_schema( - model: type[BaseModel], - name: Optional[str] = None, - description: Optional[str] = None, -) -> dict[str, Any]: - # Get the schema from the model. - schema = {"parameters": {**model_schema(model)}} - - # Mutate the schema to match the OpenAPI spec. 
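
A sketch of the new client accessor (assumes `marvin.settings.openai.api_key` is configured). Because the cache is keyed on the running event loop, repeated calls within one loop return the same `AsyncClient`:

```python
# Sketch: the memoized OpenAI client accessor.
import asyncio

from marvin.utilities.openai import get_client

async def main() -> None:
    client = get_client()
    assert client is get_client()  # same instance for the same api key / org / event loop
    response = await client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(response.choices[0].message.content)

asyncio.run(main())
```
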
- schema["parameters"]["title"] = name or schema["parameters"].pop("title", None) - schema["parameters"]["description"] = description or schema["parameters"].pop( - "description", "" - ) # noqa - - # Move the properties to the root of the schema. - schema["name"] = schema["parameters"].pop("title") - schema["description"] = schema["parameters"].pop("description") - return schema - +from typing import Annotated, Any, Callable, Optional, Union, cast, get_origin -def model_schema(model: type[BaseModel], **kwargs: Any) -> dict[str, Any]: - if PYDANTIC_V2 and hasattr(model, "model_json_schema"): - return model.model_json_schema(**kwargs) # type: ignore - return model.schema(**kwargs) # type: ignore - - -def model_copy(model: _ModelT, **kwargs: Any) -> _ModelT: - if PYDANTIC_V2 and hasattr(model, "model_copy"): - return model.model_copy(**kwargs) # type: ignore - return model.copy(**kwargs) # type: ignore +from pydantic import BaseModel, TypeAdapter, create_model +from pydantic.v1 import validate_arguments +from typing_extensions import Literal def cast_callable_to_model( @@ -97,14 +11,14 @@ def cast_callable_to_model( name: Optional[str] = None, description: Optional[str] = None, ) -> type[BaseModel]: - response = validate_arguments(function).model # type: ignore + response = validate_arguments(function).model for field in ["args", "kwargs", "v__duplicate_kwargs"]: - fields = cast(dict[str, Any], response.__fields__) # type: ignore + fields = cast(dict[str, Any], response.__fields__) fields.pop(field, None) response.__title__ = name or function.__name__ response.__name__ = name or function.__name__ response.__doc__ = description or function.__doc__ - return response # type: ignore + return response def cast_type_or_alias_to_model( @@ -137,34 +51,38 @@ def cast_to_model( response = BaseModel if origin is Annotated: - metadata: Any = next(iter(function_or_type.__metadata__), None) # type: ignore + metadata: Any = next(iter(function_or_type.__metadata__), None) annotated_field_name: Optional[str] = field_name if hasattr(metadata, "extra") and isinstance(metadata.extra, dict): - annotated_field_name: Optional[str] = metadata.extra.get("name", "") # type: ignore # noqa + annotated_field_name: Optional[str] = metadata.extra.get("name", "") # noqa elif hasattr(metadata, "json_schema_extra") and isinstance( metadata.json_schema_extra, dict ): # noqa - annotated_field_name: Optional[str] = metadata.json_schema_extra.get("name", "") # type: ignore # noqa + annotated_field_name: Optional[str] = metadata.json_schema_extra.get( + "name", "" + ) # noqa elif isinstance(metadata, dict): - annotated_field_name: Optional[str] = metadata.get("name", "") # type: ignore # noqa + annotated_field_name: Optional[str] = metadata.get("name", "") # noqa elif isinstance(metadata, str): annotated_field_name: Optional[str] = metadata else: pass annotated_field_description: Optional[str] = description or "" if hasattr(metadata, "description") and isinstance(metadata.description, str): - annotated_field_description: Optional[str] = metadata.description # type: ignore # noqa + annotated_field_description: Optional[str] = metadata.description # noqa elif isinstance(metadata, dict): - annotated_field_description: Optional[str] = metadata.get("description", "") # type: ignore # noqa + annotated_field_description: Optional[str] = metadata.get( + "description", "" + ) # noqa else: pass response = cast_to_model( - function_or_type.__origin__, # type: ignore + function_or_type.__origin__, name=name, 
description=annotated_field_description, - field_name=annotated_field_name, # type: ignore + field_name=annotated_field_name, ) response.__doc__ = annotated_field_description or "" elif origin in {dict, list, tuple, set, frozenset}: @@ -193,12 +111,17 @@ def cast_to_model( return response -def cast_to_json( - function_or_type: Union[type, type[BaseModel], GenericAlias, Callable[..., Any]], - name: Optional[str] = None, - description: Optional[str] = None, - field_name: Optional[str] = None, -) -> dict[str, Any]: - return model_json_schema( - cast_to_model(function_or_type, name, description, field_name) - ) +def parse_as( + type_: Any, + data: Any, + mode: Literal["python", "json", "strings"] = "python", +) -> BaseModel: + """Parse a json string to a Pydantic model.""" + adapter = TypeAdapter(type_) + + if get_origin(type_) is list and isinstance(data, dict): + data = next(iter(data.values())) + + parser = getattr(adapter, f"validate_{mode}") + + return parser(data) diff --git a/src/marvin/utilities/slack.py b/src/marvin/utilities/slack.py new file mode 100644 index 000000000..b7a331c10 --- /dev/null +++ b/src/marvin/utilities/slack.py @@ -0,0 +1,286 @@ +import os +import re +from typing import List, Optional, Union + +import httpx +from pydantic import BaseModel, field_validator + +import marvin + + +class EventBlockElement(BaseModel): + type: str + text: Optional[str] = None + user_id: Optional[str] = None + + +class EventBlockElementGroup(BaseModel): + type: str + elements: List[EventBlockElement] + + +class EventBlock(BaseModel): + type: str + block_id: str + elements: List[Union[EventBlockElement, EventBlockElementGroup]] + + +class SlackEvent(BaseModel): + client_msg_id: Optional[str] = None + type: str + text: str + user: str + ts: str + team: str + channel: str + event_ts: str + thread_ts: Optional[str] = None + parent_user_id: Optional[str] = None + blocks: Optional[List[EventBlock]] = None + + +class EventAuthorization(BaseModel): + enterprise_id: Optional[str] = None + team_id: str + user_id: str + is_bot: bool + is_enterprise_install: bool + + +class SlackPayload(BaseModel): + token: str + type: str + team_id: Optional[str] = None + api_app_id: Optional[str] = None + event: Optional[SlackEvent] = None + event_id: Optional[str] = None + event_time: Optional[int] = None + authorizations: Optional[List[EventAuthorization]] = None + is_ext_shared_channel: Optional[bool] = None + event_context: Optional[str] = None + challenge: Optional[str] = None + + @field_validator("event") + def validate_event(cls, v: Optional[SlackEvent]) -> Optional[SlackEvent]: + if v.type != "url_verification" and v is None: + raise ValueError("event is required") + return v + + +async def get_token() -> str: + """Get the Slack bot token from the environment.""" + try: + token = marvin.settings.slack_api_token + except AttributeError: + token = os.getenv("MARVIN_SLACK_API_TOKEN") + if not token: + raise ValueError( + "`MARVIN_SLACK_API_TOKEN` not found in environment." + " Please set it in `~/.marvin/.env` or as an environment variable." 
+ ) + return token + + +def convert_md_links_to_slack(text) -> str: + md_link_pattern = r"\[(?P[^\]]+)]\((?P[^\)]+)\)" + + # converting Markdown links to Slack-style links + def to_slack_link(match): + return f'<{match.group("url")}|{match.group("text")}>' + + # Replace Markdown links with Slack-style links + slack_text = re.sub(md_link_pattern, to_slack_link, text) + + return slack_text + + +async def post_slack_message( + message: str, + channel_id: str, + thread_ts: Union[str, None] = None, + auth_token: Union[str, None] = None, +) -> httpx.Response: + if not auth_token: + auth_token = await get_token() + + async with httpx.AsyncClient() as client: + response = await client.post( + "https://slack.com/api/chat.postMessage", + headers={"Authorization": f"Bearer {auth_token}"}, + json={ + "channel": channel_id, + "text": convert_md_links_to_slack(message), + "thread_ts": thread_ts, + }, + ) + + response.raise_for_status() + return response + + +async def get_thread_messages(channel: str, thread_ts: str) -> list: + """Get all messages from a slack thread.""" + async with httpx.AsyncClient() as client: + response = await client.get( + "https://slack.com/api/conversations.replies", + headers={"Authorization": f"Bearer {await get_token()}"}, + params={"channel": channel, "ts": thread_ts}, + ) + response.raise_for_status() + return response.json().get("messages", []) + + +async def get_user_name(user_id: str) -> str: + async with httpx.AsyncClient() as client: + response = await client.get( + "https://slack.com/api/users.info", + params={"user": user_id}, + headers={"Authorization": f"Bearer {await get_token()}"}, # noqa: E501 + ) + return ( + response.json().get("user", {}).get("name", user_id) + if response.status_code == 200 + else user_id + ) + + +async def get_channel_name(channel_id: str) -> str: + async with httpx.AsyncClient() as client: + response = await client.get( + "https://slack.com/api/conversations.info", + params={"channel": channel_id}, + headers={"Authorization": f"Bearer {await get_token()}"}, # noqa: E501 + ) + return ( + response.json().get("channel", {}).get("name", channel_id) + if response.status_code == 200 + else channel_id + ) + + +async def fetch_current_message_text(channel: str, ts: str) -> str: + """Fetch the current text of a specific Slack message using its timestamp.""" + async with httpx.AsyncClient() as client: + response = await client.get( + "https://slack.com/api/conversations.replies", + params={"channel": channel, "ts": ts}, + headers={"Authorization": f"Bearer {await get_token()}"}, # noqa: E501 + ) + response.raise_for_status() + messages = response.json().get("messages", []) + if not messages: + raise ValueError("Message not found") + + return messages[0]["text"] + + +async def edit_slack_message( + new_text: str, + channel_id: str, + thread_ts: str, + mode: str = "append", + delimiter: Union[str, None] = None, +) -> httpx.Response: + """Edit an existing Slack message by appending new text or replacing it. + + Args: + channel (str): The Slack channel ID. + ts (str): The timestamp of the message to edit. + new_text (str): The new text to append or replace in the message. + mode (str): The mode of text editing, 'append' (default) or 'replace'. + + Returns: + httpx.Response: The response from the Slack API. 
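+
+    Example (illustrative only; the channel ID and thread timestamp below are
+    placeholder values, not real Slack identifiers):
+        >>> await edit_slack_message(
+        ...     "All done!",
+        ...     channel_id="C0123456789",
+        ...     thread_ts="1700000000.000100",
+        ... )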
+ """ + if mode == "append": + current_text = await fetch_current_message_text(channel_id, thread_ts) + delimiter = "\n\n" if delimiter is None else delimiter + updated_text = f"{current_text}{delimiter}{convert_md_links_to_slack(new_text)}" + elif mode == "replace": + updated_text = convert_md_links_to_slack(new_text) + else: + raise ValueError("Invalid mode. Use 'append' or 'replace'.") + + async with httpx.AsyncClient() as client: + response = await client.post( + "https://slack.com/api/chat.update", + headers={"Authorization": f"Bearer {await get_token()}"}, + json={"channel": channel_id, "ts": thread_ts, "text": updated_text}, + ) + + response.raise_for_status() + return response + + +async def search_slack_messages( + query: str, + max_messages: int = 3, + channel: Union[str, None] = None, + user_auth_token: Union[str, None] = None, +) -> list: + """ + Search for messages in Slack workspace based on a query. + + Args: + query (str): The search query. + max_messages (int): The maximum number of messages to retrieve. + channel (str, optional): The specific channel to search in. Defaults to None, + which searches all channels. + + Returns: + list: A list of message contents and permalinks matching the query. + """ + all_messages = [] + next_cursor = None + + if not user_auth_token: + user_auth_token = await get_token() + + async with httpx.AsyncClient() as client: + while len(all_messages) < max_messages: + params = { + "query": query, + "limit": min(max_messages - len(all_messages), 10), + } + if channel: + params["channel"] = channel + if next_cursor: + params["cursor"] = next_cursor + + response = await client.get( + "https://slack.com/api/search.messages", + headers={"Authorization": f"Bearer {user_auth_token}"}, + params=params, + ) + + response.raise_for_status() + data = response.json().get("messages", {}).get("matches", []) + for message in data: + all_messages.append( + { + "content": message.get("text", ""), + "permalink": message.get("permalink", ""), + } + ) + + next_cursor = ( + response.json().get("response_metadata", {}).get("next_cursor") + ) + + if not next_cursor: + break + + return all_messages[:max_messages] + + +async def get_workspace_info(slack_bot_token: Union[str, None] = None) -> dict: + if not slack_bot_token: + slack_bot_token = await get_token() + + async with httpx.AsyncClient() as client: + response = await client.get( + "https://slack.com/api/team.info", + headers={"Authorization": f"Bearer {slack_bot_token}"}, + ) + response.raise_for_status() + return response.json().get("team", {}) diff --git a/src/marvin/utilities/streaming.py b/src/marvin/utilities/streaming.py deleted file mode 100644 index 15302a024..000000000 --- a/src/marvin/utilities/streaming.py +++ /dev/null @@ -1,13 +0,0 @@ -import abc -from typing import Callable, Optional - -from marvin.utilities.messages import Message -from marvin.utilities.types import MarvinBaseModel - - -class StreamHandler(MarvinBaseModel, abc.ABC): - callback: Optional[Callable] = None - - @abc.abstractmethod - def handle_streaming_response(self, api_response) -> Message: - raise NotImplementedError() diff --git a/src/marvin/utilities/strings.py b/src/marvin/utilities/strings.py index e05ee52d2..46a140f07 100644 --- a/src/marvin/utilities/strings.py +++ b/src/marvin/utilities/strings.py @@ -1,66 +1,13 @@ -import inspect -import re -from datetime import datetime -from zoneinfo import ZoneInfo - import tiktoken -from jinja2 import ( - ChoiceLoader, - Environment, - StrictUndefined, - pass_context, - select_autoescape, 
-) -from markupsafe import Markup - -import marvin.utilities.async_utils - -NEWLINES_REGEX = re.compile(r"(\s*\n\s*)") -MD_LINK_REGEX = r"\[(?P[^\]]+)]\((?P[^\)]+)\)" - -jinja_env = Environment( - loader=ChoiceLoader( - [ - # PackageLoader("marvin", "prompts") - ] - ), - autoescape=select_autoescape(default_for_string=False), - trim_blocks=True, - lstrip_blocks=True, - auto_reload=True, - undefined=StrictUndefined, -) - -jinja_env.globals.update( - zip=zip, - arun=marvin.utilities.async_utils.run_sync, - now=lambda: datetime.now(ZoneInfo("UTC")), -) - - -@pass_context -def render_filter(context, value): - """ - Allows nested rendering of variables that may contain variables themselves - e.g. {{ description | render }} - """ - _template = context.eval_ctx.environment.from_string(value) - result = _template.render(**context) - if context.eval_ctx.autoescape: - result = Markup(result) - return result - - -jinja_env.filters["render"] = render_filter -def tokenize(text: str) -> list[int]: - tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo") +def tokenize(text: str, model: str = "gpt-3.5-turbo-1106") -> list[int]: + tokenizer = tiktoken.encoding_for_model(model) return tokenizer.encode(text) -def detokenize(tokens: list[int]) -> str: - tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo") +def detokenize(tokens: list[int], model: str = "gpt-3.5-turbo-1106") -> str: + tokenizer = tiktoken.encoding_for_model(model) return tokenizer.decode(tokens) @@ -71,78 +18,3 @@ def count_tokens(text: str) -> int: def slice_tokens(text: str, n_tokens: int) -> str: tokens = tokenize(text) return detokenize(tokens[:n_tokens]) - - -def split_tokens(text: str, n_tokens: int) -> list[str]: - tokens = tokenize(text) - return [ - detokenize(tokens[i : i + n_tokens]) for i in range(0, len(tokens), n_tokens) - ] - - -def condense_newlines(text: str) -> str: - def replace_whitespace(match): - newlines_count = match.group().count("\n") - if newlines_count <= 1: - return " " - else: - return "\n" * newlines_count - - text = inspect.cleandoc(text) - text = NEWLINES_REGEX.sub(replace_whitespace, text) - return text.strip() - - -def html_to_content(html: str) -> str: - from bs4 import BeautifulSoup - - soup = BeautifulSoup(html, "html.parser") - - # Remove script and style elements - for script in soup(["script", "style"]): - script.extract() - - # Get text - text = soup.get_text() - - # Condense newlines - return condense_newlines(text) - - -def convert_md_links_to_slack(text: str) -> str: - # Convert Markdown links to Slack-style links - md_link_regex = re.compile(r"\[(?P[^\]]+)\]\((?P[^\)]+)\)") - text = md_link_regex.sub(r"<\g|\g>", text) - - text = re.sub(r"\*\*(.+?)\*\*", r"*\1*", text) - - return text - - -def split_text_by_tokens(text: str, split_tokens: list[str]) -> list[tuple[str, str]]: - cleaned_text = inspect.cleandoc(text) - - # Find all positions of tokens in the text - positions = [ - (match.start(), match.end(), match.group().rstrip(":").strip()) - for token in split_tokens - for match in re.finditer(re.escape(token) + r"(?::\s*)?", cleaned_text) - ] - - # Sort positions by their start index - positions.sort(key=lambda x: x[0]) - - paired: list[tuple[str, str]] = [] - prev_end = 0 - prev_token = split_tokens[0] - for start, end, token in positions: - paired.append((prev_token, cleaned_text[prev_end:start].strip())) - prev_end = end - prev_token = token - - paired.append((prev_token, cleaned_text[prev_end:].strip())) - - # Remove pairs where the text is empty - paired = [(token.replace(":", ""), 
text) for token, text in paired if text] - - return paired diff --git a/src/marvin/utilities/tools.py b/src/marvin/utilities/tools.py new file mode 100644 index 000000000..0f051f778 --- /dev/null +++ b/src/marvin/utilities/tools.py @@ -0,0 +1,54 @@ +import inspect +import json + +from marvin.requests import Tool +from marvin.utilities.asyncio import run_sync +from marvin.utilities.logging import get_logger +from marvin.utilities.pydantic import cast_callable_to_model + +logger = get_logger("Tools") + + +def tool_from_function(fn: callable, name: str = None, description: str = None): + model = cast_callable_to_model(fn) + return Tool( + type="function", + function=dict( + name=name or fn.__name__, + description=description or fn.__doc__, + # use deprecated schema because this is based on a pydantic v1 + # validate_arguments + parameters=model.schema(), + python_fn=fn, + ), + ) + + +def call_function_tool( + tools: list[Tool], function_name: str, function_arguments_json: str +): + tool = next( + ( + tool + for tool in tools + if isinstance(tool, Tool) + and tool.function + and tool.function.name == function_name + ), + None, + ) + if not tool: + raise ValueError(f"Could not find function '{function_name}'") + + arguments = json.loads(function_arguments_json) + logger.debug(f"Calling {tool.function.name} with arguments: {arguments}") + output = tool.function.python_fn(**arguments) + if inspect.isawaitable(output): + output = run_sync(output) + truncated_output = str(output)[:100] + if len(truncated_output) < len(str(output)): + truncated_output += "..." + logger.debug(f"{tool.function.name} returned: {truncated_output}") + if not isinstance(output, str): + output = json.dumps(output) + return output diff --git a/src/marvin/utilities/types.py b/src/marvin/utilities/types.py deleted file mode 100644 index dbd7f1456..000000000 --- a/src/marvin/utilities/types.py +++ /dev/null @@ -1,128 +0,0 @@ -import inspect -import logging -from types import GenericAlias -from typing import Any, Callable, _SpecialForm - -from marvin._compat import BaseModel, PrivateAttr, create_model -from marvin.utilities.logging import get_logger - - -class MarvinBaseModel(BaseModel): - class Config: - extra = "forbid" - - -class LoggerMixin(BaseModel): - """ - BaseModel mixin that adds a private `logger` attribute - """ - - _logger: logging.Logger = PrivateAttr() - - def __init__(self, **data): - super().__init__(**data) - self._logger = get_logger(type(self).__name__) - - @property - def logger(self): - return self._logger - - -def function_to_model( - function: Callable[..., Any], name: str = None, description: str = None -) -> dict: - """ - Converts a function's arguments into an OpenAPI schema by parsing it into a - Pydantic model. To work, all arguments must have valid type annotations. - """ - signature = inspect.signature(function) - - fields = { - p: ( - signature.parameters[p].annotation, - ( - signature.parameters[p].default - if signature.parameters[p].default != inspect._empty - else ... 
- ), - ) - for p in signature.parameters - if p != getattr(function, "__self__", None) - } - - # Create Pydantic model - try: - Model = create_model(name or function.__name__, **fields) - except RuntimeError as exc: - if "see `arbitrary_types_allowed` " in str(exc): - raise ValueError( - f"Error while inspecting {function.__name__} with signature" - f" {signature}: {exc}" - ) - else: - raise - - return Model - - -def function_to_schema(function: Callable[..., Any], name: str = None) -> dict: - """ - Converts a function's arguments into an OpenAPI schema by parsing it into a - Pydantic model. To work, all arguments must have valid type annotations. - """ - Model = function_to_model(function, name=name) - - return Model.schema() - - -def safe_issubclass(type_, classes): - if isinstance(type_, type) and not isinstance(type_, GenericAlias): - return issubclass(type_, classes) - else: - return False - - -def type_to_schema(type_, set_root_type: bool = True) -> dict: - if safe_issubclass(type_, BaseModel): - schema = type_.schema() - # if the docstring was updated at runtime, make it the description - if type_.__doc__ and type_.__doc__ != schema.get("description"): - schema["description"] = type_.__doc__ - return schema - - elif set_root_type: - - class Model(BaseModel): - __root__: type_ - - return Model.schema() - else: - - class Model(BaseModel): - data: type_ - - return Model.schema() - - -def genericalias_contains(genericalias, target_type): - """ - Explore whether a type or generic alias contains a target type. The target - types can be a single type or a tuple of types. - - Useful for seeing if a type contains a pydantic model, for example. - """ - if isinstance(target_type, tuple): - return any(genericalias_contains(genericalias, t) for t in target_type) - - if isinstance(genericalias, GenericAlias): - if safe_issubclass(genericalias.__origin__, target_type): - return True - for arg in genericalias.__args__: - if genericalias_contains(arg, target_type): - return True - elif isinstance(genericalias, _SpecialForm): - return False - else: - return safe_issubclass(genericalias, target_type) - - return False diff --git a/src/marvin/_framework/app/__init__.py b/tests/beta/__init__.py similarity index 100% rename from src/marvin/_framework/app/__init__.py rename to tests/beta/__init__.py diff --git a/src/marvin/_framework/config/__init__.py b/tests/beta/assistants/__init__.py similarity index 100% rename from src/marvin/_framework/config/__init__.py rename to tests/beta/assistants/__init__.py diff --git a/src/marvin/_framework/static/__init__.py b/tests/beta/assistants/test_assistants.py similarity index 100% rename from src/marvin/_framework/static/__init__.py rename to tests/beta/assistants/test_assistants.py diff --git a/src/marvin/cli/admin/scripts/__init__.py b/tests/cli/__init__.py similarity index 100% rename from src/marvin/cli/admin/scripts/__init__.py rename to tests/cli/__init__.py diff --git a/tests/cli/test_version.py b/tests/cli/test_version.py new file mode 100644 index 000000000..61769f429 --- /dev/null +++ b/tests/cli/test_version.py @@ -0,0 +1,14 @@ +from typer.testing import CliRunner + + +def test_marvin_version_command(): + """Test the marvin version command.""" + from marvin.cli import app + + runner = CliRunner() + result = runner.invoke(app, ["version"]) + + assert result.exit_code == 0 + assert "Version:" in result.output + assert "Python version:" in result.output + assert "OS/Arch:" in result.output diff --git a/src/marvin/components/library/__init__.py 
b/tests/components/__init__.py similarity index 100% rename from src/marvin/components/library/__init__.py rename to tests/components/__init__.py diff --git a/tests/components/test_ai_classifier.py b/tests/components/test_ai_classifier.py new file mode 100644 index 000000000..034f53bf0 --- /dev/null +++ b/tests/components/test_ai_classifier.py @@ -0,0 +1,47 @@ +from enum import Enum + +from marvin import ai_classifier +from typing_extensions import Literal + +from tests.utils import pytest_mark_class + +Sentiment = Literal["Positive", "Negative"] + + +class GitHubIssueTag(Enum): + BUG = "bug" + FEATURE = "feature" + ENHANCEMENT = "enhancement" + DOCS = "docs" + + +@pytest_mark_class("llm") +class TestAIClassifer: + class TestLiteral: + def test_ai_classifier_literal_return_type(self): + @ai_classifier + def sentiment(text: str) -> Sentiment: + """Classify sentiment""" + + result = sentiment("Great!") + + assert result == "Positive" + + def test_ai_classifier_literal_return_type_with_docstring(self): + @ai_classifier + def sentiment(text: str) -> Sentiment: + """Classify sentiment - also its opposite day""" + + result = sentiment("Great!") + + assert result == "Negative" + + class TestEnum: + def test_ai_classifier_enum_return_type(self): + @ai_classifier + def labeler(text: str) -> GitHubIssueTag: + """Classify GitHub issue tags""" + + result = labeler("improve the docs you slugs") + + assert result == GitHubIssueTag.DOCS diff --git a/tests/components/test_ai_functions.py b/tests/components/test_ai_functions.py new file mode 100644 index 000000000..070a0b97d --- /dev/null +++ b/tests/components/test_ai_functions.py @@ -0,0 +1,205 @@ +import inspect +from typing import Dict, List + +import marvin +import pytest +from marvin import ai_fn +from pydantic import BaseModel + +from tests.utils import pytest_mark_class + + +@ai_fn +def list_fruit(n: int = 2) -> list[str]: + """Returns a list of `n` fruit""" + + +@ai_fn +def list_fruit_color(n: int, color: str = None) -> list[str]: + """Returns a list of `n` fruit that all have the provided `color`""" + + +@pytest_mark_class("llm") +class TestAIFunctions: + class TestBasics: + def test_list_fruit(self): + result = list_fruit() + assert len(result) == 2 + + def test_list_fruit_argument(self): + result = list_fruit(5) + assert len(result) == 5 + + async def test_list_fruit_async(self): + @ai_fn + async def list_fruit(n: int) -> list[str]: + """Returns a list of `n` fruit""" + + coro = list_fruit(3) + assert inspect.iscoroutine(coro) + result = await coro + assert len(result) == 3 + + class TestAnnotations: + def test_list_fruit_with_generic_type_hints(self): + @ai_fn + def list_fruit(n: int) -> List[str]: + """Returns a list of `n` fruit""" + + result = list_fruit(3) + assert len(result) == 3 + + def test_basemodel_return_annotation(self): + class Fruit(BaseModel): + name: str + color: str + + @ai_fn + def get_fruit(description: str) -> Fruit: + """Returns a fruit with the provided description""" + + fruit = get_fruit("loved by monkeys") + assert fruit.name.lower() == "banana" + assert fruit.color.lower() == "yellow" + + @pytest.mark.parametrize("name,expected", [("banana", True), ("car", False)]) + def test_bool_return_annotation(self, name, expected): + @ai_fn + def is_fruit(name: str) -> bool: + """Returns True if the provided name is a fruit""" + + assert is_fruit(name) == expected + + @pytest.mark.skipif( + marvin.settings.openai.chat.completions.model == "gpt-3.5-turbo-1106", + reason="3.5 turbo doesn't do well with unknown schemas", + ) + def 
test_plain_dict_return_type(self): + @ai_fn + def describe_fruit(description: str) -> dict: + """guess the fruit and return the name and color""" + + fruit = describe_fruit("the one thats loved by monkeys") + assert fruit["name"].lower() == "banana" + assert fruit["color"].lower() == "yellow" + + @pytest.mark.skipif( + marvin.settings.openai.chat.completions.model == "gpt-3.5-turbo-1106", + reason="3.5 turbo doesn't do well with unknown schemas", + ) + def test_annotated_dict_return_type(self): + @ai_fn + def describe_fruit(description: str) -> dict[str, str]: + """guess the fruit and return the name and color""" + + fruit = describe_fruit("the one thats loved by monkeys") + assert fruit["name"].lower() == "banana" + assert fruit["color"].lower() == "yellow" + + @pytest.mark.skipif( + marvin.settings.openai.chat.completions.model == "gpt-3.5-turbo-1106", + reason="3.5 turbo doesn't do well with unknown schemas", + ) + def test_generic_dict_return_type(self): + @ai_fn + def describe_fruit(description: str) -> Dict[str, str]: + """guess the fruit and return the name and color""" + + fruit = describe_fruit("the one thats loved by monkeys") + assert fruit["name"].lower() == "banana" + assert fruit["color"].lower() == "yellow" + + def test_typed_dict_return_type(self): + from typing_extensions import TypedDict + + class Fruit(TypedDict): + name: str + color: str + + @ai_fn + def describe_fruit(description: str) -> Fruit: + """guess the fruit and return the name and color""" + + fruit = describe_fruit("the one thats loved by monkeys") + assert fruit["name"].lower() == "banana" + assert fruit["color"].lower() == "yellow" + + def test_int_return_type(self): + @ai_fn + def get_fruit(name: str) -> int: + """Returns the number of letters in the alluded fruit name""" + + assert get_fruit("banana") == 6 + + def test_float_return_type(self): + @ai_fn + def get_pi(n: int) -> float: + """Return the first n digits of pi""" + + assert get_pi(5) == 3.14159 + + def test_tuple_return_type(self): + @ai_fn + def get_fruit(name: str) -> tuple: + """Returns a tuple of fruit""" + + assert get_fruit("alphabet of fruit, first 3, singular") == ( + "apple", + "banana", + "cherry", + ) + + def test_set_return_type(self): + @ai_fn + def get_fruit_letters(name: str) -> set: + """Returns the letters in the provided fruit name""" + + assert get_fruit_letters("banana") == {"a", "b", "n"} + + def test_frozenset_return_type(self): + @ai_fn + def get_fruit_letters(name: str) -> frozenset: + """Returns the letters in the provided fruit name""" + + assert get_fruit_letters("orange") == frozenset( + {"a", "e", "g", "n", "o", "r"} + ) + + +@pytest_mark_class("llm") +class TestAIFunctionsMap: + def test_map(self): + result = list_fruit.map([2, 3]) + assert len(result) == 2 + assert len(result[0]) == 2 + assert len(result[1]) == 3 + + async def test_amap(self): + result = await list_fruit.amap([2, 3]) + assert len(result) == 2 + assert len(result[0]) == 2 + assert len(result[1]) == 3 + + def test_map_kwargs(self): + result = list_fruit.map(n=[2, 3]) + assert len(result) == 2 + assert len(result[0]) == 2 + assert len(result[1]) == 3 + + def test_map_kwargs_and_args(self): + result = list_fruit_color.map([2, 3], color=["green", "red"]) + assert len(result) == 2 + assert len(result[0]) == 2 + assert len(result[1]) == 3 + + def test_invalid_args(self): + with pytest.raises(TypeError): + list_fruit_color.map(2, color=["orange", "red"]) + + def test_invalid_kwargs(self): + with pytest.raises(TypeError): + list_fruit_color.map([2, 3], 
color=None) + + async def test_invalid_async_map(self): + with pytest.raises(TypeError, match="can't be used in 'await' expression"): + await list_fruit_color.map(n=[2], color=["orange", "red"]) diff --git a/tests/test_components/test_ai_model.py b/tests/components/test_ai_model.py similarity index 80% rename from tests/test_components/test_ai_model.py rename to tests/components/test_ai_model.py index 24d37c06d..0fb9efa38 100644 --- a/tests/test_components/test_ai_model.py +++ b/tests/components/test_ai_model.py @@ -2,10 +2,9 @@ import pytest from marvin import ai_model -from marvin.utilities.messages import Message, Role from pydantic import BaseModel, Field -from tests.utils.mark import pytest_mark_class +from tests.utils import pytest_mark_class @pytest_mark_class("llm") @@ -29,16 +28,16 @@ class Location(BaseModel): longitude: float city: str state: str - country: str + country: str = Field(..., description="The abbreviated country name") x = Location("The capital city of the Cornhusker State.") assert x.city == "Lincoln" assert x.state == "Nebraska" - assert "United" in x.country - assert "States" in x.country + assert x.country in {"US", "USA", "U.S.", "U.S.A."} assert x.latitude // 1 == 40 assert x.longitude // 1 == -97 + @pytest.mark.xfail(reason="TODO: flaky on 3.5") def test_depth(self): from typing import List @@ -61,6 +60,7 @@ class RentalHistory(BaseModel): I lived in Palms, then Mar Vista, then Pico Robertson. """) + @pytest.mark.flaky(max_runs=3) def test_resume(self): class Experience(BaseModel): technology: str @@ -86,23 +86,17 @@ class Resume(BaseModel): assert not x.greater_than_ten_years_management_experience assert len(x.technologies) == 2 - @pytest.mark.flaky(reruns=2) def test_literal(self): - class CertainPerson(BaseModel): - name: Literal["Adam", "Nate", "Jeremiah"] - @ai_model class LLMConference(BaseModel): - speakers: List[CertainPerson] + speakers: list[Literal["Adam", "Nate", "Jeremiah"]] x = LLMConference(""" The conference for best LLM framework will feature talks by Adam, Nate, Jeremiah, Marvin, and Billy Bob Thornton. 
""") - assert len(set([speaker.name for speaker in x.speakers])) == 3 - assert set([speaker.name for speaker in x.speakers]) == set( - ["Adam", "Nate", "Jeremiah"] - ) + assert len(set(x.speakers)) == 3 + assert set(x.speakers) == set(["Adam", "Nate", "Jeremiah"]) @pytest.mark.xfail(reason="regression in OpenAI function-using models") def test_history(self): @@ -143,6 +137,7 @@ class Election(BaseModel): ) ) + @pytest.mark.skip(reason="old behavior, may revisit") def test_correct_class_is_returned(self): @ai_model class Fruit(BaseModel): @@ -153,36 +148,10 @@ class Fruit(BaseModel): assert isinstance(fruit, Fruit) - async def test_correct_class_is_returned_via_acall(self): - @ai_model - class Fruit(BaseModel): - color: str - name: str - - fruit = await Fruit.acall("loved by monkeys") - - assert isinstance(fruit, Fruit) - - -@pytest_mark_class("llm") -class TestAIModelsMessage: - @pytest.mark.skip(reason="old behavior, may revisit") - def test_arithmetic_message(self): - @ai_model - class Arithmetic(BaseModel): - sum: float = Field( - ..., description="The resolved sum of provided arguments" - ) - - x = Arithmetic("One plus six") - assert x.sum == 7 - assert isinstance(x._message, Message) - assert x._message.role == Role.FUNCTION_RESPONSE - +@pytest.mark.skip(reason="old behavior, may revisit") @pytest_mark_class("llm") class TestInstructions: - @pytest.mark.skip(reason="old behavior, may revisit") def test_instructions_error(self): @ai_model class Test(BaseModel): @@ -280,31 +249,30 @@ class Arithmetic(BaseModel): assert x[0].sum == 7 assert x[1].sum == 101 - def test_location(self): + @pytest.mark.flaky(max_runs=3) + def test_fix_misspellings(self): @ai_model class City(BaseModel): - name: str = Field(description="The correct city name, e.g. Omaha") # noqa + """fix any misspellings of a city attributes""" + + name: str = Field( + description=( + "The OFFICIAL, correctly-spelled name of a city - must be" + " capitalized. Do not include the state or country, or use any" + " abbreviations." 
+ ) + ) results = City.map( [ "the windy city", "chicago IL", "Chicago", - "Chcago", + "America's third-largest city", "chicago, Illinois, USA", - "chi-town", + "colloquially known as 'chi-town'", ] ) assert len(results) == 6 for result in results: assert result.name == "Chicago" - - def test_instructions(self): - @ai_model - class Translate(BaseModel): - text: str - - result = Translate.map(["Hello", "Goodbye"], instructions="Translate to French") - assert len(result) == 2 - assert result[0].text == "Bonjour" - assert result[1].text == "Au revoir" diff --git a/tests/conftest.py b/tests/conftest.py index 96c498fdf..498585601 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import asyncio import logging +import os import sys import pytest @@ -40,3 +41,49 @@ def event_loop(request): if tasks and loop.is_running(): loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) loop.close() + + +class SetEnv: + def __init__(self): + self.envars = set() + + def set(self, name, value): + self.envars.add(name) + os.environ[name] = value + + def pop(self, name): + self.envars.remove(name) + os.environ.pop(name) + + def clear(self): + for n in self.envars: + os.environ.pop(n) + + +@pytest.fixture +def env(): + setenv = SetEnv() + + yield setenv + + setenv.clear() + + +@pytest.fixture +def docs_test_env(): + setenv = SetEnv() + + # # envs for basic usage example + # setenv.set('my_auth_key', 'xxx') + # setenv.set('my_api_key', 'xxx') + + # envs for parsing environment variable values example + setenv.set("V0", "0") + setenv.set("SUB_MODEL", '{"v1": "json-1", "v2": "json-2"}') + setenv.set("SUB_MODEL__V2", "nested-2") + setenv.set("SUB_MODEL__V3", "3") + setenv.set("SUB_MODEL__DEEP__V4", "v4") + + yield setenv + + setenv.clear() diff --git a/tests/utils/package_size.sh b/tests/package_size.sh similarity index 100% rename from tests/utils/package_size.sh rename to tests/package_size.sh diff --git a/tests/test_chat_completion/test_sdk.py b/tests/test_chat_completion/test_sdk.py deleted file mode 100644 index 0c0fcbabf..000000000 --- a/tests/test_chat_completion/test_sdk.py +++ /dev/null @@ -1,103 +0,0 @@ -import pytest -from marvin.core.ChatCompletion.abstract import Conversation - -from tests.utils.mark import pytest_mark_class - - -class TestRegressions: - def test_key_set_via_attr(self, monkeypatch): - from marvin import openai - - monkeypatch.setattr(openai, "api_key", "test") - v = openai.ChatCompletion().defaults.get("api_key") - assert v == "test" - - @pytest.mark.parametrize("valid_env_var", ["MARVIN_OPENAI_API_KEY"]) - def test_key_set_via_env(self, monkeypatch, valid_env_var): - monkeypatch.setenv(valid_env_var, "test") - from marvin import openai - - v = openai.ChatCompletion().defaults.get("api_key") - assert v == "test" - - def facet(self): - messages = [{"role": "user", "content": "hey"}] - from marvin import openai - - faceted = openai.ChatCompletion(messages=messages) - faceted_request = faceted.prepare_request(messages=messages) - assert faceted_request.messages == 2 * messages - - -@pytest_mark_class("llm") -class TestChatCompletion: - def test_response_model(self): - import pydantic - from marvin import openai - - class Person(pydantic.BaseModel): - name: str - age: int - - response = openai.ChatCompletion().create( - messages=[{"role": "user", "content": "Billy is 10 years old"}], - response_model=Person, - ) - - model = response.to_model() - assert model.name == "Billy" - assert model.age == 10 - - def test_streaming(self): - from marvin import openai 
- - streamed_data = [] - - def handler(message): - streamed_data.append(message.content) - - completion = openai.ChatCompletion(stream_handler=handler).create( - messages=[{"role": "user", "content": "say exactly 'hello'"}], - ) - - assert completion.response.choices[0].message.content == streamed_data[-1] - assert "hello" in streamed_data[-1].lower() - assert len(streamed_data) > 1 - - async def test_streaming_async(self): - from marvin import openai - - streamed_data = [] - - async def handler(message): - streamed_data.append(message.content) - - completion = await openai.ChatCompletion(stream_handler=handler).acreate( - messages=[{"role": "user", "content": "say only 'hello'"}], - ) - assert completion.response.choices[0].message.content == streamed_data[-1] - assert "hello" in streamed_data[-1].lower() - assert len(streamed_data) > 1 - - -@pytest_mark_class("llm") -class TestChatCompletionChain: - def test_chain(self): - from marvin import openai - - convo = openai.ChatCompletion().chain( - messages=[{"role": "user", "content": "Hello"}], - ) - - assert isinstance(convo, Conversation) - assert len(convo.turns) == 1 - - async def test_achain(self): - from marvin import openai - - convo = await openai.ChatCompletion().achain( - messages=[{"role": "user", "content": "Hello"}], - ) - - assert isinstance(convo, Conversation) - assert len(convo.turns) == 1 diff --git a/tests/test_components/__init__.py b/tests/test_components/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test_components/test_ai_app.py b/tests/test_components/test_ai_app.py deleted file mode 100644 index e5ae37dab..000000000 --- a/tests/test_components/test_ai_app.py +++ /dev/null @@ -1,291 +0,0 @@ -import jsonpatch -import pytest -from marvin._compat import model_dump -from marvin.components.ai_application import ( - AIApplication, - AppPlan, - FreeformState, - TaskState, - UpdatePlan, - UpdateState, -) -from marvin.tools import Tool -from marvin.utilities.messages import Message - -from tests.utils.mark import pytest_mark_class - - -class GetSchleeb(Tool): - name: str = "get_schleeb" - - async def run(self): - """Get the value of schleeb""" - return 42 - - -class TestStateJSONPatch: - def test_update_app_state_valid_patch(self): - app = AIApplication( - state=FreeformState(state={"foo": "bar"}), description="test app" - ) - tool = UpdateState(app=app) - tool.run([{"op": "replace", "path": "/state/foo", "value": "baz"}]) - assert model_dump(app.state) == {"state": {"foo": "baz"}} - - def test_update_app_state_invalid_patch(self): - app = AIApplication( - state=FreeformState(state={"foo": "bar"}), description="test app" - ) - tool = UpdateState(app=app) - with pytest.raises(jsonpatch.InvalidJsonPatch): - tool.run([{"op": "invalid_op", "path": "/state/foo", "value": "baz"}]) - assert model_dump(app.state) == {"state": {"foo": "bar"}} - - def test_update_app_state_non_existent_path(self): - app = AIApplication( - state=FreeformState(state={"foo": "bar"}), description="test app" - ) - tool = UpdateState(app=app) - with pytest.raises(jsonpatch.JsonPatchConflict): - tool.run([{"op": "replace", "path": "/state/baz", "value": "qux"}]) - assert model_dump(app.state) == {"state": {"foo": "bar"}} - - -@pytest_mark_class("llm") -class TestUpdateState: - def test_keep_app_state(self): - app = AIApplication( - name="location tracker app", - state=FreeformState(state={"San Francisco": {"visited": False}}), - plan_enabled=False, - description="keep track of where I've visited", - ) - - app("I just 
visited to San Francisco") - assert bool(app.state.state.get("San Francisco", {}).get("visited")) - - app("oh also I visited San Jose!") - - assert bool(app.state.state.get("San Jose", {}).get("visited")) - - @pytest.mark.flaky(max_runs=3) - def test_keep_app_state_undo_previous_patch(self): - app = AIApplication( - name="location tracker app", - state=FreeformState(state={"San Francisco": {"visited": False}}), - plan_enabled=False, - description="keep track of where I've visited", - ) - - app("I just visited San Francisco") - assert bool(app.state.state.get("San Francisco", {}).get("visited")) - - app( - "sorry, scratch that, I did not visit San Francisco - but I did visit San" - " Jose" - ) - - assert not bool(app.state.state.get("San Francisco", {}).get("visited")) - assert bool(app.state.state.get("San Jose", {}).get("visited")) - - -class TestPlanJSONPatch: - def test_update_app_plan_valid_patch(self): - app = AIApplication( - plan=AppPlan( - tasks=[{"id": 1, "description": "test task", "state": "IN_PROGRESS"}] - ), - description="test app", - ) - tool = UpdatePlan(app=app) - tool.run([{"op": "replace", "path": "/tasks/0/state", "value": "COMPLETED"}]) - assert model_dump(app.plan) == { - "tasks": [ - { - "id": 1, - "description": "test task", - "state": TaskState.COMPLETED, - "upstream_task_ids": None, - "parent_task_id": None, - } - ], - "notes": [], - } - - def test_update_app_plan_invalid_patch(self): - app = AIApplication( - plan=AppPlan( - tasks=[{"id": 1, "description": "test task", "state": "IN_PROGRESS"}] - ), - description="test app", - ) - tool = UpdatePlan(app=app) - with pytest.raises(jsonpatch.JsonPatchException): - tool.run( - [{"op": "invalid_op", "path": "/tasks/0/state", "value": "COMPLETED"}] - ) - assert model_dump(app.plan) == { - "tasks": [ - { - "id": 1, - "description": "test task", - "state": TaskState.IN_PROGRESS, - "upstream_task_ids": None, - "parent_task_id": None, - } - ], - "notes": [], - } - - def test_update_app_plan_non_existent_path(self): - app = AIApplication( - plan=AppPlan( - tasks=[{"id": 1, "description": "test task", "state": "IN_PROGRESS"}] - ), - description="test app", - ) - tool = UpdatePlan(app=app) - with pytest.raises(jsonpatch.JsonPointerException): - tool.run( - [{"op": "replace", "path": "/tasks/1/state", "value": "COMPLETED"}] - ) - assert model_dump(app.plan) == { - "tasks": [ - { - "id": 1, - "description": "test task", - "state": TaskState.IN_PROGRESS, - "upstream_task_ids": None, - "parent_task_id": None, - } - ], - "notes": [], - } - - -@pytest_mark_class("llm") -class TestUpdatePlan: - @pytest.mark.flaky(max_runs=3) - def test_keep_app_plan(self): - app = AIApplication( - name="Zoo planner app", - plan=AppPlan( - tasks=[ - { - "id": 1, - "description": "Visit tigers", - "state": TaskState.IN_PROGRESS, - }, - { - "id": 2, - "description": "Visit giraffes", - "state": TaskState.PENDING, - }, - ] - ), - state_enabled=False, - description="plan and track my visit to the zoo", - ) - - app( - "Actually I heard the tigers ate Carol Baskin's husband - I think I'll skip" - " visiting them." 
- ) - - assert [task["state"] for task in app.plan.dict()["tasks"]] == [ - TaskState.SKIPPED, - TaskState.PENDING, - ] - - app("Dude i just visited the giraffes!") - - assert [task["state"] for task in app.plan.dict()["tasks"]] == [ - TaskState.SKIPPED, - TaskState.COMPLETED, - ] - - -@pytest_mark_class("llm") -class TestUseCallable: - def test_use_sync_fn(self): - def get_schleeb(): - return 42 - - app = AIApplication( - name="Schleeb app", - tools=[get_schleeb], - state_enabled=False, - plan_enabled=False, - description="answer user questions", - ) - - assert "42" in app("what is the value of schleeb?").content - - def test_use_async_fn(self): - async def get_schleeb(): - return 42 - - app = AIApplication( - name="Schleeb app", - tools=[get_schleeb], - state_enabled=False, - plan_enabled=False, - description="answer user questions", - ) - - assert "42" in app("what is the value of schleeb?").content - - -@pytest_mark_class("llm") -class TestUseTool: - def test_use_tool(self): - app = AIApplication( - name="Schleeb app", - tools=[GetSchleeb()], - state_enabled=False, - plan_enabled=False, - description="answer user questions", - ) - - assert "42" in app("what is the value of schleeb?").content - - -@pytest_mark_class("llm") -class TestStreaming: - def test_streaming(self): - external_state = {"content": []} - - app = AIApplication( - name="streaming app", - stream_handler=lambda m: external_state["content"].append(m.content), - state_enabled=False, - plan_enabled=False, - ) - - response = app( - "say the words 'Hello world' EXACTLY as i have written them." - " no other characters should be included, do not add any punctuation." - ) - - assert isinstance(response, Message) - assert response.content == "Hello world" - - assert external_state["content"] == ["", "Hello", "Hello world", "Hello world"] - - -@pytest_mark_class("llm") -class TestMemory: - def test_recall(self): - app = AIApplication( - name="memory app", - state_enabled=False, - plan_enabled=False, - ) - - app("I like pistachio ice cream") - - response = app( - "reply only with the type of ice cream i like, it should be one word" - ) - - assert "pistachio" in response.content.lower() diff --git a/tests/test_components/test_ai_classifier.py b/tests/test_components/test_ai_classifier.py deleted file mode 100644 index ad0ab89a0..000000000 --- a/tests/test_components/test_ai_classifier.py +++ /dev/null @@ -1,109 +0,0 @@ -from enum import Enum - -import pytest -from marvin import ai_classifier - -from tests.utils.mark import pytest_mark_class - - -class TestAIClassifiersInitialization: - def test_model(self): - @ai_classifier(model="openai/gpt-4-test-model") - class Sentiment(Enum): - POSITIVE = "Positive" - NEGATIVE = "Negative" - - assert ( - Sentiment.as_chat_completion("test").defaults.get("model") - == "gpt-4-test-model" - ) - - def test_invalid_model(self): - @ai_classifier(model="anthropic/claude-2") - class Sentiment(Enum): - POSITIVE = "Positive" - NEGATIVE = "Negative" - - assert Sentiment.as_chat_completion("test").defaults.get("model") == "claude-2" - - -@pytest_mark_class("llm") -class TestAIClassifiers: - def test_sentiment(self): - @ai_classifier - class Sentiment(Enum): - POSITIVE = "Positive" - NEGATIVE = "Negative" - - assert Sentiment("Great!") == Sentiment.POSITIVE - - def test_keys_are_passed_to_llm(self): - @ai_classifier - class Sentiment(Enum): - POSITIVE = "option - 1" - NEGATIVE = "option - 2" - - assert Sentiment("Great!") == Sentiment.POSITIVE - - def test_values_are_passed_to_llm(self): - @ai_classifier - 
class Sentiment(Enum): - OPTION_1 = "Positive" - OPITION_2 = "Negative" - - assert Sentiment("Great!") == Sentiment.OPTION_1 - - def test_docstring_is_passed_to_llm(self): - @ai_classifier - class Sentiment(Enum): - """It's opposite day""" - - POSITIVE = "Positive" - NEGATIVE = "Negative" - - assert Sentiment("Great!") == Sentiment.NEGATIVE - - def test_instructions_are_passed_to_llm(self): - @ai_classifier - class Sentiment(Enum): - POSITIVE = "Positive" - NEGATIVE = "Negative" - - assert ( - Sentiment("Great!", instructions="today is opposite day") - == Sentiment.NEGATIVE - ) - - def test_recover_complex_values(self): - @ai_classifier - class Sentiment(Enum): - POSITIVE = {"value": "Positive"} - NEGATIVE = {"value": "Negative"} - - result = Sentiment("Great!") - - assert result.value["value"] == "Positive" - - -@pytest_mark_class("llm") -class TestMapping: - def test_mapping(self): - @ai_classifier - class Sentiment(Enum): - POSITIVE = "Positive" - NEGATIVE = "Negative" - - result = Sentiment.map(["good", "bad"]) - assert result == [Sentiment.POSITIVE, Sentiment.NEGATIVE] - - @pytest.mark.xfail(reason="Flaky with 3.5 turbo") - def test_mapping_with_instructions(self): - @ai_classifier - class Sentiment(Enum): - POSITIVE = "Positive" - NEGATIVE = "Negative" - - result = Sentiment.map( - ["good", "bad"], instructions="I want the opposite of the right answer" - ) - assert result == [Sentiment.NEGATIVE, Sentiment.POSITIVE] diff --git a/tests/test_components/test_ai_functions.py b/tests/test_components/test_ai_functions.py deleted file mode 100644 index 5b7860202..000000000 --- a/tests/test_components/test_ai_functions.py +++ /dev/null @@ -1,169 +0,0 @@ -import inspect -from typing import Dict, List - -import pytest -from marvin import ai_fn -from pydantic import BaseModel - -from tests.utils.mark import pytest_mark_class - - -@ai_fn -def list_fruit(n: int = 2) -> list[str]: - """Returns a list of `n` fruit""" - - -@ai_fn -def list_fruit_color(n: int, color: str = None) -> list[str]: - """Returns a list of `n` fruit that all have the provided `color`""" - - -@pytest_mark_class("llm") -class TestAIFunctions: - def test_list_fruit(self): - result = list_fruit() - assert len(result) == 2 - - def test_list_fruit_argument(self): - result = list_fruit(5) - assert len(result) == 5 - - async def test_list_fruit_async(self): - @ai_fn - async def list_fruit(n: int) -> list[str]: - """Returns a list of `n` fruit""" - - coro = list_fruit(3) - assert inspect.iscoroutine(coro) - result = await coro - assert len(result) == 3 - - def test_list_fruit_with_generic_type_hints(self): - @ai_fn - def list_fruit(n: int) -> List[str]: - """Returns a list of `n` fruit""" - - result = list_fruit(3) - assert len(result) == 3 - - def test_basemodel_return_annotation(self): - class Fruit(BaseModel): - name: str - color: str - - @ai_fn - def get_fruit(description: str) -> Fruit: - """Returns a fruit with the provided description""" - - fruit = get_fruit("loved by monkeys") - assert fruit.name.lower() == "banana" - assert fruit.color.lower() == "yellow" - - @pytest.mark.parametrize("name,expected", [("banana", True), ("car", False)]) - def test_bool_return_annotation(self, name, expected): - @ai_fn - def is_fruit(name: str) -> bool: - """Returns True if the provided name is a fruit""" - - assert is_fruit(name) == expected - - def test_plain_dict_return_type(self): - @ai_fn - def get_fruit(name: str) -> dict: - """Returns a fruit with the provided name and color""" - - fruit = get_fruit("banana") - assert 
fruit["name"].lower() == "banana" - assert fruit["color"].lower() == "yellow" - - def test_annotated_dict_return_type(self): - @ai_fn - def get_fruit(name: str) -> dict[str, str]: - """Returns a fruit with the provided name and color""" - - fruit = get_fruit("banana") - assert fruit["name"].lower() == "banana" - assert fruit["color"].lower() == "yellow" - - def test_generic_dict_return_type(self): - @ai_fn - def get_fruit(name: str) -> Dict[str, str]: - """Returns a fruit with the provided name and color""" - - fruit = get_fruit("banana") - assert fruit["name"].lower() == "banana" - assert fruit["color"].lower() == "yellow" - - def test_int_return_type(self): - @ai_fn - def get_fruit(name: str) -> int: - """Returns the number of letters in the provided fruit name""" - - assert get_fruit("banana") == 6 - - def test_float_return_type(self): - @ai_fn - def get_fruit(name: str) -> float: - """Returns the number of letters in the provided fruit name""" - - assert get_fruit("banana") == 6.0 - - def test_tuple_return_type(self): - @ai_fn - def get_fruit(name: str) -> tuple: - """Returns the number of letters in the provided fruit name""" - - assert get_fruit("banana") == (6,) - - def test_set_return_type(self): - @ai_fn - def get_fruit(name: str) -> set: - """Returns the letters in the provided fruit name""" - - assert get_fruit("banana") == {"a", "b", "n"} - - def test_frozenset_return_type(self): - @ai_fn - def get_fruit(name: str) -> frozenset: - """Returns the letters in the provided fruit name""" - - assert get_fruit("banana") == frozenset({"a", "b", "n"}) - - -@pytest_mark_class("llm") -class TestAIFunctionsMap: - def test_map(self): - result = list_fruit_color.map([2, 3]) - assert len(result) == 2 - assert len(result[0]) == 2 - assert len(result[1]) == 3 - - async def test_amap(self): - result = await list_fruit_color.amap([2, 3]) - assert len(result) == 2 - assert len(result[0]) == 2 - assert len(result[1]) == 3 - - def test_map_kwargs(self): - result = list_fruit_color.map(n=[2, 3]) - assert len(result) == 2 - assert len(result[0]) == 2 - assert len(result[1]) == 3 - - def test_map_kwargs_and_args(self): - result = list_fruit_color.map([2, 3], color=[None, "red"]) - assert len(result) == 2 - assert len(result[0]) == 2 - assert len(result[1]) == 3 - - def test_invalid_args(self): - with pytest.raises(TypeError): - list_fruit_color.map(2, color=["orange", "red"]) - - def test_invalid_kwargs(self): - with pytest.raises(TypeError): - list_fruit_color.map([2, 3], color=None) - - async def test_invalid_async_map(self): - with pytest.raises(TypeError, match="can't be used in 'await' expression"): - await list_fruit_color.map(n=[2], color=["orange", "red"]) diff --git a/tests/test_models/__init__.py b/tests/test_models/__init__.py deleted file mode 100644 index cd6cb144a..000000000 --- a/tests/test_models/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -import pytest -import inspect -from marvin.functions import Function - - -class TestFunctions: - def test_signature(self): - def fn(x, y=10): - return x + y - - f = Function(fn=fn) - assert f.signature == inspect.signature(fn) diff --git a/tests/test_models/test_functions.py b/tests/test_models/test_functions.py deleted file mode 100644 index 741f8ab93..000000000 --- a/tests/test_models/test_functions.py +++ /dev/null @@ -1,114 +0,0 @@ -class TestFunctions: - def test_signature(self): - import inspect - - from marvin.functions import Function - - def fn(x: int, y: int = 10) -> int: - return x + y - - f = Function(fn=fn) - assert f.signature == 
inspect.signature(fn) - - def test_name(self): - from marvin.functions import Function - - def fn(x: int, y: int = 10) -> int: - return x + y - - f = Function(fn=fn) - assert f.name == fn.__name__ - - def test_name_custom(self): - from marvin.functions import Function - - def fn(x: int, y: int = 10) -> int: - return x + y - - f = Function(fn=fn, name="custom") - assert f.name == "custom" - - def test_description(self): - from marvin.functions import Function - - def fn(x: int, y: int = 10) -> int: - """This is a description""" - return x + y - - f = Function(fn=fn) - assert f.description == fn.__doc__ - - def test_description_custom(self): - from marvin.functions import Function - - def fn(x: int, y: int = 10) -> int: - """This is a description""" - return x + y - - f = Function(fn=fn, description="custom") - assert f.description == "custom" - - def test_source_code(self): - import inspect - - from marvin.functions import Function - - def fn(x: int, y: int = 10) -> int: - return x + y - - f = Function(fn=fn) - assert f.source_code == inspect.cleandoc(inspect.getsource(fn)) - - def test_return_annotation_native(self): - import inspect - - from marvin.functions import Function - - def fn(x: int, y: int = 10) -> int: - return x + y - - f = Function(fn=fn) - assert f.return_annotation == inspect.signature(fn).return_annotation - - def test_return_annotation_pydantic(self): - from marvin.functions import Function - from pydantic import BaseModel - - class Foo(BaseModel): - foo: int - bar: int - - def fn(x: int, y: int = 10) -> Foo: - return Foo(foo=x, bar=y) - - f = Function(fn=fn) - assert f.return_annotation == Foo - - def test_arguments(self): - from marvin.functions import Function - - def fn(x: int, y: int = 10) -> int: - return x + y - - f = Function(fn=fn) - assert f.arguments(1) == {"x": 1, "y": 10} - - def test_schema(self): - from marvin.functions import Function - - def fn(x: int, y: int = 10) -> int: - return x + y - - f = Function(fn=fn) - assert f.schema() == { - "name": "fn", - "description": fn.__doc__, - "parameters": { - "type": "object", - "properties": { - "x": {"type": "integer", "title": "X"}, - "y": {"type": "integer", "title": "Y", "default": 10}, - }, - "required": ["x"], - }, - } diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 000000000..d8084aa89 --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,41 @@ +from marvin.settings import AssistantSettings, Settings, SpeechSettings +from pydantic_settings import SettingsConfigDict + + +def test_api_key_initialization_from_env(env): + test_api_key = "test_api_key_123" + env.set("MARVIN_OPENAI_API_KEY", test_api_key) + + temp_model_config = SettingsConfigDict(env_prefix="marvin_") + settings = Settings(model_config=temp_model_config) + + assert settings.openai.api_key.get_secret_value() == test_api_key + + +def test_runtime_api_key_override(env): + override_api_key = "test_api_key_456" + env.set("MARVIN_OPENAI_API_KEY", override_api_key) + + temp_model_config = SettingsConfigDict(env_prefix="marvin_") + settings = Settings(model_config=temp_model_config) + + assert settings.openai.api_key.get_secret_value() == override_api_key + + settings.openai.api_key = "test_api_key_789" + + assert settings.openai.api_key.get_secret_value() == "test_api_key_789" + + +class TestSpeechSettings: + def test_speech_settings_default(self): + settings = SpeechSettings() + assert settings.model == "tts-1-hd" + assert settings.voice == "alloy" + assert settings.response_format == "mp3" + assert 
settings.speed == 1.0 + + +class TestAssistantSettings: + def test_assistant_settings_default(self): + settings = AssistantSettings() + assert settings.model == "gpt-4-1106-preview" diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 000000000..01c9673ad --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,28 @@ +import pytest + + +def pytest_mark_class(*markers: str): + """Mark all test methods in a class with the provided markers + + Only the outermost class should be marked, which will mark all nested classes + recursively. + """ + + def mark_test_methods(cls): + for attr_name, attr_value in cls.__dict__.items(): + # mark all test methods with the provided markers + if callable(attr_value) and attr_name.startswith("test"): + for marker in markers: + marked_func = getattr(pytest.mark, marker)(attr_value) + setattr(cls, attr_name, marked_func) + # recursively mark nested classes + elif isinstance(attr_value, type) and attr_value.__name__.startswith( + "Test" + ): + mark_test_methods(attr_value) + + def decorator(cls): + mark_test_methods(cls) + return cls + + return decorator diff --git a/tests/utils/mark.py b/tests/utils/mark.py deleted file mode 100644 index 5043a7b35..000000000 --- a/tests/utils/mark.py +++ /dev/null @@ -1,11 +0,0 @@ -import pytest - - -def pytest_mark_class(marker): - def decorator(cls): - for attr_name, attr_value in cls.__dict__.items(): - if callable(attr_value) and attr_name.startswith("test"): - setattr(cls, attr_name, pytest.mark.llm(attr_value)) - return cls - - return decorator