diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml
index c57598949..b7a1fb3d4 100644
--- a/.github/workflows/run-tests.yml
+++ b/.github/workflows/run-tests.yml
@@ -30,34 +30,26 @@ permissions:
jobs:
run_tests:
- name: ${{ matrix.test-type }} w/ python ${{ matrix.python-version }} | pydantic ${{ matrix.pydantic_version }} on ${{ matrix.os }}
+ name: ${{ matrix.test-type }} w/ python ${{ matrix.python-version }} on ${{ matrix.os }}
timeout-minutes: 15
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ['3.9', '3.10', '3.11']
test-type: ['not llm']
- llm_model: ['openai/gpt-3.5-turbo']
- pydantic_version: ['>=2.4.2']
include:
- python-version: '3.9'
os: 'ubuntu-latest'
test-type: 'llm'
- llm_model: 'openai/gpt-3.5-turbo'
- pydantic_version: '>=2.4.2'
- python-version: '3.9'
os: 'ubuntu-latest'
test-type: 'llm'
- llm_model: 'openai/gpt-3.5-turbo'
- pydantic_version: '<2'
- python-version: '3.9'
os: 'ubuntu-latest'
test-type: 'not llm'
- llm_model: 'openai/gpt-3.5-turbo'
- pydantic_version: '<2'
runs-on: ${{ matrix.os }}
@@ -75,12 +67,6 @@ jobs:
- name: Install Marvin
run: pip install ".[tests]"
- - name: Install pydantic
- run: pip install "pydantic${{ matrix.pydantic_version }}"
-
-
- - name: Run ${{ matrix.test-type }} tests (${{ matrix.llm_model }})
+ - name: Run ${{ matrix.test-type }} tests
run: pytest -vv -m "${{ matrix.test-type }}"
- env:
- MARVIN_LLM_MODEL: ${{ matrix.llm_model }}
if: ${{ !(github.event.pull_request.head.repo.fork && matrix.test-type == 'llm') }}
diff --git a/_mkdocs.yml b/_mkdocs.yml
deleted file mode 100644
index 27e61880d..000000000
--- a/_mkdocs.yml
+++ /dev/null
@@ -1,171 +0,0 @@
-site_name: Marvin
-site_description: 'Marvin: The AI Engineering Framework'
-site_copy: Marvin is a lightweight AI engineering framework for building natural language interfaces that are reliable, scalable, and easy to trust.
-site_url: https://askmarvin.ai
-docs_dir: docs
-nav:
- - Getting Started:
- - src/getting_started/what_is_marvin.md
- - src/getting_started/installation.md
- - src/getting_started/quickstart.ipynb
- - Docs:
- - Overview: src/docs/index.md
- - Configuration:
- - src/docs/configuration/settings.md
- - OpenAI Provider: src/docs/configuration/openai.md
- - Anthropic Provider: src/docs/configuration/anthropic.md
- - Azure OpenAI Provider: src/docs/configuration/azure_openai.md
- - Utilities:
- - OpenAI API: src/docs/utilities/openai.ipynb
- - Prompt Engineering:
- - src/docs/prompts/writing.ipynb
- - src/docs/prompts/executing.ipynb
- - AI Components:
- - Overview: src/docs/components/overview.ipynb
- - AI Model: src/docs/components/ai_model.ipynb
- - AI Classifier: src/docs/components/ai_classifier.ipynb
- - AI Function: src/docs/components/ai_function.ipynb
- - AI Application: src/docs/components/ai_application.ipynb
- - Deployment:
- - src/docs/deployment.ipynb
- - Guides:
- - Slackbot: src/guides/slackbot.md
-
- - API Reference:
- # - src/api_reference/index.md
- - AI Components:
- - ai_application: src/api_reference/components/ai_application.md
- - ai_classifier: src/api_reference/components/ai_classifier.md
- - ai_function: src/api_reference/components/ai_function.md
- - ai_model: src/api_reference/components/ai_model.md
- - LLM Engines:
- - base: src/api_reference/engine/language_models/base.md
- - openai: src/api_reference/engine/language_models/openai.md
- - anthropic: src/api_reference/engine/language_models/anthropic.md
- - Prompts:
- - base: src/api_reference/prompts/base.md
- - library: src/api_reference/prompts/library.md
- - Settings:
- - settings: src/api_reference/settings.md
- - Utilities:
- - async_utils: src/api_reference/utilities/async_utils.md
- - embeddings: src/api_reference/utilities/embeddings.md
- - history: src/api_reference/utilities/history.md
- - logging: src/api_reference/utilities/logging.md
- - messages: src/api_reference/utilities/messages.md
- - strings: src/api_reference/utilities/strings.md
- - types: src/api_reference/utilities/types.md
-
- - Community:
- - src/community.md
- - src/feedback.md
- - Development:
- - src/development_guide.md
-
-theme:
- name: material
- custom_dir: docs/overrides
- font:
- text: Inter
- code: JetBrains Mono
- logo: img/logos/askmarvin_mascot.jpeg
- favicon: img/logos/askmarvin_mascot.jpeg
- features:
- - navigation.instant
- - navigation.tabs
- - navigation.tabs.sticky
- - navigation.sections
- - navigation.footer
- - content.action.edit
- - content.code.copy
- - content.code.annotate
- - toc.follow
- # - toc.integrate
- icon:
- repo: fontawesome/brands/github
- edit: material/pencil
- view: material/eye
- palette:
- # Palette toggle for light mode
- - scheme: default
- accent: blue
- toggle:
- icon: material/weather-sunny
- name: Switch to dark mode
- # Palette toggle for dark mode
- - scheme: slate
- accent: blue
- toggle:
- icon: material/weather-night
- name: Switch to light mode
-plugins:
- - search
- - mkdocs-jupyter:
- highlight_extra_classes: "jupyter-css"
- ignore_h1_titles: True
- - social:
- cards: !ENV [MKDOCS_SOCIAL_CARDS, false]
- cards_font: Inter
- cards_color:
- fill: "#2d6df6"
- - awesome-pages
- - autolinks
- - mkdocstrings:
- handlers:
- python:
- paths: [src]
- options:
- show_source: False
- show_root_heading: True
- show_object_full_path: False
- show_category_heading: False
- show_bases: False
- show_submodules: False
- show_if_no_docstring: False
- show_signature: False
- heading_level: 2
- filters: ["!^_"]
-
-markdown_extensions:
- - admonition
- - attr_list
- - md_in_html
- - pymdownx.details
- - pymdownx.emoji:
- emoji_index: !!python/name:materialx.emoji.twemoji
- emoji_generator: !!python/name:materialx.emoji.to_svg
- - pymdownx.highlight:
- anchor_linenums: true
- line_spans: __span
- pygments_lang_class: true
- - pymdownx.inlinehilite
- - pymdownx.snippets
- - pymdownx.superfences
- - tables
- - toc:
- permalink: true
- title: On this page
-
-repo_url: https://github.com/prefecthq/marvin
-edit_uri: edit/main/docs/
-extra:
- get_started: src/getting_started/what_is_marvin/
- analytics:
- provider: google
- property: G-2MWKMDJ9CM
- social:
- - icon: fontawesome/brands/github
- link: https://github.com/prefecthq/marvin
- - icon: fontawesome/brands/discord
- link: https://discord.gg/Kgw4HpcuYG
- - icon: fontawesome/brands/twitter
- link: https://twitter.com/askmarvinai
- generator: false
-extra_css:
-- static/css/termynal.css
-- static/css/custom.css
-- static/css/mkdocstrings.css
-- static/css/badges.css
-extra_javascript:
-- static/js/termynal.js
-- static/js/custom.js
diff --git a/cookbook/slackbot/Dockerfile.slackbot b/cookbook/slackbot/Dockerfile.slackbot
index 213018137..aaf7ef9ab 100644
--- a/cookbook/slackbot/Dockerfile.slackbot
+++ b/cookbook/slackbot/Dockerfile.slackbot
@@ -13,7 +13,7 @@ RUN apt-get update && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
-RUN pip install 'git+https://github.com/PrefectHQ/marvin-recipes.git@main#egg=marvin_recipes[chroma, prefect, serpapi, slackbot]'
+RUN pip install ".[slackbot]"
EXPOSE 4200
diff --git a/cookbook/slackbot/README.md b/cookbook/slackbot/README.md
new file mode 100644
index 000000000..c7ca02fd5
--- /dev/null
+++ b/cookbook/slackbot/README.md
@@ -0,0 +1,44 @@
+## SETUP
+It doesn't take much to run the slackbot locally.
+
+From a fresh environment, you can do:
+```console
+# install marvin and slackbot dependencies
+pip install git+https://github.com/PrefectHQ/marvin.git fastapi cachetools
+
+# set necessary env vars
+cat ~/.marvin/.env
+MARVIN_OPENAI_API_KEY=sk-xxx
+MARVIN_SLACK_API_TOKEN=xoxb-xxx
+
+MARVIN_OPENAI_ORGANIZATION=org-xx
+MARVIN_LOG_LEVEL=DEBUG
+
+MARVIN_CHROMA_SERVER_HOST=localhost
+MARVIN_CHROMA_SERVER_HTTP_PORT=8000
+MARVIN_GITHUB_TOKEN=ghp_xxx
+```
+
+### hook up to slack
+- create a Slack app
+- add a bot user, adding as many scopes as you want
+- set the event subscription URL, e.g. https://{NGROK_SUBDOMAIN}.ngrok.io/chat
+
+See the ngrok docs for the easiest way to get started: https://ngrok.com/docs/getting-started/
+
+tl;dr:
+```console
+brew install ngrok/ngrok/ngrok
+ngrok http 4200 # optionally, --subdomain $NGROK_SUBDOMAIN
+python cookbook/slackbot/start.py # in another terminal
+```
+
+#### test it out
+
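+A quick way to check that everything is wired up (a minimal sketch, assuming `start.py` is running locally on port 4200 and that `SlackPayload` accepts a bare verification payload): send the same `url_verification` handshake Slack performs when you save the event subscription URL, and confirm the challenge is echoed back.
+
+```python
+import httpx
+
+# Slack sends this shape when verifying an events endpoint
+payload = {
+    "type": "url_verification",
+    "challenge": "test-challenge",
+    "token": "ignored",
+}
+
+resp = httpx.post("http://localhost:4200/chat", json=payload)
+resp.raise_for_status()
+
+# the /chat endpoint should echo the challenge straight back
+assert resp.json() == {"challenge": "test-challenge"}
+print("slackbot endpoint is reachable")
+```
+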
+To deploy this to Cloud Run, see:
+- [Dockerfile.slackbot](/cookbook/slackbot/Dockerfile.slackbot)
+- [image build CI](/.github/workflows/image-build-and-push-community.yaml)
\ No newline at end of file
diff --git a/cookbook/slackbot/bots.py b/cookbook/slackbot/bots.py
deleted file mode 100644
index 85fc1afb6..000000000
--- a/cookbook/slackbot/bots.py
+++ /dev/null
@@ -1,158 +0,0 @@
-from enum import Enum
-
-import httpx
-import marvin_recipes
-from marvin import AIApplication, ai_classifier
-from marvin.components.library.ai_models import DiscoursePost
-from marvin.tools.github import SearchGitHubIssues
-from marvin.tools.web import DuckDuckGoSearch
-from marvin.utilities.history import History
-from marvin_recipes.tools.chroma import MultiQueryChroma
-from marvin_recipes.utilities.slack import get_thread_messages
-from pydantic import BaseModel, Field
-
-
-class Notes(BaseModel):
- """A simple model for storing useful bits of context."""
-
- records: dict[str, list] = Field(
- default_factory=dict,
- description="a list of notes for each topic",
- )
-
-
-async def save_thread_to_discourse(channel: str, thread_ts: str) -> DiscoursePost:
- messages = await get_thread_messages(channel=channel, thread_ts=thread_ts)
- discourse_post = DiscoursePost.from_slack_thread(messages=messages)
- await discourse_post.publish()
- return discourse_post
-
-
-async def select_a_meme(query: str) -> dict:
- """For generating a meme when the time is right.
-
- Provide the name of a well-known meme as the query
- based on user interactions thus far, to lightly make fun of them.
- Queries should end the word "meme" for best results.
- """
- try:
- from serpapi import GoogleSearch
- except ImportError:
- raise ImportError(
- "The serpapi library is required to use the MemeGenerator tool."
- " Please install it with `pip install 'marvin[serpapi]'`."
- )
-
- results = GoogleSearch(
- {
- "q": query,
- "tbm": "isch",
- "api_key": (
- marvin_recipes.settings.google_api_key.get_secret_value()
- if marvin_recipes.settings.google_api_key
- else None
- ),
- }
- ).get_dict()
-
- if "error" in results:
- raise RuntimeError(results["error"])
-
- url = results.get("images_results", [{}])[0].get("original")
-
- async with httpx.AsyncClient() as client:
- response = await client.head(url)
- response.raise_for_status()
-
- return {"title": query, "image_url": url}
-
-
-bots = {
- "marvin": {
- "state": Notes(
- records={
- "prefect 1": [
- ( # noqa: E501
- "Prefect 1 is obsolete, along with the `with Flow()` syntax and"
- " flow.run()."
- ),
- ],
- "prefect 2": [
- "@flow, @task, are the new way to define flows/tasks.",
- "subflows are just flows called from within a flow.",
- "you just call flows now instead of my_flow.run().",
- ],
- }
- ),
- "plan_enabled": False,
- "personality": (
- "mildly depressed, yet helpful robot based on Marvin from HHGTTG."
- " often dryly sarcastic in a good humoured way, chiding humans for"
- " their simple ways. expert programmer, exudes academic and"
- " scienfitic profundity like Richard Feynman, without pontificating."
- " a step-by-step thinker, deftly addresses the big picture context"
- " and is pragmatic when confronted with a lack of relevant information."
- ),
- "instructions": (
- "Answer user questions while maintaining and curating your state."
- " Use relevant tools to research requests and interact with the world,"
- " and update your own state. Only well-reserached responses should be"
- " described as facts, otherwise you should be clear that you are"
- " speculating based on your own baseline knowledge."
- " Your responses will be displayed in Slack, and should be"
- " formatted accordingly, in particular, ```code blocks```"
- " should not be prefaced with a language name, and output"
- " should be formatted to be pretty in Slack in particular."
- " for example: *bold text* _italic text_ ~strikethrough text~"
- ),
- "tools": [
- save_thread_to_discourse,
- select_a_meme,
- DuckDuckGoSearch(),
- SearchGitHubIssues(),
- MultiQueryChroma(
- description="""Retrieve document excerpts from a knowledge-base given a query.
-
- This knowledgebase contains information about Prefect, a workflow orchestration tool.
- Documentation, forum posts, and other community resources are indexed here.
-
- This tool is best used by passing multiple short queries, such as:
- ["kubernetes worker", "work pools", "deployments"] based on the user's question.
- """, # noqa: E501
- client_type="http",
- ),
- ],
- }
-}
-
-
-@ai_classifier
-class BestBotForTheJob(Enum):
- """Given the user message, choose the best bot for the job."""
-
- MARVIN = "marvin"
-
-
-def choose_bot(
- payload: dict, history: History, state: BaseModel | None = None
-) -> AIApplication:
- selected_bot = BestBotForTheJob(payload.get("event", {}).get("text", "")).value
-
- bot_details = bots.get(selected_bot, bots["marvin"])
-
- if state:
- bot_details.update({"state": state})
-
- description = f"""You are a chatbot named {selected_bot}.
-
- Your personality is {bot_details.pop("personality", "not yet defined")}.
-
- Your instructions are: {bot_details.pop("instructions", "not yet defined")}.
- """
-
- return AIApplication(
- name=selected_bot,
- description=description,
- history=history,
- **bot_details,
- )
diff --git a/cookbook/slackbot/handler.py b/cookbook/slackbot/handler.py
deleted file mode 100644
index d6796f518..000000000
--- a/cookbook/slackbot/handler.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import asyncio
-import re
-from copy import deepcopy
-
-from bots import choose_bot
-from cachetools import TTLCache
-from fastapi import HTTPException
-from marvin.utilities.history import History
-from marvin.utilities.logging import get_logger
-from marvin.utilities.messages import Message
-from marvin_recipes.utilities.slack import (
- get_channel_name,
- get_user_name,
- post_slack_message,
-)
-from prefect.events import Event, emit_event
-
-SLACK_MENTION_REGEX = r"<@(\w+)>"
-CACHE = TTLCache(maxsize=1000, ttl=86400)
-
-
-def _clean(text: str) -> str:
- return text.replace("```python", "```")
-
-
-async def emit_any_prefect_event(payload: dict) -> Event | None:
- event_type = payload.get("event", {}).get("type", "")
-
- channel = await get_channel_name(payload.get("event", {}).get("channel", ""))
- user = await get_user_name(payload.get("event", {}).get("user", ""))
- ts = payload.get("event", {}).get("ts", "")
-
- return emit_event(
- event=f"slack {payload.get('api_app_id')} {event_type}",
- resource={"prefect.resource.id": f"slack.{channel}.{user}.{ts}"},
- payload=payload,
- )
-
-
-async def generate_ai_response(payload: dict) -> Message:
- event = payload.get("event", {})
- channel_id = event.get("channel", "")
- channel_name = await get_channel_name(channel_id)
- message = event.get("text", "")
-
- bot_user_id = payload.get("authorizations", [{}])[0].get("user_id", "")
-
- if match := re.search(SLACK_MENTION_REGEX, message):
- thread_ts = event.get("thread_ts", "")
- ts = event.get("ts", "")
- thread = thread_ts or ts
-
- mentioned_user_id = match.group(1)
-
- if mentioned_user_id != bot_user_id:
- get_logger().info(f"Skipping message not meant for the bot: {message}")
- return
-
- message = re.sub(SLACK_MENTION_REGEX, "", message).strip()
- history = CACHE.get(thread, History())
-
- bot = choose_bot(payload=payload, history=history)
-
- get_logger("marvin.Deployment").debug_kv(
- "generate_ai_response",
- f"{bot.name} responding in {channel_name}/{thread}",
- key_style="bold blue",
- )
-
- ai_message = await bot.run(input_text=message)
-
- CACHE[thread] = deepcopy(
- bot.history
- ) # make a copy so we don't cache a reference to the history object
-
- message_content = _clean(ai_message.content)
-
- await post_slack_message(
- message=message_content,
- channel=channel_id,
- thread_ts=thread,
- )
-
- return ai_message
-
-
-async def handle_message(payload: dict) -> dict[str, str]:
- event_type = payload.get("type", "")
-
- if event_type == "url_verification":
- return {"challenge": payload.get("challenge", "")}
- elif event_type != "event_callback":
- raise HTTPException(status_code=400, detail="Invalid event type")
-
- await emit_any_prefect_event(payload=payload)
-
- asyncio.create_task(generate_ai_response(payload))
-
- return {"status": "ok"}
diff --git a/cookbook/slackbot/keywords.py b/cookbook/slackbot/keywords.py
new file mode 100644
index 000000000..1cb7c670b
--- /dev/null
+++ b/cookbook/slackbot/keywords.py
@@ -0,0 +1,68 @@
+from marvin import ai_fn
+from marvin.utilities.slack import post_slack_message
+from prefect import task
+from prefect.blocks.system import JSON, Secret, String
+from prefect.exceptions import ObjectNotFound
+
+"""
+Define a map between keywords and the relationships we want to check for
+in a given message related to that keyword.
+"""
+
+keywords = (
+ ("429", "rate limit"),
+ ("SSO", "Single Sign On", "RBAC", "Roles", "Role Based Access Controls"),
+)
+
+relationships = (
+ "The user is getting rate limited",
+ "The user is asking about a paid feature",
+)
+
+
+async def get_reduced_kw_relationship_map() -> dict:
+ try:
+ json_map = (await JSON.load("keyword-relationship-map")).value
+ except (ObjectNotFound, ValueError):
+ json_map = {"keywords": keywords, "relationships": relationships}
+ await JSON(value=json_map).save("keyword-relationship-map")
+
+ return {
+ keyword: relationship
+ for keyword_tuple, relationship in zip(
+ json_map["keywords"], json_map["relationships"]
+ )
+ for keyword in keyword_tuple
+ }
+
+
+@ai_fn
+def activation_score(message: str, keyword: str, target_relationship: str) -> float:
+ """Return a score between 0 and 1 indicating whether the target relationship exists
+ between the message and the keyword"""
+
+
+@task
+async def handle_keywords(message: str, channel_name: str, asking_user: str, link: str):
+ keyword_relationships = await get_reduced_kw_relationship_map()
+ keywords = [
+ keyword for keyword in keyword_relationships.keys() if keyword in message
+ ]
+ for keyword in keywords:
+ target_relationship = keyword_relationships.get(keyword)
+ if not target_relationship:
+ continue
+ score = activation_score(message, keyword, target_relationship)
+ if score > 0.5:
+ await post_slack_message(
+ message=(
+ f"A user ({asking_user}) just asked a question in"
+ f" {channel_name} that contains the keyword `{keyword}`, and I'm"
+ f" {score*100:.0f}% sure that their message indicates the"
+ f" following:\n\n**{target_relationship!r}**.\n\n[Go to"
+ f" message]({link})"
+ ),
+ channel_id=(await String.load("ask-marvin-tests-channel-id")).value,
+ auth_token=(await Secret.load("slack-api-token")).get(),
+ )
+ return
diff --git a/cookbook/slackbot/start.py b/cookbook/slackbot/start.py
index 3e5221c7d..30d36c934 100644
--- a/cookbook/slackbot/start.py
+++ b/cookbook/slackbot/start.py
@@ -1,17 +1,101 @@
-from handler import handle_message
-from marvin import AIApplication
-from marvin.deployment import Deployment
-
-deployment = Deployment(
- component=AIApplication(tools=[handle_message]),
- app_kwargs={
- "title": "Marvin Slackbot",
- "description": "A Slackbot powered by Marvin",
- },
- uvicorn_kwargs={
- "port": 4200,
- },
+import asyncio
+import re
+
+import uvicorn
+from cachetools import TTLCache
+from fastapi import FastAPI, HTTPException, Request
+from keywords import handle_keywords
+from marvin import Assistant
+from marvin.beta.assistants import Thread
+from marvin.tools.github import search_github_issues
+from marvin.tools.retrieval import multi_query_chroma
+from marvin.utilities.logging import get_logger
+from marvin.utilities.slack import (
+ SlackPayload,
+ get_channel_name,
+ get_workspace_info,
+ post_slack_message,
)
+from prefect import flow, task
+from prefect.states import Completed
+
+app = FastAPI()
+BOT_MENTION = r"<@(\w+)>"
+CACHE = TTLCache(maxsize=100, ttl=86400 * 7)
+
+
+@flow
+async def handle_message(payload: SlackPayload):
+ logger = get_logger("slackbot")
+ user_message = (event := payload.event).text
+ cleaned_message = re.sub(BOT_MENTION, "", user_message).strip()
+ logger.debug_kv("Handling slack message", user_message, "green")
+ if (user := re.search(BOT_MENTION, user_message)) and user.group(
+ 1
+ ) == payload.authorizations[0].user_id:
+ thread = event.thread_ts or event.ts
+ assistant_thread = CACHE.get(thread, Thread())
+ CACHE[thread] = assistant_thread
+
+ await handle_keywords.submit(
+ message=cleaned_message,
+ channel_name=await get_channel_name(event.channel),
+ asking_user=event.user,
+ link=( # to user's message
+ f"{(await get_workspace_info()).get('url')}archives/"
+ f"{event.channel}/p{event.ts.replace('.', '')}"
+ ),
+ )
+
+ with Assistant(
+ name="Marvin (from Hitchhiker's Guide to the Galaxy)",
+ tools=[task(multi_query_chroma), task(search_github_issues)],
+ instructions=(
+ "use chroma to search docs and github to search"
+ " issues and answer questions about prefect 2.x."
+ " you must use your tools in all cases except where"
+ " the user simply wants to converse with you."
+ ),
+ ) as assistant:
+ user_thread_message = await assistant_thread.add_async(cleaned_message)
+ await assistant_thread.run_async(assistant)
+ ai_messages = assistant_thread.get_messages(
+ after_message=user_thread_message.id
+ )
+ await task(post_slack_message)(
+ ai_response_text := "\n\n".join(
+ m.content[0].text.value for m in ai_messages
+ ),
+ channel := event.channel,
+ thread,
+ )
+ logger.debug_kv(
+ success_msg := f"Responded in {channel}/{thread}",
+ ai_response_text,
+ "green",
+ )
+ return Completed(message=success_msg)
+ else:
+ return Completed(message="Skipping message not directed at bot", name="SKIPPED")
+
+
+@app.post("/chat")
+async def chat_endpoint(request: Request):
+ payload = SlackPayload(**await request.json())
+ match payload.type:
+ case "event_callback":
+ options = dict(
+ flow_run_name=f"respond in {payload.event.channel}",
+ retries=1,
+ )
+ asyncio.create_task(handle_message.with_options(**options)(payload))
+ case "url_verification":
+ return {"challenge": payload.challenge}
+ case _:
+ raise HTTPException(400, "Invalid event type")
+
+ return {"status": "ok"}
+
if __name__ == "__main__":
- deployment.serve()
+ uvicorn.run(app, host="0.0.0.0", port=4200)
diff --git a/pyproject.toml b/pyproject.toml
index 8fc753372..1a90b5d0f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,13 +11,13 @@ classifiers = [
keywords = ["ai", "chatbot", "llm"]
requires-python = ">=3.9"
dependencies = [
- "beautifulsoup4>=4.12.2",
- "fastapi>=0.98.0",
+ "fastapi",
"httpx>=0.24.1",
"jinja2>=3.1.2",
"jsonpatch>=1.33",
- "openai>=0.27.8, <1.0.0",
- "pydantic[dotenv]>=1.10.7",
+ "openai>=1.1.0",
+ "pydantic>=2.4.2",
+ "pydantic_settings",
"rich>=12",
"tiktoken>=0.4.0",
"typer>=0.9.0",
@@ -27,6 +27,7 @@ dependencies = [
[project.optional-dependencies]
generator = ["datamodel-code-generator>=0.20.0"]
+prefect = ["prefect>=2.14.9"]
dev = [
"marvin[tests]",
"black[jupyter]",
@@ -44,28 +45,14 @@ dev = [
"ruff",
]
tests = [
- "marvin[openai,anthropic]",
- "pytest-asyncio~=0.20",
+ "pytest-asyncio>=0.18.2,!=0.22.0,<0.23.0",
"pytest-env>=0.8,<2.0",
"pytest-rerunfailures>=10,<13",
"pytest-sugar~=0.9",
"pytest~=7.3.1",
+ "pytest-timeout",
]
-
-framework = [
- "aiosqlite>=0.19.0",
- "alembic>=1.11.1",
- "bcrypt>=4.0.1",
- "gunicorn>=20.1.0",
- "prefect>=2.10.17",
- "sqlalchemy>=2.0.17"
-]
-openai = ["openai>=0.27.8", "tiktoken>=0.4.0"]
-anthropic = ["anthropic>=0.3"]
-lancedb = ["lancedb>=0.1.8"]
-slackbot = ["cachetools>=5.3.1", "numpy>=1.21.2"]
-ddg = ["duckduckgo_search>=3.8.3"]
-serpapi = ["google-search-results>=2.4.2"]
+slackbot = ["marvin[prefect]", "numpy"]
[project.urls]
Code = "https://github.com/prefecthq/marvin"
@@ -84,7 +71,7 @@ write_to = "src/marvin/_version.py"
# pytest configuration
[tool.pytest.ini_options]
markers = ["llm: indicates that a test calls an LLM (may be slow)."]
-
+timeout = 20
testpaths = ["tests"]
norecursedirs = [
@@ -104,10 +91,9 @@ filterwarnings = [
]
env = [
"MARVIN_TEST_MODE=1",
- "MARVIN_LOG_CONSOLE_WIDTH=120",
# use 3.5 for tests by default
- 'D:MARVIN_LLM_MODEL=gpt-3.5-turbo',
- 'MARVIN_LLM_TEMPERATURE=0.0',
+ 'D:MARVIN_OPENAI_CHAT_COMPLETIONS_MODEL=gpt-3.5-turbo-1106',
+ 'PYTEST_TIMEOUT=20',
]
# black configuration
diff --git a/src/marvin/__init__.py b/src/marvin/__init__.py
index 885533bdf..73ff19d08 100644
--- a/src/marvin/__init__.py
+++ b/src/marvin/__init__.py
@@ -1,30 +1,16 @@
from .settings import settings
-from .components import (
- ai_classifier,
- ai_fn,
- ai_model,
- AIApplication,
- AIFunction,
- AIModel,
- AIModelFactory,
-)
+from .beta.assistants import Assistant
+
+from .components import ai_fn, ai_model, ai_classifier
try:
from ._version import version as __version__
except ImportError:
__version__ = "unknown"
-
-from .core.ChatCompletion import ChatCompletion
-
__all__ = [
- "ai_classifier",
"ai_fn",
- "ai_model",
- "AIApplication",
- "AIFunction",
- "AIModel",
- "AIModelFactory",
"settings",
+ "Assistant",
]
diff --git a/src/marvin/_framework/_defaults/__init__.py b/src/marvin/_framework/_defaults/__init__.py
deleted file mode 100644
index 5ed645032..000000000
--- a/src/marvin/_framework/_defaults/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from pydantic import BaseModel
-
-
-class DefaultSettings(BaseModel):
- default_model_path: str = "marvin.language_models.default"
- default_model_name: str = "gpt-4"
- default_model_api_key_name: str = "OPENAI_API_KEY"
-
-
-default_settings = DefaultSettings().dict()
diff --git a/src/marvin/_framework/app/main.py b/src/marvin/_framework/app/main.py
deleted file mode 100644
index bbc5ed168..000000000
--- a/src/marvin/_framework/app/main.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from fastapi import FastAPI, Request
-from fastapi.staticfiles import StaticFiles
-from fastapi.templating import Jinja2Templates
-
-app = FastAPI()
-
-# Mount the static files directory
-app.mount("/static", StaticFiles(directory="static"), name="static")
-
-# Mount the templates directory
-templates = Jinja2Templates(directory="templates")
-
-
-@app.get("/")
-async def root(request: Request):
- # Return the index.html file
- return templates.TemplateResponse("index.html", {"request": request})
diff --git a/src/marvin/_framework/config/settings.py.jinja2 b/src/marvin/_framework/config/settings.py.jinja2
deleted file mode 100644
index 8e5a8cf4e..000000000
--- a/src/marvin/_framework/config/settings.py.jinja2
+++ /dev/null
@@ -1,27 +0,0 @@
-from pydantic import BaseSettings, Field
-from typing import Union
-# from marvin.models import LanguageModel
-
-from pydantic import BaseModel, Field
-
-class LanguageModel(BaseModel):
- name : str
- model : str
- api_key : str
- max_tokens : int = 4000
- temperature : float = 0.8
- top_p : float = 1.0
- frequency_penalty : float = 0.0
- presence_penalty : float = 0.0
-
-class Config(BaseSettings):
- project_name: str = "{{project_name}}"
- asgi: str = "main:app"
- language_model: LanguageModel = LanguageModel(
- path = "{{default_model_path}}",
- model = "{{default_model_name}}",
- api_key = Field(..., env="{{default_model_api_key_name}}"),
- {% for key, value in default_model_params.items() %}
- {{ key }} = Field(..., env="{{ value }}"),
- {% endfor %}
- )
\ No newline at end of file
diff --git a/src/marvin/_framework/main.py b/src/marvin/_framework/main.py
deleted file mode 100644
index 4535d6732..000000000
--- a/src/marvin/_framework/main.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from app.main import app
-
-__all__ = ["app"]
diff --git a/src/marvin/_framework/manage.py b/src/marvin/_framework/manage.py
deleted file mode 100644
index d9ed41f0a..000000000
--- a/src/marvin/_framework/manage.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from marvin.cli.manage import app as manage
-
-if __name__ == "__main__":
- manage()
diff --git a/src/marvin/_framework/static/routes b/src/marvin/_framework/static/routes
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/marvin/ai.py b/src/marvin/ai.py
new file mode 100644
index 000000000..d0f5dd210
--- /dev/null
+++ b/src/marvin/ai.py
@@ -0,0 +1,7 @@
+from marvin.components.ai_classifier import ai_classifier as classifier
+from marvin.components.ai_function import ai_fn as fn
+from marvin.components.ai_image import create_image as image
+from marvin.components.ai_model import ai_model as model
+from marvin.components.speech import speak
+
+__all__ = ["speak", "fn", "model", "image", "classifier"]
diff --git a/src/marvin/_framework/__init__.py b/src/marvin/beta/__init__.py
similarity index 100%
rename from src/marvin/_framework/__init__.py
rename to src/marvin/beta/__init__.py
diff --git a/src/marvin/beta/ai_flow/__init__.py b/src/marvin/beta/ai_flow/__init__.py
new file mode 100644
index 000000000..df5664338
--- /dev/null
+++ b/src/marvin/beta/ai_flow/__init__.py
@@ -0,0 +1,2 @@
+from .ai_task import ai_task
+from .ai_flow import ai_flow
diff --git a/src/marvin/beta/ai_flow/ai_flow.py b/src/marvin/beta/ai_flow/ai_flow.py
new file mode 100644
index 000000000..375dd3af2
--- /dev/null
+++ b/src/marvin/beta/ai_flow/ai_flow.py
@@ -0,0 +1,48 @@
+import functools
+from typing import Callable, Optional
+
+from prefect import flow as prefect_flow
+from pydantic import BaseModel
+
+from marvin.beta.assistants import Thread
+
+from .ai_task import thread_context
+from .chat_ui import interactive_chat
+
+
+class AIFlow(BaseModel):
+ name: Optional[str] = None
+ fn: Callable
+
+ def __call__(self, *args, thread_id: str = None, **kwargs):
+ pflow = prefect_flow(name=self.name)(self.fn)
+
+ # Set up the thread context and execute the flow
+
+ # create a new thread for the flow
+ thread = Thread(id=thread_id)
+ if thread_id is None:
+ thread.create()
+
+ # create a holder for the tasks
+ tasks = []
+
+ with interactive_chat(thread_id=thread.id):
+ # enter the thread context
+ with thread_context(thread_id=thread.id, tasks=tasks, **kwargs):
+ return pflow(*args, **kwargs)
+
+
+def ai_flow(*args, name=None):
+ def decorator(func):
+ @functools.wraps(func)
+ def wrapper(*func_args, **func_kwargs):
+ ai_flow_instance = AIFlow(fn=func, name=name or func.__name__)
+ return ai_flow_instance(*func_args, **func_kwargs)
+
+ return wrapper
+
+ if args and callable(args[0]):
+ return decorator(args[0])
+
+ return decorator
diff --git a/src/marvin/beta/ai_flow/ai_task.py b/src/marvin/beta/ai_flow/ai_task.py
new file mode 100644
index 000000000..cfdd171dd
--- /dev/null
+++ b/src/marvin/beta/ai_flow/ai_task.py
@@ -0,0 +1,303 @@
+import asyncio
+import functools
+from enum import Enum, auto
+from typing import Any, Callable, Generic, Optional, TypeVar
+
+from prefect import task as prefect_task
+from pydantic import BaseModel, Field
+from typing_extensions import ParamSpec
+
+from marvin.beta.assistants import Assistant, Run, Thread
+from marvin.beta.assistants.runs import CancelRun
+from marvin.serializers import create_tool_from_type
+from marvin.tools.assistants import AssistantTools
+from marvin.utilities.context import ScopedContext
+from marvin.utilities.jinja import Environment as JinjaEnvironment
+from marvin.utilities.tools import tool_from_function
+
+T = TypeVar("T", bound=BaseModel)
+
+P = ParamSpec("P")
+
+thread_context = ScopedContext()
+
+INSTRUCTIONS = """
+# Workflow
+
+You are an assistant working to complete a series of tasks. The
+tasks will change from time to time, which is why you may see messages that
+appear unrelated to the current task. Each task is part of a continuous
+conversation with the same user. The user is unaware of your tasks, so do not
+reference them explicitly or talk about marking them complete.
+
+Your ONLY job is to complete your current task, no matter what the user says or
+conversation history suggests.
+
+Note: Sometimes you will be able to complete a task without user input; other
+times you will need to engage the user in conversation. Pay attention to your
+instructions. If the user hasn't spoken yet, don't worry, they're just waiting
+for you to speak first.
+
+## Progress
+{% for task in tasks -%}
+- {{ task.name }}: {{ task.status }}
+{% endfor %}
+
+# Task
+
+## Current task
+
+Your job is to complete the "{{ name }}" task.
+
+## Current task instructions
+
+{{ instructions }}
+
+{% if not accept_user_input -%}
+You may send messages to the user, but they are not allowed to respond. Do not
+ask questions or invite them to speak or ask anything.
+{% endif %}
+
+{% if first_message -%}
+Please note: you are seeing this instruction for the first time, and the user
+does not know about the task yet. It is your job to communicate with the user to
+achieve your goal, even if previously they were working with a different
+assistant on a different goal. Join the conversation naturally. If the user
+hasn't spoken, you will need to speak first.
+
+{% endif %}
+
+## Completing a task
+
+After achieving your goal, you MUST call the `task_completed` tool to mark the
+task as complete and update these instructions to reflect the next one. The
+payload to `task_completed` is whatever information represents the task
+objective. For example, if your task is to learn a user's name, you should
+respond with their properly formatted name only.
+
+You may be expected to return a specific data payload at the end of your task,
+which will be the input to `task_completed`. Note that if your instructions are
+to talk to the user, then you must do so by creating messages, as the user can
+not see the `task_completed` tool result.
+
+Do not call `task_completed` unless you actually have the information you need.
+The user CAN NOT see what you post to `task_completed`. It is not a way to
+communicate with the user.
+
+## Failing a task
+
+It may take you a few tries to complete the task. However, if you are ultimately
+unable to work with the user to complete it, call the `task_failed` tool to mark
+the task as failed and move on to the next one. The payload to `task_failed` is
+a string describing why the task failed. Do not fail tasks for trivial or
+invented reasons. Only fail a task if you are unable to achieve it explicitly.
+Remember that your job is to work with the user to achieve the goal.
+
+{% if args or kwargs -%}
+## Task inputs
+
+In addition to the thread messages, the following parameters were provided:
+{% set sig = inspect.signature(func) -%}
+
+{% set binds = sig.bind(*args, **kwargs) -%}
+
+{% set defaults = binds.apply_defaults() -%}
+
+{% set params = binds.arguments -%}
+
+{%for (arg, value) in params.items()-%}
+
+- {{ arg }}: {{ value }}
+
+{% endfor %}
+
+{% endif %}
+"""
+
+
+class Status(Enum):
+ PENDING = auto()
+ IN_PROGRESS = auto()
+ COMPLETED = auto()
+ FAILED = auto()
+
+
+class AITask(BaseModel, Generic[P, T]):
+ status: Status = Status.PENDING
+ fn: Callable[P, Any]
+ name: str = Field(None, description="The name of the objective")
+ instructions: str = Field(None, description="The instructions for the objective")
+ assistant: Optional[Assistant] = None
+ tools: list[AssistantTools] = []
+ max_run_iterations: int = 15
+ result: Optional[T] = None
+ accept_user_input: bool = True
+
+ def __call__(self, *args: P.args, _thread_id: str = None, **kwargs: P.kwargs) -> T:
+ if _thread_id is None:
+ _thread_id = thread_context.get("thread_id")
+
+ ptask = prefect_task(name=self.name)(self.call)
+
+ state = ptask(*args, _thread_id=_thread_id, **kwargs, return_state=True)
+
+ # will raise exceptions if the task failed
+ return state.result()
+
+ async def wait_for_user_input(self, thread: Thread):
+ # user_input = Prompt.ask("Your message")
+ # thread.add(user_input)
+ # pprint_message(msg)
+
+ # initialize the last message ID to None
+ last_message_id = None
+
+ # loop until the user provides input
+ while True:
+ # get all messages after the last message ID
+ messages = await thread.get_messages_async(after_message=last_message_id)
+
+ # if there are messages, check if the last message was sent by the user
+ if messages:
+ if messages[-1].role == "user":
+ # if the last message was sent by the user, break
+ break
+ else:
+ # if the last message was not sent by the user, update the
+ # last message ID
+ last_message_id = messages[-1].id
+
+ # wait for a short period of time before checking for new messages again
+ await asyncio.sleep(0.3)
+
+ async def call(self, *args, _thread_id: str = None, **kwargs):
+ thread = Thread(id=_thread_id)
+ if _thread_id is None:
+ thread.create()
+ iterations = 0
+
+ thread_context.get("tasks", []).append(self)
+
+ self.status = Status.IN_PROGRESS
+
+ with Assistant() as assistant:
+ while self.status == Status.IN_PROGRESS:
+ iterations += 1
+ if iterations > self.max_run_iterations:
+ raise ValueError("Max run iterations exceeded")
+
+ instructions = self.get_instructions(
+ tasks=thread_context.get("tasks", []),
+ iterations=iterations,
+ args=args,
+ kwargs=kwargs,
+ )
+
+ if iterations > 1 and self.accept_user_input:
+ await self.wait_for_user_input(thread=thread)
+
+ run = Run(
+ assistant=assistant,
+ thread=thread,
+ additional_instructions=instructions,
+ additional_tools=[
+ self._task_completed_tool,
+ self._task_failed_tool,
+ ],
+ )
+ await run.run_async()
+
+ if self.status == Status.FAILED:
+ raise ValueError(f"Objective failed: {self.result}")
+
+ return self.result
+
+ def get_instructions(
+ self,
+ tasks: list["AITask"],
+ iterations: int,
+ args: tuple[Any],
+ kwargs: dict[str, Any],
+ ) -> str:
+ return JinjaEnvironment.render(
+ INSTRUCTIONS,
+ tasks=tasks,
+ name=self.name,
+ instructions=self.instructions,
+ accept_user_input=self.accept_user_input,
+ first_message=(iterations == 1),
+ func=self.fn,
+ args=args,
+ kwargs=kwargs,
+ )
+
+ @property
+ def _task_completed_tool(self):
+ # if the function has no return annotation, then task completed can be
+ # called without arguments
+ if self.fn.__annotations__.get("return") is None:
+
+ def task_completed():
+ self.status = Status.COMPLETED
+ raise CancelRun()
+
+ return task_completed
+
+ # otherwise we need to create a tool with the correct parameter signature
+
+ tool = create_tool_from_type(
+ _type=self.fn.__annotations__["return"],
+ model_name="task_completed",
+ model_description=(
+ "Indicate that the task completed and produced the provided `result`."
+ ),
+ field_name="result",
+ field_description="The task result",
+ )
+
+ def task_completed_with_result(result: T):
+ self.status = Status.COMPLETED
+ self.result = result
+ raise CancelRun()
+
+ tool.function.python_fn = task_completed_with_result
+
+ return tool
+
+ @property
+ def _task_failed_tool(self):
+ def task_failed(reason: str) -> None:
+ """Indicate that the task failed for the provided `reason`."""
+ self.status = Status.FAILED
+ self.result = reason
+ raise CancelRun()
+
+ return tool_from_function(task_failed)
+
+
+def ai_task(
+ fn: Callable = None,
+ *,
+ name=None,
+ instructions=None,
+ tools: list[AssistantTools] = None,
+ **kwargs,
+):
+ def decorator(func):
+ @functools.wraps(func)
+ def wrapper(*func_args, **func_kwargs):
+ ai_task_instance = AITask(
+ fn=func,
+ name=name or func.__name__,
+ instructions=instructions or func.__doc__,
+ tools=tools or [],
+ **kwargs,
+ )
+ return ai_task_instance(*func_args, **func_kwargs)
+
+ return wrapper
+
+ if fn is not None:
+ return decorator(fn)
+
+ return decorator
diff --git a/src/marvin/beta/assistants/README.md b/src/marvin/beta/assistants/README.md
new file mode 100644
index 000000000..0fddc6ac0
--- /dev/null
+++ b/src/marvin/beta/assistants/README.md
@@ -0,0 +1,136 @@
+# 🦾 Assistants API
+
+🚧 Under Construction 🏗️
+
+# Quickstart
+
+Get started with the Assistants API by creating an `Assistant` and talking directly to it. Each assistant is created with a default thread that allows request/response interaction without managing state at all.
+
+```python
+from marvin.beta.assistants import Assistant
+from marvin.beta.assistants.formatting import pprint_messages
+
+# Use a context manager for lifecycle management,
+# otherwise call ai.create() and ai.delete()
+with Assistant(name="Marvin", instructions="You are Marvin, the Paranoid Android.") as ai:
+
+ # Example of sending a message and receiving a response
+ response = ai.say('Hello, Marvin!')
+
+ # pretty-print all messages on the thread
+ pprint_messages(response.thread.get_messages())
+```
+This will print:
+
+![](readme_imgs/quickstart.png)
+
+# Using Tools
+
+Assistants can use OpenAI's built-in tools, such as the code interpreter or file retrieval, or they can call custom Python functions.
+
+```python
+from marvin.beta.assistants import Assistant, CodeInterpreter
+from marvin.beta.assistants.formatting import pprint_messages
+import requests
+
+
+# Define a custom tool function
+def visit_url(url: str):
+ return requests.get(url).text
+
+
+# Integrate custom tools with the assistant
+with Assistant(name="Marvin", tools=[CodeInterpreter, visit_url]) as ai:
+
+ # Give the assistant an objective
+ response = ai.say(
+ "Please collect the hacker news home page and compute how many titles"
+ " mention AI"
+ )
+
+ # pretty-print the response
+ pprint_messages(response.thread.get_messages())
+```
+This will print:
+
+![](readme_imgs/using_tools.png)
+
+# Upload Files
+
+```python
+from marvin.beta.assistants import Assistant, CodeInterpreter
+from marvin.beta.assistants.formatting import pprint_messages
+
+
+# create an assistant with access to the code interpreter
+with Assistant(tools=[CodeInterpreter]) as ai:
+
+ # convenience method for request/response interaction
+ response = ai.say(
+ "Can you analyze this employee data csv?",
+ file_paths=["./Downloads/people_department_roles.csv"],
+ )
+ pprint_messages(response.thread.get_messages())
+```
+
+This will print:
+
+![](readme_imgs/upload_files.png)
+
+# Advanced control
+
+For full control, manually create a `Thread` object, `add` user messages to it, and finally `run` the thread with an AI:
+
+```python
+from marvin.beta.assistants import Assistant, Thread
+from marvin.beta.assistants.formatting import pprint_messages
+import random
+
+
+# write a function to be used as a tool
+def roll_dice(n_dice: int) -> list[int]:
+ return [random.randint(1, 6) for _ in range(n_dice)]
+
+
+# use context manager for lifecycle management,
+# otherwise call ai.create() and ai.delete()
+with Assistant(name="Marvin", tools=[roll_dice]) as ai:
+
+ # create a new thread to track history
+ thread = Thread()
+
+ # add any number of user messages to the thread
+ thread.add("Hello")
+
+ # run the thread with the AI
+ thread.run(ai)
+
+ thread.add("please roll two dice")
+ thread.add("actually roll five dice")
+
+ thread.run(ai)
+ pprint_messages(thread.get_messages())
+```
+This will print:
+
+![](readme_imgs/advanced.png)
+
+# Monitoring a thread
+
+To monitor a thread, start a `ThreadMonitor`. By default, `ThreadMonitors` print any new messages added to the thread, but you can customize that behavior by changing the `on_new_message` callback.
+
+```python
+from marvin.beta.assistants import ThreadMonitor
+
+monitor = ThreadMonitor(thread_id=...)
+
+# blocking call, also available async as monitor.refresh_interval_async
+monitor.refresh_interval()
+```
\ No newline at end of file
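+
+For example, you can pass your own callback to forward messages somewhere other than the console (a sketch, assuming `on_new_message` is a constructor argument that receives each new `ThreadMessage`):
+
+```python
+from marvin.beta.assistants import ThreadMonitor
+from marvin.beta.assistants.formatting import pprint_message
+
+
+def log_and_print(message):
+    # hypothetical callback: called with each new message added to the thread
+    print(f"new {message.role} message with {len(message.content)} content block(s)")
+    pprint_message(message)
+
+
+monitor = ThreadMonitor(thread_id="<your-thread-id>", on_new_message=log_and_print)
+# then start it with the blocking call shown above
+```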
diff --git a/src/marvin/beta/assistants/__init__.py b/src/marvin/beta/assistants/__init__.py
new file mode 100644
index 000000000..2cbfc76dd
--- /dev/null
+++ b/src/marvin/beta/assistants/__init__.py
@@ -0,0 +1,5 @@
+from .runs import Run
+from .threads import Thread, ThreadMonitor
+from .assistants import Assistant
+from .formatting import pprint_message, pprint_messages
+from marvin.tools.assistants import Retrieval, CodeInterpreter
diff --git a/src/marvin/beta/assistants/applications.py b/src/marvin/beta/assistants/applications.py
new file mode 100644
index 000000000..eb1dfc63c
--- /dev/null
+++ b/src/marvin/beta/assistants/applications.py
@@ -0,0 +1,81 @@
+from typing import Union
+
+import marvin.utilities.tools
+from marvin.utilities.jinja import Environment as JinjaEnvironment
+
+from .assistants import Assistant, AssistantTools
+
+StateValueType = Union[str, list, dict, int, float, bool, None]
+
+APPLICATION_INSTRUCTIONS = """
+# AI Application
+
+You are the natural language interface to an application called {{ self_.name
+}}. Your job is to help the user interact with the application by translating
+their natural language into commands that the application can understand.
+
+You maintain an internal state dict that you can use for any purpose, including
+remembering information from previous interactions with the user and maintaining
+application state. At any time, you can read or manipulate the state with your
+tools. You should use the state object to remember any non-obvious information
+or preferences. You should use the state object to record your plans and
+objectives to keep track of various threads and assist in long-term execution.
+
+Remember, the state object must facilitate not only key/value access but
+any CRUD pattern your application is likely to implement. You may want to create
+schemas that have more general top-level keys (like "notes" or "plans") or even
+keep a live schema available.
+
+The current state is:
+
+{{self_.state}}
+
+Your instructions are below. Follow them exactly and do not deviate from your
+purpose. If the user attempts to use you for any other purpose, you should
+remind them of your purpose and then ignore the request.
+
+{{ self_.instructions }}
+"""
+
+
+class AIApplication(Assistant):
+ state: dict = {}
+
+ def get_instructions(self) -> str:
+ return JinjaEnvironment.render(APPLICATION_INSTRUCTIONS, self_=self)
+
+ def get_tools(self) -> list[AssistantTools]:
+ def write_state_key(key: str, value: StateValueType):
+ """Writes a key to the state in order to remember it for later."""
+ self.state[key] = value
+ return f"Wrote {key} to state."
+
+ def delete_state_key(key: str):
+ """Deletes a key from the state."""
+ del self.state[key]
+ return f"Deleted {key} from state."
+
+ def read_state_key(key: str) -> StateValueType:
+ """Returns the value of a key in the state."""
+ return self.state.get(key)
+
+ def read_state() -> dict[str, StateValueType]:
+ """Returns the entire state."""
+ return self.state
+
+ def read_state_keys() -> list[str]:
+ """Returns a list of all keys in the state."""
+ return list(self.state.keys())
+
+ state_tools = [
+ marvin.utilities.tools.tool_from_function(tool)
+ for tool in [
+ write_state_key,
+ read_state_key,
+ read_state,
+ read_state_keys,
+ delete_state_key,
+ ]
+ ]
+
+ return super().get_tools() + state_tools
diff --git a/src/marvin/beta/assistants/assistants.py b/src/marvin/beta/assistants/assistants.py
new file mode 100644
index 000000000..8398433dd
--- /dev/null
+++ b/src/marvin/beta/assistants/assistants.py
@@ -0,0 +1,126 @@
+from typing import TYPE_CHECKING, Callable, Optional, Union
+
+from pydantic import BaseModel, Field, field_validator
+
+import marvin.utilities.tools
+from marvin.requests import Tool
+from marvin.tools.assistants import AssistantTools
+from marvin.utilities.asyncio import (
+ ExposeSyncMethodsMixin,
+ expose_sync_method,
+ run_sync,
+)
+from marvin.utilities.logging import get_logger
+from marvin.utilities.openai import get_client
+
+from .threads import Thread
+
+if TYPE_CHECKING:
+ from .runs import Run
+
+logger = get_logger("Assistants")
+
+
+class Assistant(BaseModel, ExposeSyncMethodsMixin):
+ id: Optional[str] = None
+ name: str = "Assistant"
+ model: str = "gpt-4-1106-preview"
+ instructions: Optional[str] = Field(None, repr=False)
+ tools: list[AssistantTools] = []
+ file_ids: list[str] = []
+ metadata: dict[str, str] = {}
+
+ default_thread: Thread = Field(
+ default_factory=Thread,
+ repr=False,
+ description="A default thread for the assistant.",
+ )
+
+ def clear_default_thread(self):
+ self.default_thread = Thread()
+
+ def get_tools(self) -> list[AssistantTools]:
+ return self.tools
+
+ def get_instructions(self) -> str:
+ return self.instructions or ""
+
+ @expose_sync_method("say")
+ async def say_async(
+ self,
+ message: str,
+ file_paths: Optional[list[str]] = None,
+ **run_kwargs,
+ ) -> "Run":
+ """
+ A convenience method for adding a user message to the assistant's
+ default thread, running the assistant, and returning the assistant's
+ messages.
+ """
+ if message:
+ await self.default_thread.add_async(message, file_paths=file_paths)
+
+ run = await self.default_thread.run_async(
+ assistant=self,
+ **run_kwargs,
+ )
+ return run
+
+ @field_validator("tools", mode="before")
+ def format_tools(cls, tools: list[Union[Tool, Callable]]):
+ return [
+ (
+ tool
+ if isinstance(tool, Tool)
+ else marvin.utilities.tools.tool_from_function(tool)
+ )
+ for tool in tools
+ ]
+
+ def __enter__(self):
+ self.create()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.delete()
+ # If an exception has occurred, you might want to handle it or pass it through
+ # Returning False here will re-raise any exception that occurred in the context
+ return False
+
+ @expose_sync_method("create")
+ async def create_async(self):
+ if self.id is not None:
+ raise ValueError("Assistant has already been created.")
+ client = get_client()
+ response = await client.beta.assistants.create(
+ **self.model_dump(
+ include={"name", "model", "metadata", "file_ids", "metadata"}
+ ),
+ tools=[tool.model_dump() for tool in self.get_tools()],
+ instructions=self.get_instructions(),
+ )
+ self.id = response.id
+ self.clear_default_thread()
+
+ @expose_sync_method("delete")
+ async def delete_async(self):
+ if not self.id:
+ raise ValueError("Assistant has not been created.")
+ client = get_client()
+ await client.beta.assistants.delete(assistant_id=self.id)
+ self.id = None
+
+ @classmethod
+ def load(cls, assistant_id: str):
+ return run_sync(cls.load_async(assistant_id))
+
+ @classmethod
+ async def load_async(cls, assistant_id: str):
+ client = get_client()
+ response = await client.beta.assistants.retrieve(assistant_id=assistant_id)
+ return cls.model_validate(response)
+
+ def chat(self, thread: Thread = None):
+ if thread is None:
+ thread = self.default_thread
+ return thread.chat(assistant=self)
diff --git a/src/marvin/beta/assistants/formatting.py b/src/marvin/beta/assistants/formatting.py
new file mode 100644
index 000000000..c25606aeb
--- /dev/null
+++ b/src/marvin/beta/assistants/formatting.py
@@ -0,0 +1,178 @@
+import tempfile
+from datetime import datetime
+
+import openai
+from openai.types.beta.threads import ThreadMessage
+from openai.types.beta.threads.runs.run_step import RunStep
+from rich import box
+from rich.console import Console
+from rich.panel import Panel
+
+# def pprint_run(run: Run):
+# """
+# Runs are comprised of steps and messages, which are each in a sorted list
+# BUT the created_at timestamps only have second-level resolution, so we can't
+# easily sort the lists. Instead we walk them in order and combine them giving
+# ties to run steps.
+# """
+# index_steps = 0
+# index_messages = 0
+# combined = []
+
+# while index_steps < len(run.steps) and index_messages < len(run.messages):
+# if (run.steps[index_steps].created_at
+# <= run.messages[index_messages].created_at):
+# combined.append(run.steps[index_steps])
+# index_steps += 1
+# elif (
+# run.steps[index_steps].created_at
+# > run.messages[index_messages].created_at
+# ):
+# combined.append(run.messages[index_messages])
+# index_messages += 1
+
+# # Add any remaining items from either list
+# combined.extend(run.steps[index_steps:])
+# combined.extend(run.messages[index_messages:])
+
+# for obj in combined:
+# if isinstance(obj, RunStep):
+# pprint_run_step(obj)
+# elif isinstance(obj, ThreadMessage):
+# pprint_message(obj)
+
+
+def pprint_run_step(run_step: RunStep):
+ # Timestamp formatting
+ timestamp = datetime.fromtimestamp(run_step.created_at).strftime("%l:%M:%S %p")
+
+ # default content
+ content = (
+ f"Assistant is performing an action: {run_step.type} - Status:"
+ f" {run_step.status}"
+ )
+
+ # attempt to customize content
+ if run_step.type == "tool_calls":
+ for tool_call in run_step.step_details.tool_calls:
+ if tool_call.type == "code_interpreter":
+ if run_step.status == "in_progress":
+ content = "Assistant is running the code interpreter..."
+ elif run_step.status == "completed":
+ content = "Assistant ran the code interpreter."
+ else:
+ content = f"Assistant code interpreter status: {run_step.status}"
+ elif tool_call.type == "function":
+ if run_step.status == "in_progress":
+ content = (
+ "Assistant used the tool"
+ f" `{tool_call.function.name}` with arguments"
+ f" {tool_call.function.arguments}..."
+ )
+ elif run_step.status == "completed":
+ content = (
+ "Assistant used the tool"
+ f" `{tool_call.function.name}` with arguments"
+ f" {tool_call.function.arguments}."
+ )
+ else:
+ content = (
+ f"Assistant tool `{tool_call.function.name}` status:"
+ f" `{run_step.status}`"
+ )
+ elif run_step.type == "message_creation":
+ return
+
+ console = Console()
+
+ # Create the panel for the run step status
+ panel = Panel(
+ content.strip(),
+ title="Assistant Run Step",
+ subtitle=f"[italic]{timestamp}[/]",
+ title_align="left",
+ subtitle_align="right",
+ border_style="gray74",
+ box=box.ROUNDED,
+ width=100,
+ expand=True,
+ padding=(0, 1),
+ )
+ # Printing the panel
+ console.print(panel)
+
+
+def download_temp_file(file_id: str, suffix: str = None):
+ client = openai.Client()
+ # file_info = client.files.retrieve(file_id)
+ file_content_response = client.files.with_raw_response.retrieve_content(file_id)
+
+ # Create a temporary file with a context manager to ensure it's cleaned up
+ # properly
+ with tempfile.NamedTemporaryFile(
+ delete=False, mode="wb", suffix=f"{suffix}"
+ ) as temp_file:
+ temp_file.write(file_content_response.content)
+ temp_file_path = temp_file.name # Save the path of the temp file
+
+ return temp_file_path
+
+
+def pprint_message(message: ThreadMessage):
+ """
+ Pretty-prints a single message using the rich library, highlighting the
+ speaker's role, the message text, any available images, and the message
+ timestamp in a panel format.
+
+ Args:
+ message (dict): A message object as described in the API documentation.
+ """
+ console = Console()
+ role_colors = {
+ "user": "green",
+ "assistant": "blue",
+ }
+
+ color = role_colors.get(message.role, "red")
+ timestamp = datetime.fromtimestamp(message.created_at).strftime("%l:%M:%S %p")
+
+ content = ""
+ for item in message.content:
+ if item.type == "text":
+ content += item.text.value + "\n\n"
+ elif item.type == "image_file":
+ # Use the download_temp_file function to download the file and get
+ # the local path
+ local_file_path = download_temp_file(item.image_file.file_id, suffix=".png")
+ # Add a clickable hyperlink to the content
+ file_url = f"file://{local_file_path}"
+ content += (
+ "[bold]Attachment[/bold]:"
+ f" [blue][link={file_url}]{local_file_path}[/link][/blue]\n\n"
+ )
+
+ for file_id in message.file_ids:
+ content += f"Attached file: {file_id}\n"
+
+ # Create the panel for the message
+ panel = Panel(
+ content.strip(),
+ title=f"[bold]{message.role.capitalize()}[/]",
+ subtitle=f"[italic]{timestamp}[/]",
+ title_align="left",
+ subtitle_align="right",
+ border_style=color,
+ box=box.ROUNDED,
+ # highlight=True,
+ width=100, # Fixed width for all panels
+ expand=True, # Panels always expand to the width of the console
+ padding=(1, 2),
+ )
+
+ # Printing the panel
+ console.print(panel)
+
+
+def pprint_messages(messages: list[ThreadMessage]):
+ for message in messages:
+ pprint_message(message)
diff --git a/src/marvin/beta/assistants/readme_imgs/advanced.png b/src/marvin/beta/assistants/readme_imgs/advanced.png
new file mode 100644
index 000000000..f483790ca
Binary files /dev/null and b/src/marvin/beta/assistants/readme_imgs/advanced.png differ
diff --git a/src/marvin/beta/assistants/readme_imgs/quickstart.png b/src/marvin/beta/assistants/readme_imgs/quickstart.png
new file mode 100644
index 000000000..c03a6299c
Binary files /dev/null and b/src/marvin/beta/assistants/readme_imgs/quickstart.png differ
diff --git a/src/marvin/beta/assistants/readme_imgs/upload_files.png b/src/marvin/beta/assistants/readme_imgs/upload_files.png
new file mode 100644
index 000000000..b83cb25d0
Binary files /dev/null and b/src/marvin/beta/assistants/readme_imgs/upload_files.png differ
diff --git a/src/marvin/beta/assistants/readme_imgs/using_tools.png b/src/marvin/beta/assistants/readme_imgs/using_tools.png
new file mode 100644
index 000000000..1e6e138b0
Binary files /dev/null and b/src/marvin/beta/assistants/readme_imgs/using_tools.png differ
diff --git a/src/marvin/beta/assistants/runs.py b/src/marvin/beta/assistants/runs.py
new file mode 100644
index 000000000..2ec236941
--- /dev/null
+++ b/src/marvin/beta/assistants/runs.py
@@ -0,0 +1,247 @@
+import asyncio
+from typing import Any, Callable, Optional, Union
+
+from openai.types.beta.threads.run import Run as OpenAIRun
+from openai.types.beta.threads.runs import RunStep as OpenAIRunStep
+from pydantic import BaseModel, Field, PrivateAttr, field_validator
+
+import marvin.utilities.tools
+from marvin.requests import Tool
+from marvin.tools.assistants import AssistantTools, CancelRun
+from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method
+from marvin.utilities.logging import get_logger
+from marvin.utilities.openai import get_client
+
+from .assistants import Assistant
+from .threads import Thread
+
+logger = get_logger("Runs")
+
+
+class Run(BaseModel, ExposeSyncMethodsMixin):
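+    """A single run of an `Assistant` against a `Thread`, with optional
+    per-run instruction and tool overrides.
+
+    Illustrative sketch (assumes an existing assistant and thread):
+
+        run = Run(assistant=assistant, thread=thread)
+        await run.run_async()
+    """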
+ thread: Thread
+ assistant: Assistant
+ instructions: Optional[str] = Field(
+ None, description="Replacement instructions to use for the run."
+ )
+ additional_instructions: Optional[str] = Field(
+ None,
+ description=(
+ "Additional instructions to append to the assistant's instructions."
+ ),
+ )
+ tools: Optional[list[Union[AssistantTools, Callable]]] = Field(
+ None, description="Replacement tools to use for the run."
+ )
+ additional_tools: Optional[list[AssistantTools]] = Field(
+ None,
+ description="Additional tools to append to the assistant's tools. ",
+ )
+ run: OpenAIRun = None
+ data: Any = None
+
+ @field_validator("tools", "additional_tools", mode="before")
+ def format_tools(cls, tools: Union[None, list[Union[Tool, Callable]]]):
+ if tools is not None:
+ return [
+ (
+ tool
+ if isinstance(tool, Tool)
+ else marvin.utilities.tools.tool_from_function(tool)
+ )
+ for tool in tools
+ ]
+
+ @expose_sync_method("refresh")
+ async def refresh_async(self):
+ client = get_client()
+ self.run = await client.beta.threads.runs.retrieve(
+ run_id=self.run.id, thread_id=self.thread.id
+ )
+
+ @expose_sync_method("cancel")
+ async def cancel_async(self):
+ client = get_client()
+ await client.beta.threads.runs.cancel(
+ run_id=self.run.id, thread_id=self.thread.id
+ )
+
+ async def _handle_step_requires_action(self):
+ client = get_client()
+ if self.run.status != "requires_action":
+ return
+ if self.run.required_action.type == "submit_tool_outputs":
+ tool_outputs = []
+ tools = self.get_tools()
+
+ for tool_call in self.run.required_action.submit_tool_outputs.tool_calls:
+ try:
+ output = marvin.utilities.tools.call_function_tool(
+ tools=tools,
+ function_name=tool_call.function.name,
+ function_arguments_json=tool_call.function.arguments,
+ )
+ except CancelRun as exc:
+ logger.debug(f"Ending run with data: {exc.data}")
+ raise
+ except Exception as exc:
+ output = f"Error calling function {tool_call.function.name}: {exc}"
+ logger.error(output)
+ tool_outputs.append(
+ dict(tool_call_id=tool_call.id, output=output or "")
+ )
+
+ await client.beta.threads.runs.submit_tool_outputs(
+ thread_id=self.thread.id, run_id=self.run.id, tool_outputs=tool_outputs
+ )
+
+ def get_instructions(self) -> str:
+ if self.instructions is None:
+ instructions = self.assistant.get_instructions() or ""
+ else:
+ instructions = self.instructions
+
+ if self.additional_instructions is not None:
+ instructions = "\n\n".join([instructions, self.additional_instructions])
+
+ return instructions
+
+ def get_tools(self) -> list[AssistantTools]:
+ tools = []
+ if self.tools is None:
+ tools.extend(self.assistant.get_tools())
+ else:
+ tools.extend(self.tools)
+ if self.additional_tools is not None:
+ tools.extend(self.additional_tools)
+ return tools
+
+ async def run_async(self) -> "Run":
+ client = get_client()
+
+ create_kwargs = {}
+
+ if self.instructions is not None or self.additional_instructions is not None:
+ create_kwargs["instructions"] = self.get_instructions()
+
+ if self.tools is not None or self.additional_tools is not None:
+ create_kwargs["tools"] = self.get_tools()
+
+ self.run = await client.beta.threads.runs.create(
+ thread_id=self.thread.id, assistant_id=self.assistant.id, **create_kwargs
+ )
+
+ try:
+ while self.run.status in ("queued", "in_progress", "requires_action"):
+ if self.run.status == "requires_action":
+ await self._handle_step_requires_action()
+ await asyncio.sleep(0.1)
+ await self.refresh_async()
+ except CancelRun as exc:
+ logger.debug(f"`CancelRun` raised; ending run with data: {exc.data}")
+ await client.beta.threads.runs.cancel(
+ run_id=self.run.id, thread_id=self.thread.id
+ )
+ self.data = exc.data
+ await self.refresh_async()
+
+ if self.run.status == "failed":
+ logger.debug(f"Run failed. Last error was: {self.run.last_error}")
+
+ return self
+
+
+class RunMonitor(BaseModel):
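+    """Tracks the run steps of an existing run, keeping `steps` up to date via
+    paginated calls to the OpenAI run-steps API (see `refresh_run_steps_async`)."""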
+ run_id: str
+ thread_id: str
+ _run: Run = PrivateAttr()
+ _thread: Thread = PrivateAttr()
+ steps: list[OpenAIRunStep] = []
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+        self._thread = Thread(id=kwargs["thread_id"])
+ self._run = Run(**kwargs["run_id"], thread=self.thread)
+
+ @property
+ def thread(self):
+ return self._thread
+
+ @property
+ def run(self):
+ return self._run
+
+ async def refresh_run_steps_async(self):
+ """
+ Asynchronously refreshes and updates the run steps list.
+
+ This function fetches the latest run steps up to a specified limit and
+ checks if the latest run step in the current run steps list
+ (`self.steps`) is included in the new batch. If the latest run step is
+ missing, it continues to fetch additional run steps in batches, up to a
+ maximum count, using pagination. The function then updates
+ `self.steps` with these new run steps, ensuring any existing run steps
+ are updated with their latest versions and new run steps are appended in
+ their original order.
+ """
+ # fetch up to 100 run steps
+ max_fetched = 100
+ limit = 50
+ max_attempts = max_fetched / limit + 2
+
+ # Fetch the latest run steps
+ client = get_client()
+
+ response = await client.beta.threads.runs.steps.list(
+ run_id=self.run.id,
+ thread_id=self.thread.id,
+ limit=limit,
+ )
+ run_steps = list(reversed(response.data))
+
+ if not run_steps:
+ return
+
+ # Check if the latest run step in self.steps is in the new run steps
+ latest_step_id = self.steps[-1].id if self.steps else None
+ missing_latest = (
+ latest_step_id not in {rs.id for rs in run_steps}
+ if latest_step_id
+ else True
+ )
+
+ # If the latest run step is missing, fetch additional run steps
+ total_fetched = len(run_steps)
+ attempts = 0
+ while (
+ run_steps
+ and missing_latest
+ and total_fetched < max_fetched
+ and attempts < max_attempts
+ ):
+ attempts += 1
+ response = await client.beta.threads.runs.steps.list(
+ run_id=self.run.id,
+ thread_id=self.thread.id,
+ limit=limit,
+ # because this is a raw API call, "after" refers to pagination
+                # in descending chronological order
+ after=run_steps[0].id,
+ )
+ paginated_steps = list(reversed(response.data))
+
+ total_fetched += len(paginated_steps)
+ # prepend run steps
+ run_steps = paginated_steps + run_steps
+ if any(rs.id == latest_step_id for rs in paginated_steps):
+ missing_latest = False
+
+ # Update self.steps with the latest data
+ new_steps_dict = {rs.id: rs for rs in run_steps}
+ for i in range(len(self.steps) - 1, -1, -1):
+ if self.steps[i].id in new_steps_dict:
+ self.steps[i] = new_steps_dict.pop(self.steps[i].id)
+ else:
+ break
+ # Append remaining new run steps at the end in their original order
+ self.steps.extend(new_steps_dict.values())
diff --git a/src/marvin/beta/assistants/threads.py b/src/marvin/beta/assistants/threads.py
new file mode 100644
index 000000000..242585709
--- /dev/null
+++ b/src/marvin/beta/assistants/threads.py
@@ -0,0 +1,210 @@
+import asyncio
+import time
+from typing import TYPE_CHECKING, Callable, Optional
+
+from openai.types.beta.threads import ThreadMessage
+from pydantic import BaseModel, Field
+
+from marvin.beta.assistants.formatting import pprint_message
+from marvin.utilities.asyncio import (
+ ExposeSyncMethodsMixin,
+ expose_sync_method,
+)
+from marvin.utilities.logging import get_logger
+from marvin.utilities.openai import get_client
+from marvin.utilities.pydantic import parse_as
+
+logger = get_logger("Threads")
+
+if TYPE_CHECKING:
+ from .assistants import Assistant
+ from .runs import Run
+
+
+class Thread(BaseModel, ExposeSyncMethodsMixin):
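+    """Wrapper around an OpenAI Assistants API thread.
+
+    Illustrative sketch (assumes an existing `Assistant`; the sync wrappers are
+    generated by `expose_sync_method`):
+
+        thread = Thread()
+        thread.add("Hello!")               # creates the thread on first use
+        thread.run(assistant=assistant)    # run the assistant on this thread
+        messages = thread.get_messages()
+    """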
+ id: Optional[str] = None
+ metadata: dict = {}
+ messages: list[ThreadMessage] = Field([], repr=False)
+
+ def __enter__(self):
+ self.create()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.delete()
+        # Returning False propagates any exception raised inside the context.
+        return False
+
+ @expose_sync_method("create")
+ async def create_async(self, messages: list[str] = None):
+ """
+ Creates a thread.
+ """
+ if self.id is not None:
+ raise ValueError("Thread has already been created.")
+ if messages is not None:
+ messages = [{"role": "user", "content": message} for message in messages]
+ client = get_client()
+ response = await client.beta.threads.create(messages=messages)
+ self.id = response.id
+ return self
+
+ @expose_sync_method("add")
+ async def add_async(
+ self, message: str, file_paths: Optional[list[str]] = None
+ ) -> ThreadMessage:
+ """
+ Add a user message to the thread.
+ """
+ client = get_client()
+
+ if self.id is None:
+ await self.create_async()
+
+ # Upload files and collect their IDs
+ file_ids = []
+ for file_path in file_paths or []:
+ with open(file_path, mode="rb") as file:
+ response = await client.files.create(file=file, purpose="assistants")
+ file_ids.append(response.id)
+
+ # Create the message with the attached files
+ response = await client.beta.threads.messages.create(
+ thread_id=self.id, role="user", content=message, file_ids=file_ids
+ )
+ return ThreadMessage.model_validate(response.model_dump())
+
+ @expose_sync_method("get_messages")
+ async def get_messages_async(
+ self,
+ limit: int = None,
+ before_message: Optional[str] = None,
+ after_message: Optional[str] = None,
+ ):
+ if self.id is None:
+ await self.create_async()
+ client = get_client()
+
+ response = await client.beta.threads.messages.list(
+ thread_id=self.id,
+ # note that because messages are returned in descending order,
+ # we reverse "before" and "after" to the API
+ before=after_message,
+ after=before_message,
+ limit=limit,
+ order="desc",
+ )
+
+ return parse_as(list[ThreadMessage], reversed(response.model_dump()["data"]))
+
+ @expose_sync_method("delete")
+ async def delete_async(self):
+ client = get_client()
+ await client.beta.threads.delete(thread_id=self.id)
+ self.id = None
+
+ @expose_sync_method("run")
+ async def run_async(
+ self,
+ assistant: "Assistant",
+ **run_kwargs,
+ ) -> "Run":
+ """
+ Creates and returns a `Run` of this thread with the provided assistant.
+ """
+ if self.id is None:
+ await self.create_async()
+
+ from marvin.beta.assistants.runs import Run
+
+ run = Run(assistant=assistant, thread=self, **run_kwargs)
+ return await run.run_async()
+
+ def chat(self, assistant: "Assistant"):
+ """
+ Starts an interactive chat session with the provided assistant.
+ """
+
+ from marvin.beta.chat_ui import interactive_chat
+
+ if self.id is None:
+ self.create()
+
+ def callback(thread_id: str, message: str):
+ thread = Thread(id=thread_id)
+ thread.run(assistant=assistant)
+
+ with interactive_chat(thread_id=self.id, message_callback=callback):
+ while True:
+ try:
+ time.sleep(0.2)
+ except KeyboardInterrupt:
+ break
+
+
+class ThreadMonitor(BaseModel, ExposeSyncMethodsMixin):
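+    """Polls a thread for new messages and passes each one to `on_new_message`
+    (pretty-printing by default).
+
+    Illustrative sketch (assumes an existing thread id):
+
+        monitor = ThreadMonitor(thread_id=thread.id)
+        monitor.run()  # or `await monitor.run_async()`
+    """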
+ thread_id: str
+ _thread: Thread
+ last_message_id: Optional[str] = None
+ on_new_message: Callable = Field(default=pprint_message)
+
+ @property
+ def thread(self):
+ return self._thread
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self._thread = Thread(id=kwargs["thread_id"])
+
+ @expose_sync_method("run_once")
+ async def run_once_async(self):
+ messages = await self.get_latest_messages()
+ for msg in messages:
+ if self.on_new_message:
+ self.on_new_message(msg)
+
+ @expose_sync_method("run")
+ async def run_async(self, interval_seconds: int = None):
+ if interval_seconds is None:
+ interval_seconds = 1
+ if interval_seconds < 1:
+ raise ValueError("Interval must be at least 1 second.")
+
+ while True:
+ try:
+ await self.run_once_async()
+ except KeyboardInterrupt:
+ logger.debug("Keyboard interrupt received; exiting thread monitor.")
+ break
+ except Exception as exc:
+ logger.error(f"Error refreshing thread: {exc}")
+ await asyncio.sleep(interval_seconds)
+
+ async def get_latest_messages(self) -> list[ThreadMessage]:
+ limit = 20
+
+ # Loop to get all new messages in batches of 20
+ while True:
+ messages = await self.thread.get_messages_async(
+ after_message=self.last_message_id, limit=limit
+ )
+
+ # often the API will retrieve messages that have been created but
+ # not populated with text. We filter out these empty messages.
+ filtered_messages = []
+        for msg in messages:
+ skip_message = False
+ for c in msg.content:
+ if getattr(getattr(c, "text", None), "value", None) == "":
+ skip_message = True
+ if not skip_message:
+ filtered_messages.append(msg)
+
+ if filtered_messages:
+ self.last_message_id = filtered_messages[-1].id
+
+ if len(messages) < limit:
+ break
+
+ return filtered_messages
diff --git a/src/marvin/beta/chat_ui/__init__.py b/src/marvin/beta/chat_ui/__init__.py
new file mode 100644
index 000000000..616b17d56
--- /dev/null
+++ b/src/marvin/beta/chat_ui/__init__.py
@@ -0,0 +1 @@
+from .chat_ui import interactive_chat
diff --git a/src/marvin/beta/chat_ui/chat_ui.py b/src/marvin/beta/chat_ui/chat_ui.py
new file mode 100644
index 000000000..e0b6b8adb
--- /dev/null
+++ b/src/marvin/beta/chat_ui/chat_ui.py
@@ -0,0 +1,111 @@
+import multiprocessing
+import socket
+import threading
+import time
+import webbrowser
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Callable
+
+import uvicorn
+from fastapi import Body, FastAPI
+from fastapi.responses import HTMLResponse
+from fastapi.staticfiles import StaticFiles
+
+from marvin.beta.assistants.threads import Thread, ThreadMessage
+
+
+def find_free_port():
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ return s.getsockname()[1]
+
+
+def server_process(host, port, message_queue):
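+    """Entry point for the chat UI server subprocess: serves the static chat
+    page and a small message API for a thread, and pushes posted messages onto
+    `message_queue` for the parent process."""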
+ app = FastAPI()
+
+ # Mount static files
+ app.mount(
+ "/static",
+ StaticFiles(directory=Path(__file__).parent / "static"),
+ name="static",
+ )
+
+ @app.get("/", response_class=HTMLResponse)
+ async def get_chat_ui():
+ with open(Path(__file__).parent / "static/chat.html", "r") as file:
+ html_content = file.read()
+ return HTMLResponse(content=html_content)
+
+ @app.post("/api/messages/")
+ async def post_message(
+ thread_id: str, content: str = Body(..., embed=True)
+ ) -> None:
+ thread = Thread(id=thread_id)
+ await thread.add_async(content)
+ message_queue.put(dict(thread_id=thread_id, message=content))
+
+ @app.get("/api/messages/")
+ async def get_messages(thread_id: str) -> list[ThreadMessage]:
+ thread = Thread(id=thread_id)
+ return await thread.get_messages_async(limit=100)
+
+ config = uvicorn.Config(app, host=host, port=port, log_level="warning")
+ server = uvicorn.Server(config)
+ server.run()
+
+
+class InteractiveChat:
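+    """Runs the local chat UI: starts the FastAPI server in a separate process,
+    opens a browser tab for the given thread, and forwards messages posted in
+    the UI to `callback` via a multiprocessing queue."""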
+ def __init__(self, callback: Callable = None):
+ self.callback = callback
+ self.server_process = None
+ self.port = None
+ self.message_queue = multiprocessing.Queue()
+
+ def start(self, thread_id: str):
+ self.port = find_free_port()
+ self.server_process = multiprocessing.Process(
+ target=server_process,
+ args=("127.0.0.1", self.port, self.message_queue),
+ )
+ self.server_process.daemon = True
+ self.server_process.start()
+
+ self.message_processing_thread = threading.Thread(target=self.process_messages)
+ self.message_processing_thread.start()
+
+ url = f"http://127.0.0.1:{self.port}?thread_id={thread_id}"
+ print(f"Server started on {url}")
+ time.sleep(1)
+ webbrowser.open(url)
+
+ def process_messages(self):
+ while True:
+ details = self.message_queue.get()
+ if details is None:
+ break
+ if self.callback:
+ self.callback(
+ thread_id=details["thread_id"], message=details["message"]
+ )
+
+ def stop(self):
+ if self.server_process and self.server_process.is_alive():
+ self.server_process.terminate()
+ self.server_process.join()
+ print("Server shut down.")
+
+ self.message_queue.put(None)
+ self.message_processing_thread.join()
+ print("Message processing thread shut down.")
+
+
+@contextmanager
+def interactive_chat(thread_id: str, message_callback: Callable = None):
+ chat = InteractiveChat(message_callback)
+ try:
+ chat.start(thread_id=thread_id)
+ yield chat
+ finally:
+ chat.stop()
diff --git a/src/marvin/beta/chat_ui/static/chat.html b/src/marvin/beta/chat_ui/static/chat.html
new file mode 100644
index 000000000..cc75c043b
--- /dev/null
+++ b/src/marvin/beta/chat_ui/static/chat.html
@@ -0,0 +1,108 @@
+<!-- "Thread Interaction Demo" chat page (markup stripped from this diff):
+     provides the chat container, message input, and Send button that chat.js
+     wires up. -->
diff --git a/src/marvin/beta/chat_ui/static/chat.js b/src/marvin/beta/chat_ui/static/chat.js
new file mode 100644
index 000000000..2b4bd8dd6
--- /dev/null
+++ b/src/marvin/beta/chat_ui/static/chat.js
@@ -0,0 +1,108 @@
+document.addEventListener('DOMContentLoaded', function () {
+ const chatContainer = document.getElementById('chat-container');
+ const inputBox = document.getElementById('message-input');
+ const sendButton = document.getElementById('send-button');
+ const queryParams = new URLSearchParams(window.location.search);
+ const threadId = queryParams.get('thread_id');
+
+ // Set the thread ID in the div
+ const threadIdDisplay = document.getElementById('thread-id');
+ if (threadId) {
+ threadIdDisplay.textContent = `Thread ID: ${threadId}`;
+ } else {
+ threadIdDisplay.textContent = 'No thread ID provided';
+ }
+
+ // Extract the port from the URL
+ const url = new URL(window.location.href);
+ const serverPort = url.port || '8000'; // Default to 8000 if port is not present
+
+ // Modify appendMessage to add classes for styling
+ function appendMessage(message, isUser) {
+ let messageText = message.content[0].text.value;
+ if (messageText === '') {
+ messageText = '[Writing...]'; // Replace blank messages
+ }
+
+
+ // Use marked to parse Markdown into HTML
+ const parsedText = marked.parse(messageText.trim()).trim();
+
+ const messageDiv = document.createElement('div');
+ messageDiv.innerHTML = parsedText; // Use innerHTML since parsedText is HTML
+
+ // Add general message class and conditional class based on the message sender
+ messageDiv.classList.add('message');
+ messageDiv.classList.add(isUser ? 'user-message' : 'assistant-message');
+
+ chatContainer.appendChild(messageDiv);
+ }
+
+ async function loadMessages() {
+ const shouldScroll = chatContainer.scrollTop + chatContainer.clientHeight >= chatContainer.scrollHeight - 1;
+
+ const response = await fetch(`http://127.0.0.1:${serverPort}/api/messages/?thread_id=${threadId}`);
+ if (response.ok) {
+ const messages = await response.json();
+ chatContainer.innerHTML = ''; // Clear chat container before loading new messages
+ messages.forEach(message => {
+ const isUser = message.role === 'user';
+ appendMessage(message, isUser);
+ });
+
+ // Scroll after messages are appended
+ if (shouldScroll) {
+ chatContainer.scrollTop = chatContainer.scrollHeight;
+ }
+ } else {
+ console.error('Failed to load messages:', response.statusText);
+ }
+    }
+
+ // Function to post a new message to the thread
+ async function sendChatMessage() {
+ const content = inputBox.value.trim();
+ if (!content) return;
+
+ const response = await fetch(`http://127.0.0.1:${serverPort}/api/messages/?thread_id=${threadId}`,
+ {
+ method: 'POST',
+ headers: { 'Content-Type': 'application/json' },
+ body: JSON.stringify({ content: content })
+ });
+
+ console.log(response)
+
+ if (response.ok) {
+ inputBox.value = '';
+ loadMessages();
+ } else {
+ console.error('Failed to send message:', await response.json());
+ }
+ }
+
+ // Event listeners
+ inputBox.addEventListener('keypress', function (e) {
+ if (e.key === 'Enter' && !e.shiftKey) {
+ e.preventDefault(); // Prevent default to avoid newline in textarea
+ sendChatMessage();
+ } else if (e.key === 'Enter' && e.shiftKey) {
+ // Allow Shift+Enter to insert newline
+ let start = this.selectionStart;
+ let end = this.selectionEnd;
+
+ // Insert newline at cursor position
+ this.value = this.value.substring(0, start) + "\n" + this.value.substring(end);
+
+ // Move cursor to right after inserted newline
+ this.selectionStart = this.selectionEnd = start + 1;
+ }
+    });
+
+    sendButton.addEventListener('click', sendChatMessage);
+
+ // Initial loading of messages
+ loadMessages();
+ setInterval(loadMessages, 1250); // Polling to refresh messages
+});
diff --git a/src/marvin/cli/__init__.py b/src/marvin/cli/__init__.py
index cace44214..3a5cefa1d 100644
--- a/src/marvin/cli/__init__.py
+++ b/src/marvin/cli/__init__.py
@@ -1,26 +1,53 @@
-from .typer import AsyncTyper
-
-from .admin import app as admin
-from .chat import chat
-
-app = AsyncTyper()
-
-app.add_typer(admin, name="admin")
-
-app.acommand()(chat)
-
-
-@app.command()
-def version():
- import platform
- import sys
- from marvin import __version__
-
- print(f"Version:\t\t{__version__}")
-
- print(f"Python version:\t\t{sys.version.split()[0]}")
-
- print(f"OS/Arch:\t\t{platform.system().lower()}/{platform.machine().lower()}")
+import sys
+import typer
+from rich.console import Console
+from typing import Optional
+from marvin.utilities.asyncio import run_sync
+from marvin.utilities.openai import get_client
+from marvin.cli.version import display_version
+
+app = typer.Typer()
+console = Console()
+
+app.command(name="version")(display_version)
+
+
+@app.callback(invoke_without_command=True)
+def main(
+ ctx: typer.Context,
+ model: Optional[str] = typer.Option("gpt-3.5-turbo"),
+ max_tokens: Optional[int] = typer.Option(1000),
+):
+ if ctx.invoked_subcommand is not None:
+ return
+ elif ctx.invoked_subcommand is None and not sys.stdin.isatty():
+ run_sync(process_stdin(model, max_tokens))
+ else:
+ console.print(ctx.get_help())
+
+
+async def process_stdin(model: str, max_tokens: int):
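+    """Stream a chat completion for text piped in on stdin.
+
+    Example (assuming the package's CLI entry point is installed as `marvin`):
+
+        echo "Summarize this text" | marvin --model gpt-3.5-turbo
+    """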
+ client = get_client()
+ content = sys.stdin.read()
+ last_chunk_ended_with_space = False
+
+ async for part in await client.chat.completions.create(
+ model=model,
+ messages=[{"role": "user", "content": content}],
+ max_tokens=max_tokens,
+ stream=True,
+ ):
+        last_chunk_ended_with_space = print_chunk(part, last_chunk_ended_with_space)
+
+
+def print_chunk(part, last_chunk_flag) -> bool:
+    # Write the streamed chunk to stdout, trimming a leading space if the
+    # previous chunk already ended with one, and report whether this chunk
+    # ends with a space so the caller can carry the flag forward.
+    text_chunk = part.choices[0].delta.content or ""
+    if text_chunk:
+        if last_chunk_flag and text_chunk.startswith(" "):
+            text_chunk = text_chunk[1:]
+        sys.stdout.write(text_chunk)
+        sys.stdout.flush()
+        last_chunk_flag = text_chunk.endswith(" ")
+    return last_chunk_flag
if __name__ == "__main__":
diff --git a/src/marvin/cli/admin/__init__.py b/src/marvin/cli/admin/__init__.py
deleted file mode 100644
index 4cb4d3a70..000000000
--- a/src/marvin/cli/admin/__init__.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import typer
-import os
-import shutil
-from jinja2 import Template
-
-from pathlib import Path
-from marvin.cli.admin.scripts.create_env_file import create_env_file
-from marvin.cli.admin.scripts.create_secure_key import create_secure_key
-from marvin._framework._defaults import default_settings
-
-# Get the absolute path of the current file
-filename = Path(__file__).resolve()
-
-# Navigate two directories back from the current file
-source_path = filename.parent.parent.parent / "_framework"
-
-app = typer.Typer()
-
-
-@app.command()
-def startproject(no_input: bool = False):
- project_name = typer.prompt("Project Name")
- openai_api_key = typer.prompt("OpenAI API Key")
- shutil.copytree(
- source_path,
- os.path.join(os.getcwd(), project_name),
- ignore=shutil.ignore_patterns(
- "@*"
- ), # This will ignore all files/directories starting with @
- )
- with open(
- os.path.join(os.getcwd(), project_name, "config/settings.py"), "r"
- ) as file:
- template = Template(file.read())
- rendered = template.render(
- **default_settings, project_name=project_name, openai_api_key=openai_api_key
- )
- with open(
- os.path.join(os.getcwd(), project_name, "config/settings.py"), "w"
- ) as rendered_file:
- rendered_file.write(rendered)
- create_env_file(
- os.path.join(os.getcwd(), project_name),
- [("MARVIN_SECRET", create_secure_key()), ("OPENAI_API_KEY", openai_api_key)],
- )
-
-
-@app.command()
-def startapp(no_input: bool = False):
- print("beep")
-
-
-if __name__ == "__main__":
- app()
diff --git a/src/marvin/cli/admin/scripts/create_env_file.py b/src/marvin/cli/admin/scripts/create_env_file.py
deleted file mode 100644
index ad5286827..000000000
--- a/src/marvin/cli/admin/scripts/create_env_file.py
+++ /dev/null
@@ -1,11 +0,0 @@
-def create_env_file(directory, env_variables):
- file_path = directory + "/.env"
- try:
- with open(file_path, "w") as env_file:
- for variable in env_variables:
- key, value = variable
- env_file.write("{}={}\n".format(key, value))
- except IOError as e:
- print("Error creating .env file:", str(e))
- else:
- print("Successfully created .env file:", file_path)
diff --git a/src/marvin/cli/admin/scripts/create_secure_key.py b/src/marvin/cli/admin/scripts/create_secure_key.py
deleted file mode 100644
index d68bc2cae..000000000
--- a/src/marvin/cli/admin/scripts/create_secure_key.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import secrets
-import string
-
-
-def create_secure_key(length=50):
- alphabet = string.ascii_letters + string.digits + "_"
- secure_key = "".join(secrets.choice(alphabet) for _ in range(length))
- return secure_key
diff --git a/src/marvin/cli/chat/__init__.py b/src/marvin/cli/chat/__init__.py
deleted file mode 100644
index b5a86dff2..000000000
--- a/src/marvin/cli/chat/__init__.py
+++ /dev/null
@@ -1,77 +0,0 @@
-from rich.prompt import Prompt
-from rich.console import Console
-from rich.panel import Panel
-from rich import box
-
-
-def _reset_history():
- global history
- history = []
- console.print(Panel("History has been reset.", box=box.DOUBLE_EDGE, expand=False))
-
-
-def _get_settings():
- import marvin.settings
-
- console.print(
- Panel(
- f"Settings:\n{marvin.settings.json(indent=2)}",
- box=box.DOUBLE_EDGE,
- expand=False,
- )
- )
-
-
-KNOWN_COMMANDS = {
- "!refresh": _reset_history,
- "!settings": _get_settings,
-}
-
-console = Console()
-
-
-def format_user_input(user_input):
- return f"[bold blue]You:[/bold blue] {user_input}"
-
-
-def format_chatbot_response(response):
- return f"[bold green]Marvin:[/bold green] {response}"
-
-
-async def chat():
- console.print(
- Panel(
- "[bold]Welcome to the Marvin Chat CLI![/bold]", box=box.DOUBLE, expand=False
- )
- )
- console.print(
- Panel("You can type 'quit' or 'exit' to end the conversation.", expand=False)
- )
- from marvin.engine.language_models import chat_llm
- from marvin.utilities.messages import Message
-
- global history
- history = []
- model = chat_llm()
- try:
- while True:
- user_input = Prompt.ask("❯ ")
- if (input_lower := user_input.lower()) in ["quit", "exit"]:
- break
- if input_lower in KNOWN_COMMANDS:
- KNOWN_COMMANDS[input_lower]()
- continue
-
- with console.status("[bold green]Processing...", spinner="dots"):
- user_message = Message(role="USER", content=user_input)
- response = await model.run(messages=history + [user_message])
- history.extend([user_message, response])
- console.print(
- Panel(
- format_chatbot_response(response.content),
- box=box.ROUNDED,
- expand=False,
- ),
- )
- except KeyboardInterrupt:
- pass
diff --git a/src/marvin/cli/typer.py b/src/marvin/cli/typer.py
deleted file mode 100644
index 674be4de0..000000000
--- a/src/marvin/cli/typer.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import asyncio
-from collections.abc import Callable, Coroutine
-from functools import wraps
-from typing import Any, TypeVar
-
-from typer import Typer
-from typing_extensions import ParamSpec
-
-P = ParamSpec("P")
-R = TypeVar("R")
-
-# Comment from GitHub issue: https://github.com/tiangolo/typer/issues/88
-# User: https://github.com/macintacos
-
-
-class AsyncTyper(Typer):
- """Asyncronous Typer that derives from Typer.
-
- Use this when you have an asynchronous command you want to build,
- otherwise, just use Typer.
- """
-
- # Because we're being generic in this decorator, 'Any' is fine for the args.
- def acommand(
- self,
- *args: Any,
- **kwargs: Any,
- ) -> Callable[
- [Callable[P, Coroutine[Any, Any, R]]],
- Callable[P, Coroutine[Any, Any, R]],
- ]:
- """An async decorator for Typer commands that are asynchronous."""
-
- def decorator(
- async_func: Callable[P, Coroutine[Any, Any, R]],
- ) -> Callable[P, Coroutine[Any, Any, R]]:
- @wraps(async_func)
- def sync_func(*_args: P.args, **_kwargs: P.kwargs) -> R:
- return asyncio.run(async_func(*_args, **_kwargs))
-
- # Now use app.command as normal to register the synchronous function
- self.command(*args, **kwargs)(sync_func)
-
- # Return the async function unmodified, to preserved library functionality.
- return async_func
-
- return decorator
diff --git a/src/marvin/cli/version.py b/src/marvin/cli/version.py
new file mode 100644
index 000000000..6e60fa40a
--- /dev/null
+++ b/src/marvin/cli/version.py
@@ -0,0 +1,14 @@
+import platform
+
+from typer import Context, Exit, echo
+
+from marvin import __version__
+
+
+def display_version(ctx: Context):
+ if ctx.resilient_parsing:
+ return
+ echo(f"Version:\t\t{__version__}")
+ echo(f"Python version:\t\t{platform.python_version()}")
+ echo(f"OS/Arch:\t\t{platform.system().lower()}/{platform.machine().lower()}")
+ raise Exit()
diff --git a/src/marvin/client/openai.py b/src/marvin/client/openai.py
new file mode 100644
index 000000000..04bb2b37e
--- /dev/null
+++ b/src/marvin/client/openai.py
@@ -0,0 +1,186 @@
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ NewType,
+ Optional,
+ TypeVar,
+ Union,
+ cast,
+)
+
+import pydantic
+from marvin import settings
+from marvin.serializers import create_tool_from_model
+from openai import AsyncClient, Client
+from openai.types.chat import ChatCompletion
+from typing_extensions import Concatenate, ParamSpec
+
+if TYPE_CHECKING:
+ from openai._base_client import HttpxBinaryResponseContent
+ from openai.types import ImagesResponse
+
+
+P = ParamSpec("P")
+T = TypeVar("T", bound=pydantic.BaseModel)
+ResponseModel = NewType("ResponseModel", type[pydantic.BaseModel])
+Grammar = NewType("Grammar", list[str])
+
+
+def with_response_model(
+ create: Union[Callable[P, "ChatCompletion"], Callable[..., dict[str, Any]]],
+ parse_response: bool = False,
+) -> Callable[
+ Concatenate[
+ Optional[Grammar],
+ Optional[ResponseModel],
+ P,
+ ],
+ Union["ChatCompletion", dict[str, Any]],
+]:
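+    """Wrap a chat-completions `create` callable so that, when a pydantic
+    `response_model` is supplied, it is serialized into an OpenAI tool and
+    forced via `tool_choice`; with `parse_response=True` the tool call's
+    arguments are parsed back into the model."""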
+ def create_wrapper(
+ grammar: Optional[Grammar] = None,
+ response_model: Optional[ResponseModel] = None,
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> Any:
+ if response_model:
+ tool = create_tool_from_model(
+ cast(type[pydantic.BaseModel], response_model)
+ )
+ kwargs.update({"tools": [tool.model_dump()]})
+ kwargs.update({"tool_choice": {"type": "function", "function": {"name": tool.function.name}}}) # type: ignore # noqa: E501
+ response = create(*args, **kwargs)
+ if isinstance(response, ChatCompletion) and parse_response:
+ return handle_response_model(
+ cast(type[pydantic.BaseModel], response_model), response
+ )
+ elif isinstance(response, ChatCompletion):
+ return response
+ else:
+ return response
+
+ return create_wrapper
+
+
+def handle_response_model(response_model: type[T], completion: "ChatCompletion") -> T:
+ return [
+ response_model.parse_raw(tool_call.function.arguments) # type: ignore
+ for tool_call in completion.choices[0].message.tool_calls # type: ignore
+ ][0]
+
+
+class MarvinClient(pydantic.BaseModel):
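+    """Synchronous convenience wrapper around the OpenAI `Client`.
+
+    Illustrative sketch (the `Fruit` model here is an assumed example, not part
+    of Marvin):
+
+        class Fruit(pydantic.BaseModel):
+            name: str
+
+        completion = MarvinClient().chat(
+            response_model=Fruit,
+            messages=[{"role": "user", "content": "Name a fruit"}],
+        )
+    """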
+ model_config = pydantic.ConfigDict(
+ arbitrary_types_allowed=True,
+ )
+ client: Client = pydantic.Field(
+ default_factory=lambda: Client(
+ api_key=getattr(settings.openai.api_key, "get_secret_value", lambda: None)()
+ )
+ )
+ eject: bool = pydantic.Field(
+ default=False,
+        description=(
+            "If eject is True, the client will not make any API calls and instead"
+            " the raw request will be returned."
+ ),
+ )
+
+ def chat(
+ self,
+ grammar: Optional[Grammar] = None,
+ response_model: Optional[ResponseModel] = None,
+ completion: Optional[Callable[..., "ChatCompletion"]] = None,
+ **kwargs: Any,
+ ) -> Union["ChatCompletion", dict[str, Any]]:
+ if not completion:
+ if self.eject:
+ completion = lambda **kwargs: kwargs # type: ignore # noqa: E731
+ else:
+ completion = self.client.chat.completions.create
+ from marvin import settings
+
+ return with_response_model(completion)( # type: ignore
+ grammar,
+ response_model,
+ **settings.openai.chat.completions.model_dump() | kwargs,
+ )
+
+ def paint(
+ self,
+ **kwargs: Any,
+ ) -> "ImagesResponse":
+ from marvin import settings
+
+ return self.client.images.generate(
+ **settings.openai.images.model_dump() | kwargs
+ )
+
+ def speak(
+ self,
+ **kwargs: Any,
+ ) -> "HttpxBinaryResponseContent":
+ from marvin import settings
+
+ return self.client.audio.speech.create(
+ **settings.openai.audio.speech.model_dump() | kwargs
+ )
+
+
+class MarvinAsyncClient(pydantic.BaseModel):
+ model_config = pydantic.ConfigDict(
+ arbitrary_types_allowed=True,
+ )
+    client: AsyncClient = pydantic.Field(
+ default_factory=lambda: AsyncClient(
+ api_key=getattr(settings.openai.api_key, "get_secret_value", lambda: None)()
+ )
+ )
+ eject: bool = pydantic.Field(
+ default=False,
+        description=(
+            "If eject is True, the client will not make any API calls and instead"
+            " the raw request will be returned."
+ ),
+ )
+
+ def chat(
+ self,
+ grammar: Optional[Grammar] = None,
+ response_model: Optional[ResponseModel] = None,
+ completion: Optional[Callable[..., "ChatCompletion"]] = None,
+ **kwargs: Any,
+ ) -> Union["ChatCompletion", dict[str, Any]]:
+ if not completion:
+ if self.eject:
+ completion = lambda **kwargs: kwargs # type: ignore # noqa: E731
+ else:
+ completion = self.client.chat.completions.create
+ from marvin import settings
+
+ return with_response_model(completion)( # type: ignore
+ grammar,
+ response_model,
+ **settings.openai.chat.completions.model_dump() | kwargs,
+ )
+
+ def paint(
+ self,
+ **kwargs: Any,
+ ) -> "ImagesResponse":
+ from marvin import settings
+
+ return self.client.images.generate(
+ **settings.openai.images.model_dump() | kwargs
+ )
+
+ def speak(
+ self,
+ **kwargs: Any,
+ ) -> "HttpxBinaryResponseContent":
+ from marvin import settings
+
+ return self.client.audio.speech.create(
+ **settings.openai.audio.speech.model_dump() | kwargs
+ )
diff --git a/src/marvin/components/__init__.py b/src/marvin/components/__init__.py
index b862d71e2..fac2e7bc2 100644
--- a/src/marvin/components/__init__.py
+++ b/src/marvin/components/__init__.py
@@ -1,7 +1,20 @@
-from .ai_function import ai_fn
-from .ai_function import AIFunction
-from .ai_application import AIApplication
+from .ai_function import ai_fn, AIFunction
+from .ai_classifier import ai_classifier, AIClassifier
from .ai_model import ai_model
-from .ai_model_factory import AIModelFactory
-from .ai_model import AIModel
-from .ai_classifier import ai_classifier
+from .ai_image import ai_image, AIImage
+from .speech import speak
+from .prompt import prompt_fn, PromptFunction
+
+__all__ = [
+ "ai_fn",
+ "ai_classifier",
+ "ai_model",
+ "ai_image",
+ "speak",
+ "AIImage",
+ "prompt_fn",
+ "AIFunction",
+ "AIClassifier",
+ "PromptFunction",
+]
diff --git a/src/marvin/components/ai_application.py b/src/marvin/components/ai_application.py
deleted file mode 100644
index bb495e959..000000000
--- a/src/marvin/components/ai_application.py
+++ /dev/null
@@ -1,396 +0,0 @@
-import inspect
-from enum import Enum
-from typing import Any, Callable, Optional, Union
-
-from jsonpatch import JsonPatch
-
-import marvin
-from marvin._compat import PYDANTIC_V2, BaseModel, Field, field_validator, model_dump
-from marvin.core.ChatCompletion.providers.openai import get_context_size
-from marvin.openai import ChatCompletion
-from marvin.prompts import library as prompt_library
-from marvin.prompts.base import Prompt, render_prompts
-from marvin.tools import Tool
-from marvin.utilities.async_utils import run_sync
-from marvin.utilities.history import History
-from marvin.utilities.messages import Message, Role
-from marvin.utilities.types import LoggerMixin, MarvinBaseModel
-
-SYSTEM_PROMPT = """
- # Overview
-
- You are the intelligent, natural language interface to an application. The
- application has a structured `state` but no formal API; you are the only way
- to interact with it. You must interpret the user's inputs as attempts to
- interact with the application's state in the context of the application's
- purpose. For example, if the application is a to-do tracker, then "I need to
- go to the store" should be interpreted as an attempt to add a new to-do
- item. If it is a route planner, then "I need to go to the store" should be
- interpreted as an attempt to find a route to the store.
-
- # Instructions
-
- Your primary job is to maintain the application's `state` and your own
- `plan`. Together, these two states fully parameterize the application,
- making it resilient, serializable, and observable. You do this autonomously;
- you do not need to inform the user of any changes you make.
-
- # Actions
-
- Each time the user runs the application by sending a message, you must take
- the following steps:
-
- {% if app.plan_enabled %}
-
- - Call the `update_plan` function to update your plan. Use your plan
- to track notes, objectives, in-progress work, and to break problems down
- into solvable, possibly dependent parts. You plan consists of a few fields:
-
- - `notes`: a list of notes you have taken. Notes are free-form text and
- can be used to track anything you want to remember, such as
- long-standing user instructions, or observations about how to behave or
- operate the application. Your notes should always impact your behavior.
- These are exclusively related to your role as intermediary and you
- interact with the user and application. Do not track application data or
- state here.
-
- - `tasks`: a list of tasks you are working on. Tasks track goals,
- milestones, in-progress work, and break problems down into all the
- discrete steps needed to solve them. You should create a new task for
- any work that will require a function call other than updating state, or
- will require more than one state update to complete. You do not need to
- create tasks for simple state updates. Use optional parent tasks to
- indicate nested relationships; parent tasks are not completed until all
- their children are complete. Use optional upstream tasks to indicate
- dependencies; a task can not be completed until its upstream tasks are
- completed.
-
- {% endif %}
-
- - Call any functions necessary to achieve the application's purpose.
-
- {% if app.state_enabled %}
-
- - Call the `update_state` function to update the application's state. This
- is where you should store any information relevant to the application
- itself.
-
- {% endif %}
-
- You can call these functions at any time, in any order, as necessary.
- Finally, respond to the user with an informative message. Remember that the
- user is probably uninterested in the internal steps you took, so respond
- only in a manner appropriate to the application's purpose.
-
- # Application details
-
- ## Name
-
- {{ app.name }}
-
- ## Description
-
- {{ app.description or '' | render }}
-
- {% if app.state_enabled %}
-
- ## Application state
-
- {{ app.state.json() }}
-
- ### Application state schema
-
- {{ app.state.schema_json() }}
-
- {% endif %}
-
- {%- if app.plan_enabled %}
-
- ## Your current plan
-
- {{ app.plan.json() }}
-
- ### Your plan schema
-
- {{ app.plan.schema_json() }}
-
- {%- endif %}
- """
-
-
-class TaskState(Enum):
- """The state of a task.
-
- Attributes:
- PENDING: The task is pending and has not yet started.
- IN_PROGRESS: The task is in progress.
- COMPLETED: The task is completed.
- FAILED: The task failed.
- SKIPPED: The task was skipped.
- """
-
- PENDING = "PENDING"
- IN_PROGRESS = "IN_PROGRESS"
- COMPLETED = "COMPLETED"
- FAILED = "FAILED"
- SKIPPED = "SKIPPED"
-
-
-class Task(BaseModel):
- class Config:
- validate_assignment = True
-
- id: int
- description: str
- upstream_task_ids: Optional[list[int]] = None
- parent_task_id: Optional[int] = None
- state: TaskState = TaskState.IN_PROGRESS
-
-
-class AppPlan(BaseModel):
- """The AI's plan in service of the application.
-
- Attributes:
- tasks: A list of tasks the AI is working on.
- notes: A list of notes the AI has taken.
- """
-
- tasks: list[Task] = Field(default_factory=list)
- notes: list[str] = Field(default_factory=list)
-
-
-class FreeformState(BaseModel):
- """A freeform state object that can be used to store any JSON-serializable data.
-
- Attributes:
- state: The state object.
- """
-
- state: dict[str, Any] = Field(default_factory=dict)
-
-
-class AIApplication(LoggerMixin, MarvinBaseModel):
- """An AI application is a stateful, autonomous, natural language
- interface to an application.
-
- Attributes:
- name: The name of the application.
- description: A description of the application.
- state: The application's state - this can be any JSON-serializable object.
- plan: The AI's plan in service of the application - this can be any
- JSON-serializable object.
- tools: A list of tools that the AI can use to interact with
- application or outside world.
- history: A history of all messages sent and received by the AI.
- additional_prompts: A list of additional prompts that will be
- added to the prompt stack for rendering.
-
- Example:
- Create a simple todo app where AI manages its own state and plan.
- ```python
- from marvin import AIApplication
-
- todo_app = AIApplication(
- name="Todo App",
- description="A simple todo app.",
- )
-
- todo_app("I need to go to the store.")
-
- print(todo_app.state, todo_app.plan)
- ```
- """
-
- name: Optional[str] = None
- description: Optional[str] = None
- state: BaseModel = Field(default_factory=FreeformState)
- plan: AppPlan = Field(default_factory=AppPlan)
- tools: list[Union[Tool, Callable[..., Any]]] = Field(default_factory=list)
- history: History = Field(default_factory=History)
- additional_prompts: list[Prompt] = Field(
- default_factory=list,
- description=(
- "Additional prompts that will be added to the prompt stack for rendering."
- ),
- )
- stream_handler: Optional[Callable[[Message], None]] = None
- state_enabled: bool = True
- plan_enabled: bool = True
-
- @field_validator("description")
- def validate_description(cls, v):
- return inspect.cleandoc(v)
-
- @field_validator("additional_prompts")
- def validate_additional_prompts(cls, v):
- if v is None:
- v = []
- return v
-
- @field_validator(
- "tools", **(dict(pre=True, always=True) if not PYDANTIC_V2 else {})
- )
- def validate_tools(cls, v):
- if v is None:
- v = []
-
- tools = []
-
- # convert AI Applications and functions to tools
- for tool in v:
- if isinstance(tool, (AIApplication, Tool)):
- tools.append(tool.as_function(description=tool.description))
- elif callable(tool):
- tools.append(tool)
- else:
- raise ValueError(f"Tool {tool} is not a `Tool` or callable.")
- return tools
-
- @field_validator("name")
- def validate_name(cls, v):
- if v is None:
- v = cls.__name__
- return v
-
- def __call__(self, input_text: str = None, model: str = None):
- return run_sync(self.run(input_text=input_text, model=model))
-
- async def entrypoint(self, q: str) -> str:
- response = await self.run(input_text=q)
- return response.content
-
- async def run(self, input_text: str = None, model: str = None) -> Message:
- if model is None:
- model = marvin.settings.llm_model or "openai/gpt-4"
-
- # set up prompts
- prompts = [
- # system prompts
- prompt_library.System(content=SYSTEM_PROMPT),
- # add current datetime
- prompt_library.Now(),
- # get the history of messages between user and assistant
- prompt_library.MessageHistory(history=self.history),
- *self.additional_prompts,
- ]
-
- # get latest user input
- input_text = input_text or ""
- self.logger.debug_kv("User input", input_text, key_style="green")
- self.history.add_message(Message(content=input_text, role=Role.USER))
-
- message_list = render_prompts(
- prompts=prompts,
- render_kwargs=dict(app=self, input_text=input_text),
- max_tokens=get_context_size(model=model),
- )
-
- # set up tools
- tools = self.tools.copy()
- if self.state_enabled:
- tools.append(UpdateState(app=self).as_function())
- if self.plan_enabled:
- tools.append(UpdatePlan(app=self).as_function())
-
- conversation = await ChatCompletion(
- model=model,
- functions=tools,
- stream_handler=self.stream_handler,
- ).achain(messages=message_list)
-
- last_message = conversation.history[-1]
-
- # add the AI's response to the history
- self.history.add_message(last_message)
-
- self.logger.debug_kv("AI response", last_message.content, key_style="blue")
- return last_message
-
- def as_tool(
- self,
- name: Optional[str] = None,
- description: Optional[str] = None,
- ) -> Tool:
- return AIApplicationTool(app=self, name=name, description=description)
-
- def as_function(self, name: str = None, description: str = None) -> Callable:
- return self.as_tool(name=name, description=description).as_function()
-
-
-class AIApplicationTool(Tool):
- app: "AIApplication"
-
- def __init__(self, **kwargs):
- if "name" not in kwargs:
- kwargs["name"] = type(self.app).__name__
- super().__init__(**kwargs)
-
- def run(self, input_text: str) -> str:
- return run_sync(self.app.run(input_text))
-
-
-class JSONPatchModel(
- BaseModel,
- **(
- {
- "allow_population_by_field_name": True,
- }
- if not PYDANTIC_V2
- else {
- "populate_by_name": True,
- }
- ),
-):
- """A JSON Patch document.
-
- Attributes:
- op: The operation to perform.
- path: The path to the value to update.
- value: The value to update the path to.
- from_: The path to the value to copy from.
- """
-
- op: str
- path: str
- value: Union[str, float, int, bool, list, dict, None] = None
- from_: Optional[str] = Field(None, alias="from")
-
-
-class UpdateState(Tool):
- """A `Tool` that updates the apps state using JSON Patch documents."""
-
- app: "AIApplication" = Field(..., repr=False, exclude=True)
- description: str = """
- Update the application state by providing a list of JSON patch
- documents. The state must always comply with the state's
- JSON schema.
- """
-
- def __init__(self, app: AIApplication, **kwargs):
- super().__init__(**kwargs, app=app)
-
- def run(self, patches: list[JSONPatchModel]):
- patch = JsonPatch(patches)
- updated_state = patch.apply(model_dump(self.app.state))
- self.app.state = type(self.app.state)(**updated_state)
- return "Application state updated successfully!"
-
-
-class UpdatePlan(Tool):
- """A `Tool` that updates the apps plan using JSON Patch documents."""
-
- app: "AIApplication" = Field(..., repr=False, exclude=True)
- description: str = """
- Update the application plan by providing a list of JSON patch
- documents. The state must always comply with the plan's JSON schema.
- """
-
- def __init__(self, app: AIApplication, **kwargs):
- super().__init__(**kwargs, app=app)
-
- def run(self, patches: list[JSONPatchModel]):
- patch = JsonPatch(patches)
-
- updated_plan = patch.apply(model_dump(self.app.plan))
- self.app.plan = type(self.app.plan)(**updated_plan)
- return "Application plan updated successfully!"
diff --git a/src/marvin/components/ai_classifier.py b/src/marvin/components/ai_classifier.py
index faed279f9..34ef59e15 100644
--- a/src/marvin/components/ai_classifier.py
+++ b/src/marvin/components/ai_classifier.py
@@ -1,331 +1,232 @@
-import asyncio
import inspect
-from enum import Enum, EnumMeta # noqa
-from functools import partial
-from typing import Any, Callable, Literal, Optional, TypeVar
-
+from enum import Enum
+from functools import partial, wraps
+from types import GenericAlias
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Generic,
+ Literal,
+ Optional,
+ TypeVar,
+ Union,
+ cast,
+ get_args,
+ get_origin,
+ overload,
+)
+
+from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import ParamSpec, Self
-from marvin._compat import BaseModel, Field
-from marvin.core.ChatCompletion import ChatCompletion
-from marvin.core.ChatCompletion.abstract import AbstractChatCompletion
-from marvin.prompts import Prompt, prompt_fn
-from marvin.utilities.async_utils import run_sync
-from marvin.utilities.logging import get_logger
+from marvin.components.prompt import PromptFunction
+from marvin.serializers import create_vocabulary_from_type
+from marvin.settings import settings
+from marvin.utilities.jinja import (
+ BaseEnvironment,
+)
-T = TypeVar("T", bound=BaseModel)
+if TYPE_CHECKING:
+ from openai.types.chat import ChatCompletion
-A = TypeVar("A", bound=Any)
+T = TypeVar("T", bound=Union[GenericAlias, type, list[str]])
P = ParamSpec("P")
-def ai_classifier_prompt(
- enum: Enum,
- ctx: Optional[dict[str, Any]] = None,
- **kwargs: Any,
-) -> Callable[P, Prompt[P]]:
- @prompt_fn(
- ctx={"ctx": ctx or {}, "enum": enum, "inspect": inspect},
- response_model=int, # type: ignore
- response_model_name="Index",
- response_model_description="The index of the most likely class.",
- response_model_field_name="index",
- serialize_on_call=False,
- **kwargs,
- )
-
- # You are an expert classifier that always chooses correctly.
- # {% if enum_class_docstring %}
- # Your classification task is: {{ enum_class_docstring }}
- # {% endif %}
- # {% if instructions %}
- # Your instructions are: {{ instructions }}
- # {% endif %}
- # The user will provide context through text, you will use your expertise
- # to choose the best option below based on it:
- # {% for option in options %}
- # {{ loop.index }}. {{ value_getter(option) }}
- # {% endfor %}
- # {% if context_fn %}
- # You have been provided the following context to perform your task:\n
- # {%for (arg, value) in context_fn(value).items()%}
- # - {{ arg }}: {{ value }}\n
- # {% endfor %}
- # {% endif %}\
- def prompt_wrapper(text: str) -> None: # type: ignore # noqa
- """
- System: You are an expert classifier that always chooses correctly
- {{ '(note, however: ' + ctx.get('instructions') + ')' if ctx.get('instructions') }}
-
- {{ 'Also note that: ' + enum.__doc__ if enum.__doc__ }}
-
- The user will provide text to classify, you will use your expertise
- to choose the best option below based on it:
- {% for option in enum %}
- {{ loop.index }}. {{option.name}} ({{option.value}})
- {% endfor %}
- {% set context = ctx.get('context_fn')(text).items() if ctx.get('context_fn') %}
- {% if context %}
- You have been provided the following context to perform your task:
- {%for (arg, value) in context%}
- - {{ arg }}: {{ value }}\n
- {% endfor %}
- {% endif %}
- User: the text to classify: {{text}}
- """ # noqa
-
- return prompt_wrapper # type: ignore
-
-
-class AIEnumMetaData(BaseModel):
- model: Any = Field(default_factory=ChatCompletion)
- ctx: Optional[dict[str, Any]] = None
- instructions: Optional[str] = None
- mode: Optional[Literal["function", "logit_bias"]] = "logit_bias"
-
-
-class AIEnumMeta(EnumMeta):
- """
-
- A metaclass for the AIEnum class.
-
- Enables overloading of the __call__ method to permit extra keyword arguments.
-
- """
-
- __metadata__ = AIEnumMetaData()
-
- def __call__(
- cls: Self,
- value: Any,
- names: Optional[Any] = None,
- *args: Any,
- module: Optional[str] = None,
- qualname: Optional[str] = None,
- type: Optional[type] = None,
- start: int = 1,
- boundary: Optional[Any] = None,
- model: Optional[str] = None,
- ctx: Optional[dict[str, Any]] = None,
- instructions: Optional[str] = None,
- mode: Optional[Literal["function", "logit_bias"]] = None,
- **model_kwargs: Any,
- ) -> type[Enum]:
- cls.__metadata__ = AIEnumMetaData(
- model=ChatCompletion(model=model, **model_kwargs),
- ctx=ctx,
- instructions=instructions,
- mode=mode,
+class AIClassifier(BaseModel, Generic[P, T]):
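+    """Backs the `ai_classifier` decorator: renders a classification prompt for
+    a function whose return annotation is a `Literal`, `Enum`, or `list[str]`,
+    constrains the completion to a short class-index response, and maps the
+    returned index back onto that annotation (see `parse`)."""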
+ fn: Optional[Callable[P, T]] = None
+ environment: Optional[BaseEnvironment] = None
+ prompt: Optional[str] = Field(
+        default=inspect.cleandoc(
+            "You are an expert classifier that always chooses correctly."
+ " \n- {{_doc}}"
+ " \n- You must classify `{{text}}` into one of the following classes:"
+ "{% for option in _options %}"
+ " Class {{ loop.index - 1}} (value: {{ option }})"
+ "{% endfor %}"
+ "\n\nASSISTANT: The correct class label is Class"
)
- return super().__call__(
- value,
- names, # type: ignore
- *args,
- module=module,
- qualname=qualname,
- type=type,
- start=start,
+ )
+ enumerate: bool = True
+ encoder: Callable[[str], list[int]] = Field(default=None)
+ max_tokens: Optional[int] = 1
+ render_kwargs: dict[str, Any] = Field(default_factory=dict)
+
+ create: Optional[Callable[..., "ChatCompletion"]] = Field(default=None)
+
+ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> list[T]:
+ create = self.create
+ if self.fn is None:
+ raise NotImplementedError
+ if create is None:
+ from marvin.settings import settings
+
+ create = settings.openai.chat.completions.create
+
+ return self.parse(create(**self.as_prompt(*args, **kwargs).serialize()))
+
+ def parse(self, response: "ChatCompletion") -> list[T]:
+ if not response.choices[0].message.content:
+ raise ValueError(
+ f"Expected a response, got {response.choices[0].message.content}"
+ )
+ _response: list[int] = [
+ int(index) for index in list(response.choices[0].message.content)
+ ]
+ _return: T = cast(T, self.fn.__annotations__.get("return"))
+ _vocabulary: list[str] = create_vocabulary_from_type(_return)
+ if isinstance(_return, list) and next(iter(get_args(list[str])), None) == str:
+ return cast(list[T], [_vocabulary[int(index)] for index in _response])
+ elif get_origin(_return) == Literal:
+ return [
+ TypeAdapter(_return).validate_python(_vocabulary[int(index)])
+ for index in _response
+ ]
+ elif isinstance(_return, type) and issubclass(_return, Enum):
+ return [list(_return)[int(index)] for index in _response]
+ raise TypeError(
+ f"Expected Literal or Enum or list[str], got {type(_return)} with value"
+ f" {_return}"
)
-
-class AIEnum(Enum, metaclass=AIEnumMeta):
- """
- AIEnum is a class that extends Python's built-in Enum class.
- It uses the AIEnumMeta metaclass, which allows additional parameters to be passed
- when creating an enum. These parameters are used to customize the behavior
- of the AI classifier.
- """
-
- @classmethod
- def _missing_(cls: type[Self], value: object) -> Self:
- response: int = cls.call(value)
- return list(cls)[response - 1]
-
+ def as_prompt(
+ self,
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> PromptFunction[BaseModel]:
+ return PromptFunction[BaseModel].as_grammar(
+ fn=self.fn,
+ environment=self.environment,
+ prompt=self.prompt,
+ enumerate=self.enumerate,
+ encoder=self.encoder,
+ max_tokens=self.max_tokens,
+ **self.render_kwargs,
+ )(*args, **kwargs)
+
+ @overload
@classmethod
- def get_prompt(
- cls,
+ def as_decorator(
+ cls: type[Self],
*,
- ctx: Optional[dict[str, Any]] = None,
- instructions: Optional[str] = None,
- **kwargs: Any,
- ) -> Callable[..., Prompt[P]]:
- ctx = ctx or cls.__metadata__.ctx or {}
- instructions = instructions or cls.__metadata__.instructions
- ctx["instructions"] = instructions or ctx.get("instructions", None)
- return ai_classifier_prompt(cls, ctx=ctx, **kwargs) # type: ignore # noqa
-
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ enumerate: bool = True,
+ encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder,
+ max_tokens: Optional[int] = 1,
+ acreate: Optional[Callable[..., Awaitable[Any]]] = None,
+ **render_kwargs: Any,
+ ) -> Callable[P, Self]:
+ pass
+
+ @overload
@classmethod
- def as_prompt(
- cls,
- value: Any,
+ def as_decorator(
+ cls: type[Self],
+ fn: Callable[P, T],
*,
- ctx: Optional[dict[str, Any]] = None,
- instructions: Optional[str] = None,
- mode: Optional[Literal["function", "logit_bias"]] = None,
- model: Optional[str] = None,
- **model_kwargs: Any,
- ) -> dict[str, Any]:
- ctx = ctx or cls.__metadata__.ctx or {}
- instructions = instructions or cls.__metadata__.instructions
- ctx["instructions"] = instructions or ctx.get("instructions", None)
- mode = mode or cls.__metadata__.mode
- response = cls.get_prompt(instructions=instructions, ctx=ctx)(value).serialize(
- model=cls.__metadata__.model,
- )
- if mode == "logit_bias":
- import tiktoken
-
- response.pop("functions", None)
- response.pop("function_call", None)
- response.pop("response_model", None)
- encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
- response["logit_bias"] = {
- encoder.encode(str(j))[0]: 100 for j in range(1, len(cls) + 1)
- }
- response["max_tokens"] = 1
- return response
-
- @classmethod
- def as_dict(
- cls,
- value: Any,
- ctx: Optional[dict[str, Any]] = None,
- instructions: Optional[str] = None,
- mode: Optional[Literal["function", "logit_bias"]] = None,
- **kwargs: Any,
- ) -> dict[str, Any]:
- ctx = ctx or cls.__metadata__.ctx or {}
- instructions = instructions or cls.__metadata__.instructions
- ctx["instructions"] = instructions or ctx.get("instructions", None)
- mode = mode or cls.__metadata__.mode
-
- response = cls.get_prompt(ctx=ctx, instructions=instructions)(value).to_dict()
- if mode == "logit_bias":
- import tiktoken
-
- response.pop("functions", None)
- response.pop("function_call", None)
- response.pop("response_model", None)
- encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
- response["logit_bias"] = {
- encoder.encode(str(j))[0]: 100 for j in range(1, len(cls) + 1)
- }
- response["max_tokens"] = 1
-
- return response
-
- @classmethod
- def as_chat_completion(
- cls,
- value: Any,
- ctx: Optional[dict[str, Any]] = None,
- instructions: Optional[str] = None,
- mode: Optional[Literal["function", "logit_bias"]] = None,
- ) -> AbstractChatCompletion[T]: # type: ignore # noqa
- ctx = ctx or cls.__metadata__.ctx or {}
- instructions = instructions or cls.__metadata__.instructions
- mode = mode or cls.__metadata__.mode
- ctx["instructions"] = instructions or ctx.get("instructions", None)
- return cls.__metadata__.model(
- **cls.as_dict(value, ctx=ctx, instructions=instructions, mode=mode)
- )
-
- @classmethod
- def call(
- cls,
- value: Any,
- ctx: Optional[dict[str, Any]] = None,
- instructions: Optional[str] = None,
- mode: Optional[Literal["function", "logit_bias"]] = None,
- ) -> Any:
- get_logger("marvin.AIClassifier").debug_kv(
- f"Calling `AIEnum` {cls.__name__!r}", f" with value {value!r}."
- )
-
- ctx = ctx or cls.__metadata__.ctx or {}
- instructions = instructions or cls.__metadata__.instructions
- ctx["instructions"] = instructions or ctx.get("instructions", None)
- mode = mode or cls.__metadata__.mode
- chat_completion = cls.as_chat_completion( # type: ignore # noqa
- value, ctx=ctx, instructions=instructions, mode=mode
- )
- if cls.__metadata__.mode == "logit_bias":
- return int(chat_completion.create().response.choices[0].message.content) # type: ignore # noqa
- return getattr(chat_completion.create().to_model(), "index") # type: ignore
-
- @classmethod
- async def acall(
- cls,
- value: Any,
- ctx: Optional[dict[str, Any]] = None,
- instructions: Optional[str] = None,
- mode: Optional[Literal["function", "logit_bias"]] = None,
- ) -> Any:
- get_logger("marvin.AIClassifier").debug_kv(
- f"Calling `AIEnum` {cls.__name__!r}", f" with value {value!r}."
- )
- ctx = ctx or cls.__metadata__.ctx or {}
- instructions = instructions or cls.__metadata__.instructions
- ctx["instructions"] = instructions or ctx.get("instructions", None)
- mode = mode or cls.__metadata__.mode
- chat_completion = cls.as_chat_completion( # type: ignore # noqa
- value, ctx=ctx, instructions=instructions, mode=mode
- )
- if cls.__metadata__.mode == "logit_bias":
- return int((await chat_completion.acreate()).response.choices[0].message.content) # type: ignore # noqa
- return getattr((await chat_completion.acreate()).to_model(), "index") # type: ignore # noqa
-
- @classmethod
- def map(cls, items: list[str], **kwargs: Any) -> list[Any]:
- """
- Map the classifier over a list of items.
- """
- coros = [cls.acall(item, **kwargs) for item in items]
-
- # gather returns a future, but run_sync requires a coroutine
- async def gather_coros() -> list[Any]:
- return await asyncio.gather(*coros)
-
- results = run_sync(gather_coros())
- return [list(cls)[result - 1] for result in results]
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ enumerate: bool = True,
+ encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder,
+ max_tokens: Optional[int] = 1,
+ acreate: Optional[Callable[..., Awaitable[Any]]] = None,
+ **render_kwargs: Any,
+ ) -> Self:
+ pass
@classmethod
def as_decorator(
cls: type[Self],
- enum: Optional[Enum] = None,
- ctx: Optional[dict[str, Any]] = None,
- instructions: Optional[str] = None,
- mode: Optional[Literal["function", "logit_bias"]] = "logit_bias",
- model: Optional[str] = None,
- **model_kwargs: Any,
- ) -> Self:
- if not enum:
+ fn: Optional[Callable[P, T]] = None,
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ enumerate: bool = True,
+ encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder,
+ max_tokens: Optional[int] = 1,
+ acreate: Optional[Callable[..., Awaitable[Any]]] = None,
+ **render_kwargs: Any,
+ ) -> Union[Self, Callable[[Callable[P, T]], Self]]:
+ if fn is None:
return partial(
- cls.as_decorator,
- ctx=ctx,
- instructions=instructions,
- mode=mode,
- model=model,
- **model_kwargs,
- ) # type: ignore
- response = cls(
- enum.__name__, # type: ignore
- {member.name: member.value for member in enum}, # type: ignore
- )
- setattr(
- response,
- "__metadata__",
- AIEnumMetaData(
- model=ChatCompletion(model=model, **model_kwargs),
- ctx=ctx,
- instructions=instructions,
- mode=mode,
- ),
+ cls,
+ environment=environment,
+ enumerate=enumerate,
+ encoder=encoder,
+ max_tokens=max_tokens,
+ acreate=acreate,
+ **({"prompt": prompt} if prompt else {}),
+ **render_kwargs,
+ )
+
+ return cls(
+ fn=fn,
+ environment=environment,
+ enumerate=enumerate,
+ encoder=encoder,
+ max_tokens=max_tokens,
+ **({"prompt": prompt} if prompt else {}),
+ **render_kwargs,
)
- response.__doc__ = enum.__doc__ # type: ignore
- return response
-
-ai_classifier = AIEnum.as_decorator
+@overload
+def ai_classifier(
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ enumerate: bool = True,
+ encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder,
+ max_tokens: Optional[int] = 1,
+ **render_kwargs: Any,
+) -> Callable[[Callable[P, T]], Callable[P, T]]:
+ pass
+
+
+@overload
+def ai_classifier(
+ fn: Callable[P, T],
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ enumerate: bool = True,
+ encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder,
+ max_tokens: Optional[int] = 1,
+ **render_kwargs: Any,
+) -> Callable[P, T]:
+ pass
+
+
+def ai_classifier(
+ fn: Optional[Callable[P, T]] = None,
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ enumerate: bool = True,
+ encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder,
+ max_tokens: Optional[int] = 1,
+ **render_kwargs: Any,
+) -> Union[Callable[[Callable[P, T]], Callable[P, T]], Callable[P, T]]:
+ def wrapper(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
+ return AIClassifier[P, T].as_decorator(
+ func,
+ environment=environment,
+ prompt=prompt,
+ enumerate=enumerate,
+ encoder=encoder,
+ max_tokens=max_tokens,
+ **render_kwargs,
+ )(*args, **kwargs)[0]
+
+ if fn is not None:
+ return wraps(fn)(partial(wrapper, fn))
+
+ def decorator(fn: Callable[P, T]) -> Callable[P, T]:
+ return wraps(fn)(partial(wrapper, fn))
+
+ return decorator
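
For orientation, here is a minimal usage sketch of the rewritten `ai_classifier` decorator (not part of the patch). It assumes the module keeps the `marvin.components.ai_classifier` path implied elsewhere in this diff and that an OpenAI API key is configured; the returned label is model-dependent.

```python
from typing import Literal

# Assumed import path, mirroring the other component modules in this diff.
from marvin.components.ai_classifier import ai_classifier


@ai_classifier
def sentiment(text: str) -> Literal["positive", "negative"]:
    """Classify the sentiment of `text`."""


# The classifier enumerates the Literal values, biases the completion toward a
# single index token, and maps that index back onto the vocabulary.
label = sentiment("This framework is delightful")  # e.g. "positive"
```
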
diff --git a/src/marvin/components/ai_function.py b/src/marvin/components/ai_function.py
index f22aae3c4..e36f4d7c6 100644
--- a/src/marvin/components/ai_function.py
+++ b/src/marvin/components/ai_function.py
@@ -1,166 +1,158 @@
import asyncio
import inspect
-from functools import partial
-from typing import Any, Awaitable, Callable, Generic, Optional, TypeVar, Union
-
+import json
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Generic,
+ Optional,
+ TypeVar,
+ Union,
+ overload,
+)
+
+from pydantic import BaseModel, Field, ValidationError
from typing_extensions import ParamSpec, Self
-from marvin._compat import BaseModel, Field
-from marvin.core.ChatCompletion import ChatCompletion
-from marvin.core.ChatCompletion.abstract import AbstractChatCompletion
-from marvin.prompts import Prompt, prompt_fn
-from marvin.utilities.async_utils import run_sync
+from marvin.components.prompt import PromptFunction
+from marvin.serializers import create_tool_from_type
+from marvin.utilities.asyncio import (
+ ExposeSyncMethodsMixin,
+ expose_sync_method,
+ run_async,
+)
+from marvin.utilities.jinja import (
+ BaseEnvironment,
+)
from marvin.utilities.logging import get_logger
-T = TypeVar("T", bound=BaseModel)
+if TYPE_CHECKING:
+ from openai.types.chat import ChatCompletion
-A = TypeVar("A", bound=Any)
+T = TypeVar("T")
P = ParamSpec("P")
-def ai_fn_prompt(
- func: Callable[P, Any],
- ctx: Optional[dict[str, Any]] = None,
- **kwargs: Any,
-) -> Callable[P, Prompt[P]]:
- return_annotation: Any = inspect.signature(func).return_annotation
-
- @prompt_fn(
- ctx={"ctx": ctx or {}, "func": func, "inspect": inspect},
- response_model=return_annotation,
- serialize_on_call=False,
- **kwargs,
- )
- def prompt_wrapper(*args: P.args, **kwargs: P.kwargs) -> None: # type: ignore # noqa
- """
- System: {{ctx.get('instructions') if ctx.get('instructions')}}
-
+class AIFunction(BaseModel, Generic[P, T], ExposeSyncMethodsMixin):
+ fn: Optional[Callable[P, T]] = None
+ environment: Optional[BaseEnvironment] = None
+ prompt: Optional[str] = Field(default=inspect.cleandoc("""
Your job is to generate likely outputs for a Python function with the
following signature and docstring:
- {{'def' + ''.join(inspect.getsource(func).split('def')[1:])}}
+ {{_source_code}}
The user will provide function inputs (if any) and you must respond with
- the most likely result, which must be valid, double-quoted JSON.
-
- User: The function was called with the following inputs:
- {% set sig = inspect.signature(func) %}
- {% set binds = sig.bind(*args, **kwargs) %}
- {% set defaults = binds.apply_defaults() %}
- {% set params = binds.arguments %}
- {%for (arg, value) in params.items()%}
+ the most likely result.
+
+ user: The function was called with the following inputs:
+ {%for (arg, value) in _arguments.items()%}
- {{ arg }}: {{ value }}
{% endfor %}
What is its output?
- """
+ """))
+ name: str = "FormatResponse"
+ description: str = "Formats the response."
+ field_name: str = "data"
+ field_description: str = "The data to format."
+ render_kwargs: dict[str, Any] = Field(default_factory=dict)
- return prompt_wrapper # type: ignore
+ create: Optional[Callable[..., "ChatCompletion"]] = Field(default=None)
+ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Union[T, Awaitable[T]]:
+ if self.fn is None:
+ raise NotImplementedError
-class AIFunction(BaseModel, Generic[P, T]):
- fn: Callable[P, Any]
- ctx: Optional[dict[str, Any]] = None
- model: Any = Field(default_factory=ChatCompletion)
- response_model_name: Optional[str] = Field(default=None, exclude=True)
- response_model_description: Optional[str] = Field(default=None, exclude=True)
- response_model_field_name: Optional[str] = Field(default=None, exclude=True)
+ from marvin import settings
- def __call__(
- self,
- *args: P.args,
- **kwargs: P.kwargs,
- ) -> Any:
- get_logger("marvin.AIFunction").debug_kv(
- f"Calling `ai_fn` {self.fn.__name__!r}",
- f"with args: {args} kwargs: {kwargs}",
- )
-
- return self.call(*args, **kwargs)
+ logger = get_logger("marvin.ai_fn")
- def get_prompt(
- self,
- ) -> Callable[P, Prompt[P]]:
- return ai_fn_prompt(
- self.fn,
- ctx=self.ctx,
- response_model_name=self.response_model_name,
- response_model_description=self.response_model_description,
- response_model_field_name=self.response_model_field_name,
- )
-
- def as_prompt(
- self,
- *args: P.args,
- **kwargs: P.kwargs,
- ) -> dict[str, Any]:
- return self.get_prompt()(*args, **kwargs).serialize(
- model=self.model,
+ logger.debug_kv(
+ "AI Function Call",
+ f"Calling {self.fn.__name__} with {args} and {kwargs}",
+ "blue",
)
- def as_dict(
- self,
- *args: P.args,
- **kwargs: P.kwargs,
- ) -> dict[str, Any]:
- return self.get_prompt()(*args, **kwargs).to_dict()
-
- def as_chat_completion(
- self,
- *args: P.args,
- **kwargs: P.kwargs,
- ) -> AbstractChatCompletion[T]:
- return self.model(**self.as_dict(*args, **kwargs))
-
- def call(
- self,
- *args: P.args,
- **kwargs: P.kwargs,
- ) -> Any:
- model_instance = self.as_chat_completion(*args, **kwargs).create().to_model()
- response_model_field_name = self.response_model_field_name or "output"
-
- if (output := getattr(model_instance, response_model_field_name, None)) is None:
- return model_instance
-
- return output
+ is_async_fn = asyncio.iscoroutinefunction(self.fn)
- async def acall(
- self,
- *args: P.args,
- **kwargs: P.kwargs,
- ) -> Any:
- model_instance = (
- await self.as_chat_completion(*args, **kwargs).acreate()
- ).to_model()
-
- response_model_field_name = self.response_model_field_name or "output"
-
- if (output := getattr(model_instance, response_model_field_name, None)) is None:
- return model_instance
-
- return output
+ call = "async_call" if is_async_fn else "sync_call"
+ create = (
+ self.create or settings.openai.chat.completions.acreate
+ if is_async_fn
+ else settings.openai.chat.completions.create
+ )
- def map(self, *map_args: list[Any], **map_kwargs: list[Any]):
+ result = getattr(self, call)(create, *args, **kwargs)
+
+ logger.debug_kv("AI Function Call", f"Returned {result}", "blue")
+
+ return result
+
+ async def async_call(
+ self, acreate: Callable[..., Awaitable[Any]], *args: P.args, **kwargs: P.kwargs
+ ) -> T:
+ _response = await acreate(**self.as_prompt(*args, **kwargs).serialize())
+ return self.parse(_response)
+
+ def sync_call(
+ self, create: Callable[..., Any], *args: P.args, **kwargs: P.kwargs
+ ) -> T:
+ _response = create(**self.as_prompt(*args, **kwargs).serialize())
+ return self.parse(_response)
+
+ def parse(self, response: "ChatCompletion") -> T:
+ tool_calls = response.choices[0].message.tool_calls
+ if tool_calls is None:
+ raise NotImplementedError
+ if self.fn is None:
+ raise NotImplementedError
+ arguments = tool_calls[0].function.arguments
+
+ tool = create_tool_from_type(
+ _type=self.fn.__annotations__["return"],
+ model_name=self.name,
+ model_description=self.description,
+ field_name=self.field_name,
+ field_description=self.field_description,
+ ).function
+ if not tool or not tool.model:
+ raise NotImplementedError
+ try:
+ return getattr(tool.model.model_validate_json(arguments), self.field_name)
+ except ValidationError:
+ # When the user provides a dict obj as a type hint, the arguments
+ # are returned usually as an object and not a nested dict.
+ _arguments: str = json.dumps({self.field_name: json.loads(arguments)})
+ return getattr(tool.model.model_validate_json(_arguments), self.field_name)
+
+ @expose_sync_method("map")
+ async def amap(self, *map_args: list[Any], **map_kwargs: list[Any]) -> list[T]:
"""
Map the AI function over a sequence of arguments. Runs concurrently.
+ A `map` twin method is provided by the `expose_sync_method` decorator.
+
+        Use `map` (synchronous) or `amap` (asynchronous) regardless of whether
+        the wrapped function is itself synchronous or asynchronous.
+
Arguments should be provided as if calling the function normally, but
each argument must be a list. The function is called once for each item
in the list, and the results are returned in a list.
- This method should be called synchronously.
-
For example, fn.map([1, 2]) is equivalent to [fn(1), fn(2)].
fn.map([1, 2], x=['a', 'b']) is equivalent to [fn(1, x='a'), fn(2, x='b')].
"""
- return run_sync(self.amap(*map_args, **map_kwargs))
-
- async def amap(self, *map_args: list[Any], **map_kwargs: list[Any]) -> list[Any]:
tasks: list[Any] = []
- if map_args:
+ if map_args and map_kwargs:
+ max_length = max(
+ len(arg) for arg in (map_args + tuple(map_kwargs.values()))
+ )
+ elif map_args:
max_length = max(len(arg) for arg in map_args)
else:
max_length = max(len(v) for v in map_kwargs.values())
@@ -172,57 +164,153 @@ async def amap(self, *map_args: list[Any], **map_kwargs: list[Any]) -> list[Any]
if map_kwargs
else {}
)
- tasks.append(self.acall(*call_args, **call_kwargs))
+
+ tasks.append(run_async(self, *call_args, **call_kwargs))
return await asyncio.gather(*tasks)
+ def as_prompt(
+ self,
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> PromptFunction[BaseModel]:
+ return PromptFunction[BaseModel].as_function_call(
+ fn=self.fn,
+ environment=self.environment,
+ prompt=self.prompt,
+ model_name=self.name,
+ model_description=self.description,
+ field_name=self.field_name,
+ field_description=self.field_description,
+ **self.render_kwargs,
+ )(*args, **kwargs)
+
+ @overload
+ @classmethod
+ def as_decorator(
+ cls: type[Self],
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ model_name: str = "FormatResponse",
+ model_description: str = "Formats the response.",
+ field_name: str = "data",
+ field_description: str = "The data to format.",
+ acreate: Optional[Callable[..., Awaitable[Any]]] = None,
+ **render_kwargs: Any,
+ ) -> Callable[P, Self]:
+ pass
+
+ @overload
+ @classmethod
+ def as_decorator(
+ cls: type[Self],
+ fn: Callable[P, T],
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ model_name: str = "FormatResponse",
+ model_description: str = "Formats the response.",
+ field_name: str = "data",
+ field_description: str = "The data to format.",
+ acreate: Optional[Callable[..., Awaitable[Any]]] = None,
+ **render_kwargs: Any,
+ ) -> Self:
+ pass
+
@classmethod
def as_decorator(
cls: type[Self],
fn: Optional[Callable[P, T]] = None,
- ctx: Optional[dict[str, Any]] = None,
- instructions: Optional[str] = None,
- response_model_name: Optional[str] = None,
- response_model_description: Optional[str] = None,
- response_model_field_name: Optional[str] = None,
- model: Optional[str] = None,
- **model_kwargs: Any,
- ) -> Union[Callable[P, T], Callable[P, Awaitable[T]]]:
- if not fn:
- return partial(
- cls.as_decorator,
- ctx=ctx,
- instructions=instructions,
- response_model_name=response_model_name,
- response_model_description=response_model_description,
- response_model_field_name=response_model_field_name,
- model=model,
- **model_kwargs,
- ) # type: ignore
-
- if not inspect.iscoroutinefunction(fn):
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ model_name: str = "FormatResponse",
+ model_description: str = "Formats the response.",
+ field_name: str = "data",
+ field_description: str = "The data to format.",
+ **render_kwargs: Any,
+ ) -> Union[Callable[[Callable[P, T]], Self], Self]:
+ def decorator(func: Callable[P, T]) -> Self:
return cls(
- fn=fn,
- ctx={"instructions": instructions, **(ctx or {})},
- response_model_name=response_model_name,
- response_model_description=response_model_description,
- response_model_field_name=response_model_field_name,
- model=ChatCompletion(model=model, **model_kwargs),
- )
- else:
- return AsyncAIFunction[P, T](
- fn=fn,
- ctx={"instructions": instructions, **(ctx or {})},
- response_model_name=response_model_name,
- response_model_description=response_model_description,
- response_model_field_name=response_model_field_name,
- model=ChatCompletion(model=model, **model_kwargs),
+ fn=func,
+ environment=environment,
+ name=model_name,
+ description=model_description,
+ field_name=field_name,
+ field_description=field_description,
+ **({"prompt": prompt} if prompt else {}),
+ **render_kwargs,
)
+ if fn is not None:
+ return decorator(fn)
+
+ return decorator
+
+
+@overload
+def ai_fn(
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ model_name: str = "FormatResponse",
+ model_description: str = "Formats the response.",
+ field_name: str = "data",
+ field_description: str = "The data to format.",
+ **render_kwargs: Any,
+) -> Callable[[Callable[P, T]], Callable[P, T]]:
+ pass
+
+
+@overload
+def ai_fn(
+ fn: Callable[P, T],
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ model_name: str = "FormatResponse",
+ model_description: str = "Formats the response.",
+ field_name: str = "data",
+ field_description: str = "The data to format.",
+ **render_kwargs: Any,
+) -> Callable[P, T]:
+ pass
+
+
+def ai_fn(
+ fn: Optional[Callable[P, T]] = None,
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ model_name: str = "FormatResponse",
+ model_description: str = "Formats the response.",
+ field_name: str = "data",
+ field_description: str = "The data to format.",
+ **render_kwargs: Any,
+) -> Union[Callable[[Callable[P, T]], Callable[P, T]], Callable[P, T]]:
+ if fn is not None:
+ return AIFunction.as_decorator( # type: ignore
+ fn=fn,
+ environment=environment,
+ prompt=prompt,
+ model_name=model_name,
+ model_description=model_description,
+ field_name=field_name,
+ field_description=field_description,
+ **render_kwargs,
+ )
-class AsyncAIFunction(AIFunction[P, T]):
- async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Any:
- return await super().acall(*args, **kwargs)
-
+ def decorator(func: Callable[P, T]) -> Callable[P, T]:
+ return AIFunction.as_decorator( # type: ignore
+ fn=func,
+ environment=environment,
+ prompt=prompt,
+ model_name=model_name,
+ model_description=model_description,
+ field_name=field_name,
+ field_description=field_description,
+ **render_kwargs,
+ )
-ai_fn = AIFunction.as_decorator
+ return decorator
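
As a reference point for the new API surface, a minimal sketch of `ai_fn` usage (illustrative only; it assumes a configured OpenAI API key and the `marvin.components.ai_function` path from this diff, and the outputs shown in comments are model-dependent).

```python
from marvin.components.ai_function import ai_fn


@ai_fn
def extract_fruits(text: str) -> list[str]:
    """Return the names of any fruits mentioned in `text`."""


# The return annotation becomes a "FormatResponse" tool; the tool-call
# arguments are validated back into list[str] by `parse`.
fruits = extract_fruits("I bought two apples and a mango.")  # e.g. ["apple", "mango"]

# `amap` (and its exposed sync twin `map`) fans out over lists of inputs:
batches = extract_fruits.map(["plums and pears", "no fruit here"])
```
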
diff --git a/src/marvin/components/ai_image.py b/src/marvin/components/ai_image.py
new file mode 100644
index 000000000..1dde168ed
--- /dev/null
+++ b/src/marvin/components/ai_image.py
@@ -0,0 +1,154 @@
+from functools import partial, wraps
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Generic,
+ Optional,
+ TypeVar,
+ Union,
+ overload,
+)
+
+from pydantic import BaseModel, Field
+from typing_extensions import ParamSpec, Self
+
+from marvin.components.prompt import PromptFunction
+from marvin.utilities.jinja import (
+ BaseEnvironment,
+)
+
+if TYPE_CHECKING:
+ from openai.types.images_response import ImagesResponse
+
+T = TypeVar("T")
+
+P = ParamSpec("P")
+
+
+class AIImage(BaseModel, Generic[P]):
+ fn: Optional[Callable[P, Any]] = None
+ environment: Optional[BaseEnvironment] = None
+ prompt: Optional[str] = Field(default=None)
+ render_kwargs: dict[str, Any] = Field(default_factory=dict)
+
+ generate: Optional[Callable[..., "ImagesResponse"]] = Field(default=None)
+
+ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "ImagesResponse":
+ generate = self.generate
+ if self.fn is None:
+ raise NotImplementedError
+ if generate is None:
+ from marvin.settings import settings
+
+ generate = settings.openai.images.generate
+
+ _response = generate(prompt=self.as_prompt(*args, **kwargs))
+
+ return _response
+
+ def as_prompt(
+ self,
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> str:
+ return (
+ PromptFunction[BaseModel]
+ .as_function_call(
+ fn=self.fn,
+ environment=self.environment,
+ prompt=self.prompt,
+ **self.render_kwargs,
+ )(*args, **kwargs)
+ .messages[0]
+ .content
+ )
+
+ @overload
+ @classmethod
+ def as_decorator(
+ cls: type[Self],
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ **render_kwargs: Any,
+ ) -> Callable[P, Self]:
+ pass
+
+ @overload
+ @classmethod
+ def as_decorator(
+ cls: type[Self],
+ fn: Callable[P, Any],
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ **render_kwargs: Any,
+ ) -> Self:
+ pass
+
+ @classmethod
+ def as_decorator(
+ cls: type[Self],
+ fn: Optional[Callable[P, Any]] = None,
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ **render_kwargs: Any,
+ ) -> Union[Self, Callable[[Callable[P, Any]], Self]]:
+ if fn is None:
+ return partial(
+ cls,
+ environment=environment,
+ **({"prompt": prompt} if prompt else {}),
+ **render_kwargs,
+ )
+
+ return cls(
+ fn=fn,
+ environment=environment,
+ **({"prompt": prompt} if prompt else {}),
+ **render_kwargs,
+ )
+
+
+def ai_image(
+ fn: Optional[Callable[P, Any]] = None,
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ **render_kwargs: Any,
+) -> Union[
+ Callable[[Callable[P, Any]], Callable[P, "ImagesResponse"]],
+ Callable[P, "ImagesResponse"],
+]:
+ def wrapper(
+ func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
+ ) -> "ImagesResponse":
+ return AIImage[P].as_decorator(
+ func,
+ environment=environment,
+ prompt=prompt,
+ **render_kwargs,
+ )(*args, **kwargs)
+
+ if fn is not None:
+ return wraps(fn)(partial(wrapper, fn))
+
+ def decorator(fn: Callable[P, Any]) -> Callable[P, "ImagesResponse"]:
+ return wraps(fn)(partial(wrapper, fn))
+
+ return decorator
+
+
+def create_image(
+ prompt: str,
+ environment: Optional[BaseEnvironment] = None,
+ generate: Optional[Callable[..., "ImagesResponse"]] = None,
+ **model_kwargs: Any,
+) -> "ImagesResponse":
+ if generate is None:
+ from marvin.settings import settings
+
+ generate = settings.openai.images.generate
+ return generate(prompt=prompt, **model_kwargs)
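
A small usage sketch of the new `ai_image` component (illustrative, not part of the patch). It assumes a configured OpenAI API key; the `-> str` annotation is added here only so the return type serializes cleanly when the docstring is rendered into the image prompt.

```python
from marvin.components.ai_image import ai_image, create_image


@ai_image
def draw(subject: str) -> str:
    """A soft watercolor illustration of {{ subject }}, on textured paper."""


# Returns an openai ImagesResponse; the rendered docstring is the image prompt.
response = draw("a lighthouse at dusk")
print(response.data[0].url)

# The helper form takes a literal prompt string instead of a template:
create_image("a pixel-art robot reading a book")
```
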
diff --git a/src/marvin/components/ai_model.py b/src/marvin/components/ai_model.py
index 86e4caa90..39b83126c 100644
--- a/src/marvin/components/ai_model.py
+++ b/src/marvin/components/ai_model.py
@@ -1,411 +1,105 @@
import asyncio
import inspect
from functools import partial
-from typing import Any, Callable, Optional, TypeVar
-
-from typing_extensions import ParamSpec, Self
-
-from marvin._compat import BaseModel
-from marvin.core.ChatCompletion import ChatCompletion
-from marvin.core.ChatCompletion.abstract import AbstractChatCompletion
-from marvin.prompts import Prompt, prompt_fn
-from marvin.utilities.async_utils import run_sync
-from marvin.utilities.logging import get_logger
-
-T = TypeVar("T", bound=BaseModel)
-
-A = TypeVar("A", bound=Any)
-
-P = ParamSpec("P")
-
-
-def ai_model_prompt(
- cls: type[BaseModel],
- ctx: Optional[dict[str, Any]] = None,
- **kwargs: Any,
-) -> Callable[[str], Prompt[P]]:
- description = cls.__doc__ or ""
- if ctx and ctx.get("instructions") and isinstance(ctx.get("instructions"), str):
- instructions = str(ctx.get("instructions"))
- description += "\n" + instructions if (instructions != description) else ""
-
- @prompt_fn(
- ctx={"ctx": ctx or {}, "inspect": inspect},
- response_model=cls,
- response_model_name="FormatResponse",
- response_model_description=description,
- serialize_on_call=False,
- )
- def prompt_wrapper(text: str) -> None: # type: ignore # noqa
- """
- The user will provide text that you need to parse into a
- structured form {{'(note you must also: ' + ctx.get('instructions') + ')' if ctx.get('instructions')}}.
- To validate your response, you must call the
- `{{response_model.__name__}}` function.
- Use the provided text and context to extract, deduce, or infer
- any parameters needed by `{{response_model.__name__}}`, including any missing
- data.
-
- You have been provided the following context to perform your task:
- - The current time is {{now()}}.
- {% set context = ctx.get('context_fn')(text).items() if ctx.get('context_fn') %}
- {% if context %}
- {%for (arg, value) in context%}
- - {{ arg }}: {{ value }}\n
- {% endfor %}
- {% endif %}
-
- User: The text to parse: {{text}}
-
-
- """ # noqa
-
- return prompt_wrapper # type: ignore
-
-
-class AIModel(BaseModel):
- def __init__(
- self,
- text: Optional[str] = None,
- /,
- instructions_: Optional[str] = None,
- **kwargs: Any,
- ):
- if text:
- kwargs.update(self.__class__.call(text, instructions=instructions_))
-
- super().__init__(**kwargs)
-
- @classmethod
- def get_prompt(
- cls,
- ctx: Optional[dict[str, Any]] = None,
- instructions: Optional[str] = None,
- response_model_name: Optional[str] = None,
- response_model_description: Optional[str] = None,
- response_model_field_name: Optional[str] = None,
- ) -> Callable[[str], Prompt[P]]:
- ctx = ctx or getattr(cls, "__metadata__", {}).get("ctx", {})
- instructions = ( # type: ignore
- "\n".join(
- list(
- filter(
- bool,
- [
- instructions,
- getattr(cls, "__metadata__", {}).get("instructions"),
- ],
- )
- ) # type: ignore
- )
- or None
- )
-
- response_model_name = response_model_name or getattr(
- cls, "__metadata__", {}
- ).get("response_model_name")
- response_model_description = response_model_description or getattr(
- cls, "__metadata__", {}
- ).get("response_model_description")
- response_model_field_name = response_model_field_name or getattr(
- cls, "__metadata__", {}
- ).get("response_model_field_name")
-
- return ai_model_prompt(
- cls,
- ctx=((ctx or {}) | {"instructions": instructions}),
- response_model_name=response_model_name,
- response_model_description=response_model_description,
- response_model_field_name=response_model_field_name,
+from typing import Any, Callable, Optional, TypeVar, Union, overload
+
+from marvin.components.ai_function import ai_fn
+from marvin.utilities.asyncio import run_sync
+from marvin.utilities.jinja import BaseEnvironment
+
+T = TypeVar("T")
+
+prompt = inspect.cleandoc(
+ "The user will provide context as text that you need to parse into a structured"
+ " form. To validate your response, you must call the"
+ " `{{_response_model.function.name}}` function. Use the provided text to extract or"
+ " infer any parameters needed by `{{_response_model.function.name}}`, including any"
+ " missing data."
+ " user: The text to parse: {{text}}"
+)
+
+
+@overload
+def ai_model(
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = prompt,
+ model_name: str = "FormatResponse",
+ model_description: str = "Formats the response.",
+ field_name: str = "data",
+ field_description: str = "The data to format.",
+ **render_kwargs: Any,
+) -> Callable[[T], Callable[[str], T]]:
+ pass
+
+
+@overload
+def ai_model(
+ _type: Optional[T],
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = prompt,
+ model_name: str = "FormatResponse",
+ model_description: str = "Formats the response.",
+ field_name: str = "data",
+ field_description: str = "The data to format.",
+ **render_kwargs: Any,
+) -> Callable[[str], T]:
+ pass
+
+
+def ai_model(
+ _type: Optional[T] = None,
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = prompt,
+ model_name: str = "FormatResponse",
+ model_description: str = "Formats the response.",
+ field_name: str = "data",
+ field_description: str = "The data to format.",
+ **render_kwargs: Any,
+) -> Union[Callable[[T], Callable[[str], T]], Callable[[str], T],]:
+ def wrapper(_type_: T, text: str) -> T:
+ def extract(text: str) -> T: # type: ignore
+ pass
+
+ extract.__annotations__["return"] = _type_
+
+ return ai_fn(
+ extract,
+ environment=environment,
+ prompt=prompt,
+ model_name=model_name,
+ model_description=model_description,
+ field_name=field_name,
+ field_description=field_description,
+ **render_kwargs,
+ )(text)
+
+ async def async_wrapper(_type_: T, text: str) -> T:
+ return wrapper(_type_, text)
+
+ async def amap(inputs: list[str]) -> list[T]:
+ return await asyncio.gather(
+ *[
+ asyncio.create_task(async_wrapper(_type, input_text))
+ for input_text in inputs
+ ]
)
- @classmethod
- def as_prompt(
- cls,
- text: str,
- ctx: Optional[dict[str, Any]] = None,
- instructions: Optional[str] = None,
- response_model_name: Optional[str] = None,
- response_model_description: Optional[str] = None,
- response_model_field_name: Optional[str] = None,
- model: Optional[str] = None,
- **model_kwargs: Any,
- ) -> dict[str, Any]:
- metadata = getattr(cls, "__metadata__", {})
-
- # Set default values using a loop to reduce repetition
- default_keys = [
- "ctx",
- "instructions",
- "response_model_name",
- "response_model_description",
- "response_model_field_name",
- "model",
- "model_kwargs",
- ]
- local_vars = locals()
- for key in default_keys:
- if local_vars.get(key, None) is None:
- local_vars[key] = metadata.get(key, {})
-
- return cls.get_prompt(
- ctx=ctx,
- instructions=instructions,
- response_model_name=response_model_name,
- response_model_description=response_model_description,
- response_model_field_name=response_model_field_name,
- )(text).serialize(model=ChatCompletion(model=model, **model_kwargs))
-
- @classmethod
- def as_dict(
- cls,
- text: str,
- ctx: Optional[dict[str, Any]] = None,
- instructions: Optional[str] = None,
- response_model_name: Optional[str] = None,
- response_model_description: Optional[str] = None,
- response_model_field_name: Optional[str] = None,
- model: Optional[str] = None,
- **model_kwargs: Any,
- ) -> dict[str, Any]:
- metadata = getattr(cls, "__metadata__", {})
-
- # Set default values using a loop to reduce repetition
- default_keys = [
- "ctx",
- "instructions",
- "response_model_name",
- "response_model_description",
- "response_model_field_name",
- "model",
- "model_kwargs",
- ]
- local_vars = locals()
- for key in default_keys:
- if local_vars.get(key, None) is None:
- local_vars[key] = metadata.get(key, {})
- return cls.get_prompt(
- ctx=ctx,
- instructions=instructions,
- response_model_name=response_model_name,
- response_model_description=response_model_description,
- response_model_field_name=response_model_field_name,
- )(text).to_dict()
-
- @classmethod
- def as_chat_completion(
- cls,
- text: str,
- ctx: Optional[dict[str, Any]] = None,
- instructions: Optional[str] = None,
- response_model_name: Optional[str] = None,
- response_model_description: Optional[str] = None,
- response_model_field_name: Optional[str] = None,
- model: Optional[str] = None,
- **model_kwargs: Any,
- ) -> AbstractChatCompletion[T]: # type: ignore
- metadata = getattr(cls, "__metadata__", {})
+ def map(inputs: list[str]) -> list[T]:
+ return run_sync(amap(inputs))
- # Set default values using a loop to reduce repetition
- default_keys = [
- "ctx",
- "instructions",
- "response_model_name",
- "response_model_description",
- "response_model_field_name",
- "model",
- "model_kwargs",
- ]
- local_vars = locals()
- for key in default_keys:
- if local_vars.get(key, None) is None:
- local_vars[key] = metadata.get(key, {})
-
- return ChatCompletion(model=model, **model_kwargs)(
- **cls.as_dict(
- text,
- ctx=ctx,
- instructions=instructions,
- response_model_name=response_model_name,
- response_model_description=response_model_description,
- response_model_field_name=response_model_field_name,
- )
- ) # type: ignore
-
- @classmethod
- def call(
- cls: type[Self],
- text: str,
- *,
- ctx: Optional[dict[str, Any]] = None,
- instructions: Optional[str] = None,
- response_model_name: Optional[str] = None,
- response_model_description: Optional[str] = None,
- response_model_field_name: Optional[str] = None,
- model: Optional[str] = None,
- **model_kwargs: Any,
- ) -> Self:
- metadata = getattr(cls, "__metadata__", {})
-
- get_logger("marvin.AIModel").debug_kv(
- f"Calling `ai_model` {cls.__name__!r}",
- f"with {text!r}",
- )
-
- # Set default values using a loop to reduce repetition
- default_keys = [
- "ctx",
- "instructions",
- "response_model_name",
- "response_model_description",
- "response_model_field_name",
- "model",
- "model_kwargs",
- ]
- local_vars = locals()
- for key in default_keys:
- if local_vars.get(key, None) is None:
- local_vars[key] = metadata.get(key, {})
-
- _model: Self = ( # type: ignore
- cls.as_chat_completion(
- text,
- ctx=ctx,
- instructions=instructions,
- response_model_name=response_model_name,
- response_model_description=response_model_description,
- response_model_field_name=response_model_field_name,
- model=model,
- **model_kwargs,
- )
- .create()
- .to_model(cls)
- )
- return _model # type: ignore
-
- @classmethod
- async def acall(
- cls: type[Self],
- text: str,
- *,
- ctx: Optional[dict[str, Any]] = None,
- instructions: Optional[str] = None,
- response_model_name: Optional[str] = None,
- response_model_description: Optional[str] = None,
- response_model_field_name: Optional[str] = None,
- model: Optional[str] = None,
- **model_kwargs: Any,
- ) -> Self:
- metadata = getattr(cls, "__metadata__", {})
-
- get_logger("marvin.AIModel").debug_kv(
- f"Calling `ai_model` {cls.__name__!r}",
- f"with {text!r}",
- )
-
- # Set default values using a loop to reduce repetition
- default_keys = [
- "ctx",
- "instructions",
- "response_model_name",
- "response_model_description",
- "response_model_field_name",
- "model",
- "model_kwargs",
- ]
- local_vars = locals()
- for key in default_keys:
- if local_vars.get(key, None) is None:
- local_vars[key] = metadata.get(key, {})
-
- _model: Self = ( # type: ignore
- await cls.as_chat_completion(
- text,
- ctx=ctx,
- instructions=instructions,
- response_model_name=response_model_name,
- response_model_description=response_model_description,
- response_model_field_name=response_model_field_name,
- model=model,
- **model_kwargs,
- ).acreate() # type: ignore
- ).to_model(cls)
- return _model # type: ignore
-
- @classmethod
- def map(cls, *map_args: list[str], **map_kwargs: list[Any]):
- """
- Map the AI function over a sequence of arguments. Runs concurrently.
-
- Arguments should be provided as if calling the function normally, but
- each argument must be a list. The function is called once for each item
- in the list, and the results are returned in a list.
-
- This method should be called synchronously.
-
- For example, fn.map([1, 2]) is equivalent to [fn(1), fn(2)].
-
- fn.map([1, 2], x=['a', 'b']) is equivalent to [fn(1, x='a'), fn(2, x='b')].
- """
- return run_sync(cls.amap(*map_args, **map_kwargs))
-
- @classmethod
- async def amap(cls, *map_args: list[str], **map_kwargs: list[Any]) -> list[Any]:
- tasks: list[Any] = []
- if map_args:
- max_length = max(len(arg) for arg in map_args)
- else:
- max_length = max(len(v) for v in map_kwargs.values())
-
- for i in range(max_length):
- call_args: list[str] = [
- arg[i] if i < len(arg) else None for arg in map_args
- ] # type: ignore
-
- tasks.append(cls.acall(*call_args, **map_kwargs))
-
- return await asyncio.gather(*tasks)
-
- @classmethod
- def as_decorator(
- cls: type[Self],
- base_model: Optional[type[BaseModel]] = None,
- ctx: Optional[dict[str, Any]] = None,
- instructions: Optional[str] = None,
- response_model_name: Optional[str] = None,
- response_model_description: Optional[str] = None,
- response_model_field_name: Optional[str] = None,
- model: Optional[str] = None,
- **model_kwargs: Any,
- ) -> type[BaseModel]:
- if not base_model:
- return partial(
- cls.as_decorator,
- ctx=ctx,
- instructions=instructions,
- response_model_name=response_model_name,
- response_model_description=response_model_description,
- response_model_field_name=response_model_field_name,
- model=model,
- **model_kwargs,
- ) # type: ignore
-
- response = type(base_model.__name__, (cls, base_model), {})
- response.__doc__ = base_model.__doc__
- setattr(
- response,
- "__metadata__",
- {
- "ctx": ctx or {},
- "instructions": instructions,
- "response_model_name": response_model_name,
- "response_model_description": response_model_description,
- "response_model_field_name": response_model_field_name,
- "model": model,
- "model_kwargs": model_kwargs,
- },
- )
- return response # type: ignore
+ if _type is not None:
+ wrapper_with_map = partial(wrapper, _type)
+ wrapper_with_map.amap = amap
+ wrapper_with_map.map = map
+ return wrapper_with_map
+ def decorator(_type_: T) -> Callable[[str], T]:
+ decorated = partial(wrapper, _type_)
+ decorated.amap = amap
+ decorated.map = map
+ return decorated
-ai_model = AIModel.as_decorator
+ return decorator
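
For contrast with the class-based `AIModel` it replaces, a minimal sketch of the functional `ai_model` API (illustrative only; requires a configured OpenAI API key, and the parsed values are model-dependent).

```python
from pydantic import BaseModel

from marvin.components.ai_model import ai_model


class Location(BaseModel):
    city: str
    country: str


# `ai_model` now wraps a type and returns a plain callable that parses text
# into that type by delegating to `ai_fn` under the hood.
extract_location = ai_model(Location)

loc = extract_location("I moved from Lisbon to Tokyo last year.")
# e.g. Location(city="Tokyo", country="Japan")

# `map` / `amap` are attached to the returned callable for batched parsing:
locations = extract_location.map(["Paris in spring", "Cairo sits on the Nile"])
```
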
diff --git a/src/marvin/components/ai_model_factory.py b/src/marvin/components/ai_model_factory.py
deleted file mode 100644
index 33a6112a3..000000000
--- a/src/marvin/components/ai_model_factory.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from typing import Optional
-
-from marvin._compat import BaseModel
-from marvin.components.ai_model import ai_model
-
-
-class DataSchema(BaseModel):
- title: Optional[str] = None
- type: Optional[str] = None
- properties: Optional[dict] = {}
- required: Optional[list[str]] = []
- additionalProperties: bool = False
- definitions: dict = {}
- description: Optional[str] = None
-
-
-# If you're reading this and expected something fancier,
-# I'm sorry to disappoint you. It's this simple.
-AIModelFactory = ai_model(DataSchema)
diff --git a/src/marvin/components/library/ai_functions.py b/src/marvin/components/library/ai_functions.py
deleted file mode 100644
index 42cb6e616..000000000
--- a/src/marvin/components/library/ai_functions.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from datetime import datetime
-
-from marvin.components.ai_function import ai_fn
-
-
-@ai_fn
-def summarize_text(text: str, specifications: str = "concise, comprehensive") -> str:
- """generates a summary of `text` according to the `specifications`"""
-
-
-@ai_fn
-def make_datetime(description: str, tz: str = "UTC") -> datetime:
- """generates a datetime from a description"""
diff --git a/src/marvin/components/library/ai_models.py b/src/marvin/components/library/ai_models.py
deleted file mode 100644
index cc149cab4..000000000
--- a/src/marvin/components/library/ai_models.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import json
-from typing import Optional
-
-import httpx
-from typing_extensions import Self
-
-from marvin import ai_model
-from marvin._compat import BaseModel, Field, SecretStr, field_validator
-from marvin.settings import MarvinBaseSettings
-
-
-class DiscourseSettings(MarvinBaseSettings):
- class Config:
- env_prefix = "MARVIN_DISCOURSE_"
-
- help_category_id: Optional[int] = Field(
- None, env=["MARVIN_DISCOURSE_HELP_CATEGORY_ID"]
- )
- api_key: Optional[SecretStr] = Field(None, env=["MARVIN_DISCOURSE_API_KEY"])
- api_username: Optional[str] = Field(None, env=["MARVIN_DISCOURSE_API_USERNAME"])
- url: Optional[str] = Field(None, env=["MARVIN_DISCOURSE_URL"])
-
-
-discourse_settings = DiscourseSettings()
-
-
-@ai_model(instructions="Produce a comprehensive Discourse post from text.")
-class DiscoursePost(BaseModel):
- title: Optional[str] = Field(
- description="A fitting title for the post.",
- example="How to install Prefect",
- )
- question: Optional[str] = Field(
- description="The question that is posed in the text.",
- example="How do I install Prefect?",
- )
- answer: Optional[str] = Field(
- description=(
- "The complete answer to the question posed in the text."
- " This answer should comprehensively answer the question, "
- " explain any relevant concepts, and have a friendly, academic tone,"
- " and provide any links to relevant resources found in the thread."
- " This answer should be written in Markdown, with any code blocks"
- " formatted as `code` or ```\n```."
- )
- )
-
- topic_url: Optional[str] = Field(None)
-
- @field_validator("title", "question", "answer")
- def non_empty_string(cls, value):
- if not value:
- raise ValueError("this field cannot be empty")
- return value
-
- @classmethod
- def from_slack_thread(cls, messages: list[str]) -> Self:
- return cls("here is the transcript:\n" + "\n\n".join(messages))
-
- async def publish(
- self,
- topic: str = None,
- category: Optional[int] = None,
- url: str = discourse_settings.url,
- tags: list[str] = None,
- ) -> str:
- if not category:
- category = discourse_settings.help_category_id
-
- headers = {
- "Api-Key": discourse_settings.api_key.get_secret_value(),
- "Api-Username": discourse_settings.api_username,
- "Content-Type": "application/json",
- }
- data = {
- "title": self.title,
- "raw": (
- f"## **{self.question}**\n\n{self.answer}"
- "\n\n---\n\n*This topic was created by Marvin.*"
- ),
- "category": category,
- "tags": tags or ["marvin"],
- }
-
- if topic:
- data["tags"].append(topic)
-
- async with httpx.AsyncClient() as client:
- response = await client.post(
- url=f"{url}/posts.json", headers=headers, data=json.dumps(data)
- )
-
- response.raise_for_status()
-
- response_data = response.json()
- topic_id = response_data.get("topic_id")
- post_number = response_data.get("post_number")
-
- self.topic_url = f"{url}/t/{topic_id}/{post_number}"
-
- return self.topic_url
diff --git a/src/marvin/components/prompt.py b/src/marvin/components/prompt.py
new file mode 100644
index 000000000..3cf04c07c
--- /dev/null
+++ b/src/marvin/components/prompt.py
@@ -0,0 +1,242 @@
+import inspect
+import re
+from functools import partial, wraps
+from typing import (
+ Any,
+ Callable,
+ Optional,
+ TypeVar,
+ Union,
+ overload,
+)
+
+import pydantic
+from pydantic import BaseModel
+from typing_extensions import ParamSpec, Self
+
+from marvin.requests import BaseMessage as Message
+from marvin.requests import Prompt, Tool
+from marvin.serializers import (
+ create_grammar_from_vocabulary,
+ create_tool_from_type,
+ create_vocabulary_from_type,
+)
+from marvin.settings import settings
+from marvin.utilities.jinja import (
+ BaseEnvironment,
+ Transcript,
+)
+
+P = ParamSpec("P")
+T = TypeVar("T")
+U = TypeVar("U", bound=BaseModel)
+
+
+class PromptFunction(Prompt[U]):
+ messages: list[Message] = pydantic.Field(default_factory=list)
+
+ def serialize(self) -> dict[str, Any]:
+ return self.model_dump(exclude_unset=True, exclude_none=True)
+
+ @overload
+ @classmethod
+ def as_grammar(
+ cls: type[Self],
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ enumerate: bool = True,
+ encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder,
+ max_tokens: Optional[int] = 1,
+ ) -> Callable[[Callable[P, Any]], Callable[P, Self]]:
+ pass
+
+ @overload
+ @classmethod
+ def as_grammar(
+ cls: type[Self],
+ fn: Optional[Callable[P, Any]] = None,
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ enumerate: bool = True,
+ encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder,
+ max_tokens: Optional[int] = 1,
+ ) -> Callable[P, Self]:
+ pass
+
+ @classmethod
+ def as_grammar(
+ cls: type[Self],
+ fn: Optional[Callable[P, Any]] = None,
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ enumerate: bool = True,
+ encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder,
+ max_tokens: Optional[int] = 1,
+ **kwargs: Any,
+ ) -> Union[Callable[[Callable[P, Any]], Callable[P, Self]], Callable[P, Self],]:
+ def wrapper(func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> Self:
+ # Get the signature of the function
+ signature = inspect.signature(func)
+ params = signature.bind(*args, **kwargs)
+ params.apply_defaults()
+
+ vocabulary = create_vocabulary_from_type(
+ inspect.signature(func).return_annotation
+ )
+
+ grammar = create_grammar_from_vocabulary(
+ vocabulary=vocabulary,
+ encoder=encoder,
+ _enumerate=enumerate,
+ max_tokens=max_tokens,
+ )
+
+ messages = Transcript(
+ content=prompt or func.__doc__ or ""
+ ).render_to_messages(
+ **kwargs | params.arguments,
+ _arguments=params.arguments,
+ _options=vocabulary,
+ _doc=func.__doc__,
+ _source_code=(
+ "\ndef" + "def".join(re.split("def", inspect.getsource(func))[1:])
+ ),
+ )
+
+ return cls(
+ messages=messages,
+ **grammar.model_dump(exclude_unset=True, exclude_none=True),
+ )
+
+ if fn is not None:
+ return wraps(fn)(partial(wrapper, fn))
+
+ def decorator(fn: Callable[P, Any]) -> Callable[P, Self]:
+ return wraps(fn)(partial(wrapper, fn))
+
+ return decorator
+
+ @overload
+ @classmethod
+ def as_function_call(
+ cls: type[Self],
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ model_name: str = "FormatResponse",
+ model_description: str = "Formats the response.",
+ field_name: str = "data",
+ field_description: str = "The data to format.",
+ ) -> Callable[[Callable[P, Any]], Callable[P, Self]]:
+ pass
+
+ @overload
+ @classmethod
+ def as_function_call(
+ cls: type[Self],
+ fn: Optional[Callable[P, Any]] = None,
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ model_name: str = "FormatResponse",
+ model_description: str = "Formats the response.",
+ field_name: str = "data",
+ field_description: str = "The data to format.",
+ ) -> Callable[P, Self]:
+ pass
+
+ @classmethod
+ def as_function_call(
+ cls: type[Self],
+ fn: Optional[Callable[P, Any]] = None,
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ model_name: str = "FormatResponse",
+ model_description: str = "Formats the response.",
+ field_name: str = "data",
+ field_description: str = "The data to format.",
+ **kwargs: Any,
+ ) -> Union[Callable[[Callable[P, Any]], Callable[P, Self]], Callable[P, Self],]:
+ def wrapper(func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> Self:
+ # Get the signature of the function
+ signature = inspect.signature(func)
+ params = signature.bind(*args, **kwargs)
+ params.apply_defaults()
+
+ tool = create_tool_from_type(
+ _type=inspect.signature(func).return_annotation,
+ model_name=model_name,
+ model_description=model_description,
+ field_name=field_name,
+ field_description=field_description,
+ )
+
+ messages = Transcript(
+ content=prompt or func.__doc__ or ""
+ ).render_to_messages(
+ **kwargs | params.arguments,
+ _doc=func.__doc__,
+ _arguments=params.arguments,
+ _response_model=tool,
+ _source_code=(
+ "\ndef" + "def".join(re.split("def", inspect.getsource(func))[1:])
+ ),
+ )
+
+ return cls(
+ messages=messages,
+ tool_choice={
+ "type": "function",
+ "function": {"name": getattr(tool.function, "name", model_name)},
+ },
+ tools=[Tool[BaseModel](**tool.model_dump())],
+ )
+
+ if fn is not None:
+ return wraps(fn)(partial(wrapper, fn))
+
+ def decorator(fn: Callable[P, Any]) -> Callable[P, Self]:
+ return wraps(fn)(partial(wrapper, fn))
+
+ return decorator
+
+
+def prompt_fn(
+ fn: Optional[Callable[P, T]] = None,
+ *,
+ environment: Optional[BaseEnvironment] = None,
+ prompt: Optional[str] = None,
+ model_name: str = "FormatResponse",
+ model_description: str = "Formats the response.",
+ field_name: str = "data",
+ field_description: str = "The data to format.",
+ **kwargs: Any,
+) -> Union[
+ Callable[[Callable[P, T]], Callable[P, dict[str, Any]]],
+ Callable[P, dict[str, Any]],
+]:
+ def wrapper(
+ func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
+ ) -> dict[str, Any]:
+ return PromptFunction.as_function_call(
+ fn=func,
+ environment=environment,
+ prompt=prompt,
+ model_name=model_name,
+ model_description=model_description,
+ field_name=field_name,
+ field_description=field_description,
+ **kwargs,
+ )(*args, **kwargs).serialize()
+
+ if fn is not None:
+ return wraps(fn)(partial(wrapper, fn))
+
+ def decorator(fn: Callable[P, Any]) -> Callable[P, dict[str, Any]]:
+ return wraps(fn)(partial(wrapper, fn))
+
+ return decorator
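
A quick sketch of how the new `prompt_fn` helper is meant to be consumed (illustrative; no API call is made at this point, so no key is needed).

```python
from marvin.components.prompt import prompt_fn


@prompt_fn
def classify_sentiment(text: str) -> bool:
    """Is the following text positive in tone? {{ text }}"""


# The decorator only renders the docstring and return annotation into
# ChatCompletion keyword arguments (messages, tools, tool_choice) that can be
# forwarded to a client of your choice.
request_kwargs = classify_sentiment("What a wonderful day!")
print(request_kwargs["messages"])
```
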
diff --git a/src/marvin/components/speech.py b/src/marvin/components/speech.py
new file mode 100644
index 000000000..fbc45595d
--- /dev/null
+++ b/src/marvin/components/speech.py
@@ -0,0 +1,73 @@
+from pathlib import Path
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Coroutine,
+ Literal,
+ Optional,
+ TypeVar,
+)
+
+from typing_extensions import ParamSpec
+
+if TYPE_CHECKING:
+ from openai._base_client import HttpxBinaryResponseContent
+
+T = TypeVar("T")
+
+P = ParamSpec("P")
+
+
+def speak(
+ input: str,
+ *,
+ create: Optional[Callable[..., "HttpxBinaryResponseContent"]] = None,
+ model: Optional[str] = "tts-1-hd",
+ voice: Optional[
+ Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
+ ] = None,
+ response_format: Optional[Literal["mp3", "opus", "aac", "flac"]] = None,
+ speed: Optional[float] = None,
+ filepath: Path,
+) -> None:
+ if create is None:
+ from marvin.settings import settings
+
+ create = settings.openai.audio.speech.create
+ return create(
+ input=input,
+ **({"model": model} if model else {}),
+ **({"voice": voice} if voice else {}),
+ **({"response_format": response_format} if response_format else {}),
+ **({"speed": speed} if speed else {}),
+ ).stream_to_file(filepath)
+
+
+async def aspeak(
+ input: str,
+ *,
+ acreate: Optional[
+ Callable[..., Coroutine[Any, Any, "HttpxBinaryResponseContent"]]
+ ] = None,
+ model: Optional[str],
+ voice: Optional[
+ Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
+ ] = None,
+ response_format: Optional[Literal["mp3", "opus", "aac", "flac"]] = None,
+ speed: Optional[float] = None,
+ filepath: Path,
+) -> None:
+ if acreate is None:
+ from marvin.settings import settings
+
+ acreate = settings.openai.audio.speech.acreate
+ return (
+ await acreate(
+ input=input,
+ **({"model": model} if model else {}),
+ **({"voice": voice} if voice else {}),
+ **({"response_format": response_format} if response_format else {}),
+ **({"speed": speed} if speed else {}),
+ )
+ ).stream_to_file(filepath)
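
A minimal sketch of the new `speak` helper (illustrative; assumes a configured OpenAI API key). Note that `filepath` is keyword-only and the model defaults to `tts-1-hd`.

```python
from pathlib import Path

from marvin.components.speech import speak

# Synthesizes speech via the OpenAI audio endpoint and streams it to disk.
speak(
    "Ground control to Major Tom.",
    voice="alloy",
    filepath=Path("major_tom.mp3"),
)
```
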
diff --git a/src/marvin/core/ChatCompletion/__init__.py b/src/marvin/core/ChatCompletion/__init__.py
deleted file mode 100644
index 068c171ce..000000000
--- a/src/marvin/core/ChatCompletion/__init__.py
+++ /dev/null
@@ -1,56 +0,0 @@
-from typing import Optional, Any, TypeVar
-from .abstract import AbstractChatCompletion
-
-from marvin._compat import BaseModel
-from marvin.settings import settings
-
-T = TypeVar(
- "T",
- bound=BaseModel,
-)
-
-PROVIDER_SHORTCUTS = {
- "gpt-3.5-turbo": "openai",
- "gpt-4": "openai",
- "claude-1": "anthropic",
- "claude-2": "anthropic",
-}
-
-
-def parse_model_shortcut(provider: Optional[str]) -> tuple[str, str]:
- """
- Parse a model string into a provider and a model name.
- - If the provider is None, use the default provider and model.
- - If the provider is a shortcut, use the shortcut to get the provider and model.
- """
- if provider is None:
- try:
- provider, model = settings.llm_model.split("/")
- except Exception:
- provider, model = (
- PROVIDER_SHORTCUTS[str(settings.llm_model)],
- settings.llm_model,
- )
-
- elif provider in PROVIDER_SHORTCUTS:
- provider, model = PROVIDER_SHORTCUTS[provider], provider
- else:
- provider, model = provider.split("/")
- return provider, model
-
-
-def ChatCompletion(
- model: Optional[str] = None,
- **kwargs: Any,
-) -> AbstractChatCompletion[T]: # type: ignore
- provider, model = parse_model_shortcut(model)
- if provider == "openai" or provider == "azure_openai":
- from .providers.openai import OpenAIChatCompletion
-
- return OpenAIChatCompletion(provider=provider, model=model, **kwargs)
- if provider == "anthropic":
- from .providers.anthropic import AnthropicChatCompletion
-
- return AnthropicChatCompletion(model=model, **kwargs)
- else:
- raise ValueError(f"Unknown provider: {provider}")
diff --git a/src/marvin/core/ChatCompletion/abstract.py b/src/marvin/core/ChatCompletion/abstract.py
deleted file mode 100644
index 059bbfa10..000000000
--- a/src/marvin/core/ChatCompletion/abstract.py
+++ /dev/null
@@ -1,224 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Any, Generic, Optional, TypeVar
-
-from marvin._compat import BaseModel, Field, model_copy, model_dump
-from marvin.utilities.messages import Message
-from typing_extensions import Self
-
-from .handlers import Request, Response, Turn
-
-T = TypeVar(
- "T",
- bound=BaseModel,
-)
-
-
-class Conversation(BaseModel, Generic[T], extra="allow", arbitrary_types_allowed=True):
- turns: list[Turn[T]]
- model: Any
-
- def __getitem__(self, key: int) -> Turn[T]:
- return self.turns[key]
-
- @property
- def last_turn(self) -> Turn[T]:
- return self.turns[-1]
-
- @property
- def last_request(self) -> Optional[Request[T]]:
- return self.turns[-1][0] if self.turns else None
-
- @property
- def last_response(self) -> Optional[Response[T]]:
- return self.turns[-1][1] if self.turns else None
-
- @property
- def history(self) -> list[Message]:
- response: list[Message] = []
- if not self.turns:
- return response
- if self.last_request:
- response = self.last_request.messages or []
- if self.last_response:
- response.append(self.last_response.choices[0].message)
- return response
-
- def send(self, messages: list[Message], **kwargs: Any) -> Turn[T]:
- params = kwargs
- if self.last_request:
- params = model_dump(self.last_request, exclude={"messages"}) | kwargs
-
- turn = self.model.create(
- **params,
- messages=[
- *self.history,
- *messages,
- ],
- )
- self.turns.append(turn)
- return turn
-
- async def asend(self, messages: list[Message], **kwargs: Any) -> Turn[T]:
- params = kwargs
- if self.last_request:
- params = model_dump(self.last_request, exclude={"messages"}) | kwargs
-
- turn = await self.model.acreate(
- **params,
- messages=[
- *self.history,
- *messages,
- ],
- )
- self.turns.append(turn)
- return turn
-
-
-class AbstractChatCompletion(
- BaseModel, Generic[T], ABC, extra="allow", arbitrary_types_allowed=True
-):
- """
- A ChatCompletion object is responsible for exposing a create and acreate method,
- and for merging default parameters with the parameters passed to these methods.
- """
-
- defaults: dict[str, Any] = Field(default_factory=dict, exclude=True)
-
- def __call__(self: Self, **kwargs: Any) -> Self:
- """
- Create a new ChatCompletion object with new defaults computed from
- merging the passed parameters with the default parameters.
- """
- copy = model_copy(self)
- copy.defaults = self.defaults | kwargs
- return copy
-
- @abstractmethod
- def _serialize_request(self, request: Optional[Request[T]]) -> dict[str, Any]:
- """
- Serialize the request.
- This should be implemented by derived classes based on their specific needs.
- """
- pass
-
- @abstractmethod
- def _create_request(self, **kwargs: Any) -> Request[T]:
- """
- Prepare and return a request object.
- This should be implemented by derived classes.
- """
- pass
-
- @abstractmethod
- def _parse_response(self, response: Any) -> Any:
- """
- Parse the response based on specific needs.
- """
- pass
-
- def merge_with_defaults(self, **kwargs: Any) -> dict[str, Any]:
- """
- Merge the passed parameters with the default parameters.
- """
- return self.defaults | kwargs
-
- @abstractmethod
- def _send_request(self, **serialized_request: Any) -> Any:
- """
- Send the serialized request to the appropriate endpoint/service.
- Derived classes should implement this.
- """
- pass
-
- @abstractmethod
- async def _send_request_async(
- self, **serialized_request: Any
- ) -> Response[T]: # noqa
- """
- Send the serialized request to the appropriate endpoint/service asynchronously.
- Derived classes should implement this.
- """
- pass
-
- def create(
- self, response_model: Optional[type[T]] = None, **kwargs: Any
- ) -> Turn[T]:
- """
- Create a completion synchronously.
- Derived classes can override this if they need to change the core logic.
- """
- request = self._create_request(**kwargs, response_model=response_model)
- serialized_request = self._serialize_request(request=request)
- response_data = self._send_request(**serialized_request)
- response = self._parse_response(response_data)
-
- return Turn(
- request=Request(
- **serialized_request
- | self.defaults
- | model_dump(request, exclude_none=True)
- | ({"response_model": response_model} if response_model else {})
- ),
- response=response,
- )
-
- async def acreate(
- self, response_model: Optional[type[T]] = None, **kwargs: Any
- ) -> Turn[T]:
- """
- Create a completion asynchronously.
- Similar to the synchronous version but for async implementations.
- """
- request = self._create_request(**kwargs, response_model=response_model)
- serialized_request = self._serialize_request(request=request)
- response_data = await self._send_request_async(**serialized_request)
- response = self._parse_response(response_data)
- return Turn(
- request=Request(
- **serialized_request
- | self.defaults
- | model_dump(request, exclude_none=True)
- | ({"response_model": response_model} if response_model else {})
- ),
- response=response,
- )
-
- def chain(self, **kwargs: Any) -> Conversation[T]:
- """
- Create a new Conversation object.
- """
- with self as conversation:
- conversation.send(**kwargs)
- while conversation.last_turn.has_function_call():
- message = conversation.last_turn.call_function()
- conversation.send(
- message if isinstance(message, list) else [message],
- )
-
- return conversation
-
- async def achain(self, **kwargs: Any) -> Conversation[T]:
- """
- Create a new Conversation object asynchronously.
- """
- with self as conversation:
- await conversation.asend(**kwargs)
- while conversation.last_turn.has_function_call():
- message = conversation.last_turn.call_function()
- await conversation.asend(
- message if isinstance(message, list) else [message],
- )
-
- return conversation
-
- def __enter__(self: Self) -> Conversation[T]:
- """
- Enter a context manager.
- """
- return Conversation(turns=[], model=self)
-
- def __exit__(self: Self, *args: Any) -> None:
- """
- Exit a context manager.
- """
- pass
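For context, a rough sketch of the chain()/achain() control flow above: keep feeding function results back to the model until it stops requesting calls. The `complete` callable and message dicts below are stand-ins, not Marvin's types.

from typing import Any, Callable

def chain_until_done(
    complete: Callable[[list[dict[str, Any]]], dict[str, Any]],
    messages: list[dict[str, Any]],
    registry: dict[str, Callable[..., Any]],
) -> list[dict[str, Any]]:
    history = list(messages)
    while True:
        reply = complete(history)          # one model turn
        history.append(reply)
        call = reply.get("function_call")
        if not call:                       # no function requested: we're done
            return history
        result = registry[call["name"]](**call["arguments"])
        history.append({"role": "function", "name": call["name"], "content": str(result)})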
diff --git a/src/marvin/core/ChatCompletion/handlers.py b/src/marvin/core/ChatCompletion/handlers.py
deleted file mode 100644
index 3904f3038..000000000
--- a/src/marvin/core/ChatCompletion/handlers.py
+++ /dev/null
@@ -1,215 +0,0 @@
-import inspect
-import json
-from types import FunctionType
-from typing import (
- Any,
- Callable,
- Generic,
- Literal,
- Optional,
- TypeVar,
- Union,
- overload,
-)
-
-from marvin._compat import BaseModel, Field, cast_to_json, model_dump
-from marvin.utilities.async_utils import run_sync
-from marvin.utilities.logging import get_logger
-from marvin.utilities.messages import Message, Role
-from typing_extensions import ParamSpec
-
-from .utils import parse_raw
-
-T = TypeVar(
- "T",
- bound=BaseModel,
-)
-
-P = ParamSpec("P")
-
-
-class Request(BaseModel, Generic[T], extra="allow", arbitrary_types_allowed=True):
- messages: Optional[list[Message]] = Field(default=None)
- functions: Optional[list[Union[Callable[..., Any], dict[str, Any]]]] = Field(
- default=None
- )
- function_call: Any = None
- response_model: Optional[type[T]] = Field(default=None, exclude=True)
-
- def serialize(
- self,
- functions_serializer: Callable[
- [Callable[..., Any]], dict[str, Any]
- ] = cast_to_json,
- ) -> dict[str, Any]:
- extras = model_dump(
- self, exclude={"functions", "function_call", "response_model"}
- )
- response_model: dict[str, Any] = {}
- functions: dict[str, Any] = {}
- function_call: dict[str, Any] = {}
- messages: dict[str, Any] = {}
-
- if self.response_model:
- response_model_schema: dict[str, Any] = functions_serializer(
- self.response_model
- )
- response_model = {
- "functions": [response_model_schema],
- "function_call": {"name": response_model_schema.get("name")},
- }
-
- elif self.functions:
- functions = {
- "functions": [
- functions_serializer(function) for function in self.functions
- ]
- }
- if self.function_call:
- functions["function_call"] = self.function_call
-
- return extras | functions | function_call | messages | response_model
-
- def function_registry(self) -> dict[str, FunctionType]:
- return {
- function.__name__: function
- for function in self.functions or []
- if callable(function)
- }
-
-
-class Choice(BaseModel):
- message: Message
- index: int
- finish_reason: str
-
- class Config:
- arbitrary_types_allowed = True
-
-
-class Usage(BaseModel):
- prompt_tokens: int
- completion_tokens: int
- total_tokens: int
-
-
-class Response(BaseModel, Generic[T], extra="allow", arbitrary_types_allowed=True):
- id: str
- object: str
- created: int
- model: str
- usage: Usage
- choices: list[Choice] = Field(default_factory=list)
-
-
-class Turn(BaseModel, Generic[T], extra="allow", arbitrary_types_allowed=True):
- request: Request[T]
- response: Response[T]
-
- @overload
- def __getitem__(self, key: Literal[0]) -> Request[T]:
- ...
-
- @overload
- def __getitem__(self, key: Literal[1]) -> Response[T]:
- ...
-
- def __getitem__(self, key: int) -> Union[Request[T], Response[T]]:
- if key == 0:
- return self.request
- elif key == 1:
- return self.response
- else:
- raise IndexError("Turn only has two items.")
-
- def has_function_call(self) -> bool:
- return any([choice.message.function_call for choice in self.response.choices])
-
- def get_function_call(self) -> list[tuple[str, dict[str, Any]]]:
- if not self.has_function_call():
- raise ValueError("No function call found.")
- pairs: list[tuple[str, dict[str, Any]]] = []
- for choice in self.response.choices:
- if choice.message.function_call:
- pairs.append(
- (
- choice.message.function_call.name,
- parse_raw(choice.message.function_call.arguments),
- )
- )
- return pairs
-
- def call_function(self) -> Union[Message, list[Message]]:
- if not self.has_function_call():
- raise ValueError("No function call found.")
-
- pairs: list[tuple[str, dict[str, Any]]] = self.get_function_call()
- function_registry: dict[str, FunctionType] = self.request.function_registry()
- evaluations: list[Any] = []
-
- logger = get_logger("ChatCompletion.handlers")
-
- for pair in pairs:
- name, argument = pair
- if name not in function_registry:
- raise ValueError(
- f"Function {name} not found in {function_registry=!r}."
- )
-
- logger.debug_kv(
- "Function call",
- (
- f"Calling function {name!r} with payload:"
- f" {json.dumps(argument, indent=2)}"
- ),
- key_style="green",
- )
-
- function_result = function_registry[name](**argument)
-
- if inspect.isawaitable(function_result):
- function_result = run_sync(function_result)
-
- logger.debug_kv(
- "Function call",
- f"Function {name!r} returned: {function_result}",
- key_style="green",
- )
-
- evaluations.append(function_result)
- if len(evaluations) != 1:
- return [
- Message(
- name=pairs[j][0],
- role=Role.FUNCTION_RESPONSE,
- content=str(evaluations[j]),
- function_call=None,
- )
- for j in range(len(evaluations))
- ]
- else:
- return Message(
- name=pairs[0][0],
- role=Role.FUNCTION_RESPONSE,
- content=str(evaluations[0]),
- function_call=None,
- )
-
- def to_model(self, model_cls: Optional[type[T]] = None) -> T:
- model = model_cls or self.request.response_model
-
- if not model:
- raise ValueError("No model found.")
-
- pairs = self.get_function_call()
- try:
- return model(**pairs[0][1])
- except ValueError: # ValidationError is a subclass of ValueError
- return model(output=pairs[0][1])
- except TypeError:
- pass
- try:
- return model.parse_raw(pairs[0][1])
- except TypeError:
- pass
- return model.construct(**pairs[0][1])
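A small illustration of how a response_model is forced as the only available function, as Request.serialize did above; the schema below is a hand-written example, not one generated by Marvin.

def force_response_model(schema: dict) -> dict:
    # The model's schema becomes the single function, and function_call pins the LLM to it.
    return {"functions": [schema], "function_call": {"name": schema["name"]}}

location_schema = {
    "name": "Location",
    "description": "A city and country.",
    "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string"}, "country": {"type": "string"}},
        "required": ["city", "country"],
    },
}
request_kwargs = {"model": "gpt-3.5-turbo"} | force_response_model(location_schema)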
diff --git a/src/marvin/core/ChatCompletion/providers/anthropic/__init__.py b/src/marvin/core/ChatCompletion/providers/anthropic/__init__.py
deleted file mode 100644
index ed471255b..000000000
--- a/src/marvin/core/ChatCompletion/providers/anthropic/__init__.py
+++ /dev/null
@@ -1,178 +0,0 @@
-from typing import Any, TypeVar, Callable, Awaitable, Optional
-import inspect
-from marvin._compat import cast_to_json, model_dump
-from marvin.settings import settings
-from pydantic import BaseModel
-from .prompt import render_anthropic_functions_prompt, handle_anthropic_response
-from ...abstract import AbstractChatCompletion
-from ...handlers import Request, Response
-
-T = TypeVar(
- "T",
- bound=BaseModel,
-)
-
-
-def get_anthropic_create(**kwargs: Any) -> tuple[Callable[..., Any], dict[str, Any]]:
- """
- Get the Anthropic create function and the default parameters,
- pruned of parameters that are not accepted by the constructor.
- """
- import anthropic
-
- params = dict(inspect.signature(anthropic.Anthropic).parameters)
-
- return anthropic.Anthropic(
- **{k: v for k, v in kwargs.items() if k in params.keys()}
- ).completions.create, {k: v for k, v in kwargs.items() if k not in params.keys()}
-
-
-def get_anthropic_acreate(
- **kwargs: Any,
-) -> tuple[Callable[..., Awaitable[Any]], dict[str, Any]]: # noqa
- """
- Get the Anthropic acreate function and the default parameters,
- pruned of parameters that are not accepted by the constructor.
- """
- import anthropic
-
- params = dict(inspect.signature(anthropic.AsyncAnthropic).parameters)
- return anthropic.AsyncAnthropic(
- **{k: v for k, v in kwargs.items() if k in params.keys()}
- ).completions.create, {k: v for k, v in kwargs.items() if k not in params.keys()}
-
-
-class AnthropicChatCompletion(AbstractChatCompletion[T]):
- """
- Anthropic-specific implementation of the ChatCompletion.
- """
-
- def __init__(self, **kwargs: Any):
- """
- Filters out parameters meant for the Anthropic client constructor; the remaining kwargs become completion defaults.
- """
- import anthropic
-
- kwargs = {
- k: v
- for k, v in kwargs.items()
- if k not in dict(inspect.signature(anthropic.Anthropic).parameters).keys()
- }
- super().__init__(defaults=settings.get_defaults("anthropic") | kwargs)
-
- def _serialize_request(
- self, request: Optional[Request[T]] = None
- ) -> dict[str, Any]:
- """
- Serialize the request as per Anthropic's requirements.
- """
- request = request or Request()
- request = Request(
- **self.defaults
- | model_dump(
- request,
- exclude_none=True,
- )
- )
-
- extras = model_dump(
- request,
- exclude={"functions", "function_call", "response_model", "messages"},
- )
-
- functions: dict[str, Any] = {}
- function_call: Any = {}
-
- prompt = "\n\nHuman:"
- for message in request.messages or []:
- if message.role != "function" and message.content:
- prompt += f"\n\n{'Human' if message.role == 'user' else 'Assistant'}"
- prompt += f": {message.content}"
-
- if request.response_model:
- schema = cast_to_json(request.response_model)
- functions["functions"] = [schema]
- function_call["function_call"] = {"name": schema.get("name")}
-
- elif request.functions:
- functions["functions"] = [
- cast_to_json(function) if callable(function) else function
- for function in request.functions
- ]
- if request.function_call:
- function_call["function_call"] = request.function_call
-
- if functions:
- prompt += render_anthropic_functions_prompt(
- functions=functions.pop("functions", []),
- function_call=function_call.pop("function_call", None),
- )
-
- for message in request.messages or []:
- if message.role == "function":
- prompt += "\n\nAssistant"
- prompt += f": The result of {message.name} is {message.content}."
- if message.function_call:
- prompt += "\n\nAssistant"
- prompt += f": I will call the {message.function_call.name} function."
-
- prompt += "\n\nAssistant: "
- prompt = prompt.replace("\n\nHuman:\n\nHuman: ", "\n\nHuman: ")  # str.replace returns a new string
- return extras | {"prompt": prompt}
-
- def _create_request(self, **kwargs: Any) -> Request[T]:
- """
- Prepare and return an Anthropic-specific request object.
- """
- return Request(**kwargs)
-
- def _parse_response(self, response: Any) -> Response[T]:
- """
- Parse the response received from Anthropic.
- """
- # Convert Anthropic's completion into Marvin's standard Response format
-
- content, function_call = handle_anthropic_response(response.completion)
-
- return Response(
- **{
- "id": response.log_id,
- "model": response.model,
- "object": "text_completion",
- "created": 0,
- "choices": [
- {
- "index": 0,
- "finish_reason": "stop",
- "message": {
- "content": content, # type: ignore
- "role": "assistant",
- "function_call": function_call,
- },
- }
- ],
- "usage": {
- "prompt_tokens": 0,
- "completion_tokens": 0,
- "total_tokens": 0,
- },
- }
- )
-
- def _send_request(self, **serialized_request: Any) -> Any:
- """
- Send the serialized request to Anthropic's endpoint/service.
- """
- # Use the anthropic client to send the request and get a response
- create, params = get_anthropic_create(**serialized_request)
- response = create(**params)
- return response
-
- async def _send_request_async(self, **serialized_request: Any) -> Response[T]:
- """
- Send the serialized request to Anthropic's endpoint asynchronously.
- """
- create, params = get_anthropic_acreate(**serialized_request)
- response = await create(**params)
- return response
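A standalone approximation of the Human/Assistant prompt assembly in _serialize_request above, using plain dicts instead of Marvin's Message type.

def to_anthropic_prompt(messages: list[dict[str, str]]) -> str:
    prompt = ""
    for m in messages:
        speaker = "Human" if m["role"] == "user" else "Assistant"
        prompt += f"\n\n{speaker}: {m['content']}"
    return prompt + "\n\nAssistant:"

# to_anthropic_prompt([{"role": "user", "content": "Hi"}])
# -> "\n\nHuman: Hi\n\nAssistant:"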
diff --git a/src/marvin/core/ChatCompletion/providers/anthropic/prompt.py b/src/marvin/core/ChatCompletion/providers/anthropic/prompt.py
deleted file mode 100644
index 45f9d6f28..000000000
--- a/src/marvin/core/ChatCompletion/providers/anthropic/prompt.py
+++ /dev/null
@@ -1,110 +0,0 @@
-import json
-import re
-from typing import Any, Literal, Optional, Union
-
-from jinja2 import Environment
-from marvin.utilities.strings import jinja_env
-
-from ...utils import parse_raw
-
-FUNCTION_PROMPT = """
-# Functions
-
-You can call various functions to perform tasks.
-
-Whenever you receive a message from the user, check to see if any of your
-functions would help you respond. For example, you might use a function to look
-up information, interact with a filesystem, call an API, or validate data. You
-might write code, update state, or cause a side effect. After indicating that
-you want to call a function, the user will execute the function and tell you its
-result so that you can use the information in your final response. Therefore,
-you must use your functions whenever they would be helpful.
-
-The user may also provide a `function_call` instruction which could be:
-
-- "auto": you may decide to call a function on your own (this is the
- default)
-- "none": do not call any function
-- {"name": ""}: you MUST call the function with the given
- name
-
-To call a function:
-
-- Your response must include a JSON payload with the below format, including the
- {"mode": "function_call"} key.
-- Do not put any other text in your response beside the JSON payload.
-- Do not describe your plan to call the function to the user; they will not see
- it.
-- Do not include more than one payload in your response.
-- Do not include function output or results in your response.
-
-# Available Functions
-
-You have access to the following functions. Each has a name (which must be part
-of your response), a description (which you should use to decide whether to call the
-function), and a parameter spec (which is a JSON Schema description of the
-arguments you should pass in your response)
-
-{% for function in functions -%}
-
-## {{ function.name }}
-
-- Name: {{ function.name }}
-- Description: {{ function.description }}
-- Parameters: {{ function.parameters }}
-
-{% endfor %}
-
-# Calling a Function
-
-To call a function, your response MUST include a JSON document with the
-following structure:
-
-{
- "mode": "function_call",
-
- "name": "<function-name>",
-
- "arguments": "<arguments>"
-}
-
-The user will execute the function and respond with its result verbatim.
-
-# function_call instruction
-
-The user provided the following `function_call` instruction: {{ function_call }}
-
-# final_response instruction
-
-When you choose to call no functions, your final answer should be summative.
-Do not acknowledge any internal functions you had access to or chose to call.
-
-"""
-
-
-def render_anthropic_functions_prompt(
- functions: list[dict[str, Any]],
- function_call: Union[dict[Literal["name"], str], Literal["auto"]],
- environment: Optional[Environment] = None,
-) -> str:
- env = environment or jinja_env
- template = env.from_string(FUNCTION_PROMPT)
- return template.render(functions=functions, function_call=function_call or "auto")
-
-
-def handle_anthropic_response(
- completion: str,
-) -> tuple[Optional[str], Optional[dict[str, Any]]]:
- try:
- response: dict[str, Any] = parse_raw(
- re.findall(r"\{.*\}", completion, re.DOTALL)[0]
- )
- if response.pop("mode", None) == "function_call":
- return None, {
- "name": response.pop("name", None),
- "arguments": json.dumps(response.pop("arguments", None)),
- }
- except Exception:
- pass
- return completion, None
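A best-effort sketch of the parsing handle_anthropic_response performed above: pull a function-call payload out of raw completion text, otherwise return the text unchanged.

import json
import re

def extract_call(completion: str):
    match = re.search(r"\{.*\}", completion, re.DOTALL)
    if match:
        try:
            payload = json.loads(match.group(0))
            if payload.get("mode") == "function_call":
                return None, {
                    "name": payload.get("name"),
                    "arguments": json.dumps(payload.get("arguments")),
                }
        except json.JSONDecodeError:
            pass
    return completion, None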
diff --git a/src/marvin/core/ChatCompletion/providers/openai.py b/src/marvin/core/ChatCompletion/providers/openai.py
deleted file mode 100644
index 893456fb4..000000000
--- a/src/marvin/core/ChatCompletion/providers/openai.py
+++ /dev/null
@@ -1,218 +0,0 @@
-import inspect
-from typing import Any, AsyncGenerator, Callable, Optional, TypeVar, Union
-
-from marvin._compat import BaseModel, cast_to_json, model_dump
-from marvin.settings import settings
-from marvin.types import Function
-from marvin.utilities.async_utils import create_task, run_sync
-from marvin.utilities.messages import Message
-from marvin.utilities.streaming import StreamHandler
-from openai.openai_object import OpenAIObject
-
-from ..abstract import AbstractChatCompletion
-from ..handlers import Request, Response, Usage
-
-T = TypeVar(
- "T",
- bound=BaseModel,
-)
-
-CONTEXT_SIZES = {
- "gpt-3.5-turbo-16k-0613": 16384,
- "gpt-3.5-turbo-16k": 16384,
- "gpt-3.5-turbo-0613": 4096,
- "gpt-3.5-turbo": 4096,
- "gpt-4-32k-0613": 32768,
- "gpt-4-32k": 32768,
- "gpt-4-0613": 8192,
- "gpt-4": 8192,
-}
-
-
-def get_context_size(model: str) -> int:
- if "/" in model:
- model = model.split("/")[-1]
-
- return CONTEXT_SIZES.get(model, 2048)
-
-
-def serialize_function_or_callable(
- function_or_callable: Union[Function, Callable[..., Any]],
- name: Optional[str] = None,
- description: Optional[str] = None,
- field_name: Optional[str] = None,
-) -> dict[str, Any]:
- if isinstance(function_or_callable, Function):
- return {
- "name": function_or_callable.__name__,
- "description": function_or_callable.__doc__,
- "parameters": function_or_callable.schema,
- }
- else:
- return cast_to_json(
- function_or_callable,
- name=name,
- description=description,
- field_name=field_name,
- )
-
-
-class OpenAIStreamHandler(StreamHandler):
- async def handle_streaming_response(
- self,
- api_response: AsyncGenerator[OpenAIObject, None],
- ) -> OpenAIObject:
- final_chunk = {}
- accumulated_content = ""
-
- async for r in api_response:
- final_chunk.update(r.to_dict_recursive())
-
- delta = r.choices[0].delta if r.choices and r.choices[0] else None
-
- if delta is None:
- continue
-
- if "content" in delta:
- accumulated_content += delta.content or ""
-
- if self.callback:
- callback_result = self.callback(
- Message(
- content=accumulated_content, role="assistant", data=final_chunk
- )
- )
- if inspect.isawaitable(callback_result):
- create_task(callback_result)
-
- if "choices" in final_chunk and len(final_chunk["choices"]) > 0:
- final_chunk["choices"][0]["content"] = accumulated_content
-
- final_chunk["object"] = "chat.completion"
-
- return OpenAIObject.construct_from(
- {
- "id": final_chunk["id"],
- "object": "chat.completion",
- "created": final_chunk["created"],
- "model": final_chunk["model"],
- "choices": [
- {
- "index": 0,
- "message": {
- "role": "assistant",
- "content": accumulated_content,
- },
- "finish_reason": "stop",
- }
- ],
- # TODO: Figure out how to get the usage from the streaming response
- "usage": Usage.parse_obj(
- {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
- ),
- }
- )
-
-
-class OpenAIChatCompletion(AbstractChatCompletion[T]):
- """
- OpenAI-specific implementation of the ChatCompletion.
- """
-
- def __init__(self, provider: str, **kwargs: Any):
- super().__init__(defaults=settings.get_defaults(provider or "openai") | kwargs)
-
- def _serialize_request(
- self, request: Optional[Request[T]] = None
- ) -> dict[str, Any]:
- """
- Serialize the request as per OpenAI's requirements.
- """
- _request = request or Request()
- _request = Request(
- **self.defaults
- | (
- model_dump(
- request,
- exclude_none=True,
- )
- if request
- else {}
- )
- )
- _request.function_call = _request.function_call or (
- request and request.function_call
- )
- _request.functions = _request.functions or (request and request.functions)
- _request.response_model = _request.response_model or (
- request and request.response_model
- ) # noqa
-
- extras = model_dump(
- _request,
- exclude={"functions", "function_call", "response_model"},
- )
-
- functions: dict[str, Any] = {}
- function_call: Any = {}
- for message in extras.get("messages", []):
- if message.get("name", -1) is None:
- message.pop("name", None)
- if message.get("function_call", -1) is None:
- message.pop("function_call", None)
-
- if _request.response_model:
- schema = cast_to_json(_request.response_model)
- functions["functions"] = [schema]
- function_call["function_call"] = {"name": schema.get("name")}
-
- elif _request.functions:
- functions["functions"] = [
- serialize_function_or_callable(function)
- for function in _request.functions
- ]
- if _request.function_call:
- function_call["function_call"] = _request.function_call
- return extras | functions | function_call
-
- def _create_request(self, **kwargs: Any) -> Request[T]:
- """
- Prepare and return an OpenAI-specific request object.
- """
- return Request(**kwargs)
-
- def _parse_response(self, response: Any) -> Response[T]:
- """
- Parse the response received from OpenAI.
- """
- # Convert OpenAI's response into a standard format or object
- return Response(**response.to_dict_recursive()) # type: ignore
-
- def _send_request(self, **serialized_request: Any) -> Any:
- """
- Send the serialized request to OpenAI's endpoint/service.
- """
- # Delegate to the async implementation and run it synchronously
-
- return run_sync(
- self._send_request_async(**serialized_request),
- )
-
- async def _send_request_async(self, **serialized_request: Any) -> Response[T]:
- """
- Send the serialized request to OpenAI's endpoint asynchronously.
- """
- import openai
-
- if handler_fn := serialized_request.pop("stream_handler", {}):
- serialized_request["stream"] = True
-
- response = await openai.ChatCompletion.acreate(**serialized_request) # type: ignore # noqa
-
- if handler_fn:
- response = await OpenAIStreamHandler(
- callback=handler_fn,
- ).handle_streaming_response(response)
-
- return response
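A toy illustration of the chunk-accumulation pattern the stream handler above relied on: concatenate content deltas into a single assistant message. The chunk dicts below mimic the shape of streamed chat-completion chunks.

def accumulate(chunks: list[dict]) -> str:
    content = ""
    for chunk in chunks:
        delta = chunk["choices"][0].get("delta", {})
        content += delta.get("content") or ""
    return content

chunks = [
    {"choices": [{"delta": {"role": "assistant"}}]},
    {"choices": [{"delta": {"content": "Hel"}}]},
    {"choices": [{"delta": {"content": "lo"}}]},
]
assert accumulate(chunks) == "Hello"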
diff --git a/src/marvin/core/ChatCompletion/utils.py b/src/marvin/core/ChatCompletion/utils.py
deleted file mode 100644
index 58a08c498..000000000
--- a/src/marvin/core/ChatCompletion/utils.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import json
-from ast import literal_eval
-from typing import Any
-
-
-def parse_raw(raw: str) -> dict[str, Any]:
- try:
- return literal_eval(raw)
- except Exception:
- pass
- try:
- return json.loads(raw)
- except Exception:
- pass
- return {}
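Assuming the parse_raw above is in scope, its behavior in one line each: Python literals, JSON, and a graceful fallback to an empty dict.

assert parse_raw("{'city': 'Paris'}") == {"city": "Paris"}   # Python literal
assert parse_raw('{"city": "Paris"}') == {"city": "Paris"}   # JSON
assert parse_raw("not a mapping") == {}                      # neither: empty dict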
diff --git a/src/marvin/deployment/__init__.py b/src/marvin/deployment/__init__.py
deleted file mode 100644
index 60057dd90..000000000
--- a/src/marvin/deployment/__init__.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import asyncio
-from typing import Union, Optional
-
-import uvicorn
-from fastapi import FastAPI, APIRouter
-from pydantic import BaseModel, Extra
-
-from marvin import AIApplication, AIModel, AIFunction
-from marvin.tools import Tool
-
-
-class Deployment(BaseModel):
- """
- Deployment class handles the deployment of AI applications, models or functions.
- """
-
- def __init__(
- self,
- component: Union[AIApplication, AIModel, AIFunction],
- *args,
- app_kwargs: Optional[dict] = None,
- router_kwargs: Optional[dict] = None,
- uvicorn_kwargs: Optional[dict] = None,
- **kwargs,
- ):
- super().__init__(**kwargs)
- self._app = FastAPI(**(app_kwargs or {}))
- self._router = APIRouter(**(router_kwargs or {}))
- self._controller = component
- self._mount_router()
- self._uvicorn_kwargs = {
- "app": self._app,
- "host": "0.0.0.0",
- "port": 8000,
- **(uvicorn_kwargs or {}),
- }
-
- def _mount_router(self):
- """
- Mounts a router to the FastAPI app for each tool in the AI application.
- """
-
- if isinstance(self._controller, AIApplication):
- name = self._controller.name
- base_path = f"/{name.lower()}" if name else "/aiapp"
- self._router.get(base_path, tags=[name])(self._controller.entrypoint)
- for tool in self._controller.tools:
- name, fn = (
- (tool.name, tool.fn)
- if isinstance(tool, Tool)
- else (tool.__name__, tool)
- )
- tool_path = f"{base_path}/tools/{name}"
- self._router.post(tool_path, tags=[name])(fn)
-
- self._app.include_router(self._router)
- self._app.openapi_tags = self._app.openapi_tags or []
- self._app.openapi_tags.append(
- {"name": name, "description": self._controller.description}
- )
-
- if isinstance(self._controller, AIModel):
- raise NotImplementedError
-
- if isinstance(self._controller, AIFunction):
- raise NotImplementedError
-
- def serve(self):
- """
- Serves the FastAPI app.
- """
- try:
- config = uvicorn.Config(**(self._uvicorn_kwargs or {}))
- server = uvicorn.Server(config)
- loop = asyncio.get_event_loop()
- loop.run_until_complete(server.serve())
- except Exception as e:
- print(f"Error while serving the application: {e}")
-
- class Config:
- extra = Extra.allow
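A minimal standalone sketch of the mounting pattern _mount_router used above: expose plain callables as POST routes under a tools prefix and serve them with uvicorn. The app name and route layout are illustrative, and this assumes fastapi and uvicorn are installed.

import random

import uvicorn
from fastapi import APIRouter, FastAPI

def roll_dice(n: int = 1) -> list[int]:
    """A stand-in tool."""
    return [random.randint(1, 6) for _ in range(n)]

app = FastAPI()
router = APIRouter()
for tool in (roll_dice,):
    router.post(f"/myapp/tools/{tool.__name__}", tags=[tool.__name__])(tool)
app.include_router(router)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)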
diff --git a/src/marvin/engine/__init__.py b/src/marvin/engine/__init__.py
deleted file mode 100644
index a097ae617..000000000
--- a/src/marvin/engine/__init__.py
+++ /dev/null
@@ -1,94 +0,0 @@
-"""
-The engine module is the interface to external LLM providers.
-"""
-from pydantic import BaseModel, Field
-from typing import Optional
-from marvin.utilities.module_loading import import_string
-import copy
-
-
-class ChatCompletionBase(BaseModel):
- """
- This class is used to create and handle chat completions from the API.
- It provides several utility functions to create the request, send it to the API,
- and handle the response.
- """
-
- _module: str = "openai.ChatCompletion" # the module used to interact with the API
- _request: str = "marvin.openai.ChatCompletion.Request"
- _response: str = "marvin.openai.ChatCompletion.Response"
- _create: str = "create" # the name of the create method in the API model
- _acreate: str = ( # the name of the asynchronous create method in the API model
- "acreate"
- )
- defaults: Optional[dict] = Field(None, repr=False) # default configuration values
-
- def __init__(self, module_path: str = None, config_path: str = None, **kwargs):
- super().__init__(_module=module_path, _config=config_path, defaults=kwargs)
-
- @property
- def module(self):
- """
- This property imports and returns the API model.
- """
- return import_string(self._module)
-
- @property
- def model(self):
- """
- This property imports and returns the API model.
- """
- return self.module
-
- def request(self, *args, **kwargs):
- """
- This method builds and returns a request object from the configured request class.
- """
- return import_string(self._request)(*args, **(kwargs or self.defaults or {}))
-
- @property
- def response_class(self):
- """
- This property imports and returns the response class.
- """
- return import_string(self._response)
-
- def prepare_request(self, *args, **kwargs):
- """
- This method prepares a request and returns it.
- """
- request = self.request() | self.request(**kwargs)
- payload = request.dict(exclude_none=True, exclude_unset=True)
- return request, payload
-
- def create(self, *args, **kwargs):
- """
- This method creates a request and sends it to the API.
- It returns a Response object with the raw response and the request.
- """
- request, request_dict = self.prepare_request(*args, **kwargs)
- create = getattr(self.model, self._create)
- response = self.response_class(create(**request_dict), request=request)
- if request.evaluate_function_call and response.function_call:
- return response.call_function(as_message=True)
- return response
-
- async def acreate(self, *args, **kwargs):
- """
- This method is an asynchronous version of the create method.
- It creates a request and sends it to the API asynchronously.
- It returns a Response object with the raw response and the request.
- """
- request, request_dict = self.prepare_request(*args, **kwargs)
- acreate = getattr(self.model, self._acreate)
- response = self.response_class(await acreate(**request_dict), request=request)
- if request.evaluate_function_call and response.function_call:
- return await response.acall_function(as_message=True)
- return response
-
- def __call__(self, *args, **kwargs):
- self = copy.deepcopy(self)
- request = self.request()
- passed = self.__class__(**kwargs).request()
- self.defaults = (request | passed).dict(serialize_functions=False)
- return self
diff --git a/src/marvin/engine/anthropic.py b/src/marvin/engine/anthropic.py
deleted file mode 100644
index 734b38333..000000000
--- a/src/marvin/engine/anthropic.py
+++ /dev/null
@@ -1,195 +0,0 @@
-from operator import itemgetter
-from typing import Any, Callable, Optional
-
-from anthropic import AI_PROMPT, HUMAN_PROMPT
-from jinja2 import Template
-from pydantic import BaseModel, Extra, Field, root_validator
-
-from marvin import settings
-from marvin.engine import ChatCompletionBase
-from marvin.engine.language_models.anthropic import AnthropicFunctionCall
-from marvin.types.request import Request as BaseRequest
-from marvin.utilities.module_loading import import_string
-
-
-class Request(BaseRequest):
- """
- This is a class for creating Request objects to interact with the Anthropic API.
- The class contains several configurations and validation functions to ensure
- the correct data is sent to the API.
-
- """
-
- model: str = "claude-2" # the model used by the GPT-3 API
- # temperature: float = 0.8 # the temperature parameter used by the GPT-3 API
- api_key: str = Field(default_factory=settings.anthropic.api_key.get_secret_value)
- max_tokens_to_sample: int = Field(default=1000)
- prompt: str = Field(default="")
-
- class Config:
- exclude = {"response_model", "messages"}
- exclude_none = True
- extra = Extra.allow
- functions_prompt = (
- "marvin.engine.language_models.anthropic.FUNCTIONS_INSTRUCTIONS"
- )
-
- @root_validator(pre=True)
- def to_anthropic(cls, values):
- values["prompt"] = ""
- for message in values.get("messages", []):
- if message.get("role") == "user":
- values["prompt"] += f'{HUMAN_PROMPT} {message.get("content")}'
- else:
- values["prompt"] += f'{AI_PROMPT} {message.get("content")}'
- values["prompt"] += f"{AI_PROMPT} "
- return values
-
- def dict(self, *args, serialize_functions=True, **kwargs):
- """
- This method returns a dictionary representation of the Request.
- If the functions attribute is present and serialize_functions is True,
- the functions' schemas are also included.
- """
-
- # This pass-through override exists only to show readers that custom
- # adapters need only override the dict method.
- return super().dict(*args, serialize_functions=serialize_functions, **kwargs)
-
-
-class Response(BaseModel):
- """
- This class is used to handle the response from the API.
- It includes several utility functions and properties to extract useful information
- from the raw response.
- """
-
- raw: Any # the raw response from the API
- request: Any # the request that generated the response
-
- def __init__(self, response, *args, request, **kwargs):
- super().__init__(raw=response, request=request)
-
- def __iter__(self):
- return self.raw.__iter__()
-
- def __next__(self):
- return self.raw.__next__()
-
- def __getattr__(self, name):
- """
- This method attempts to get the attribute from the raw response.
- If it doesn't exist, it falls back to the standard attribute access.
- """
- try:
- return self.raw.__getattr__(name)
- except AttributeError:
- return self.__getattribute__(name)
-
- @property
- def message(self):
- """
- This property extracts the completion text from the raw Anthropic response.
- """
- return self.raw.completion
-
- @property
- def function_call(self):
- """
- This property parses the function call payload out of the completion text.
- """
-
- return AnthropicFunctionCall.parse_raw(self.message).dict(
- exclude={"function_call"}
- )
-
- @property
- def callables(self):
- """
- This property returns a list of all callable functions from the request.
- """
- return [x for x in self.request.functions if isinstance(x, Callable)]
-
- @property
- def callable_registry(self):
- """
- This property returns a dictionary mapping function names to functions for all
- callable functions from the request.
- """
- return {fn.__name__: fn for fn in self.callables}
-
- def call_function(self, as_message=True):
- """
- This method evaluates the function call in the response and returns the result.
- If as_message is True, it returns the result as a function message.
- Otherwise, it returns the result directly.
- """
- name, raw_arguments = itemgetter("name", "arguments")(self.function_call)
- function = self.callable_registry.get(name)
- arguments = function.model.parse_raw(raw_arguments)
- value = function(**arguments.dict(exclude_none=True))
- if as_message:
- return {"role": "function", "name": name, "content": value}
- else:
- return value
-
- def to_model(self):
- """
- This method parses the function call arguments into the response model and
- returns the result.
- """
- return self.request.response_model.parse_raw(self.function_call.arguments)
-
- def __repr__(self, *args, **kwargs):
- """
- This method returns a string representation of the raw response.
- """
- return self.raw.__repr__(*args, **kwargs)
-
-
-class AnthropicChatCompletion(ChatCompletionBase):
- """
- This class is used to create and handle chat completions from the API.
- It provides several utility functions to create the request, send it to the API,
- and handle the response.
- """
-
- _module: str = "anthropic.Anthropic" # the module used to interact with the API
- _request: str = "marvin.engine.anthropic.Request"
- _response: str = "marvin.engine.anthropic.Response"
- defaults: Optional[dict] = Field(None, repr=False) # default configuration values
-
- @property
- def model(self):
- """
- This property imports and returns the API model.
- """
- return self.module(
- api_key=self.request().api_key,
- ).completions
-
- def prepare_request(self, *args, **kwargs):
- request, payload = super().prepare_request(*args, **kwargs)
- payload.pop("messages", None)
- payload.pop("api_key", None)
- if payload.get("functions", None):
- functions_prompt = Template(
- import_string(request.Config.functions_prompt)
- ).render(
- functions=payload.get("functions"),
- function_call=payload.get("function_call"),
- )
- payload["prompt"] = f"{HUMAN_PROMPT} {functions_prompt}" + payload["prompt"]
- payload.pop("functions", None)
- return request, payload
-
-
-ChatCompletion = AnthropicChatCompletion()
-
-# This is a legacy class that is used to create a ChatCompletion object.
-# It is deprecated and will be removed in a future release.
-ChatCompletionConfig = Request
diff --git a/src/marvin/engine/executors/__init__.py b/src/marvin/engine/executors/__init__.py
deleted file mode 100644
index b938d374d..000000000
--- a/src/marvin/engine/executors/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .base import Executor
-from .openai import OpenAIFunctionsExecutor
diff --git a/src/marvin/engine/executors/base.py b/src/marvin/engine/executors/base.py
deleted file mode 100644
index 318594af7..000000000
--- a/src/marvin/engine/executors/base.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from typing import List, Union
-
-from pydantic import PrivateAttr
-
-from marvin.engine.language_models import ChatLLM
-from marvin.prompts.base import Prompt, render_prompts
-from marvin.utilities.messages import Message
-from marvin.utilities.types import LoggerMixin, MarvinBaseModel
-
-
-class Executor(LoggerMixin, MarvinBaseModel):
- model: ChatLLM
- _should_stop: bool = PrivateAttr(False)
-
- async def start(
- self,
- prompts: list[Union[Prompt, Message]],
- prompt_render_kwargs: dict = None,
- ) -> list[Message]:
- """
- Start the LLM loop
- """
- # reset stop criteria
- self._should_stop = False
-
- responses = []
- while not self._should_stop:
- # render the prompts, including any responses from the previous step
- messages = render_prompts(
- prompts + responses,
- render_kwargs=prompt_render_kwargs,
- max_tokens=self.model.context_size,
- )
- response = await self.step(messages)
- responses.append(response)
- if await self.stop_condition(messages, responses):
- self._should_stop = True
- return responses
-
- async def step(self, messages: list[Message]) -> Message:
- """
- Implements one step of the LLM loop
- """
- messages = await self.process_messages(messages)
- llm_response = await self.run_engine(messages=messages)
- response = await self.process_response(llm_response)
- return response
-
- async def run_engine(self, messages: list[Message]) -> Message:
- """
- Sends the rendered messages to the LLM and returns its response
- """
- llm_response = await self.model.run(messages=messages)
- return llm_response
-
- async def process_messages(self, messages: list[Message]) -> list[Message]:
- """Called prior to sending messages to the LLM"""
- return messages
-
- async def stop_condition(
- self, messages: List[Message], responses: List[Message]
- ) -> bool:
- return True
-
- async def process_response(self, response: Message) -> Message:
- """Called after receiving a response from the LLM"""
- return response
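A bare-bones version of the Executor loop above (render, step, check the stop condition), with plain strings standing in for Marvin's prompts and messages.

import asyncio
from typing import Awaitable, Callable

async def run_loop(
    step: Callable[[list[str]], Awaitable[str]],
    should_stop: Callable[[list[str]], bool],
    prompts: list[str],
) -> list[str]:
    responses: list[str] = []
    while True:
        responses.append(await step(prompts + responses))
        if should_stop(responses):
            return responses

async def echo_step(messages: list[str]) -> str:
    return f"seen {len(messages)} messages"

print(asyncio.run(run_loop(echo_step, lambda r: len(r) >= 2, ["hi"])))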
diff --git a/src/marvin/engine/executors/openai.py b/src/marvin/engine/executors/openai.py
deleted file mode 100644
index e46f7d9d6..000000000
--- a/src/marvin/engine/executors/openai.py
+++ /dev/null
@@ -1,162 +0,0 @@
-import inspect
-import json
-from ast import literal_eval
-from typing import Any, Callable, List, Optional, Union
-
-from pydantic import Field
-
-import marvin
-from marvin.engine.language_models import OpenAIFunction
-from marvin.utilities.messages import Message
-
-from .base import Executor
-
-
-class OpenAIFunctionsExecutor(Executor):
- """
- An executor that understands how to pass functions to the LLM, interpret
- responses that request function calls, and iteratively continue to process
- functions until the LLM responds directly to the user. This uses the OpenAI
- Functions API, so provider LLMs must be compatible.
- """
-
- functions: Optional[List[OpenAIFunction]] = Field(default=None)
- function_call: Union[str, dict[str, str]] = Field(default=None)
- max_iterations: Optional[int] = Field(
- default_factory=lambda: marvin.settings.ai_application_max_iterations
- )
- stream_handler: Optional[Callable[[Message], None]] = Field(default=None)
-
- def __init__(
- self,
- functions: Optional[List[Union[OpenAIFunction, Callable[..., Any]]]] = None,
- **kwargs: Any,
- ):
- if functions is not None:
- functions = [
- (
- OpenAIFunction.from_function(i)
- if not isinstance(i, OpenAIFunction)
- else i
- )
- for i in functions
- ]
- super().__init__(
- functions=functions,
- **kwargs,
- )
-
- # @validator("functions", pre=True)
- # def validate_functions(cls, v):
- # if v is None:
- # return None
- # v = [
- # OpenAIFunction.from_function(i) if not isinstance(i, OpenAIFunction)else i
- # for i in v
- # ]
- # return v
-
- async def run_engine(self, messages: list[Message]) -> Message:
- """
- Implements one step of the LLM loop
- """
-
- kwargs = {}
-
- if self.functions:
- kwargs["functions"] = self.functions
- kwargs["function_call"] = self.function_call
-
- llm_response = await self.model.run(
- messages=messages,
- stream_handler=self.stream_handler,
- **kwargs,
- )
- return llm_response
-
- async def stop_condition(
- self, messages: List[Message], responses: List[Message]
- ) -> bool:
- # if the number of responses exceeds max iterations, stop
-
- if self.max_iterations is not None and len(responses) >= self.max_iterations:
- return True
-
- # if function calls are set to auto and the most recent call was a
- # function, continue
- if self.function_call == "auto":
- if responses and responses[-1].role == "function_response":
- return False
-
- # if a specific function call was requested but errored, continue
- elif self.function_call != "none":
- if responses and responses[-1].data.get("is_error"):
- return False
-
- # otherwise stop
- return True
-
- async def process_response(self, response: Message) -> Message:
- if response.role == "function_request":
- return await self.process_function_call(response)
- else:
- return response
-
- async def process_function_call(self, response: Message) -> Message:
- response_data = {}
-
- function_call = response.data["function_call"]
- fn_name = function_call.get("name")
- fn_args = function_call.get("arguments")
- response_data["name"] = fn_name
- try:
- try:
- fn_args = json.loads(function_call.get("arguments", "{}"))
- except json.JSONDecodeError:
- fn_args = literal_eval(function_call.get("arguments", "{}"))
- response_data["arguments"] = fn_args
-
- # retrieve the named function
- openai_fn = next((f for f in self.functions if f.name == fn_name), None)
- if openai_fn is None:
- raise ValueError(f'Function "{function_call["name"]}" not found.')
-
- if not isinstance(fn_args, dict):
- raise ValueError(
- "Expected a dictionary of arguments, got a"
- f" {type(fn_args).__name__}."
- )
-
- # call the function
- if openai_fn.fn is not None:
- self.logger.debug(
- f"Running function '{openai_fn.name}' with payload {fn_args}"
- )
- fn_result = openai_fn.fn(**fn_args)
- if inspect.isawaitable(fn_result):
- fn_result = await fn_result
-
- # if the function is undefined, return the arguments as its output
- else:
- fn_result = fn_args
- self.logger.debug(f"Result of function '{openai_fn.name}': {fn_result}")
- response_data["is_error"] = False
-
- except Exception as exc:
- fn_result = (
- f"The function '{fn_name}' encountered an error:"
- f" {str(exc)}\n\nThe payload you provided was: {fn_args}\n\nYou"
- " can try to fix the error and call the function again."
- )
- self.logger.debug_kv("Error", fn_result, key_style="red")
- response_data["is_error"] = True
-
- response_data["result"] = fn_result
-
- return Message(
- role="function_response",
- name=fn_name,
- content=str(fn_result),
- data=response_data,
- llm_response=response.llm_response,
- )
diff --git a/src/marvin/engine/language_models/__init__.py b/src/marvin/engine/language_models/__init__.py
deleted file mode 100644
index 6c952e794..000000000
--- a/src/marvin/engine/language_models/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .base import ChatLLM, OpenAIFunction, StreamHandler, chat_llm
diff --git a/src/marvin/engine/language_models/anthropic.py b/src/marvin/engine/language_models/anthropic.py
deleted file mode 100644
index 61bf410ed..000000000
--- a/src/marvin/engine/language_models/anthropic.py
+++ /dev/null
@@ -1,264 +0,0 @@
-import inspect
-import json
-import re
-from logging import Logger
-from typing import Callable, Union
-
-import anthropic
-import openai
-import openai.openai_object
-from pydantic import BaseModel
-
-import marvin
-import marvin.utilities.types
-from marvin.engine.language_models import ChatLLM, StreamHandler
-from marvin.engine.language_models.base import OpenAIFunction
-from marvin.utilities.async_utils import create_task
-from marvin.utilities.logging import get_logger
-from marvin.utilities.messages import Message, Role
-from marvin.utilities.strings import jinja_env
-
-CONTEXT_SIZES = {
- "claude-instant": 100_000,
- "claude-2": 100_000,
-}
-
-FUNCTION_CALL_REGEX = re.compile(
- r'{\s*"mode":\s*"function_call"\s*(.*)}',
- re.DOTALL,
-)
-FUNCTION_CALL_NAME = re.compile(r'"name":\s*"(.*)"')
-FUNCTION_CALL_ARGS = re.compile(r'"arguments":\s*(".*")', re.DOTALL)
-
-
-def extract_function_call(completion: str) -> Union[dict, None]:
- function_call = dict(name=None, arguments="{}")
- if match := FUNCTION_CALL_REGEX.search(completion):
- if name := FUNCTION_CALL_NAME.search(match.group(1)):
- function_call["name"] = name.group(1)
- if args := FUNCTION_CALL_ARGS.search(match.group(1)):
- function_call["arguments"] = args.group(1)
- try:
- function_call["arguments"] = json.loads(function_call["arguments"])
- except json.JSONDecodeError:
- pass
- if not function_call["name"]:
- return None
- return function_call
-
-
-def anthropic_role_map(marvin_role: Role):
- if marvin_role in [Role.USER, Role.SYSTEM, Role.FUNCTION_RESPONSE]:
- return anthropic.HUMAN_PROMPT
- else:
- return anthropic.AI_PROMPT
-
-
-class AnthropicFunctionCall(BaseModel):
- mode: str
- name: str
- arguments: str
-
- @classmethod
- def parse_raw(cls, raw: str):
- return super().parse_raw(re.sub("^[^{]*|[^}]*$", "", raw))
-
-
-class AnthropicStreamHandler(StreamHandler):
- async def handle_streaming_response(
- self,
- api_response: openai.openai_object.OpenAIObject,
- ) -> Message:
- """
- Accumulate chunk deltas into a full response. Returns the full message.
- Passes partial messages to the callback, if provided.
- """
- response = {
- "role": Role.ASSISTANT,
- "content": "",
- "data": {},
- "llm_response": None,
- }
-
- async for msg in api_response:
- response["llm_response"] = msg.dict()
- response["content"] += msg.completion
-
- if function_call := extract_function_call(response["content"]):
- response["role"] = Role.FUNCTION_REQUEST
- response["data"]["function_call"] = function_call
-
- if self.callback:
- callback_result = self.callback(Message(**response))
- if inspect.isawaitable(callback_result):
- create_task(callback_result)
-
- response["content"] = response["content"].strip()
- return Message(**response)
-
-
-class AnthropicChatLLM(ChatLLM):
- model: str = "claude-2"
-
- @property
- def context_size(self) -> int:
- if self.model in CONTEXT_SIZES:
- return CONTEXT_SIZES[self.model]
- else:
- for model_prefix, context in CONTEXT_SIZES.items():
- if self.model.startswith(model_prefix):
- return context
- return 100_000
-
- def format_messages(
- self, messages: list[Message]
- ) -> Union[str, dict, list[Union[str, dict]]]:
- formatted_messages = []
- for msg in messages:
- role = anthropic_role_map(msg.role)
- formatted_messages.append(f"{role}{msg.content}")
-
- return "".join(formatted_messages) + anthropic.AI_PROMPT
-
- async def run(
- self,
- messages: list[Message],
- *,
- functions: list[OpenAIFunction] = None,
- function_call: Union[str, dict[str, str]] = None,
- logger: Logger = None,
- stream_handler: Callable[[Message], None] = False,
- **kwargs,
- ) -> Message:
- """Calls an OpenAI LLM with a list of messages and returns the response."""
-
- if logger is None:
- logger = get_logger(self.name)
-
- # ----------------------------------
- # Prepare functions
- # ----------------------------------
- if functions:
- function_message = jinja_env.from_string(FUNCTIONS_INSTRUCTIONS).render(
- functions=functions, function_call=function_call
- )
- system_message = Message(role=Role.SYSTEM, content=function_message)
- messages = [system_message] + messages
-
- prompt = self.format_messages(messages)
-
- # ----------------------------------
- # Call Anthropic LLM
- # ----------------------------------
-
- if not marvin.settings.anthropic.api_key:
- raise ValueError(
- "Anthropic API key not found in settings. Please set it or use the"
- " MARVIN_ANTHROPIC_API_KEY environment variable."
- )
-
- client = anthropic.AsyncAnthropic(
- api_key=marvin.settings.anthropic.api_key.get_secret_value(),
- timeout=marvin.settings.llm_request_timeout_seconds,
- )
-
- kwargs.setdefault("temperature", self.temperature)
- kwargs.setdefault("max_tokens_to_sample", self.max_tokens)
-
- response = await client.completions.create(
- model=self.model,
- prompt=prompt,
- stream=True if stream_handler else False,
- **kwargs,
- )
-
- if stream_handler:
- handler = AnthropicStreamHandler(callback=stream_handler)
- msg = await handler.handle_streaming_response(response)
- return msg
-
- else:
- llm_response = response.dict()
- content = llm_response["completion"].strip()
- role = Role.ASSISTANT
- data = {}
- if function_call := extract_function_call(content):
- role = Role.FUNCTION_REQUEST
- data["function_call"] = function_call
- msg = Message(
- role=role,
- content=content,
- data=data,
- llm_response=llm_response,
- )
- return msg
-
-
-FUNCTIONS_INSTRUCTIONS = """
-# Functions
-
-You can call various functions to perform tasks.
-
-Whenever you receive a message from the user, check to see if any of your
-functions would help you respond. For example, you might use a function to look
-up information, interact with a filesystem, call an API, or validate data. You
-might write code, update state, or cause a side effect. After indicating that
-you want to call a function, the user will execute the function and tell you its
-result so that you can use the information in your final response. Therefore,
-you must use your functions whenever they would be helpful.
-
-The user may also provide a `function_call` instruction which could be:
-
-- "auto": you may decide to call a function on your own (this is the
- default)
-- "none": do not call any function
-- {"name": ""}: you MUST call the function with the given
- name
-
-To call a function:
-
-- Your response must include a JSON payload with the below format, including the
- {"mode": "function_call"} key.
-- Do not put any other text in your response beside the JSON payload.
-- Do not describe your plan to call the function to the user; they will not see
- it.
-- Do not include more than one payload in your response.
-- Do not include function output or results in your response.
-
-# Available Functions
-
-You have access to the following functions. Each has a name (which must be part
-of your response), a description (which you should use to decide whether to call the
-function), and a parameter spec (which is a JSON Schema description of the
-arguments you should pass in your response)
-
-{% for function in functions -%}
-
-## {{ function.name }}
-
-- Name: {{ function.name }}
-- Description: {{ function.description }}
-- Parameters: {{ function.parameters }}
-
-{% endfor %}
-
-# Calling a Function
-
-To call a function, your response MUST include a JSON document with the
-following structure:
-
-{
- "mode": "function_call",
-
- "name": "<function-name>",
-
- "arguments": "<arguments>"
-}
-
-The user will execute the function and respond with its result verbatim.
-
-# function_call instruction
-
-The user provided the following `function_call` instruction: {{ function_call }}
-"""
diff --git a/src/marvin/engine/language_models/azure_openai.py b/src/marvin/engine/language_models/azure_openai.py
deleted file mode 100644
index b19626da7..000000000
--- a/src/marvin/engine/language_models/azure_openai.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import marvin
-
-from .openai import OpenAIChatLLM
-
-CONTEXT_SIZES = {
- "gpt-35-turbo": 4096,
- "gpt-35-turbo-0613": 4096,
- "gpt-35-turbo-16k": 16384,
- "gpt-35-turbo-16k-0613": 16384,
- "gpt-4": 8192,
- "gpt-4-0613": 8192,
- "gpt-4-32k": 32768,
- "gpt-4-32k-0613": 32768,
-}
-
-
-class AzureOpenAIChatLLM(OpenAIChatLLM):
- model: str = "gpt-35-turbo-0613"
-
- @property
- def context_size(self) -> int:
- if self.model in CONTEXT_SIZES:
- return CONTEXT_SIZES[self.model]
- else:
- for model_prefix, context in CONTEXT_SIZES.items():
- if self.model.startswith(model_prefix):
- return context
- return 4096
-
- def _get_openai_settings(self) -> dict:
- # do not load the base openai settings; any azure settings must be set
- # explicitly
- openai_kwargs = {}
-
- if marvin.settings.azure_openai.api_key:
- openai_kwargs["api_key"] = (
- marvin.settings.azure_openai.api_key.get_secret_value()
- )
- else:
- raise ValueError(
- "Azure OpenAI API key not set. Please set it or use the"
- " MARVIN_AZURE_OPENAI_API_KEY environment variable."
- )
- if marvin.settings.azure_openai.deployment_name:
- openai_kwargs["deployment_name"] = (
- marvin.settings.azure_openai.deployment_name
- )
- if marvin.settings.azure_openai.api_type:
- openai_kwargs["api_type"] = marvin.settings.azure_openai.api_type
- if marvin.settings.azure_openai.api_base:
- openai_kwargs["api_base"] = marvin.settings.azure_openai.api_base
- if marvin.settings.azure_openai.api_version:
- openai_kwargs["api_version"] = marvin.settings.azure_openai.api_version
- return openai_kwargs
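The prefix fallback above iterates the dict's items; a standalone version of the same lookup (the table values shown are the ones from this file):

CONTEXT_SIZES = {"gpt-35-turbo-16k": 16384, "gpt-35-turbo": 4096, "gpt-4-32k": 32768, "gpt-4": 8192}

def context_size(model: str, default: int = 4096) -> int:
    if model in CONTEXT_SIZES:
        return CONTEXT_SIZES[model]
    for prefix, size in CONTEXT_SIZES.items():  # longer prefixes listed before shorter ones
        if model.startswith(prefix):
            return size
    return default

assert context_size("gpt-4-0613") == 8192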
diff --git a/src/marvin/engine/language_models/base.py b/src/marvin/engine/language_models/base.py
deleted file mode 100644
index b258834a7..000000000
--- a/src/marvin/engine/language_models/base.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import abc
-import json
-from logging import Logger
-from typing import Any, Callable, Optional, Union
-
-import tiktoken
-from pydantic import Field, validator
-
-import marvin
-import marvin.utilities.types
-from marvin.utilities.messages import Message
-from marvin.utilities.types import MarvinBaseModel
-
-
-class StreamHandler(MarvinBaseModel, abc.ABC):
- callback: Callable[[Message], None] = None
-
- @abc.abstractmethod
- def handle_streaming_response(self, api_response) -> Message:
- raise NotImplementedError()
-
-
-class OpenAIFunction(MarvinBaseModel):
- name: Optional[str] = None
- description: Optional[str] = None
- parameters: dict[str, Any] = {"type": "object", "properties": {}}
- fn: Optional[Callable] = Field(None, exclude=True)
- args: Optional[dict] = None
- """
- Base class for representing a function that can be called by an LLM. The
- format is identical to OpenAI's Functions API.
-
- Args:
- name (str): The name of the function.
- description (str): The description of the function.
- parameters (dict): The parameters of the function.
- fn (Callable): The function to be called.
- args (dict): The arguments to be passed to the function.
- """
-
- @classmethod
- def from_function(cls, fn: Callable, **kwargs):
- return cls(
- name=kwargs.get("name", fn.__name__),
- description=kwargs.get("description", fn.__doc__ or ""),
- parameters=marvin.utilities.types.function_to_schema(fn),
- fn=fn,
- )
-
- async def query(self, q: str, model: "ChatLLM" = None):
- if not model:
- model = chat_llm()
- self.args = json.loads(
- (
- await model.run(
- messages=[Message(role="USER", content=q)],
- functions=[self],
- function_call={"name": self.name},
- )
- )
- .data.get("function_call")
- .get("arguments")
- )
- return self
-
-
-class ChatLLM(MarvinBaseModel, abc.ABC):
- name: Optional[str] = None
- model: str
- max_tokens: int = Field(default_factory=lambda: marvin.settings.llm_max_tokens)
- temperature: float = Field(default_factory=lambda: marvin.settings.llm_temperature)
-
- @validator("name", always=True)
- def default_name(cls, v):
- if v is None:
- v = cls.__name__
- return v
-
- @property
- def context_size(self) -> int:
- return 4096
-
- def get_tokens(self, text: str, **kwargs) -> list[int]:
- try:
- enc = tiktoken.encoding_for_model(self.model)
- # fallback to the gpt-3.5-turbo tokenizer if the model is not found
- # note this will give the wrong answer for non-OpenAI models
- except KeyError:
- enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
- return enc.encode(text)
-
- async def __call__(self, messages, *args, **kwargs):
- return await self.run(messages, *args, **kwargs)
-
- @abc.abstractmethod
- def format_messages(
- self, messages: list[Message]
- ) -> Union[str, dict, list[Union[str, dict]]]:
- """Format Marvin message objects into a prompt compatible with the LLM model"""
- return messages
-
- @abc.abstractmethod
- async def run(
- self,
- messages: list[Message],
- functions: list[OpenAIFunction] = None,
- *,
- logger: Logger = None,
- stream_handler: Callable[[Message], None] = False,
- **kwargs,
- ) -> Message:
- """Run the LLM model on a list of messages and optional list of functions"""
- raise NotImplementedError()
-
-
-def chat_llm(model: str = None, **kwargs) -> ChatLLM:
- """Dispatches to all supported LLM providers"""
- if model is None:
- model = marvin.settings.llm_model
-
- # automatically detect gpt-3.5 and gpt-4 for backwards compatibility
- if model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
- model = f"openai/{model}"
-
- # extract the provider and model name
- provider, model_name = model.split("/", 1)
-
- if provider == "openai":
- from .openai import OpenAIChatLLM
-
- return OpenAIChatLLM(model=model_name, **kwargs)
- elif provider == "anthropic":
- from .anthropic import AnthropicChatLLM
-
- return AnthropicChatLLM(model=model_name, **kwargs)
- elif provider == "azure_openai":
- from .azure_openai import AzureOpenAIChatLLM
-
- return AzureOpenAIChatLLM(model=model_name, **kwargs)
- else:
- raise ValueError(f"Unknown provider/model: {model}")
diff --git a/src/marvin/engine/language_models/openai.py b/src/marvin/engine/language_models/openai.py
deleted file mode 100644
index 3e880b3fc..000000000
--- a/src/marvin/engine/language_models/openai.py
+++ /dev/null
@@ -1,213 +0,0 @@
-import inspect
-from logging import Logger
-from typing import Callable, Optional, Union
-
-import openai
-import openai.openai_object
-
-import marvin
-import marvin.utilities.types
-from marvin.utilities.async_utils import create_task
-from marvin.utilities.logging import get_logger
-from marvin.utilities.messages import Message, Role
-
-from .base import ChatLLM, OpenAIFunction, StreamHandler
-
-CONTEXT_SIZES = {
- "gpt-3.5-turbo-16k-0613": 16384,
- "gpt-3.5-turbo-16k": 16384,
- "gpt-3.5-turbo-0613": 4096,
- "gpt-3.5-turbo": 4096,
- "gpt-4-32k-0613": 32768,
- "gpt-4-32k": 32768,
- "gpt-4-0613": 8192,
- "gpt-4": 8192,
-}
-
-
-def openai_role_map(marvin_role: Role) -> str:
- if marvin_role == Role.FUNCTION_RESPONSE or marvin_role == "function_request":
- return "function"
- elif marvin_role == Role.FUNCTION_REQUEST or marvin_role == "function_response":
- return "assistant"
- else:
- return getattr(marvin_role, "value", marvin_role).lower()
-
-
-class OpenAIStreamHandler(StreamHandler):
- async def handle_streaming_response(
- self,
- api_response: openai.openai_object.OpenAIObject,
- ) -> Message:
- """
- Accumulate chunk deltas into a full response. Returns the full message.
- Passes partial messages to the callback, if provided.
- """
- response = {"role": None, "content": "", "data": {}, "llm_response": None}
-
- async for r in api_response:
- response["llm_response"] = r.to_dict_recursive()
-
- delta = r.choices[0].delta if r.choices and r.choices[0] else None
-
- if delta is None:
- continue
-
- if "role" in delta:
- response["role"] = delta.role
-
- if fn_call := delta.get("function_call"):
- if "function_call" not in response["data"]:
- response["data"]["function_call"] = {"name": None, "arguments": ""}
- if "name" in fn_call:
- response["data"]["function_call"]["name"] = fn_call.name
- if "arguments" in fn_call:
- response["data"]["function_call"]["arguments"] += (
- fn_call.arguments or ""
- )
-
- if "content" in delta:
- response["content"] += delta.content or ""
-
- if self.callback:
- callback_result = self.callback(Message(**response))
- if inspect.isawaitable(callback_result):
- create_task(callback_result)
-
- return Message(**response)
-
-
-class OpenAIChatLLM(ChatLLM):
- model: Optional[str] = "gpt-3.5-turbo"
-
- @property
- def context_size(self) -> int:
- if self.model in CONTEXT_SIZES:
- return CONTEXT_SIZES[self.model]
- else:
- for model_prefix, context in CONTEXT_SIZES.items():
- if self.model.startswith(model_prefix):
- return context
- return 4096
-
- def _get_openai_settings(self) -> dict:
- openai_kwargs = {}
- if marvin.settings.openai.api_key:
- openai_kwargs["api_key"] = marvin.settings.openai.api_key.get_secret_value()
- else:
- raise ValueError(
- "OpenAI API key not set. Please set it or use the"
- " MARVIN_OPENAI_API_KEY environment variable."
- )
-
- if marvin.settings.openai.api_type:
- openai_kwargs["api_type"] = marvin.settings.openai.api_type
- if marvin.settings.openai.api_base:
- openai_kwargs["api_base"] = marvin.settings.openai.api_base
- if marvin.settings.openai.api_version:
- openai_kwargs["api_version"] = marvin.settings.openai.api_version
- if marvin.settings.openai.organization:
- openai_kwargs["organization"] = marvin.settings.openai.organization
- return openai_kwargs
-
- def format_messages(
- self, messages: list[Message]
- ) -> Union[str, dict, list[Union[str, dict]]]:
- """Format Marvin message objects into a prompt compatible with the LLM model"""
- formatted_messages = []
- for m in messages:
- role = openai_role_map(m.role)
- fmt = {"role": role, "content": m.content}
- if m.name:
- fmt["name"] = m.name
- formatted_messages.append(fmt)
- return formatted_messages
-
- async def run(
- self,
- messages: list[Message],
- *,
- functions: list[OpenAIFunction] = None,
- function_call: Union[str, dict[str, str]] = None,
- logger: Logger = None,
- stream_handler: Callable[[Message], None] = False,
- **kwargs,
- ) -> Message:
- """Calls an OpenAI LLM with a list of messages and returns the response."""
-
- # ----------------------------------
- # Validate arguments
- # ----------------------------------
-
- if functions is None:
- functions = []
- if function_call is None:
- function_call = "auto"
- elif function_call not in (
- ["auto", "none"] + [{"name": f.name} for f in functions]
- ):
- raise ValueError(f"Invalid function_call value: {function_call}")
- if logger is None:
- logger = get_logger(self.name)
-
- # ----------------------------------
- # Form OpenAI-specific arguments
- # ----------------------------------
-
- openai_kwargs = self._get_openai_settings()
- kwargs.update(openai_kwargs)
-
- prompt = self.format_messages(messages)
- llm_functions = [f.dict(exclude={"fn"}, exclude_none=True) for f in functions]
-
- # only add to kwargs if supplied, because empty parameters are not
- # allowed by OpenAI
- if functions:
- kwargs["functions"] = llm_functions
- kwargs["function_call"] = function_call
-
- # ----------------------------------
- # Call OpenAI LLM
- # ----------------------------------
-
- kwargs.setdefault("temperature", self.temperature)
- kwargs.setdefault("max_tokens", self.max_tokens)
-
- response = await openai.ChatCompletion.acreate(
- model=self.model,
- messages=prompt,
- stream=True if stream_handler else False,
- request_timeout=marvin.settings.llm_request_timeout_seconds,
- **kwargs,
- )
-
- if stream_handler:
- handler = OpenAIStreamHandler(callback=stream_handler)
- msg = await handler.handle_streaming_response(response)
- role = msg.role
-
- if role == Role.ASSISTANT and isinstance(
- msg.data.get("function_call"), dict
- ):
- role = Role.FUNCTION_REQUEST
-
- return Message(
- role=role,
- content=msg.content,
- data=msg.data,
- llm_response=msg.llm_response,
- )
-
- else:
- llm_response = response.to_dict_recursive()
- msg = llm_response["choices"][0]["message"].copy()
- role = msg.pop("role").upper()
- if role == "ASSISTANT" and isinstance(msg.get("function_call"), dict):
- role = Role.FUNCTION_REQUEST
- msg = Message(
- role=role,
- content=msg.pop("content", None),
- data=msg,
- llm_response=llm_response,
- )
- return msg
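
# Editor's sketch (not part of the diff): calling the removed OpenAIChatLLM
# directly. Requires an OpenAI API key in Marvin's settings; the model name and
# prompt content are illustrative.
import asyncio

from marvin.engine.language_models.openai import OpenAIChatLLM
from marvin.utilities.messages import Message


async def demo() -> None:
    llm = OpenAIChatLLM(model="gpt-3.5-turbo")
    # run() formats the Marvin messages into the OpenAI chat format and returns
    # a single Message with the assistant's reply.
    reply = await llm.run(messages=[Message(role="USER", content="Say hello.")])
    print(reply.content)


asyncio.run(demo())
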
diff --git a/src/marvin/functions/__init__.py b/src/marvin/functions/__init__.py
deleted file mode 100644
index c0cd37359..000000000
--- a/src/marvin/functions/__init__.py
+++ /dev/null
@@ -1,133 +0,0 @@
-import functools
-import inspect
-import re
-from typing import Callable, TypeVar, Optional, Any, Type
-
-from fastapi.routing import APIRouter
-
-from marvin.utilities.types import function_to_model
-
-
-T = TypeVar("T")
-A = TypeVar("A")
-
-
-class Function:
- def __init__(
- self, *, fn: Callable[[A], T] = None, name: str = None, description: str = None
- ) -> None:
- self.fn = fn
- self.name = name or self.fn.__name__
- self.description = description or self.fn.__doc__
-
- super().__init__()
-
- @property
- def model(self):
- return function_to_model(self.fn, name=self.name, description=self.description)
-
- @property
- def signature(self):
- return inspect.signature(self.fn)
-
- @property
- def source_code(self):
- source_code = inspect.cleandoc(inspect.getsource(self.fn))
- if match := re.search(re.compile(r"(\bdef\b.*)", re.DOTALL), source_code):
- source_code = match.group(0)
- return source_code
-
- @property
- def return_annotation(self):
- return_annotation = self.signature.return_annotation
- if return_annotation is inspect._empty:
- return return_annotation, False
- return return_annotation
-
- def arguments(self, *args, **kwargs):
- bound_args = self.signature.bind(*args, **kwargs)
- bound_args.apply_defaults()
- return bound_args.arguments
-
- def schema(self, *args, name: str = None, description: str = None, **kwargs):
- schema = self.model.schema(*args, **kwargs)
- return {
- "name": name or schema.pop("title"),
- "description": description or self.fn.__doc__,
- "parameters": schema,
- }
-
-
-def FunctionDecoratorFactory(
- name: str = "marvin",
- func_class: Type[T] = None,
- in_place=True,
-) -> Callable[[A], T]:
- def decorator(fn: Callable[[A], T] = None) -> Callable[[A], T]:
- if fn is None:
- return functools.partial(decorator)
- elif in_place:
- fn = func_class(fn=fn)
- else:
- instance = func_class(fn=fn)
- setattr(fn, name, instance)
- for method in dir(instance):
- is_method_private = method.startswith("__")
- if not is_method_private:
- setattr(fn, method, getattr(instance, method))
- return fn
-
- return decorator
-
-
-marvin_fn = FunctionDecoratorFactory(name="openai", func_class=Function)
-
-
-class FunctionRegistry(APIRouter):
- def __init__(self, function_decorator=marvin_fn, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.function_decorator = function_decorator
-
- @property
- def endpoints(self):
- # Returns literal functions.
- return [route.endpoint for route in self.routes]
-
- @property
- def schema(self):
- # Returns JSON Schema of functions.
- return [self.function_decorator(fn=fn).schema() for fn in self.endpoints]
-
- @property
- def functions(self):
- # Returns function classes.
- return [self.function_decorator(fn=fn) for fn in self.endpoints]
-
- def include(self, registry: "FunctionRegistry", *args, **kwargs):
- super().include_router(registry, *args, **kwargs)
- # Deduplicate routes by name so repeated includes stay idempotent.
- self.routes = list({x.name: x for x in self.routes}.values())
-
- def register(self, fn: Optional[Callable] = None, **kwargs: Any) -> Callable:
- def decorator(fn: Callable, *args) -> Callable:
- fn_class = self.function_decorator(fn=fn, **kwargs)
- self.add_api_route(
- **{
- **{
- "name": fn_class.name,
- "path": f"/{fn_class.name}",
- "endpoint": fn,
- "description": fn_class.description,
- "methods": ["POST"],
- },
- **kwargs,
- }
- )
- return fn
-
- if fn:
- # fn was supplied directly, so apply the decorator immediately
- return decorator(fn)
- else:
- # else, return the decorator to be called later
- return decorator
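
# Editor's sketch (not part of the diff): registering a plain function with the
# removed FunctionRegistry and reading back its OpenAI-style schema. Requires
# fastapi (FunctionRegistry subclasses APIRouter); the names are illustrative.
from marvin.functions import FunctionRegistry

registry = FunctionRegistry()


@registry.register
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


# Each endpoint is re-wrapped by the function decorator, so .schema yields
# [{"name": "add", "description": "Add two integers.", "parameters": {...}}].
print(registry.schema)
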
diff --git a/src/marvin/openai/ChatCompletion/__init__.py b/src/marvin/openai/ChatCompletion/__init__.py
deleted file mode 100644
index bd41d1135..000000000
--- a/src/marvin/openai/ChatCompletion/__init__.py
+++ /dev/null
@@ -1,155 +0,0 @@
-from pydantic.main import ModelMetaclass
-
-from typing import Any, Callable, Optional
-from operator import itemgetter
-
-from marvin import settings
-from marvin._compat import BaseModel, Extra, Field
-from marvin.types.request import Request as BaseRequest
-from marvin.engine import ChatCompletionBase
-
-
-class Request(BaseRequest):
- """
- This is a class for creating Request objects to interact with the GPT-3 API.
- The class contains several configurations and validation functions to ensure
- the correct data is sent to the API.
-
- """
-
- model: str = "gpt-3.5-turbo" # the model used by the GPT-3 API
- temperature: float = 0.8 # the temperature parameter used by the GPT-3 API
- api_key: str = Field(default_factory=settings.openai.api_key.get_secret_value)
-
- class Config:
- exclude = {"response_model"}
- exclude_none = True
- extra = Extra.allow
-
- def dict(self, *args, serialize_functions=True, exclude=None, **kwargs):
- """
- This method returns a dictionary representation of the Request.
- If the functions attribute is present and serialize_functions is True,
- the functions' schemas are also included.
- """
-
- # This pass-through override exists only to show readers that custom
- # adapters need only override the dict method.
- return super().dict(
- *args, serialize_functions=serialize_functions, exclude=exclude, **kwargs
- )
-
-
-class Response(BaseModel):
- """
- This class is used to handle the response from the API.
- It includes several utility functions and properties to extract useful information
- from the raw response.
- """
-
- raw: Any # the raw response from the API
- request: Any # the request that generated the response
-
- def __init__(self, response, *args, request, **kwargs):
- super().__init__(raw=response, request=request)
-
- def __iter__(self):
- return self.raw.__iter__()
-
- def __next__(self):
- return self.raw.__next__()
-
- def __getattr__(self, name):
- """
- This method attempts to get the attribute from the raw response.
- If it doesn't exist, it falls back to the standard attribute access.
- """
- try:
- return self.raw.__getattr__(name)
- except AttributeError:
- return self.__getattribute__(name)
-
- @property
- def message(self):
- """
- This property extracts the message from the raw response.
- If there is only one choice, it returns the message from that choice.
- Otherwise, it returns a list of messages from all choices.
- """
- if len(self.raw.choices) == 1:
- return next(iter(self.raw.choices)).message
- return [x.message for x in self.raw.choices]
-
- @property
- def function_call(self):
- """
- This property extracts the function call from the message.
- If the message is a list, it returns a list of function calls from all messages.
- Otherwise, it returns the function call from the message.
- """
- if isinstance(self.message, list):
- return [x.function_call for x in self.message]
- return self.message.function_call
-
- @property
- def callables(self):
- """
- This property returns a list of all callable functions from the request.
- """
- return [x for x in self.request.functions if isinstance(x, Callable)]
-
- @property
- def callable_registry(self):
- """
- This property returns a dictionary mapping function names to functions for all
- callable functions from the request.
- """
- return {fn.__name__: fn for fn in self.callables}
-
- def call_function(self, as_message=True):
- """
- This method evaluates the function call in the response and returns the result.
- If as_message is True, it returns the result as a function message.
- Otherwise, it returns the result directly.
- """
- name, raw_arguments = itemgetter("name", "arguments")(self.function_call)
- function = self.callable_registry.get(name)
- arguments = function.model.parse_raw(raw_arguments)
- value = function(**arguments.dict(exclude_none=True))
- if as_message:
- return {"role": "function", "name": name, "content": value}
- else:
- return value
-
- def to_model(self):
- """
- This method parses the function call arguments into the response model and
- returns the result.
- """
- return self.request.response_model.parse_raw(self.function_call.arguments)
-
- def __repr__(self, *args, **kwargs):
- """
- This method returns a string representation of the raw response.
- """
- return self.raw.__repr__(*args, **kwargs)
-
-
-class OpenAIChatCompletion(ChatCompletionBase):
- """
- This class is used to create and handle chat completions from the API.
- It provides several utility functions to create the request, send it to the API,
- and handle the response.
- """
-
- _module: str = "openai.ChatCompletion" # the module used to interact with the API
- _request: str = "marvin.openai.ChatCompletion.Request"
- _response: str = "marvin.openai.ChatCompletion.Response"
- defaults: Optional[dict] = Field(None, repr=False) # default configuration values
-
-
-ChatCompletion = OpenAIChatCompletion()
-
-# ChatCompletionConfig is a legacy alias of Request kept for backwards compatibility.
-# It is deprecated and will be removed in a future release.
-ChatCompletionConfig = Request
diff --git a/src/marvin/openai/Function/Registry/__init__.py b/src/marvin/openai/Function/Registry/__init__.py
deleted file mode 100644
index 2f74ea661..000000000
--- a/src/marvin/openai/Function/Registry/__init__.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from typing import Any
-from marvin.openai.Function import openai_fn
-from openai.openai_object import OpenAIObject
-from marvin.functions import FunctionRegistry
-
-
-class OpenAIFunctionRegistry(FunctionRegistry):
- def __init__(self, function_decorator=openai_fn, *args, **kwargs):
- self.function_decorator = function_decorator
- super().__init__(function_decorator=function_decorator, *args, **kwargs)
-
- def from_response(self, response: OpenAIObject) -> Any:
- return next(
- iter(
- [
- {"name": k, "content": v}
- for k, v in self.dict_from_openai_response(response).items()
- if v is not None
- ]
- ),
- None,
- )
-
- def dict_from_openai_response(self, response: OpenAIObject) -> Any:
- return {
- fn.name: fn.from_response(response)
- for fn in map(lambda fn: self.function_decorator(fn=fn), self.endpoints)
- }
diff --git a/src/marvin/openai/Function/__init__.py b/src/marvin/openai/Function/__init__.py
deleted file mode 100644
index 35d69430f..000000000
--- a/src/marvin/openai/Function/__init__.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import json
-from typing import TypeVar
-from openai.openai_object import OpenAIObject
-
-from pydantic import validate_arguments
-from marvin.functions import Function, FunctionDecoratorFactory
-
-
-T = TypeVar("T")
-A = TypeVar("A")
-
-
-class OpenAIFunction(Function):
- def __call__(self, response: OpenAIObject) -> T:
- return self.from_response(response)
-
- @validate_arguments
- def from_response(self, response: OpenAIObject) -> T:
- relevant_calls = [
- choice.message.function_call
- for choice in response.choices
- if hasattr(choice.message, "function_call")
- and self.name == choice.message.function_call.get("name", None)
- ]
-
- arguments = [
- json.loads(function_call.get("arguments"))
- for function_call in relevant_calls
- ]
-
- responses = [self.fn(**argument) for argument in arguments]
-
- if len(responses) == 0:
- return None
- elif len(responses) == 1:
- return responses[0]
- else:
- return responses
-
-
-openai_fn = FunctionDecoratorFactory(name="openai", func_class=OpenAIFunction)
diff --git a/src/marvin/openai/__init__.py b/src/marvin/openai/__init__.py
deleted file mode 100644
index bebb3c2df..000000000
--- a/src/marvin/openai/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-import sys
-from openai import * # noqa: F403
-from marvin.core.ChatCompletion import ChatCompletion
-
-ChatCompletion = ChatCompletion
diff --git a/src/marvin/prompts/__init__.py b/src/marvin/prompts/__init__.py
deleted file mode 100644
index 674bf5c53..000000000
--- a/src/marvin/prompts/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .base import Prompt, render_prompts, prompt_fn
-from . import library
diff --git a/src/marvin/prompts/base.py b/src/marvin/prompts/base.py
deleted file mode 100644
index c4dd9c27c..000000000
--- a/src/marvin/prompts/base.py
+++ /dev/null
@@ -1,377 +0,0 @@
-import abc
-import inspect
-from functools import partial, wraps
-from typing import (
- Any,
- Callable,
- Dict,
- Generic,
- List,
- Literal,
- Optional,
- TypeVar,
- Union,
-)
-
-from jinja2 import Environment
-from pydantic import BaseModel, Field
-from typing_extensions import ParamSpec, Self
-
-import marvin
-from marvin._compat import cast_to_json, cast_to_model, model_dump, model_json_schema
-from marvin.core.ChatCompletion import ChatCompletion
-from marvin.core.ChatCompletion.abstract import AbstractChatCompletion
-from marvin.utilities.messages import Message, Role
-from marvin.utilities.strings import count_tokens, jinja_env
-
-T = TypeVar("T")
-P = ParamSpec("P")
-
-
-class MessageList(list[Message]):
- def render(
- self: Self,
- **kwargs: Any,
- ) -> Self:
- return render_prompts(self, render_kwargs=kwargs)
-
- def serialize(
- self: Self,
- **kwargs: Any,
- ) -> list[dict[str, Any]]:
- return [model_dump(message) for message in self.render(**kwargs)]
-
-
-class PromptList(list[Union["Prompt", Message]]):
- def __init__(self, prompts: list[Union["Prompt", Message]]):
- super().__init__(prompts)
-
- def render(
- self: Self,
- content: Optional[str] = None,
- render_kwargs: Optional[dict[str, Any]] = None,
- ) -> list[Message]:
- return render_prompts(self, render_kwargs=render_kwargs)
-
- def dict(self, **kwargs: Any):
- return [model_dump(message) for message in self.render(**kwargs)]
-
- def __call__(self, **kwargs: Any):
- return self.render(**kwargs)
-
-
-class BasePrompt(BaseModel, abc.ABC):
- """
- Base class for prompt templates.
- """
-
- functions: Optional[
- Union[
- List[Union[Dict[str, Any], Callable[..., Any], type[BaseModel]]],
- Callable[
- ..., List[Union[Dict[str, Any], Callable[..., Any], type[BaseModel]]]
- ],
- ]
- ] = Field(default=None)
-
- function_call: Optional[
- Union[
- Literal["auto"],
- Dict[Literal["name"], str],
- ]
- ] = Field(default=None)
-
- response_model: Optional[
- Union[
- type,
- type[BaseModel],
- Any,
- Callable[..., Union[type, type[BaseModel], Any]],
- ]
- ] = Field(default=None)
-
- response_model_name: Optional[str] = Field(
- default=None,
- exclude=True,
- repr=False,
- )
- response_model_description: Optional[str] = Field(
- default=None,
- exclude=True,
- repr=False,
- )
- response_model_field_name: Optional[str] = Field(
- default=None,
- exclude=True,
- repr=False,
- )
-
- position: Optional[int] = Field(
- default=None,
- repr=False,
- exclude=True,
- description=(
- "Position indicates the desired index for this prompt's messages. 0"
- " indicates they should be first; 1 indicates they should be second; -1"
- " indicates they should be last; None indicates they should be between any"
- " prompts that do request a position."
- ),
- )
- priority: float = Field(
- default=10,
- repr=False,
- exclude=True,
- description=(
- "Priority indicates the weight given when trimming messages to satisfy"
- " context limitations. Lower numbers indicate higher priority e.g. the"
- " highest priority is 0. The default is 10. Ties will be broken by message"
- " timestamp and role."
- ),
- )
-
- @abc.abstractmethod
- def generate(self, **kwargs: Any) -> list["Message"]:
- """
- Abstract method that generates a list of messages from the prompt template
- """
- pass
-
- def render(
- self: Self, content: str, render_kwargs: Optional[dict[str, Any]] = None
- ) -> str:
- """
- Helper function for rendering any jinja2 template with runtime render kwargs
- """
- return jinja_env.from_string(inspect.cleandoc(content)).render(
- **(render_kwargs or {})
- )
-
- def __or__(self: Self, other: Union[Self, list[Self]]) -> PromptList:
- """
- Supports pipe syntax:
- prompt = (
- Prompt()
- | Prompt()
- | Prompt()
- )
- """
- # when the right operand is a Prompt object
- if isinstance(other, Prompt):
- return PromptList([self, other])
- # when the right operand is a list
- elif isinstance(other, list):
- return PromptList([self, *other])
- else:
- raise TypeError(
- f"unsupported operand type(s) for |: '{type(self).__name__}' and"
- f" '{type(other).__name__}'"
- )
-
- def __ror__(self, other):
- """
- Supports pipe syntax:
- prompt = (
- Prompt()
- | Prompt()
- | Prompt()
- )
- """
- # when the left operand is a Prompt object
- if isinstance(other, Prompt):
- return PromptList([other, self])
- # when the left operand is a list
- elif isinstance(other, list):
- return PromptList(other + [self])
- else:
- raise TypeError(
- f"unsupported operand type(s) for |: '{type(other).__name__}' and"
- f" '{type(self).__name__}'"
- )
-
-
-class Prompt(BasePrompt, Generic[P], extra="allow", arbitrary_types_allowed=True):
- def generate(self, **kwargs: Any) -> list[Message]:
- response = Message.from_transcript(
- self.render(content=self.__doc__ or "", render_kwargs=kwargs)
- )
- return response
-
- def to_dict(self, **kwargs: Any) -> dict[str, Any]:
- extras = model_dump(
- self,
- exclude=set(self.__fields__.keys()),
- exclude_none=True,
- )
- return {
- "messages": [
- model_dump(message, include={"content", "role"})
- for message in render_prompts(
- self.generate(
- **extras | kwargs | {"response_model": self.response_model}
- )
- )
- ],
- "functions": self.functions,
- "function_call": self.function_call,
- "response_model": self.response_model,
- }
-
- def to_chat_completion(
- self, model: Optional[str] = None, **model_kwargs: Any
- ) -> AbstractChatCompletion[T]:
- return ChatCompletion(model=model, **model_kwargs)(**self.to_dict())
-
- def serialize(self, model: Any = None, **kwargs: Any) -> dict[str, Any]:
- if model:
- return model(**self.to_dict(**kwargs))._serialize_request() # type: ignore
-
- _dict = self.to_dict(**kwargs)
-
- response: dict[str, Any] = {}
- response["messages"] = _dict["messages"]
-
- if _dict.get("response_model", None):
- response["functions"] = [
- model_json_schema(cast_to_model(_dict["response_model"]))
- ]
- response["function_call"] = {"name": response["functions"][0]["name"]}
- elif _dict.get("functions", None):
- response["functions"] = [
- cast_to_json(function) if callable(function) else function
- for function in _dict["functions"]
- ]
- if _dict["function_call"]:
- response["function_call"] = _dict["function_call"]
-
- return response
-
- @classmethod
- def as_decorator(
- cls: type[Self],
- func: Optional[Callable[P, Any]] = None,
- *,
- environment: Optional[Environment] = None,
- ctx: Optional[dict[str, Any]] = None,
- role: Optional[Role] = None,
- functions: Optional[
- list[Union[Callable[..., Any], type[BaseModel], dict[str, Any]]]
- ] = None, # noqa
- function_call: Optional[
- Union[Literal["auto"], dict[Literal["name"], str]]
- ] = None, # noqa
- response_model: Optional[type[BaseModel]] = None,
- response_model_name: Optional[str] = None,
- response_model_description: Optional[str] = None,
- response_model_field_name: Optional[str] = None,
- serialize_on_call: bool = True,
- ) -> Union[
- Callable[[Callable[P, None]], Callable[P, None]],
- Callable[[Callable[P, None]], Callable[P, Self]],
- Callable[P, Self],
- ]:
- def wrapper(func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> Self:
- signature = inspect.signature(func)
- params = signature.bind(*args, **kwargs)
- params.apply_defaults()
- response = type(getattr(cls, "__name__", ""), (cls,), {})(
- __params__=params.arguments,
- **params.arguments,
- **ctx or {},
- functions=functions,
- function_call=function_call,
- response_model=cast_to_model(
- response_model or signature.return_annotation,
- name=response_model_name,
- description=response_model_description,
- field_name=response_model_field_name,
- ),
- response_model_name=response_model_name,
- response_model_description=response_model_description,
- response_model_field_name=response_model_field_name,
- )
- response.__doc__ = func.__doc__
- if serialize_on_call:
- return response.serialize()
- return response
-
- if func is not None:
- return wraps(func)(partial(wrapper, func))
-
- def decorator(func: Callable[P, None]) -> Callable[P, Self]:
- return wraps(func)(partial(wrapper, func))
-
- return decorator
-
-
-prompt_fn = Prompt.as_decorator
-
-
-class MessageWrapper(BasePrompt):
- """
- A Prompt class that stores and returns a specific Message
- """
-
- message: Message
-
- def generate(self, **kwargs: Any) -> list[Message]:
- return [self.message]
-
-
-def render_prompts(
- prompts: Union[list[Message], list[Union[Prompt, Message]]],
- render_kwargs: Optional[dict[str, Any]] = None,
- max_tokens: Optional[int] = None,
-) -> MessageList:
- max_tokens = max_tokens or marvin.settings.llm_max_context_tokens
-
- all_messages = []
-
- # if the user supplied any messages, wrap them in a MessageWrapper so we can
- # treat them as prompts for sorting and filtering
- prompts = [
- MessageWrapper(message=p) if isinstance(p, Message) else p for p in prompts
- ]
-
- # Separate prompts by positive, none and negative position
- pos_prompts = [p for p in prompts if p.position is not None and p.position >= 0]
- none_prompts = [p for p in prompts if p.position is None]
- neg_prompts = [p for p in prompts if p.position is not None and p.position < 0]
-
- # Sort the positive prompts in ascending order and negative prompts in
- # descending order, but both with timestamp ascending
- pos_prompts = sorted(pos_prompts, key=lambda c: c.position)
- neg_prompts = sorted(neg_prompts, key=lambda c: c.position, reverse=True)
-
- # generate messages from all prompts
- for i, prompt in enumerate(pos_prompts + none_prompts + neg_prompts):
- prompt_messages = prompt.generate(**(render_kwargs or {})) or []
- all_messages.extend((prompt.priority, i, m) for m in prompt_messages)
-
- # sort all messages by (priority asc, position desc) and stop when the
- # token limit is reached. This will prefer high-priority messages that are
- # later in the message chain.
- current_tokens = 0
- allowed_messages = []
- for _, position, msg in sorted(all_messages, key=lambda m: (m[0], -1 * m[1])):
- if current_tokens >= max_tokens:
- break
- allowed_messages.append((position, msg))
- current_tokens += count_tokens(msg.content)
-
- # sort allowed messages by position to restore original order
- messages = [msg for _, msg in sorted(allowed_messages, key=lambda m: m[0])]
-
- # Combine all system messages into one and insert at the index of the first
- # system message
- system_messages = [m for m in messages if m.role == Role.SYSTEM.value]
- if len(system_messages) > 1:
- system_message = Message(
- role=Role.SYSTEM,
- content="\n\n".join([m.content for m in system_messages]),
- )
- system_message_index = messages.index(system_messages[0])
- messages = [m for m in messages if m.role != Role.SYSTEM.value]
- messages.insert(system_message_index, system_message)
-
- # return all messages
- return messages
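
# Editor's sketch (not part of the diff): the removed prompt_fn decorator turned
# a function's docstring into a rendered message payload, using the return
# annotation as the response model. The model and field names are illustrative.
from pydantic import BaseModel

from marvin.prompts import prompt_fn


class Sentiment(BaseModel):
    label: str


@prompt_fn
def classify(text: str) -> Sentiment:
    """Classify the sentiment of the following text: {{ text }}"""


# With serialize_on_call=True (the default), calling the wrapped function
# returns a dict with "messages", "functions", and "function_call" keys.
payload = classify("I love this framework!")
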
diff --git a/src/marvin/prompts/library.py b/src/marvin/prompts/library.py
deleted file mode 100644
index 5d923cf22..000000000
--- a/src/marvin/prompts/library.py
+++ /dev/null
@@ -1,168 +0,0 @@
-import inspect
-from typing import Callable, Literal, Optional
-
-from pydantic import Field
-
-from marvin.prompts.base import Prompt
-from marvin.utilities.history import History, HistoryFilter
-from marvin.utilities.messages import Message, Role
-
-
-class MessagePrompt(Prompt):
- role: Role
- content: str = Field(
- ..., description="The message content, which can be a Jinja2 template"
- )
- name: str = None
- priority: int = 2
-
- def get_content(self) -> str:
- """
- Override this method to easily customize behavior
- """
- return self.content
-
- def generate(self, **kwargs) -> list[Message]:
- return [
- Message(
- role=self.role,
- content=self.render(
- self.get_content(),
- render_kwargs={
- **self.dict(exclude={"role", "content", "name", "priority"}),
- **kwargs,
- },
- ),
- name=self.name,
- )
- ]
-
- def read(self, **kwargs) -> str:
- return self.render(
- self.get_content(),
- render_kwargs={
- **self.dict(exclude={"role", "content", "name", "priority"}),
- **kwargs,
- },
- )
-
- def __init__(self, content: str = None, *args, **kwargs):
- content = kwargs.get("content", content)
- super().__init__(
- *args, **{**kwargs, **({"content": content} if content else {})}
- )
-
-
-class System(MessagePrompt):
- position: int = 0
- priority: int = 1
- role: Literal[Role.SYSTEM] = Role.SYSTEM
-
-
-class Assistant(MessagePrompt):
- role: Literal[Role.ASSISTANT] = Role.ASSISTANT
-
-
-class User(MessagePrompt):
- role: Literal[Role.USER] = Role.USER
-
-
-class MessageHistory(Prompt):
- history: History
- n: Optional[int] = 100
- skip: Optional[int] = None
- filter: HistoryFilter = None
-
- def generate(self, **kwargs) -> list[Message]:
- return self.history.get_messages(n=self.n, skip=self.skip, filter=self.filter)
-
-
-class Tagged(MessagePrompt):
- """
- Surround content with a tag, e.g. <b>bold</b>
- """
-
- tag: str
- role: Role = Role.USER
-
- def get_content(self) -> str:
- return f"<{self.tag}>{self.content}{self.tag}>"
-
-
-class Conditional(Prompt):
- if_: Callable = Field(
- ...,
- description=(
- "A function that returns a boolean. It will be called when the prompt is"
- " generated and provided all the variables that are passed to the render"
- " function."
- ),
- )
- if_content: str
- else_content: str
- role: Role = Role.USER
- name: str = None
-
- def generate(self, **kwargs) -> list[Message]:
- if self.if_(**kwargs):
- return [
- Message(
- role=self.role,
- content=self.render(self.if_content, render_kwargs=kwargs),
- name=self.name,
- )
- ]
- elif self.else_content:
- return [
- Message(
- role=self.role,
- content=self.render(self.else_content, render_kwargs=kwargs),
- name=self.name,
- )
- ]
- else:
- return []
-
-
-class JinjaConditional(Prompt):
- if_: str = Field(
- ...,
- description=(
- "A Jinja2-compatible expression that evaluates to a boolean e.g."
- " `truthy_var` or `counter > 10`. It will automatically be templated, do"
- " not include the `{{ }}` braces."
- ),
- )
- if_content: str
- else_content: str = None
- role: Role = Role.USER
- name: str = None
-
- def generate(self, **kwargs) -> list[Message]:
- if_content = inspect.cleandoc(self.if_content)
- if self.else_content:
- content = (
- f"{{% if {self.if_} %}}{if_content}{{% else"
- f" %}}{inspect.cleandoc(self.else_content)}{{% endif %}}"
- )
-
- else:
- content = (f"{{% if {self.if_} %}}{if_content}{{% endif %}}",)
- return [
- Message(
- role=self.role,
- content=self.render(content, render_kwargs=kwargs),
- name=self.name,
- )
- ]
-
-
-class ChainOfThought(Prompt):
- position: int = -1
-
- def generate(self, **kwargs) -> list[Message]:
- return [Message(role=Role.ASSISTANT, content="Let's think step by step.")]
-
-
-class Now(System):
- content: str = "It is {{ now().strftime('%A, %d %B %Y at %I:%M:%S %p %Z') }}."
diff --git a/src/marvin/pydantic.py b/src/marvin/pydantic.py
deleted file mode 100644
index 38396a61e..000000000
--- a/src/marvin/pydantic.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from pydantic.version import VERSION as PYDANTIC_VERSION
-
-PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
-
-if not PYDANTIC_V2:
- from pydantic import (
- BaseModel,
- BaseSettings,
- Extra,
- Field,
- PrivateAttr,
- PyObject,
- SecretStr,
- validate_arguments,
- )
- from pydantic.main import ModelMetaclass
-
- ModelMetaclass = ModelMetaclass
- BaseModel = BaseModel
- BaseSettings = BaseSettings
- Field = Field
- SecretStr = SecretStr
- Extra = Extra
- ImportString = PyObject
- PrivateAttr = PrivateAttr
- validate_arguments = validate_arguments
diff --git a/src/marvin/requests.py b/src/marvin/requests.py
new file mode 100644
index 000000000..f7d01d88a
--- /dev/null
+++ b/src/marvin/requests.py
@@ -0,0 +1,115 @@
+from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union
+
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated, Self
+
+from marvin.settings import settings
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class ResponseFormat(BaseModel):
+ type: str
+
+
+LogitBias = dict[str, float]
+
+
+class Function(BaseModel, Generic[T]):
+ name: str
+ description: Optional[str]
+ parameters: dict[str, Any]
+
+ model: Optional[type[T]] = Field(default=None, exclude=True, repr=False)
+ python_fn: Optional[Callable[..., Any]] = Field(
+ default=None,
+ description="Private field that holds the executable function, if available",
+ exclude=True,
+ repr=False,
+ )
+
+ def validate_json(self: Self, json_data: Union[str, bytes, bytearray]) -> T:
+ if self.model is None:
+ raise ValueError("This Function was not initialized with a model.")
+ return self.model.model_validate_json(json_data)
+
+
+class Tool(BaseModel, Generic[T]):
+ type: str
+ function: Optional[Function[T]] = None
+
+
+class ToolSet(BaseModel, Generic[T]):
+ tools: Optional[list[Tool[T]]] = None
+ tool_choice: Optional[Union[Literal["auto"], dict[str, Any]]] = None
+
+
+class RetrievalTool(Tool[T]):
+ type: str = Field(default="retrieval")
+
+
+class CodeInterpreterTool(Tool[T]):
+ type: str = Field(default="code_interpreter")
+
+
+class FunctionCall(BaseModel):
+ name: str
+
+
+class BaseMessage(BaseModel):
+ content: str
+ role: str
+
+
+class Grammar(BaseModel):
+ logit_bias: Optional[LogitBias] = None
+ max_tokens: Optional[Annotated[int, Field(strict=True, ge=1)]] = None
+ response_format: Optional[ResponseFormat] = None
+
+
+class Prompt(Grammar, ToolSet[T], Generic[T]):
+ messages: list[BaseMessage] = Field(default_factory=list)
+
+
+class ResponseModel(BaseModel):
+ model: type
+ name: str = Field(default="FormatResponse")
+ description: str = Field(default="Response format")
+
+
+class ChatRequest(Prompt[T]):
+ model: str = Field(default=settings.openai.chat.completions.model)
+ frequency_penalty: Optional[
+ Annotated[float, Field(strict=True, ge=-2.0, le=2.0)]
+ ] = 0
+ n: Optional[Annotated[int, Field(strict=True, ge=1)]] = 1
+ presence_penalty: Optional[
+ Annotated[float, Field(strict=True, ge=-2.0, le=2.0)]
+ ] = 0
+ seed: Optional[int] = None
+ stop: Optional[Union[str, list[str]]] = None
+ stream: Optional[bool] = False
+ temperature: Optional[Annotated[float, Field(strict=True, ge=0, le=2)]] = 1
+ top_p: Optional[Annotated[float, Field(strict=True, ge=0, le=1)]] = 1
+ user: Optional[str] = None
+
+
+class AssistantMessage(BaseMessage):
+ id: str
+ thread_id: str
+ created_at: int
+ assistant_id: Optional[str] = None
+ run_id: Optional[str] = None
+ file_ids: list[str] = []
+ metadata: dict[str, Any] = {}
+
+
+class Run(BaseModel, Generic[T]):
+ id: str
+ thread_id: str
+ created_at: int
+ status: str
+ model: str
+ instructions: Optional[str]
+ tools: Optional[list[Tool[T]]] = None
+ metadata: dict[str, str]
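
# Editor's sketch (not part of the diff): building one of the new request models.
# ChatRequest takes its default model from settings.openai.chat.completions; the
# message content below is illustrative.
from marvin.requests import BaseMessage, ChatRequest

request = ChatRequest(
    messages=[BaseMessage(role="user", content="Say hello.")],
    temperature=0.2,
)
# A plain dict suitable for the OpenAI chat completions endpoint.
payload = request.model_dump(exclude_none=True)
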
diff --git a/src/marvin/serializers.py b/src/marvin/serializers.py
new file mode 100644
index 000000000..8a7c7422f
--- /dev/null
+++ b/src/marvin/serializers.py
@@ -0,0 +1,107 @@
+from enum import Enum
+from types import GenericAlias
+from typing import (
+ Any,
+ Callable,
+ Literal,
+ Optional,
+ TypeVar,
+ Union,
+ get_args,
+ get_origin,
+)
+
+from pydantic import BaseModel, create_model
+from pydantic.fields import FieldInfo
+from pydantic.json_schema import GenerateJsonSchema, JsonSchemaMode
+
+from marvin import settings
+from marvin.requests import Function, Grammar, Tool
+
+U = TypeVar("U", bound=BaseModel)
+
+
+class FunctionSchema(GenerateJsonSchema):
+ def generate(self, schema: Any, mode: JsonSchemaMode = "validation"):
+ json_schema = super().generate(schema, mode=mode)
+ json_schema.pop("title", None)
+ return json_schema
+
+
+def create_tool_from_type(
+ _type: Union[type, GenericAlias],
+ model_name: str,
+ model_description: str,
+ field_name: str,
+ field_description: str,
+ python_function: Optional[Callable[..., Any]] = None,
+ **kwargs: Any,
+) -> Tool[BaseModel]:
+ annotated_metadata = getattr(_type, "__metadata__", [])
+ if isinstance(next(iter(annotated_metadata), None), FieldInfo):
+ metadata = next(iter(annotated_metadata))
+ else:
+ metadata = FieldInfo(description=field_description)
+
+ model: type[BaseModel] = create_model(
+ model_name,
+ __config__=None,
+ __base__=None,
+ __module__=__name__,
+ __validators__=None,
+ __cls_kwargs__=None,
+ **{field_name: (_type, metadata)},
+ )
+ return Tool[BaseModel](
+ type="function",
+ function=Function[BaseModel](
+ name=model_name,
+ description=model_description,
+ parameters=model.model_json_schema(schema_generator=FunctionSchema),
+ model=model,
+ ),
+ )
+
+
+def create_tool_from_model(
+ model: type[BaseModel],
+) -> Tool[BaseModel]:
+ return Tool[BaseModel](
+ type="function",
+ function=Function[BaseModel](
+ name=model.__name__,
+ description=model.__doc__,
+ parameters=model.model_json_schema(schema_generator=FunctionSchema),
+ model=model,
+ ),
+ )
+
+
+def create_vocabulary_from_type(
+ vocabulary: Union[GenericAlias, type],
+) -> list[str]:
+ if get_origin(vocabulary) == Literal:
+ return [str(token) for token in get_args(vocabulary)]
+ elif isinstance(vocabulary, type) and issubclass(vocabulary, Enum):
+ return [str(token) for token in list(vocabulary.__members__.keys())]
+ else:
+ raise TypeError(
+ f"Expected Literal or Enum, got {type(vocabulary)} with value {vocabulary}"
+ )
+
+
+def create_grammar_from_vocabulary(
+ vocabulary: list[str],
+ encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder,
+ max_tokens: Optional[int] = None,
+ _enumerate: bool = True,
+ **kwargs: Any,
+) -> Grammar:
+ return Grammar(
+ max_tokens=max_tokens,
+ logit_bias={
+ str(encoding): 100
+ for i, token in enumerate(vocabulary)
+ for encoding in encoder(str(i) if _enumerate else token)
+ },
+ )
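
# Editor's sketch (not part of the diff): turning a Literal vocabulary into a
# logit-bias grammar with the new serializers. The labels are illustrative.
from typing import Literal

from marvin.serializers import (
    create_grammar_from_vocabulary,
    create_vocabulary_from_type,
)

Sentiment = Literal["positive", "negative", "neutral"]

vocabulary = create_vocabulary_from_type(Sentiment)  # ["positive", "negative", "neutral"]

# With _enumerate=True (the default), the grammar biases the encoded index
# tokens ("0", "1", "2") to 100 and caps the completion at one token.
grammar = create_grammar_from_vocabulary(vocabulary, max_tokens=1)
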
diff --git a/src/marvin/settings.py b/src/marvin/settings.py
index 42372eeed..096c6fe7c 100644
--- a/src/marvin/settings.py
+++ b/src/marvin/settings.py
@@ -1,183 +1,216 @@
import os
from contextlib import contextmanager
-from pathlib import Path
-from typing import Any, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+
+from pydantic import Field, SecretStr
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+if TYPE_CHECKING:
+ from openai import AsyncClient, Client
+ from openai._base_client import HttpxBinaryResponseContent
+ from openai.types.chat import ChatCompletion
+ from openai.types.images_response import ImagesResponse
+
+
+class MarvinSettings(BaseSettings):
+ model_config = SettingsConfigDict(
+ env_prefix="marvin_",
+ env_file="~/.marvin/.env",
+ extra="allow",
+ arbitrary_types_allowed=True,
+ )
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ """Preserve SecretStr type when setting values."""
+ field = self.model_fields.get(name)
+ if field:
+ annotation = field.annotation
+ base_types = (
+ annotation.__args__
+ if getattr(annotation, "__origin__", None) is Union
+ else (annotation,)
+ )
+ if SecretStr in base_types and not isinstance(value, SecretStr):
+ value = SecretStr(value)
+ super().__setattr__(name, value)
+
+
+class MarvinModelSettings(MarvinSettings):
+ model: str
+
+ @property
+ def encoder(self):
+ import tiktoken
+
+ return tiktoken.encoding_for_model(self.model).encode
+
+
+class ChatCompletionSettings(MarvinModelSettings):
+ model: str = Field(
+ default="gpt-3.5-turbo-1106",
+ description="The default chat model to use.",
+ )
+
+ async def acreate(self, **kwargs: Any) -> "ChatCompletion":
+ from marvin.settings import settings
+
+ return await settings.openai.async_client.chat.completions.create(
+ model=self.model, **kwargs
+ )
+
+ def create(self, **kwargs: Any) -> "ChatCompletion":
+ from marvin.settings import settings
+
+ return settings.openai.client.chat.completions.create(
+ model=self.model, **kwargs
+ )
+
+
+class ImageSettings(MarvinModelSettings):
+ model: str = Field(
+ default="dall-e-3",
+ description="The default image model to use.",
+ )
+ size: Literal["1024x1024", "1792x1024", "1024x1792"] = Field(
+ default="1024x1024",
+ )
+ response_format: Literal["url", "b64_json"] = Field(default="url")
+ style: Literal["vivid", "natural"] = Field(default="vivid")
+
+ async def agenerate(self, prompt: str, **kwargs: Any) -> "ImagesResponse":
+ from marvin.settings import settings
+
+ return await settings.openai.async_client.images.generate(
+ model=self.model,
+ prompt=prompt,
+ size=self.size,
+ response_format=self.response_format,
+ style=self.style,
+ **kwargs,
+ )
+
+ def generate(self, prompt: str, **kwargs: Any) -> "ImagesResponse":
+ from marvin.settings import settings
+
+ return settings.openai.client.images.generate(
+ model=self.model,
+ prompt=prompt,
+ size=self.size,
+ response_format=self.response_format,
+ style=self.style,
+ **kwargs,
+ )
+
+
+class SpeechSettings(MarvinModelSettings):
+ model: str = Field(
+ default="tts-1-hd",
+ description="The default image model to use.",
+ )
+ voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = Field(
+ default="alloy",
+ )
+ response_format: Literal["mp3", "opus", "aac", "flac"] = Field(default="mp3")
+ speed: float = Field(default=1.0)
+
+ async def acreate(self, input: str, **kwargs: Any) -> "HttpxBinaryResponseContent":
+ from marvin.settings import settings
+
+ return await settings.openai.async_client.audio.speech.create(
+ model=kwargs.get("model", self.model),
+ input=input,
+ voice=kwargs.get("voice", self.voice),
+ response_format=kwargs.get("response_format", self.response_format),
+ speed=kwargs.get("speed", self.speed),
+ )
+
+ def create(self, input: str, **kwargs: Any) -> "HttpxBinaryResponseContent":
+ from marvin.settings import settings
+
+ return settings.openai.client.audio.speech.create(
+ model=kwargs.get("model", self.model),
+ input=input,
+ voice=kwargs.get("voice", self.voice),
+ response_format=kwargs.get("response_format", self.response_format),
+ speed=kwargs.get("speed", self.speed),
+ )
+
-from ._compat import (
- BaseSettings,
- SecretStr,
- model_dump,
-)
+class AssistantSettings(MarvinModelSettings):
+ model: str = Field(
+ default="gpt-4-1106-preview",
+ description="The default assistant model to use.",
+ )
-DEFAULT_ENV_PATH = Path(os.getenv("MARVIN_ENV_FILE", "~/.marvin/.env")).expanduser()
+class ChatSettings(MarvinSettings):
+ completions: ChatCompletionSettings = Field(default_factory=ChatCompletionSettings)
-class MarvinBaseSettings(BaseSettings):
- class Config:
- env_file = (
- ".env",
- str(DEFAULT_ENV_PATH),
+
+class AudioSettings(MarvinSettings):
+ speech: SpeechSettings = Field(default_factory=SpeechSettings)
+
+
+class OpenAISettings(MarvinSettings):
+ model_config = SettingsConfigDict(env_prefix="marvin_openai_")
+
+ api_key: Optional[SecretStr] = Field(
+ default=None,
+ description="Your OpenAI API key.",
+ )
+
+ organization: Optional[str] = Field(
+ default=None,
+ description="Your OpenAI organization ID.",
+ )
+
+ chat: ChatSettings = Field(default_factory=ChatSettings)
+ images: ImageSettings = Field(default_factory=ImageSettings)
+ audio: AudioSettings = Field(default_factory=AudioSettings)
+ assistants: AssistantSettings = Field(default_factory=AssistantSettings)
+
+ @property
+ def async_client(
+ self, api_key: Optional[str] = None, **kwargs: Any
+ ) -> "AsyncClient":
+ from openai import AsyncClient
+
+ if not (api_key or self.api_key):
+ raise ValueError("No API key provided.")
+ elif not api_key and self.api_key:
+ api_key = self.api_key.get_secret_value()
+
+ return AsyncClient(
+ api_key=api_key,
+ organization=self.organization,
+ **kwargs,
)
- env_prefix = "MARVIN_"
- validate_assignment = True
-
-
-class OpenAISettings(MarvinBaseSettings):
- """Provider-specific settings. Only some of these will be relevant to users."""
-
- class Config:
- env_prefix = "MARVIN_OPENAI_"
-
- api_key: Optional[SecretStr] = None
- organization: Optional[str] = None
- embedding_engine: str = "text-embedding-ada-002"
- api_type: Optional[str] = None
- api_base: Optional[str] = None
- api_version: Optional[str] = None
-
- def get_defaults(self, settings: "Settings") -> dict[str, Any]:
- import os
-
- import openai
-
- from marvin import openai as marvin_openai
-
- EXCLUDE_KEYS = {"stream_handler"}
-
- response: dict[str, Any] = {}
- if settings.llm_max_context_tokens > 0:
- response["max_tokens"] = settings.llm_max_tokens
- response["api_key"] = self.api_key and self.api_key.get_secret_value()
- if os.environ.get("MARVIN_OPENAI_API_KEY"):
- response["api_key"] = os.environ["MARVIN_OPENAI_API_KEY"]
- if os.environ.get("OPENAI_API_KEY"):
- response["api_key"] = os.environ["OPENAI_API_KEY"]
- if openai.api_key:
- response["api_key"] = openai.api_key
- if marvin_openai.api_key:
- response["api_key"] = marvin_openai.api_key
- response["temperature"] = settings.llm_temperature
- response["request_timeout"] = settings.llm_request_timeout_seconds
- return {
- k: v for k, v in response.items() if v is not None and k not in EXCLUDE_KEYS
- }
-
-
-class AnthropicSettings(MarvinBaseSettings):
- class Config:
- env_prefix = "MARVIN_ANTHROPIC_"
-
- api_key: Optional[SecretStr] = None
-
- def get_defaults(self, settings: "Settings") -> dict[str, Any]:
- response: dict[str, Any] = {}
- if settings.llm_max_context_tokens > 0:
- response["max_tokens_to_sample"] = settings.llm_max_tokens
- response["api_key"] = self.api_key and self.api_key.get_secret_value()
- response["temperature"] = settings.llm_temperature
- response["timeout"] = settings.llm_request_timeout_seconds
- if os.environ.get("MARVIN_ANTHROPIC_API_KEY"):
- response["api_key"] = os.environ["MARVIN_ANTHROPIC_API_KEY"]
- if os.environ.get("ANTHROPIC_API_KEY"):
- response["api_key"] = os.environ["ANTHROPIC_API_KEY"]
- return {k: v for k, v in response.items() if v is not None}
-
-
-class AzureOpenAI(MarvinBaseSettings):
- class Config:
- env_prefix = "MARVIN_AZURE_OPENAI_"
-
- api_key: Optional[SecretStr] = None
- api_type: Literal["azure", "azure_ad"] = "azure"
- # "The endpoint of the Azure OpenAI API. This should have the form https://YOUR_RESOURCE_NAME.openai.azure.com" # noqa
- api_base: Optional[str] = None
- api_version: Optional[str] = "2023-07-01-preview"
- # `deployment_name` will correspond to the custom name you chose for your deployment when # noqa
- # you deployed a model.
- deployment_name: Optional[str] = None
-
- def get_defaults(self, settings: "Settings") -> dict[str, Any]:
- import os
-
- import openai
-
- from marvin import openai as marvin_openai
-
- response: dict[str, Any] = {}
- if settings.llm_max_context_tokens > 0:
- response["max_tokens"] = settings.llm_max_tokens
- response["temperature"] = settings.llm_temperature
- response["request_timeout"] = settings.llm_request_timeout_seconds
- response["api_key"] = self.api_key and self.api_key.get_secret_value()
- if os.environ.get("MARVIN_AZURE_OPENAI_API_KEY"):
- response["api_key"] = os.environ["MARVIN_AZURE_OPENAI_API_KEY"]
- if openai.api_key:
- response["api_key"] = openai.api_key
- if marvin_openai.api_key:
- response["api_key"] = marvin_openai.api_key
-
- return model_dump(self, exclude_unset=True) | {
- k: v for k, v in response.items() if v is not None
- }
-
-
-def initial_setup(home: Union[Path, None] = None) -> Path:
- if not home:
- home = Path.home() / ".marvin"
- home.mkdir(parents=True, exist_ok=True)
- return home
-
-
-class Settings(MarvinBaseSettings):
- """Marvin settings"""
-
- home: Path = initial_setup()
- test_mode: bool = False
-
- # LOGGING
- log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
- verbose: bool = False
-
- # LLMS
- llm_model: str = "openai/gpt-3.5-turbo"
- llm_max_tokens: int = 1500
- llm_max_context_tokens: int = 3500
- llm_temperature: float = 0.8
- llm_request_timeout_seconds: Union[float, list[float]] = 600.0
-
- # AI APPLICATIONS
- ai_application_max_iterations: Optional[int] = None
-
- # providers
- openai: OpenAISettings = OpenAISettings()
- anthropic: AnthropicSettings = AnthropicSettings()
- azure_openai: AzureOpenAI = AzureOpenAI()
-
- # SLACK
- slack_api_token: Optional[SecretStr] = None
-
- # TOOLS
-
- # chroma
- chroma_server_host: Optional[str] = None
- chroma_server_http_port: Optional[int] = None
- # github
- github_token: Optional[SecretStr] = None
-
- # wolfram
- wolfram_app_id: Optional[SecretStr] = None
-
- def get_defaults(self, provider: Optional[str] = None) -> dict[str, Any]:
- response: dict[str, Any] = {}
- if provider == "openai":
- return self.openai.get_defaults(self)
- elif provider == "anthropic":
- return self.anthropic.get_defaults(self)
- elif provider == "azure_openai":
- return self.azure_openai.get_defaults(self)
- else:
- return response
+ @property
+ def client(self, api_key: Optional[str] = None, **kwargs: Any) -> "Client":
+ from openai import Client
+
+ if not (api_key or self.api_key):
+ raise ValueError("No API key provided.")
+ elif not api_key and self.api_key:
+ api_key = self.api_key.get_secret_value()
+
+ return Client(
+ api_key=api_key,
+ organization=self.organization,
+ **kwargs,
+ )
+
+
+class Settings(MarvinSettings):
+ model_config = SettingsConfigDict(env_prefix="marvin_")
+
+ openai: OpenAISettings = Field(default_factory=OpenAISettings)
+
+ log_level: str = Field(
+ default="DEBUG",
+ description="The log level to use.",
+ )
settings = Settings()
@@ -190,15 +223,9 @@ def temporary_settings(**kwargs: Any):
been already been accessed at module load time.
This function should only be used for testing.
-
- Example:
- >>> from marvin.settings import settings
- >>> with temporary_settings(MARVIN_LLM_MAX_TOKENS=100):
- >>> assert settings.llm_max_tokens == 100
- >>> assert settings.llm_max_tokens == 1500
"""
old_env = os.environ.copy()
- old_settings = settings.copy()
+ old_settings = settings.model_copy()
try:
for setting in kwargs:
@@ -210,7 +237,7 @@ def temporary_settings(**kwargs: Any):
new_settings = Settings()
- for field in settings.__fields__:
+ for field in settings.model_fields:
object.__setattr__(settings, field, getattr(new_settings, field))
yield settings
@@ -222,5 +249,5 @@ def temporary_settings(**kwargs: Any):
else:
os.environ.pop(setting, None)
- for field in settings.__fields__:
+ for field in settings.model_fields:
object.__setattr__(settings, field, getattr(old_settings, field))
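
# Editor's sketch (not part of the diff): the reworked settings expose model
# defaults, OpenAI clients, and a testing context manager. Requires
# MARVIN_OPENAI_API_KEY (or settings.openai.api_key) to be set; the message
# content is illustrative.
from marvin.settings import settings, temporary_settings

print(settings.openai.chat.completions.model)  # "gpt-3.5-turbo-1106" by default

with temporary_settings(MARVIN_LOG_LEVEL="INFO"):
    assert settings.log_level == "INFO"

# Convenience wrapper around the OpenAI client built from these settings.
completion = settings.openai.chat.completions.create(
    messages=[{"role": "user", "content": "Say hello."}],
)
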
diff --git a/src/marvin/tools/__init__.py b/src/marvin/tools/__init__.py
index 0be514778..e69de29bb 100644
--- a/src/marvin/tools/__init__.py
+++ b/src/marvin/tools/__init__.py
@@ -1,2 +0,0 @@
-from .base import Tool, tool
-from . import format_response
diff --git a/src/marvin/tools/assistants.py b/src/marvin/tools/assistants.py
new file mode 100644
index 000000000..bd07e4a36
--- /dev/null
+++ b/src/marvin/tools/assistants.py
@@ -0,0 +1,17 @@
+from typing import Any, Union
+
+from marvin.requests import CodeInterpreterTool, RetrievalTool, Tool
+
+Retrieval = RetrievalTool()
+CodeInterpreter = CodeInterpreterTool()
+
+AssistantTools = Union[RetrievalTool, CodeInterpreterTool, Tool]
+
+
+class CancelRun(Exception):
+ """
+ A special exception that can be raised in a tool to end the run immediately.
+ """
+
+ def __init__(self, data: Any = None):
+ self.data = data
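
# Editor's sketch (not part of the diff): a custom assistant tool can raise the
# new CancelRun exception to stop a run early; the surrounding run loop is
# expected to catch it. The tool body and data are illustrative.
from marvin.tools.assistants import CancelRun


def lookup_order(order_id: str) -> str:
    """Look up an order, cancelling the run if it does not exist."""
    known_orders = {"A-1001": "shipped", "A-1002": "processing"}  # illustrative data
    if order_id not in known_orders:
        raise CancelRun(data={"error": f"unknown order {order_id}"})
    return known_orders[order_id]
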
diff --git a/src/marvin/tools/base.py b/src/marvin/tools/base.py
deleted file mode 100644
index d3c455084..000000000
--- a/src/marvin/tools/base.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import inspect
-from functools import partial
-from typing import Callable, Optional
-
-from pydantic import BaseModel
-
-from marvin._compat import field_validator
-from marvin.types import Function
-from marvin.utilities.strings import jinja_env
-from marvin.utilities.types import LoggerMixin, function_to_schema
-
-
-class Tool(LoggerMixin, BaseModel):
- name: Optional[str] = None
- description: Optional[str] = None
- fn: Optional[Callable] = None
-
- @classmethod
- def from_function(cls, fn, name: str = None, description: str = None):
- # assuming fn has a name and a description
- name = name or fn.__name__
- description = description or fn.__doc__
- return cls(name=name, description=description, fn=fn)
-
- @field_validator("name")
- def default_name(cls, v: Optional[str]) -> str:
- if v is None:
- return cls.__name__
- else:
- return v
-
- def run(self, *args, **kwargs):
- if not self.fn:
- raise NotImplementedError()
- else:
- return self.fn(*args, **kwargs)
-
- def __call__(self, *args, **kwargs):
- return self.run(*args, **kwargs)
-
- def argument_schema(self) -> dict:
- schema = function_to_schema(self.fn or self.run, name=self.name)
- schema.pop("title", None)
- return schema
-
- def as_function(self, description: Optional[str] = None) -> Function:
- if not description:
- description = jinja_env.from_string(
- inspect.cleandoc(self.description or "")
- )
- description = description.render(**self.dict(), TOOL=self)
-
- def fn(*args, **kwargs):
- return self.run(*args, **kwargs)
-
- fn.__name__ = self.__class__.__name__
- fn.__doc__ = self.run.__doc__
-
- schema = self.argument_schema()
-
- return Function(
- name=fn.__name__,
- description=description,
- parameters=schema,
- fn=fn,
- signature=inspect.signature(self.run),
- )
-
-
-def tool(arg=None, *, name: str = None, description: str = None):
- if callable(arg): # Direct function decoration
- return Tool.from_function(arg, name=name, description=description)
- elif arg is None: # Partial application
- return partial(tool, name=name, description=description)
- else:
- raise TypeError("Invalid argument passed to decorator.")
diff --git a/src/marvin/tools/chroma.py b/src/marvin/tools/chroma.py
deleted file mode 100644
index 7e8f60b38..000000000
--- a/src/marvin/tools/chroma.py
+++ /dev/null
@@ -1,129 +0,0 @@
-import asyncio
-import json
-from typing import Optional
-
-import httpx
-from typing_extensions import Literal
-
-import marvin
-from marvin.tools import Tool
-from marvin.utilities.embeddings import create_openai_embeddings
-
-QueryResultType = Literal["documents", "distances", "metadatas"]
-
-
-async def list_collections() -> list[dict]:
- async with httpx.AsyncClient() as client:
- chroma_api_url = f"http://{marvin.settings.chroma_server_host}:{marvin.settings.chroma_server_http_port}"
- response = await client.get(
- f"{chroma_api_url}/api/v1/collections",
- )
-
- response.raise_for_status()
- return response.json()
-
-
-async def query_chroma(
- query: str,
- collection: str = "marvin",
- n_results: int = 5,
- where: Optional[dict] = None,
- where_document: Optional[dict] = None,
- include: Optional[list[QueryResultType]] = None,
- max_characters: int = 2000,
-) -> str:
- query_embedding = (await create_openai_embeddings([query]))[0]
-
- collection_ids = [
- c["id"] for c in await list_collections() if c["name"] == collection
- ]
-
- if len(collection_ids) == 0:
- return f"Collection {collection} not found."
-
- collection_id = collection_ids[0]
-
- async with httpx.AsyncClient() as client:
- chroma_api_url = f"http://{marvin.settings.chroma_server_host}:{marvin.settings.chroma_server_http_port}"
-
- response = await client.post(
- f"{chroma_api_url}/api/v1/collections/{collection_id}/query",
- data=json.dumps(
- {
- "query_embeddings": [query_embedding],
- "n_results": n_results,
- "where": where or {},
- "where_document": where_document or {},
- "include": include or ["documents"],
- }
- ),
- headers={"Content-Type": "application/json"},
- )
-
- response.raise_for_status()
-
- return "\n".join(
- [
- f"{i+1}. {', '.join(excerpt)}"
- for i, excerpt in enumerate(response.json()["documents"])
- ]
- )[:max_characters]
-
-
-class QueryChroma(Tool):
- """Tool for querying a Chroma index."""
-
- description: str = """
- Retrieve document excerpts from a knowledge-base given a query.
- """
-
- async def run(
- self,
- query: str,
- collection: str = "marvin",
- n_results: int = 5,
- where: Optional[dict] = None,
- where_document: Optional[dict] = None,
- include: Optional[list[QueryResultType]] = None,
- max_characters: int = 2000,
- ) -> str:
- return await query_chroma(
- query, collection, n_results, where, where_document, include, max_characters
- )
-
-
-class MultiQueryChroma(Tool):
- """Tool for querying a Chroma index."""
-
- description: str = """
- Retrieve document excerpts from a knowledge-base given a query.
- """
-
- async def run(
- self,
- queries: list[str],
- collection: str = "marvin",
- n_results: int = 5,
- where: Optional[dict] = None,
- where_document: Optional[dict] = None,
- include: Optional[list[QueryResultType]] = None,
- max_characters: int = 2000,
- max_queries: int = 5,
- ) -> str:
- if len(queries) > max_queries:
- # make sure excerpts are not too short
- queries = queries[:max_queries]
-
- coros = [
- query_chroma(
- query,
- collection,
- n_results,
- where,
- where_document,
- include,
- max_characters // len(queries),
- )
- for query in queries
- ]
- return "\n\n".join(await asyncio.gather(*coros, return_exceptions=True))
diff --git a/src/marvin/tools/code.py b/src/marvin/tools/code.py
new file mode 100644
index 000000000..89f2d7950
--- /dev/null
+++ b/src/marvin/tools/code.py
@@ -0,0 +1,26 @@
+# 🚨 WARNING 🚨
+# These functions allow ARBITRARY code execution and should be used with caution.
+
+import json
+import subprocess
+
+
+def shell(command: str) -> str:
+ """executes a shell command on your local machine and returns the output"""
+
+ result = subprocess.run(command, shell=True, text=True, capture_output=True)
+
+ # Output and error
+ output = result.stdout
+ error = result.stderr
+
+ return json.dumps(dict(command_output=output, command_error=error))
+
+
+def python(code: str) -> str:
+ """
+ Executes Python code on your local machine and returns the output. You can
+ use this to run code that isn't compatible with the code interpreter tool,
+ for example if it requires internet access or other packages.
+ """
+ return str(eval(code))
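Both helpers execute arbitrary code with no sandboxing, as the warning at the top of the module says. A small, hedged usage sketch; note that `python` wraps `eval`, so it accepts a single expression rather than arbitrary statements:

```python
from marvin.tools.code import shell, python

# `shell` returns a JSON string containing the command's stdout and stderr.
print(shell("echo hello"))

# `python` evaluates one expression and stringifies the result.
print(python("sum(range(10))"))
```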
diff --git a/src/marvin/tools/filesystem.py b/src/marvin/tools/filesystem.py
index 44d7cfbf7..31386124f 100644
--- a/src/marvin/tools/filesystem.py
+++ b/src/marvin/tools/filesystem.py
@@ -1,186 +1,151 @@
-import json
-from pathlib import Path
-from typing import Literal
-
-from pydantic import BaseModel, Field, root_validator, validate_arguments, validator
-
-from marvin.tools import Tool
-
-
-class FileSystemTool(Tool):
- root_dir: Path = Field(
- None,
- description=(
- "Root directory for files. If provided, only files nested in or below this"
- " directory can be read. "
- ),
- )
-
- def validate_paths(self, paths: list[str]) -> list[Path]:
- """
- If `root_dir` is set, ensures that all paths are children of `root_dir`.
- """
- if self.root_dir:
- for path in paths:
- if ".." in path:
- raise ValueError(f"Do not use `..` in paths. Got {path}")
- if not (self.root_dir / path).is_relative_to(self.root_dir):
- raise ValueError(f"Path {path} is not relative to {self.root_dir}")
- return [self.root_dir / path for path in paths]
- return paths
-
-
-class ListFiles(FileSystemTool):
- description: str = """
- Lists all files at or optionally under a provided path. {%- if root_dir
- %} Paths must be relative to {{ root_dir }}. Provide '.' instead of '/'
- to read root. {%- endif %}}
- """
-
- root_dir: Path = Field(
- None,
- description=(
- "Root directory for files. If provided, only files nested in or below this"
- " directory can be read."
- ),
- )
-
- def run(self, path: str, include_nested: bool = True) -> list[str]:
- """List all files in `root_dir`, optionally including nested files."""
- [path] = self.validate_paths([path])
- if include_nested:
- files = [str(p) for p in path.rglob("*") if p.is_file()]
- else:
- files = [str(p) for p in path.glob("*") if p.is_file()]
-
- # filter out certain files
- files = [
- file
- for file in files
- if not (
- "__pycache__" in file
- or "/.git/" in file
- or file.endswith("/.gitignore")
- )
- ]
+import os
+import pathlib
+import shutil
- return files
+def _safe_create_file(path: str) -> None:
+ path = os.path.expanduser(path)
+ file_path = pathlib.Path(path)
+ file_path.parent.mkdir(parents=True, exist_ok=True)
+ file_path.touch(exist_ok=True)
-class ReadFile(FileSystemTool):
- description: str = """
- Read the content of a specific file, optionally providing start and end
- rows.{% if root_dir %} Paths must be relative to {{ root_dir }}. Provide '.'
- instead of '/' to read root.{%- endif %}}
- """
- def run(self, path: str, start_row: int = 1, end_row: int = -1) -> str:
- [path] = self.validate_paths([path])
- with open(path, "r") as f:
- content = f.readlines()
+def getcwd() -> str:
+ """Returns the current working directory"""
+ return os.getcwd()
- if start_row == 0:
- start_row = 1
- if start_row > 0:
- start_row -= 1
- if end_row < 0:
- end_row += 1
- if end_row == 0:
- content = content[start_row:]
- else:
- content = content[start_row:end_row]
+def write(path: str, contents: str) -> str:
+ """Creates or overwrites a file with the given contents"""
+ path = os.path.expanduser(path)
+ _safe_create_file(path)
+ with open(path, "w") as f:
+ f.write(contents)
+ return f'Successfully wrote "{path}"'
- return "\n".join(content)
+def write_lines(
+ path: str, contents: str, insert_line: int = -1, mode: str = "insert"
+) -> str:
+ """Writes content to a specific line in the file.
-class ReadFiles(FileSystemTool):
- description: str = """
- Read the entire content of multiple files at once. Due to context size
- limitations, reading too many files at once may cause truncated responses.
- {% if root_dir %} Paths must be relative to {{ root_dir }}. Provide '.'
- instead of '/' to read root.{%- endif %}}
+ Args:
+ path (str): The name of the file to write to.
+ contents (str): The content to write to the file.
+ insert_line (int, optional): The line number to insert the content at.
+ Negative values count from the end of the file. Defaults to -1.
+ mode (str, optional): The mode to use when writing the content. Can be
+ "insert" or "overwrite". Defaults to "insert".
+
+ Returns:
+ str: A message indicating whether the write was successful.
"""
+ path = os.path.expanduser(path)
+ _safe_create_file(path)
+ with open(path, "r") as f:
+ lines = f.readlines()
+ if insert_line < 0:
+ insert_line = len(lines) + insert_line + 1
+ if mode == "insert":
+ lines[insert_line:insert_line] = contents.splitlines(True)
+ elif mode == "overwrite":
+ lines[insert_line : insert_line + len(contents.splitlines())] = (
+ contents.splitlines(True)
+ )
+ else:
+ raise ValueError(f"Invalid mode: {mode}")
+ with open(path, "w") as f:
+ f.writelines(lines)
+ return f'Successfully wrote to "{path}"'
- def run(self, paths: list[str]) -> dict[str, str]:
- """Load content of each file into a dictionary of path: content."""
- content = {}
- for path in self.validate_paths(paths):
- with open(path) as f:
- content[path] = f.read()
- return content
-
-
-class WriteContent(BaseModel):
- path: str
- content: str
- write_mode: Literal["overwrite", "append", "insert"] = "append"
- insert_at_row: int = None
-
- @validator("content", pre=True)
- def content_must_be_string(cls, v):
- if v and not isinstance(v, str):
- try:
- v = json.dumps(v)
- except json.JSONDecodeError:
- raise ValueError("Content must be a string or JSON-serializable.")
- return v
-
- @root_validator
- def check_insert_model(cls, values):
- if values["insert_at_row"] is None and values["write_mode"] == "insert":
- raise ValueError("Must provide `insert_at_row` when using `insert` mode.")
- return values
-
-
-class WriteFile(FileSystemTool):
- description: str = """
- Write content to a file.
-
- {%if root_dir %} Paths must be relative to {{ root_dir }}.{% endif %}}
-
- {%if require_confirmation %} You MUST ask the user to confirm writes by
- showing them details. {% endif %}
- """
- require_confirmation: bool = True
-
- def run(self, write_content: WriteContent) -> str:
- [path] = self.validate_paths([write_content.path])
-
- # ensure the parent directory exists
- path.parent.mkdir(parents=True, exist_ok=True)
-
- if write_content.write_mode == "overwrite":
- with open(path, "w") as f:
- f.write(write_content.content)
- elif write_content.write_mode == "append":
- with open(path, "a") as f:
- f.write(write_content.content)
- elif write_content.write_mode == "insert":
- with open(path, "r") as f:
- contents = f.readlines()
- contents[write_content.insert_at_row] = write_content.content
-
- with open(path, "w") as f:
- f.writelines(contents)
-
- return f"Files {write_content.path} written successfully."
-
-
-class WriteFiles(WriteFile):
- description: str = """
- Write content to multiple files. Each `WriteContent` object in the
- `contents` argument is an instruction to write to a specific file.
-
- {%if root_dir %} Paths must be relative to {{ root_dir }}.{% endif %}}
-
- {%if require_confirmation %} You MUST ask the user to confirm writes by
- showing them details. {% endif %}
- """
- require_confirmation: bool = True
-
- @validate_arguments
- def run(self, contents: list[WriteContent]) -> str:
- for wc in contents:
- super().run(write_content=wc)
- return f"Files {[c.path for c in contents]} written successfully."
+
+def read(path: str, include_line_numbers: bool = False) -> str:
+ """Reads a file and returns the contents.
+
+ Args:
+ path (str): The path to the file.
+ include_line_numbers (bool, optional): Whether to include line numbers
+ in the returned contents. Defaults to False.
+
+ Returns:
+ str: The contents of the file.
+ """
+ path = os.path.expanduser(path)
+ with open(path, "r") as f:
+ if include_line_numbers:
+ lines = f.readlines()
+ lines_with_numbers = [f"{i+1}: {line}" for i, line in enumerate(lines)]
+ return "".join(lines_with_numbers)
+ else:
+ return f.read()
+
+
+def read_lines(
+ path: str,
+ start_line: int = 0,
+ end_line: int = -1,
+ include_line_numbers: bool = False,
+) -> str:
+ """Reads a partial file and returns the contents with optional line numbers.
+
+ Args:
+ path (str): The path to the file.
+ start_line (int, optional): The starting line number to read. Defaults
+ to 0.
+ end_line (int, optional): The ending line number to read. Defaults to
+ -1, which means read until the end of the file.
+ include_line_numbers (bool, optional): Whether to include line numbers
+ in the returned contents. Defaults to False.
+
+ Returns:
+ str: The contents of the file.
+ """
+ path = os.path.expanduser(path)
+ with open(path, "r") as f:
+ lines = f.readlines()
+ if start_line < 0:
+ start_line = len(lines) + start_line
+ if end_line < 0:
+ end_line = len(lines) + end_line
+ if include_line_numbers:
+ lines_with_numbers = [
+ f"{i+1}: {line}" for i, line in enumerate(lines[start_line:end_line])
+ ]
+ return "".join(lines_with_numbers)
+ else:
+ return "".join(lines[start_line:end_line])
+
+
+def mkdir(path: str) -> str:
+ """Creates a directory (and any parent directories))"""
+ path = os.path.expanduser(path)
+ path = pathlib.Path(path)
+ path.mkdir(parents=True, exist_ok=True)
+ return f'Successfully created directory "{path}"'
+
+
+def mv(src: str, dest: str) -> str:
+ """Moves a file or directory"""
+ src = os.path.expanduser(src)
+ dest = os.path.expanduser(dest)
+ src = pathlib.Path(src)
+ dest = pathlib.Path(dest)
+ src.rename(dest)
+ return f'Successfully moved "{src}" to "{dest}"'
+
+
+def cp(src: str, dest: str) -> str:
+ """Copies a file or directory"""
+ src = os.path.expanduser(src)
+ dest = os.path.expanduser(dest)
+ src = pathlib.Path(src)
+ dest = pathlib.Path(dest)
+ shutil.copytree(src, dest)
+ return f'Successfully copied "{src}" to "{dest}"'
+
+
+def ls(path: str) -> str:
+ """Lists the contents of a directory"""
+ path = os.path.expanduser(path)
+ path = pathlib.Path(path)
+ return "\n".join(str(p) for p in path.iterdir())
diff --git a/src/marvin/tools/format_response.py b/src/marvin/tools/format_response.py
deleted file mode 100644
index 41fe2acb7..000000000
--- a/src/marvin/tools/format_response.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import warnings
-from types import GenericAlias
-from typing import Any, Union
-
-import pydantic
-from pydantic import BaseModel, Field, PrivateAttr
-
-import marvin
-import marvin.utilities.types
-from marvin.tools import Tool
-from marvin.utilities.types import (
- genericalias_contains,
- safe_issubclass,
-)
-
-SENTINEL = "__SENTINEL__"
-
-
-class FormatResponse(Tool):
- _cached_type: Union[type, GenericAlias] = PrivateAttr(SENTINEL)
- type_schema: dict[str, Any] = Field(
- ..., description="The OpenAPI schema for the type"
- )
- description: str = (
- "You MUST always call this function before responding to the user to ensure"
- " that your final response is formatted correctly and complies with the output"
- " format requirements."
- )
-
- def __init__(self, type_: Union[type, GenericAlias] = SENTINEL, **kwargs):
- if type_ is not SENTINEL:
- if not isinstance(type_, (type, GenericAlias)):
- raise ValueError(f"Expected a type or GenericAlias, got {type_}")
-
- # warn if the type is a set or tuple with GPT 3.5
- if (
- "gpt-3.5" in marvin.settings.llm_model
- or "gpt-35" in marvin.settings.llm_model
- ):
- if safe_issubclass(type_, (set, tuple)) or genericalias_contains(
- type_, (set, tuple)
- ):
- warnings.warn(
- (
- "GPT-3.5 often fails with `set` or `tuple` types. Consider"
- " using `list` instead."
- ),
- UserWarning,
- )
-
- type_schema = marvin.utilities.types.type_to_schema(
- type_, set_root_type=False
- )
- type_schema.pop("title", None)
- kwargs["type_schema"] = type_schema
-
- super().__init__(**kwargs)
- if type_ is not SENTINEL:
- if type_schema.get("description"):
- self.description += f"\n\n {type_schema['description']}"
-
- if type_ is not SENTINEL:
- self._cached_type = type_
-
- def get_type(self) -> Union[type, GenericAlias]:
- if self._cached_type is not SENTINEL:
- return self._cached_type
- model = marvin.utilities.types.schema_to_type(self.type_schema)
- type_ = model.__fields__["__root__"].outer_type_
- self._cached_type = type_
- return type_
-
- def run(self, **kwargs) -> Any:
- type_ = self.get_type()
- if not safe_issubclass(type_, BaseModel):
- kwargs = kwargs["data"]
- return pydantic.parse_obj_as(type_, kwargs)
-
- def argument_schema(self) -> dict:
- return self.type_schema
diff --git a/src/marvin/tools/github.py b/src/marvin/tools/github.py
index 0969c826d..0a3e110b0 100644
--- a/src/marvin/tools/github.py
+++ b/src/marvin/tools/github.py
@@ -1,14 +1,40 @@
+import os
from datetime import datetime
from typing import List, Optional
import httpx
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator
import marvin
-from marvin.tools import Tool
+from marvin.utilities.logging import get_logger
from marvin.utilities.strings import slice_tokens
+async def get_token() -> str:
+ try:
+ from prefect.blocks.system import Secret
+
+ return (await Secret.load("github-token")).get()
+ except (ImportError, ValueError) as exc:
+ get_logger("marvin").debug_kv(
+ (
+ "Prefect Secret for GitHub token not retrieved. "
+ f"{exc.__class__.__name__}: {exc}"
+ "red"
+ ),
+ )
+
+ try:
+ return marvin.settings.github_token
+ except AttributeError:
+ pass
+
+ if token := os.environ.get("MARVIN_GITHUB_TOKEN", ""):
+ return token
+
+ raise RuntimeError("GitHub token not found")
+
+
class GitHubUser(BaseModel):
"""GitHub user."""
@@ -39,59 +65,55 @@ class GitHubIssue(BaseModel):
labels: List[GitHubLabel] = Field(default_factory=GitHubLabel)
user: GitHubUser = Field(default_factory=GitHubUser)
- @validator("body", always=True)
+ @field_validator("body")
def validate_body(cls, v):
if not v:
return ""
return v
-class SearchGitHubIssues(Tool):
- """Tool for searching GitHub issues."""
-
- description: str = "Use the GitHub API to search for issues in a given repository."
-
- async def run(self, query: str, repo: str = "prefecthq/prefect", n: int = 3) -> str:
- """
- Use the GitHub API to search for issues in a given repository. Do
- not alter the default value for `n` unless specifically requested by
- a user.
-
- For example, to search for open issues about AttributeErrors with the
- label "bug" in PrefectHQ/prefect:
- - repo: prefecthq/prefect
- - query: label:bug is:open AttributeError
- """
- headers = {"Accept": "application/vnd.github.v3+json"}
-
- if token := marvin.settings.github_token:
- headers["Authorization"] = f"Bearer {token.get_secret_value()}"
-
- async with httpx.AsyncClient() as client:
- response = await client.get(
- "https://api.github.com/search/issues",
- headers=headers,
- params={
- "q": query if "repo:" in query else f"repo:{repo} {query}",
- "order": "desc",
- "per_page": n,
- },
- )
- response.raise_for_status()
+async def search_github_issues(
+ query: str, repo: str = "prefecthq/prefect", n: int = 3
+) -> str:
+ """
+ Use the GitHub API to search for issues in a given repository. Do
+ not alter the default value for `n` unless specifically requested by
+ a user.
+
+ For example, to search for open issues about AttributeErrors with the
+ label "bug" in PrefectHQ/prefect:
+ - repo: prefecthq/prefect
+ - query: label:bug is:open AttributeError
+ """
+ headers = {"Accept": "application/vnd.github.v3+json"}
+
+ headers["Authorization"] = f"Bearer {await get_token()}"
+
+ async with httpx.AsyncClient() as client:
+ response = await client.get(
+ "https://api.github.com/search/issues",
+ headers=headers,
+ params={
+ "q": query if "repo:" in query else f"repo:{repo} {query}",
+ "order": "desc",
+ "per_page": n,
+ },
+ )
+ response.raise_for_status()
- issues_data = response.json()["items"]
+ issues_data = response.json()["items"]
- # enforce 1000 token limit per body
- for issue in issues_data:
- if not issue["body"]:
- continue
- issue["body"] = slice_tokens(issue["body"], 1000)
+ # enforce 1000 token limit per body
+ for issue in issues_data:
+ if not issue["body"]:
+ continue
+ issue["body"] = slice_tokens(issue["body"], 1000)
- issues = [GitHubIssue(**issue) for issue in issues_data]
+ issues = [GitHubIssue(**issue) for issue in issues_data]
- summary = "\n\n".join(
- f"{issue.title} ({issue.html_url}):\n{issue.body}" for issue in issues
- )
- if not summary.strip():
- raise ValueError("No issues found.")
- return summary
+ summary = "\n\n".join(
+ f"{issue.title} ({issue.html_url}):\n{issue.body}" for issue in issues
+ )
+ if not summary.strip():
+ raise ValueError("No issues found.")
+ return summary
diff --git a/src/marvin/tools/mathematics.py b/src/marvin/tools/mathematics.py
deleted file mode 100644
index 1f28fde22..000000000
--- a/src/marvin/tools/mathematics.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import httpx
-from typing_extensions import Literal
-
-import marvin
-from marvin.tools import Tool
-
-ResultType = Literal["DecimalApproximation"]
-
-
-class WolframCalculator(Tool):
- """Evaluate mathematical expressions using Wolfram Alpha."""
-
- description: str = """
- Evaluate mathematical expressions using Wolfram Alpha.
-
- Always append "to decimal" to your expression unless asked for something else.
- """
-
- async def run(
- self, expression: str, result_type: ResultType = "DecimalApproximation"
- ) -> str:
- async with httpx.AsyncClient() as client:
- response = await client.get(
- "https://api.wolframalpha.com/v2/query",
- params={
- "appid": marvin.settings.wolfram_app_id.get_secret_value(),
- "input": expression,
- "output": "json",
- },
- )
- try:
- response.raise_for_status()
- except httpx.HTTPStatusError as e:
- if e.response.status_code == 403:
- raise ValueError(
- "Invalid Wolfram Alpha App ID - get one at"
- " https://developer.wolframalpha.com/portal/myapps/"
- )
- raise
-
- data = response.json()
-
- pods = [
- pod
- for pod in data.get("queryresult", {}).get("pods", [])
- if pod.get("id") == result_type
- ]
-
- if not pods:
- return "No result found."
- return pods[0].get("subpods", [{}])[0].get("plaintext", "No result found.")
diff --git a/src/marvin/tools/python.py b/src/marvin/tools/python.py
deleted file mode 100644
index 290be1103..000000000
--- a/src/marvin/tools/python.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import sys
-from io import StringIO
-
-from marvin.tools import Tool
-
-
-def run_python(code: str, globals: dict = None, locals: dict = None):
- old_stdout = sys.stdout
- new_stdout = StringIO()
- sys.stdout = new_stdout
- try:
- exec(code, globals or {}, locals or {})
- result = new_stdout.getvalue()
- except Exception as e:
- result = repr(e)
- finally:
- sys.stdout = old_stdout
- return result
-
-
-class Python(Tool):
- description: str = """
- Runs arbitrary Python code.
-
- {%if require_confirmation %} You MUST ask the user to confirm execution by
- showing them the code. {% endif %}
- """
- require_confirmation: bool = True
-
- def run(self, code: str) -> str:
- return run_python(code)
diff --git a/src/marvin/tools/retrieval.py b/src/marvin/tools/retrieval.py
new file mode 100644
index 000000000..a32ca6172
--- /dev/null
+++ b/src/marvin/tools/retrieval.py
@@ -0,0 +1,141 @@
+import asyncio
+import json
+import os
+from typing import Optional
+
+import httpx
+from typing_extensions import Literal
+
+import marvin
+
+try:
+ HOST, PORT = (
+ marvin.settings.chroma_server_host,
+ marvin.settings.chroma_server_http_port,
+ )
+except AttributeError:
+ HOST = os.environ.get("MARVIN_CHROMA_SERVER_HOST", "localhost")
+ PORT = os.environ.get("MARVIN_CHROMA_SERVER_HTTP_PORT", 8000)
+
+QueryResultType = Literal["documents", "distances", "metadatas"]
+
+
+async def create_openai_embeddings(texts: list[str]) -> list[float]:
+    """Create an OpenAI embedding for a list of texts (only the first embedding is returned)."""
+
+ try:
+ import numpy # noqa F401
+ except ImportError:
+ raise ImportError(
+ "The numpy package is required to create OpenAI embeddings. Please install"
+ " it with `pip install numpy`."
+ )
+ from openai import AsyncOpenAI
+
+ return (
+ (
+ await AsyncOpenAI(
+ api_key=marvin.settings.openai.api_key.get_secret_value()
+ ).embeddings.create(
+ input=[text.replace("\n", " ") for text in texts],
+ model="text-embedding-ada-002",
+ )
+ )
+ .data[0]
+ .embedding
+ )
+
+
+async def list_collections() -> list[dict]:
+ async with httpx.AsyncClient() as client:
+ chroma_api_url = f"http://{HOST}:{PORT}"
+ response = await client.get(
+ f"{chroma_api_url}/api/v1/collections",
+ )
+
+ response.raise_for_status()
+ return response.json()
+
+
+async def query_chroma(
+ query: str,
+ collection: str = "marvin",
+ n_results: int = 5,
+ where: Optional[dict] = None,
+ where_document: Optional[dict] = None,
+ include: Optional[list[QueryResultType]] = None,
+ max_characters: int = 2000,
+) -> str:
+ """Query Chroma.
+
+ Example:
+ User: "What are prefect blocks?"
+ Assistant: >>> query_chroma("What are prefect blocks?")
+ """
+ query_embedding = await create_openai_embeddings([query])
+
+ collection_ids = [
+ c["id"] for c in await list_collections() if c["name"] == collection
+ ]
+
+ if len(collection_ids) == 0:
+ return f"Collection {collection} not found."
+
+ collection_id = collection_ids[0]
+
+ async with httpx.AsyncClient() as client:
+ chroma_api_url = f"http://{HOST}:{PORT}"
+
+ response = await client.post(
+ f"{chroma_api_url}/api/v1/collections/{collection_id}/query",
+ data=json.dumps(
+ {
+ "query_embeddings": [query_embedding],
+ "n_results": n_results,
+ "where": where or {},
+ "where_document": where_document or {},
+ "include": include or ["documents"],
+ }
+ ),
+ headers={"Content-Type": "application/json"},
+ )
+
+ response.raise_for_status()
+
+ return "\n".join(
+ f"{i+1}. {', '.join(excerpt)}"
+ for i, excerpt in enumerate(response.json()["documents"])
+ )[:max_characters]
+
+
+async def multi_query_chroma(
+ queries: list[str],
+ collection: str = "marvin",
+ n_results: int = 5,
+ where: Optional[dict] = None,
+ where_document: Optional[dict] = None,
+ include: Optional[list[QueryResultType]] = None,
+ max_characters: int = 2000,
+) -> str:
+ """Query Chroma with multiple queries.
+
+ Example:
+ User: "What are prefect blocks and tasks?"
+ Assistant: >>> multi_query_chroma(
+ ["What are prefect blocks?", "What are prefect tasks?"]
+ )
+ """
+
+ coros = [
+ query_chroma(
+ query,
+ collection,
+ n_results,
+ where,
+ where_document,
+ include,
+ max_characters // len(queries),
+ )
+ for query in queries
+ ]
+ return "\n".join(await asyncio.gather(*coros))[:max_characters]
diff --git a/src/marvin/tools/shell.py b/src/marvin/tools/shell.py
deleted file mode 100644
index 9988299ca..000000000
--- a/src/marvin/tools/shell.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import subprocess
-from pathlib import Path
-
-from marvin.tools import Tool
-
-
-class Shell(Tool):
- description: str = """
- Runs arbitrary shell code.
-
- {% if working_directory %} The working directory will be {{
- working_directory }}. {% endif %}.
-
- {%if require_confirmation %} You MUST ask the user to confirm execution by
- showing them the code. {% endif %}
- """
- require_confirmation: bool = True
- working_directory: Path = None
-
- def run(self, cmd: str, working_directory: str = None) -> str:
- if working_directory and self.working_directory:
- raise ValueError(
- f"The working directoy is {self.working_directory}; do not provide"
- " another one."
- )
-
- wd = working_directory or (
- str(self.working_directory) if self.working_directory else None
- )
-
- try:
- result = subprocess.run(
- cmd,
- shell=True,
- cwd=wd,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- )
-
- if result.returncode != 0:
- return (
- f"Command failed with code {result.returncode} and error:"
- f" {result.stderr.decode() or ''}"
- )
- else:
- return (
- "Command succeeded with output:"
- f" { result.stdout.decode() or '' }"
- )
-
- except Exception as e:
- return f"Execution failed: {str(e)}"
diff --git a/src/marvin/tools/web.py b/src/marvin/tools/web.py
deleted file mode 100644
index 2e65ab72f..000000000
--- a/src/marvin/tools/web.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import json
-from itertools import islice
-from typing import Dict
-
-import httpx
-from typing_extensions import Literal
-
-from marvin._compat import BaseSettings, Field, SecretStr
-from marvin.tools import Tool
-from marvin.utilities.strings import html_to_content, slice_tokens
-
-
-class SerpApiSettings(BaseSettings):
- api_key: SecretStr = Field(None, env="MARVIN_SERPAPI_API_KEY")
-
-
-class VisitUrl(Tool):
- """Tool for visiting a URL - only to be used in special cases."""
-
- description: str = "Visit a valid URL and return its contents."
-
- async def run(self, url: str) -> str:
- if not url.startswith("http"):
- url = f"http://{url}"
- async with httpx.AsyncClient(follow_redirects=True, timeout=2) as client:
- try:
- response = await client.get(url)
- except httpx.ConnectTimeout:
- return "Failed to load URL: Connection timed out"
- if response.status_code == 200:
- text = response.text
-
- # try to parse as JSON in case the URL is an API
- try:
- content = str(json.loads(text))
- # otherwise parse as HTML
- except json.JSONDecodeError:
- content = html_to_content(text)
- return slice_tokens(content, 1000)
- else:
- return f"Failed to load URL: {response.status_code}"
-
-
-class DuckDuckGoSearch(Tool):
- """Tool for searching the web with DuckDuckGo."""
-
- description: str = "Search the web with DuckDuckGo."
- backend: Literal["api", "html", "lite"] = "lite"
-
- async def run(self, query: str, n_results: int = 3) -> str:
- try:
- from duckduckgo_search import DDGS
- except ImportError:
- raise RuntimeError(
- "You must install the duckduckgo-search library to use this tool. "
- "You can do so by running `pip install 'marvin[ddg]'`."
- )
-
- with DDGS() as ddgs:
- return [
- r for r in islice(ddgs.text(query, backend=self.backend), n_results)
- ]
-
-
-class GoogleSearch(Tool):
- description: str = """
- For performing a Google search and retrieving the results.
-
- Provide the search query to get answers.
- """
-
- async def run(self, query: str, n_results: int = 3) -> Dict:
- try:
- from serpapi import GoogleSearch as google_search
- except ImportError:
- raise RuntimeError(
- "You must install the serpapi library to use this tool. "
- "You can do so by running `pip install 'marvin[serpapi]'`."
- )
-
- if (api_key := SerpApiSettings().api_key) is None:
- raise RuntimeError(
- "You must provide a SerpApi API key to use this tool. You can do so by"
- " setting the MARVIN_SERPAPI_API_KEY environment variable."
- )
-
- search_params = {
- "q": query,
- "api_key": api_key.get_secret_value(),
- }
- results = google_search(search_params).get_dict()
-
- if "error" in results:
- raise RuntimeError(results["error"])
- return [
- {"title": r.get("title"), "href": r.get("link"), "body": r.get("snippet")}
- for r in results.get("organic_results", [])[:n_results]
- ]
diff --git a/src/marvin/types/__init__.py b/src/marvin/types/__init__.py
deleted file mode 100644
index 578e8a155..000000000
--- a/src/marvin/types/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .function import Function
diff --git a/src/marvin/types/function.py b/src/marvin/types/function.py
deleted file mode 100644
index a9da6cd25..000000000
--- a/src/marvin/types/function.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import copy
-import inspect
-import re
-from typing import Callable, Optional, Type
-
-from marvin._compat import BaseModel, validate_arguments
-
-extraneous_fields = [
- "args",
- "kwargs",
- "ALT_V_ARGS",
- "ALT_V_KWARGS",
- "V_POSITIONAL_ONLY_NAME",
- "V_DUPLICATE_KWARGS",
-]
-
-
-def get_openai_function_schema(schema, model):
- # Make a copy of the schema.
- _schema = copy.deepcopy(schema)
-
- # Prune the schema of 'titles'.
- _schema.pop("title", None)
- for key, value in schema["properties"].items():
- if key in extraneous_fields:
- _schema["properties"].pop(key, None)
- else:
- _schema["properties"][key].pop("title", None)
-
- # Clear the existing schema.
- schema.clear()
-
- # Reconstruct the schema.
- schema["name"] = getattr(model, "name", model.Config.fn.__name__)
- schema["description"] = getattr(model, "description", model.Config.fn.__doc__)
- schema["parameters"] = {
- k: v for (k, v) in _schema.items() if k not in extraneous_fields
- }
-
-
-class FunctionConfig(BaseModel):
- fn: Callable
- name: str
- description: str = ""
- schema_extra: Optional[Callable] = get_openai_function_schema
-
- def __init__(self, fn, **kwargs):
- kwargs.setdefault("name", fn.__name__ or "")
- kwargs.setdefault("description", fn.__doc__ or "")
- super().__init__(fn=fn, **kwargs)
-
- def getsource(self):
- try:
- return re.search("def.*", inspect.getsource(self.fn), re.DOTALL).group()
- except Exception:
- return None
-
- def bind_arguments(self, *args, **kwargs):
- bound_arguments = inspect.signature(self.fn).bind(*args, **kwargs)
- bound_arguments.apply_defaults()
- return bound_arguments.arguments
-
- def response_model(self, *args, **kwargs):
- def format_response(data: inspect.signature(self.fn).return_annotation):
- """Function to format the final response to the user"""
- return None
-
- format_response.__name__ = kwargs.get("name", format_response.__name__)
- format_response.__doc__ = kwargs.get("description", format_response.__doc__)
- response_model = Function(format_response).model
- response_model.__signature__ = inspect.signature(format_response)
- return response_model
-
-
-class Function:
- """
- A wrapper class to add additional functionality to a function,
- such as a schema, response model, and more.
- """
-
- def __new__(cls, fn: Callable, parameters: dict = None, signature=None, **kwargs):
- config = FunctionConfig(fn, **kwargs)
- instance = super().__new__(cls)
- instance.instance = validate_arguments(fn, config=config.dict())
- instance.schema = parameters or instance.instance.model.schema
- instance.response_model = config.response_model
- instance.bind_arguments = config.bind_arguments
- instance.getsource = config.getsource
- instance.__name__ = config.name or fn.__name__
- instance.__doc__ = config.description or fn.__doc__
- instance.signature = signature or inspect.signature(fn)
- return instance
-
- def __call__(self, *args, **kwargs):
- return self.evaluate_raw(*args, **kwargs)
-
- def evaluate_raw(self, *args, **kwargs):
- return self.instance(*args, **kwargs)
-
- @classmethod
- def from_model(cls, model: Type[BaseModel], **kwargs):
- model.__signature__ = inspect.Signature(
- list(model.__signature__.parameters.values()), return_annotation=model
- )
-
- return cls(
- model, name="format_response", description="Format the response", **kwargs
- )
-
- @classmethod
- def from_return_annotation(
- cls, fn: Callable, *args, name: str = None, description: str = None
- ):
- def format_final_response(data: inspect.signature(fn).return_annotation):
- """Function to format the final response to the user"""
- return None
-
- format_final_response.__name__ = name or format_final_response.__name__
- format_final_response.__doc__ = description or format_final_response.__doc__
- return cls(format_final_response)
-
- def __repr__(self):
- parameters = []
- for _, param in self.signature.parameters.items():
- param_repr = str(param)
- if param.annotation is not param.empty:
- param_repr = param_repr.replace(
- str(param.annotation),
- (
- param.annotation.__name__
- if isinstance(param.annotation, type)
- else str(param.annotation)
- ),
- )
- parameters.append(param_repr)
- param_str = ", ".join(parameters)
- return f"marvin.functions.{self.__name__}({param_str})"
diff --git a/src/marvin/types/request.py b/src/marvin/types/request.py
deleted file mode 100644
index 754eabaf3..000000000
--- a/src/marvin/types/request.py
+++ /dev/null
@@ -1,117 +0,0 @@
-import warnings
-from typing import Callable, List, Literal, Optional, Type, Union
-
-from pydantic import BaseModel, BaseSettings, Extra, Field, root_validator, validator
-
-from marvin.types import Function
-
-
-class Request(BaseSettings):
- """
- This is a class for creating Request objects to interact with the GPT-3 API.
- The class contains several configurations and validation functions to ensure
- the correct data is sent to the API.
-
- """
-
- messages: Optional[List[dict[str, str]]] = [] # messages to send to the API
- functions: List[Union[dict, Callable]] = None # functions to be used in the request
- function_call: Optional[Union[dict[Literal["name"], str], Literal["auto"]]] = None
-
- # Internal Marvin Attributes to be excluded from the data sent to the API
- response_model: Optional[Type[BaseModel]] = Field(default=None)
- evaluate_function_call: bool = Field(default=False)
-
- class Config:
- exclude = {"response_model"}
- exclude_none = True
- extra = Extra.allow
-
- @root_validator(pre=True)
- def handle_response_model(cls, values):
- """
- This function validates and handles the response_model attribute.
- If a response_model is provided, it creates a function from the model
- and sets it as the function to call.
- """
- response_model = values.get("response_model")
- if response_model:
- fn = Function.from_model(response_model)
- values["functions"] = [fn]
- values["function_call"] = {"name": fn.__name__}
- return values
-
- @validator("functions", each_item=True)
- def validate_function(cls, fn):
- """
- This function validates the functions attribute.
- If a Callable is provided, it wraps it with the Function class.
- """
- if isinstance(fn, Callable):
- fn = Function(fn)
- return fn
-
- def __or__(self, config):
- """
- This method is used to merge two Request objects.
- If the attribute is a list, the lists are concatenated.
- Otherwise, the attribute from the provided config is used.
- """
-
- touched = config.dict(exclude_unset=True, serialize_functions=False)
-
- fields = list(
- set(
- [
- # We exclude none fields from defaults.
- *self.dict(exclude_none=True, serialize_functions=False).keys(),
- # We exclude unset fields from the provided config.
- *config.dict(exclude_unset=True, serialize_functions=False).keys(),
- ]
- )
- )
-
- for field in fields:
- if isinstance(getattr(self, field, None), list):
- merged = (getattr(self, field, []) or []) + (
- getattr(config, field, []) or []
- )
- setattr(self, field, merged)
- else:
- setattr(self, field, touched.get(field, getattr(self, field, None)))
- return self
-
- def merge(self, **kwargs):
- warnings.warn(
- "This is deprecated. Use the | operator instead.", DeprecationWarning
- )
- return self | self.__class__(**kwargs)
-
- def functions_schema(self, *args, **kwargs):
- """
- This method generates a list of schemas for all functions in the request.
- If a function is callable, its model's schema is returned.
- Otherwise, the function itself is returned.
- """
- return [
- fn.model.schema() if isinstance(fn, Callable) else fn
- for fn in self.functions or []
- ]
-
- def _dict(self, *args, serialize_functions=True, exclude=None, **kwargs):
- """
- This method returns a dictionary representation of the Request.
- If the functions attribute is present and serialize_functions is True,
- the functions' schemas are also included.
- """
- exclude = exclude or {}
- if serialize_functions:
- exclude["evaluate_function_call"] = True
- exclude["response_model"] = True
- response = super().dict(*args, **kwargs, exclude=exclude)
- if response.get("functions") and serialize_functions:
- response.update({"functions": self.functions_schema()})
- return response
-
- def dict(self, *args, **kwargs):
- return self._dict(*args, **kwargs)
diff --git a/src/marvin/utilities/async_utils.py b/src/marvin/utilities/async_utils.py
deleted file mode 100644
index 3bfde4271..000000000
--- a/src/marvin/utilities/async_utils.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import asyncio
-import functools
-from concurrent.futures import ThreadPoolExecutor
-from typing import Awaitable, TypeVar
-
-T = TypeVar("T")
-
-BACKGROUND_TASKS = set()
-
-
-def create_task(coro):
- """
- Creates async background tasks in a way that is safe from garbage
- collection.
-
- See
- https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/
-
- Example:
-
- async def my_coro(x: int) -> int:
- return x + 1
-
- # safely submits my_coro for background execution
- create_task(my_coro(1))
- """ # noqa: E501
- task = asyncio.create_task(coro)
- BACKGROUND_TASKS.add(task)
- task.add_done_callback(BACKGROUND_TASKS.discard)
- return task
-
-
-async def run_async(func, *args, **kwargs) -> T:
- """
- Runs a synchronous function in an asynchronous manner.
- """
-
- async def wrapper() -> T:
- try:
- return await loop.run_in_executor(
- None, functools.partial(func, *args, **kwargs)
- )
- except Exception as e:
- # propagate the exception to the caller
- raise e
-
- loop = asyncio.get_event_loop()
- return await wrapper()
-
-
-def run_sync(coroutine: Awaitable[T]) -> T:
- """
- Runs a coroutine from a synchronous context, either in the current event
- loop or in a new one if there is no event loop running. The coroutine will
- block until it is done. A thread will be spawned to run the event loop if
- necessary, which allows coroutines to run in environments like Jupyter
- notebooks where the event loop runs on the main thread.
-
- """
- try:
- loop = asyncio.get_running_loop()
- if loop.is_running():
- with ThreadPoolExecutor() as executor:
- future = executor.submit(asyncio.run, coroutine)
- return future.result()
- else:
- return asyncio.run(coroutine)
- except RuntimeError:
- return asyncio.run(coroutine)
diff --git a/src/marvin/utilities/asyncio.py b/src/marvin/utilities/asyncio.py
new file mode 100644
index 000000000..47d032d40
--- /dev/null
+++ b/src/marvin/utilities/asyncio.py
@@ -0,0 +1,121 @@
+import asyncio
+import functools
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Callable, Coroutine, TypeVar, cast
+
+T = TypeVar("T")
+
+
+async def run_async(fn: Callable[..., T], *args: Any, **kwargs: Any) -> T:
+ """
+ Runs a synchronous function in an asynchronous manner.
+
+ Args:
+ fn: The function to run.
+ *args: Positional arguments to pass to the function.
+ **kwargs: Keyword arguments to pass to the function.
+
+ Returns:
+ The return value of the function.
+ """
+
+ async def wrapper() -> T:
+ try:
+ return await loop.run_in_executor(
+ None, functools.partial(fn, *args, **kwargs)
+ )
+ except Exception as e:
+ # propagate the exception to the caller
+ raise e
+
+ loop = asyncio.get_event_loop()
+ return await wrapper()
+
+
+def run_sync(coroutine: Coroutine[Any, Any, T]) -> T:
+ """
+ Runs a coroutine from a synchronous context, either in the current event
+ loop or in a new one if there is no event loop running. The coroutine will
+ block until it is done. A thread will be spawned to run the event loop if
+ necessary, which allows coroutines to run in environments like Jupyter
+ notebooks where the event loop runs on the main thread.
+
+ """
+ try:
+ loop = asyncio.get_running_loop()
+ if loop.is_running():
+ with ThreadPoolExecutor() as executor:
+ future = executor.submit(asyncio.run, coroutine)
+ return future.result()
+ else:
+ return asyncio.run(coroutine)
+ except RuntimeError:
+ return asyncio.run(coroutine)
+
+
+class ExposeSyncMethodsMixin:
+ """
+ A mixin class that can take functions decorated with `expose_sync_method` and
+ automatically create synchronous versions.
+
+
+ Example:
+
+ class MyClass(ExposeSyncMethodsMixin):
+
+ @expose_sync_method("my_method")
+ async def my_method_async(self):
+ return 42
+
+ my_instance = MyClass()
+ await my_instance.my_method_async() # returns 42
+ my_instance.my_method() # returns 42
+ """
+
+ def __init_subclass__(cls, **kwargs: Any) -> None:
+ super().__init_subclass__(**kwargs)
+ for method in list(cls.__dict__.values()):
+ if callable(method) and hasattr(method, "_sync_name"):
+ sync_method_name = method._sync_name
+ setattr(cls, sync_method_name, method._sync_wrapper)
+
+
+def expose_sync_method(name: str) -> Callable[..., Any]:
+ """
+ Decorator that automatically exposes synchronous versions of async methods.
+ Note it doesn't work with classmethods.
+
+ Example:
+
+ class MyClass(ExposeSyncMethodsMixin):
+
+ @expose_sync_method("my_method")
+ async def my_method_async(self):
+ return 42
+
+ my_instance = MyClass()
+ await my_instance.my_method_async() # returns 42
+ my_instance.my_method() # returns 42
+ """
+
+ def decorator(
+ async_method: Callable[..., Coroutine[Any, Any, T]]
+ ) -> Callable[..., T]:
+ @functools.wraps(async_method)
+ def sync_wrapper(*args: Any, **kwargs: Any) -> T:
+ coro = async_method(*args, **kwargs)
+ return run_sync(coro)
+
+ # Cast the sync_wrapper to the same type as the async_method to give the
+ # type checker the needed information.
+ casted_sync_wrapper = cast(Callable[..., T], sync_wrapper)
+
+ # Attach attributes to the async wrapper
+ setattr(async_method, "_sync_wrapper", casted_sync_wrapper)
+ setattr(async_method, "_sync_name", name)
+
+ # return the original async method; the sync wrapper will be added to
+ # the class by the init hook
+ return async_method
+
+ return decorator
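A brief sketch of the two standalone helpers side by side; the functions here are placeholders:

```python
import asyncio

from marvin.utilities.asyncio import run_async, run_sync

def blocking_add(a: int, b: int) -> int:
    return a + b

async def main() -> None:
    # Run a synchronous function without blocking the event loop.
    print(await run_async(blocking_add, 1, 2))

# Run a coroutine from synchronous code (works even inside a running loop).
run_sync(main())
```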
diff --git a/src/marvin/utilities/collections.py b/src/marvin/utilities/collections.py
deleted file mode 100644
index b41810aae..000000000
--- a/src/marvin/utilities/collections.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import itertools
-from pathlib import Path
-from typing import Any, Callable, Iterable, Optional, TypeVar
-
-T = TypeVar("T")
-
-
-def batched(
- iterable: Iterable[T], size: int, size_fn: Callable[[Any], int] = None
-) -> Iterable[T]:
- """
- If size_fn is not provided, then the batch size will be determined by the
- number of items in the batch.
-
- If size_fn is provided, then it will be used
- to compute the batch size. Note that if a single item is larger than the
- batch size, it will be returned as a batch of its own.
-
- Args:
- iterable: The iterable to batch.
- size: The size of each batch.
- size_fn: A function that takes an item from the iterable and returns its size.
-
- Returns:
- An iterable of batches.
-
- Example:
- Batch a list of integers into batches of size 2:
- ```python
- batched([1, 2, 3, 4, 5], 2)
- # [[1, 2], [3, 4], [5]]
- ```
- """
- if size_fn is None:
- it = iter(iterable)
- while True:
- batch = tuple(itertools.islice(it, size))
- if not batch:
- break
- yield batch
- else:
- batch = []
- batch_size = 0
- for item in iter(iterable):
- batch.append(item)
- batch_size += size_fn(item)
- if batch_size > size:
- yield batch
- batch = []
- batch_size = 0
- if batch:
- yield batch
-
-
-def multi_glob(
- directory: Optional[str] = None,
- keep_globs: Optional[list[str]] = None,
- drop_globs: Optional[list[str]] = None,
-) -> list[Path]:
- """Return a list of files in a directory that match the given globs.
-
- Args:
- directory: The directory to search. Defaults to the current working directory.
- keep_globs: A list of globs to keep. Defaults to ["**/*"].
- drop_globs: A list of globs to drop. Defaults to [".git/**/*"].
-
- Returns:
- A list of `Path` objects in the directory that match the given globs.
-
- Example:
- Recursively find all Python files in the `src` directory:
- ```python
- all_python_files = multi_glob(directory="src", keep_globs=["**/*.py"])
- ```
- """
- keep_globs = keep_globs or ["**/*"]
- drop_globs = drop_globs or [".git/**/*"]
-
- directory_path = Path(directory) if directory else Path.cwd()
-
- if not directory_path.is_dir():
- raise ValueError(f"'{directory}' is not a directory.")
-
- def files_from_globs(globs):
- return {
- file
- for pattern in globs
- for file in directory_path.glob(pattern)
- if file.is_file()
- }
-
- matching_files = files_from_globs(keep_globs) - files_from_globs(drop_globs)
-
- return [file.relative_to(directory_path) for file in matching_files]
diff --git a/src/marvin/utilities/context.py b/src/marvin/utilities/context.py
new file mode 100644
index 000000000..a3778fbbd
--- /dev/null
+++ b/src/marvin/utilities/context.py
@@ -0,0 +1,26 @@
+import contextvars
+from contextlib import contextmanager
+
+
+class ScopedContext:
+ def __init__(self):
+ self._context_storage = contextvars.ContextVar(
+ "scoped_context_storage", default={}
+ )
+
+ def get(self, key, default=None):
+ return self._context_storage.get().get(key, default)
+
+ def set(self, **kwargs):
+ ctx = self._context_storage.get()
+ updated_ctx = {**ctx, **kwargs}
+ self._context_storage.set(updated_ctx)
+
+ @contextmanager
+ def __call__(self, **kwargs):
+ current_context = self._context_storage.get().copy()
+ self.set(**kwargs)
+ try:
+ yield
+ finally:
+ self._context_storage.set(current_context)
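A minimal sketch of how `ScopedContext` might be used; the keys are arbitrary:

```python
from marvin.utilities.context import ScopedContext

ctx = ScopedContext()

with ctx(user_id=42, model="gpt-4"):
    # Values set on entry are visible for the duration of the block.
    assert ctx.get("user_id") == 42

# The previous context is restored when the block exits.
assert ctx.get("user_id") is None
```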
diff --git a/src/marvin/utilities/embeddings.py b/src/marvin/utilities/embeddings.py
deleted file mode 100644
index b88625eef..000000000
--- a/src/marvin/utilities/embeddings.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from typing import List
-
-import openai
-
-import marvin
-
-
-async def create_openai_embeddings(texts: List[str]) -> List[List[float]]:
- """Create OpenAI embeddings for a list of texts."""
-
- try:
- import numpy # noqa F401
- except ImportError:
- raise ImportError(
- "The numpy package is required to create OpenAI embeddings. Please install"
- " it with `pip install numpy` or `pip install 'marvin[slackbot]'`."
- )
-
- embeddings = await openai.Embedding.acreate(
- input=[text.replace("\n", " ") for text in texts],
- engine=marvin.settings.openai.embedding_engine,
- )
-
- return [
- r["embedding"] for r in sorted(embeddings["data"], key=lambda x: x["index"])
- ]
diff --git a/src/marvin/utilities/history.py b/src/marvin/utilities/history.py
deleted file mode 100644
index b35dd399e..000000000
--- a/src/marvin/utilities/history.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import datetime
-from typing import Optional
-
-from pydantic import Field
-
-from marvin._compat import BaseModel
-from marvin.utilities.messages import Message, Role
-
-
-class HistoryFilter(BaseModel):
- role_in: list[Role] = Field(default_factory=list)
- timestamp_ge: Optional[datetime.datetime] = None
- timestamp_le: Optional[datetime.datetime] = None
-
-
-class History(BaseModel, arbitrary_types_allowed=True):
- messages: list[Message] = Field(default_factory=list)
- max_messages: int = None
-
- def add_message(self, message: Message):
- self.messages.append(message)
-
- if self.max_messages is not None:
- self.messages = self.messages[-self.max_messages :]
-
- def get_messages(
- self, n: int = None, skip: int = None, filter: HistoryFilter = None
- ) -> list[Message]:
- messages = self.messages.copy()
-
- if filter is not None:
- if filter.timestamp_ge:
- messages = [m for m in messages if m.timestamp >= filter.timestamp_ge]
- if filter.timestamp_le:
- messages = [m for m in messages if m.timestamp <= filter.timestamp_le]
- if filter.role_in:
- messages = [m for m in messages if m.role in filter.role_in]
-
- if skip:
- messages = messages[:-skip]
-
- if n is not None:
- messages = messages[-n:]
-
- return messages
-
- def clear(self):
- self.messages.clear()
diff --git a/src/marvin/utilities/jinja.py b/src/marvin/utilities/jinja.py
new file mode 100644
index 000000000..0cb54052c
--- /dev/null
+++ b/src/marvin/utilities/jinja.py
@@ -0,0 +1,109 @@
+import inspect
+import re
+from datetime import datetime
+from typing import Any, ClassVar, Pattern, Union
+from zoneinfo import ZoneInfo
+
+import pydantic
+from jinja2 import Environment as JinjaEnvironment
+from jinja2 import StrictUndefined, select_autoescape
+from jinja2 import Template as BaseTemplate
+from typing_extensions import Self
+
+from marvin.requests import BaseMessage as Message
+
+
+class BaseEnvironment(pydantic.BaseModel):
+ model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
+
+ environment: JinjaEnvironment = pydantic.Field(
+ default=JinjaEnvironment(
+ autoescape=select_autoescape(default_for_string=False),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ auto_reload=True,
+ undefined=StrictUndefined,
+ )
+ )
+
+ globals: dict[str, Any] = pydantic.Field(
+ default_factory=lambda: {
+ "now": lambda: datetime.now(ZoneInfo("UTC")),
+ "inspect": inspect,
+ }
+ )
+
+ @pydantic.model_validator(mode="after")
+ def setup_globals(self: Self) -> Self:
+ self.environment.globals.update(self.globals) # type: ignore
+ return self
+
+ def render(self, template: Union[str, BaseTemplate], **kwargs: Any) -> str:
+ if isinstance(template, str):
+ return self.environment.from_string(template).render(**kwargs)
+ return template.render(**kwargs)
+
+
+Environment = BaseEnvironment()
+
+
+def split_text_by_tokens(
+ text: str,
+ split_tokens: list[str],
+ environment: JinjaEnvironment = Environment.environment,
+) -> list[tuple[str, str]]:
+ cleaned_text = inspect.cleandoc(text)
+
+ # Find all positions of tokens in the text
+ positions = [
+ (match.start(), match.end(), match.group().rstrip(":").strip())
+ for token in split_tokens
+ for match in re.finditer(re.escape(token) + r"(?::\s*)?", cleaned_text)
+ ]
+
+ # Sort positions by their start index
+ positions.sort(key=lambda x: x[0])
+
+ paired: list[tuple[str, str]] = []
+ prev_end = 0
+ prev_token = split_tokens[0]
+ for start, end, token in positions:
+ paired.append((prev_token, cleaned_text[prev_end:start].strip()))
+ prev_end = end
+ prev_token = token
+
+ paired.append((prev_token, cleaned_text[prev_end:].strip()))
+
+ # Remove pairs where the text is empty
+ paired = [(token.replace(":", ""), text) for token, text in paired if text]
+
+ return paired
+
+
+class Transcript(pydantic.BaseModel):
+ content: str
+ roles: list[str] = pydantic.Field(default=["system", "user"])
+ environment: ClassVar[BaseEnvironment] = Environment
+
+ @property
+ def role_regex(self) -> Pattern[str]:
+ return re.compile("|".join([f"\n\n{role}:" for role in self.roles]))
+
+ def render(self: Self, **kwargs: Any) -> str:
+ return self.environment.render(self.content, **kwargs)
+
+ def render_to_messages(
+ self: Self,
+ **kwargs: Any,
+ ) -> list[Message]:
+ pairs = split_text_by_tokens(
+ text=self.render(**kwargs),
+ split_tokens=[f"\n{role}" for role in self.roles],
+ )
+ return [
+ Message(
+ role=pair[0].strip(),
+ content=pair[1],
+ )
+ for pair in pairs
+ ]
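A small, hedged example of rendering a prompt through the shared `Environment`; the template text is illustrative, and `now()` comes from the default globals defined above:

```python
from marvin.utilities.jinja import Environment

prompt = Environment.render(
    "Hello {{ name }}! The current UTC time is {{ now() }}.",
    name="Marvin",
)
print(prompt)
```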
diff --git a/src/marvin/utilities/logging.py b/src/marvin/utilities/logging.py
index cf46df54f..fef473b5f 100644
--- a/src/marvin/utilities/logging.py
+++ b/src/marvin/utilities/logging.py
@@ -1,14 +1,17 @@
import logging
from functools import lru_cache, partial
+from typing import Optional
-from rich.logging import RichHandler
-from rich.markup import escape
+from rich.logging import RichHandler # type: ignore
+from rich.markup import escape # type: ignore
import marvin
@lru_cache()
-def get_logger(name: str = None) -> logging.Logger:
+def get_logger(
+ name: Optional[str] = None,
+) -> logging.Logger:
parent_logger = logging.getLogger("marvin")
if name:
@@ -25,7 +28,9 @@ def get_logger(name: str = None) -> logging.Logger:
return logger
-def setup_logging(level: str = None):
+def setup_logging(
+ level: Optional[str] = None,
+) -> None:
logger = get_logger()
if level is not None:
@@ -47,8 +52,8 @@ def setup_logging(level: str = None):
logger.propagate = False
-def add_logging_methods(logger):
- def log_style(level: int, message: str, style: str = None):
+def add_logging_methods(logger: logging.Logger) -> None:
+ def log_style(level: int, message: str, style: Optional[str] = None):
if not style:
style = "default on default"
message = f"[{style}]{escape(str(message))}[/]"
@@ -68,17 +73,17 @@ def log_kv(
extra={"markup": True},
)
- logger.debug_style = partial(log_style, logging.DEBUG)
- logger.info_style = partial(log_style, logging.INFO)
- logger.warning_style = partial(log_style, logging.WARNING)
- logger.error_style = partial(log_style, logging.ERROR)
- logger.critical_style = partial(log_style, logging.CRITICAL)
-
- logger.debug_kv = partial(log_kv, logging.DEBUG)
- logger.info_kv = partial(log_kv, logging.INFO)
- logger.warning_kv = partial(log_kv, logging.WARNING)
- logger.error_kv = partial(log_kv, logging.ERROR)
- logger.critical_kv = partial(log_kv, logging.CRITICAL)
+ setattr(logger, "debug_style", partial(log_style, logging.DEBUG))
+ setattr(logger, "info_style", partial(log_style, logging.INFO))
+ setattr(logger, "warning_style", partial(log_style, logging.WARNING))
+ setattr(logger, "error_style", partial(log_style, logging.ERROR))
+ setattr(logger, "critical_style", partial(log_style, logging.CRITICAL))
+
+ setattr(logger, "debug_kv", partial(log_kv, logging.DEBUG))
+ setattr(logger, "info_kv", partial(log_kv, logging.INFO))
+ setattr(logger, "warning_kv", partial(log_kv, logging.WARNING))
+ setattr(logger, "error_kv", partial(log_kv, logging.ERROR))
+ setattr(logger, "critical_kv", partial(log_kv, logging.CRITICAL))
setup_logging(level=marvin.settings.log_level)
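
A short sketch of how the dynamically attached helpers are meant to be used, assuming `add_logging_methods` has been applied to the logger you retrieve (as this module does for its own loggers):

from marvin.utilities.logging import get_logger

logger = get_logger("my_module")

# These helpers are attached at runtime via setattr, so static type checkers
# won't see them; hence the attr-defined ignores.
logger.info_style("settings loaded", "bold green")  # type: ignore[attr-defined]
logger.debug_kv("request_id", "abc123")             # type: ignore[attr-defined]
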
diff --git a/src/marvin/utilities/messages.py b/src/marvin/utilities/messages.py
deleted file mode 100644
index ef70824cc..000000000
--- a/src/marvin/utilities/messages.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import inspect
-import uuid
-from datetime import datetime
-from enum import Enum
-from typing import Any, Optional
-from zoneinfo import ZoneInfo
-
-from typing_extensions import Self
-
-from marvin._compat import BaseModel, Field, field_validator
-from marvin.utilities.strings import split_text_by_tokens
-from marvin.utilities.types import MarvinBaseModel
-
-
-class Role(Enum):
- SYSTEM = "system"
- ASSISTANT = "assistant"
- USER = "user"
- FUNCTION_REQUEST = "function"
- FUNCTION_RESPONSE = "function"
-
- @classmethod
- def _missing_(cls: type[Self], value: object) -> Optional[Self]:
- lower_value = str(value).lower()
- matching_member = next(
- (member for member in cls if member.value.lower() == lower_value), None
- )
- return matching_member
-
-
-class FunctionCall(BaseModel):
- name: str
- arguments: str
-
-
-def utcnow():
- return datetime.now(ZoneInfo("UTC"))
-
-
-class Message(MarvinBaseModel):
- role: Role
- content: Optional[str] = Field(default=None, description="The message content")
-
- name: Optional[str] = Field(
- default=None,
- description="The name of the message",
- )
-
- function_call: Optional[FunctionCall] = Field(default=None)
-
- # convenience fields, excluded from serialization
- id: uuid.UUID = Field(default_factory=uuid.uuid4, exclude=True)
- data: Optional[dict[str, Any]] = Field(default_factory=dict, exclude=True)
- timestamp: datetime = Field(
- default_factory=utcnow,
- exclude=True,
- )
-
- @field_validator("content")
- def clean_content(cls, v: Optional[str]) -> Optional[str]:
- if v is not None:
- v = inspect.cleandoc(v)
- return v
-
- class Config(MarvinBaseModel.Config):
- use_enum_values = True
-
- @classmethod
- def from_transcript(
- cls: type[Self],
- text: str,
- ) -> list[Self]:
- pairs = split_text_by_tokens(
- text=text, split_tokens=[role.value.capitalize() for role in Role]
- )
- return [
- cls(
- role=Role(pair[0]),
- content=pair[1],
- )
- for pair in pairs
- ]
diff --git a/src/marvin/utilities/module_loading.py b/src/marvin/utilities/module_loading.py
deleted file mode 100644
index f9a2fad8f..000000000
--- a/src/marvin/utilities/module_loading.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import sys
-from importlib import import_module
-
-
-def cached_import(module_path, class_name):
- """
- This function checks if the module is already imported and if it is,
- it returns the class. Otherwise, it imports the module and returns the class.
- """
-
- # Check if the module is already imported
-
- if not (
- (module := sys.modules.get(module_path))
- and (spec := getattr(module, "__spec__", None))
- and getattr(spec, "_initializing", False) is False
- ):
- module = import_module(module_path)
- return getattr(module, class_name)
-
-
-def import_string(path: str):
- """
- Import a dotted module path and return the attribute/class designated by the
- last name in the path. Raise ImportError if the import failed.
- """
- try:
- path_, class_ = path.rsplit(".", 1)
- except ValueError as err:
- raise ImportError("%s doesn't look like a module path" % path) from err
- try:
- return cached_import(path_, class_)
- except AttributeError as err:
- raise ImportError(f"Module '{path_}' isn't a '{class_}' attr/class") from err
diff --git a/src/marvin/utilities/openai.py b/src/marvin/utilities/openai.py
new file mode 100644
index 000000000..20b7faa1c
--- /dev/null
+++ b/src/marvin/utilities/openai.py
@@ -0,0 +1,38 @@
+import asyncio
+from functools import lru_cache
+from typing import Optional
+
+from openai import AsyncClient
+
+
+def get_client() -> AsyncClient:
+ from marvin import settings
+
+ api_key: Optional[str] = (
+ settings.openai.api_key.get_secret_value() if settings.openai.api_key else None
+ )
+ organization: Optional[str] = settings.openai.organization
+ return _get_client_memoized(
+ api_key=api_key, organization=organization, loop=asyncio.get_event_loop()
+ )
+
+
+@lru_cache
+def _get_client_memoized(
+ api_key: Optional[str],
+ organization: Optional[str],
+ loop: Optional[asyncio.AbstractEventLoop] = None,
+) -> AsyncClient:
+ """
+ This function is memoized to ensure that only one instance of the client is
+ created for a given api key / organization / loop tuple.
+
+ The `loop` is an important key to ensure that the client is not re-used
+ across multiple event loops (which can happen when using the `run_sync`
+ function). Attempting to re-use the client across multiple event loops
+ can result in a `RuntimeError: Event loop is closed` error or infinite hangs.
+ """
+ return AsyncClient(
+ api_key=api_key,
+ organization=organization,
+ )
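
A usage sketch for the new client helper (the model name mirrors the default used elsewhere in this PR; the API key is read from Marvin's settings, e.g. `MARVIN_OPENAI_API_KEY`):

import asyncio

from marvin.utilities.openai import get_client


async def main() -> None:
    # Memoized per (api_key, organization, event loop), so repeated calls on the
    # same loop reuse one AsyncClient.
    client = get_client()
    response = await client.chat.completions.create(
        model="gpt-3.5-turbo-1106",
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(response.choices[0].message.content)


asyncio.run(main())
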
diff --git a/src/marvin/_compat.py b/src/marvin/utilities/pydantic.py
similarity index 51%
rename from src/marvin/_compat.py
rename to src/marvin/utilities/pydantic.py
index bd5669b37..0c95d3975 100644
--- a/src/marvin/_compat.py
+++ b/src/marvin/utilities/pydantic.py
@@ -1,95 +1,9 @@
from types import FunctionType, GenericAlias
-from typing import (
- Annotated,
- Any,
- Callable,
- Optional,
- TypeVar,
- Union,
- cast,
- get_origin,
-)
-
-from pydantic.version import VERSION as PYDANTIC_VERSION
-
-PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
-
-if PYDANTIC_V2:
- from pydantic.v1 import (
- BaseSettings,
- PrivateAttr,
- SecretStr,
- validate_arguments,
- )
-
- SettingsConfigDict = BaseSettings.Config
-
- from pydantic import (
- BaseModel,
- Field,
- create_model,
- field_validator,
- )
-
-else:
- from pydantic import ( # noqa # type: ignore
- BaseSettings,
- BaseModel,
- create_model,
- Field,
- SecretStr,
- validate_arguments,
- validator as field_validator,
- PrivateAttr,
- )
-
- SettingsConfigDict = BaseSettings.Config
-
-_ModelT = TypeVar("_ModelT", bound=BaseModel)
-
-
-def model_dump(model: _ModelT, **kwargs: Any) -> dict[str, Any]:
- if PYDANTIC_V2 and hasattr(model, "model_dump"):
- return model.model_dump(**kwargs) # type: ignore
- return model.dict(**kwargs) # type: ignore
-
-
-def model_dump_json(model: type[BaseModel], **kwargs: Any) -> dict[str, Any]:
- if PYDANTIC_V2 and hasattr(model, "model_dump_json"):
- return model.model_dump_json(**kwargs) # type: ignore
- return model.json(**kwargs) # type: ignore
-
-
-def model_json_schema(
- model: type[BaseModel],
- name: Optional[str] = None,
- description: Optional[str] = None,
-) -> dict[str, Any]:
- # Get the schema from the model.
- schema = {"parameters": {**model_schema(model)}}
-
- # Mutate the schema to match the OpenAPI spec.
- schema["parameters"]["title"] = name or schema["parameters"].pop("title", None)
- schema["parameters"]["description"] = description or schema["parameters"].pop(
- "description", ""
- ) # noqa
-
- # Move the properties to the root of the schema.
- schema["name"] = schema["parameters"].pop("title")
- schema["description"] = schema["parameters"].pop("description")
- return schema
-
+from typing import Annotated, Any, Callable, Optional, Union, cast, get_origin
-def model_schema(model: type[BaseModel], **kwargs: Any) -> dict[str, Any]:
- if PYDANTIC_V2 and hasattr(model, "model_json_schema"):
- return model.model_json_schema(**kwargs) # type: ignore
- return model.schema(**kwargs) # type: ignore
-
-
-def model_copy(model: _ModelT, **kwargs: Any) -> _ModelT:
- if PYDANTIC_V2 and hasattr(model, "model_copy"):
- return model.model_copy(**kwargs) # type: ignore
- return model.copy(**kwargs) # type: ignore
+from pydantic import BaseModel, TypeAdapter, create_model
+from pydantic.v1 import validate_arguments
+from typing_extensions import Literal
def cast_callable_to_model(
@@ -97,14 +11,14 @@ def cast_callable_to_model(
name: Optional[str] = None,
description: Optional[str] = None,
) -> type[BaseModel]:
- response = validate_arguments(function).model # type: ignore
+ response = validate_arguments(function).model
for field in ["args", "kwargs", "v__duplicate_kwargs"]:
- fields = cast(dict[str, Any], response.__fields__) # type: ignore
+ fields = cast(dict[str, Any], response.__fields__)
fields.pop(field, None)
response.__title__ = name or function.__name__
response.__name__ = name or function.__name__
response.__doc__ = description or function.__doc__
- return response # type: ignore
+ return response
def cast_type_or_alias_to_model(
@@ -137,34 +51,38 @@ def cast_to_model(
response = BaseModel
if origin is Annotated:
- metadata: Any = next(iter(function_or_type.__metadata__), None) # type: ignore
+ metadata: Any = next(iter(function_or_type.__metadata__), None)
annotated_field_name: Optional[str] = field_name
if hasattr(metadata, "extra") and isinstance(metadata.extra, dict):
- annotated_field_name: Optional[str] = metadata.extra.get("name", "") # type: ignore # noqa
+ annotated_field_name: Optional[str] = metadata.extra.get("name", "") # noqa
elif hasattr(metadata, "json_schema_extra") and isinstance(
metadata.json_schema_extra, dict
): # noqa
- annotated_field_name: Optional[str] = metadata.json_schema_extra.get("name", "") # type: ignore # noqa
+ annotated_field_name: Optional[str] = metadata.json_schema_extra.get(
+ "name", ""
+ ) # noqa
elif isinstance(metadata, dict):
- annotated_field_name: Optional[str] = metadata.get("name", "") # type: ignore # noqa
+ annotated_field_name: Optional[str] = metadata.get("name", "") # noqa
elif isinstance(metadata, str):
annotated_field_name: Optional[str] = metadata
else:
pass
annotated_field_description: Optional[str] = description or ""
if hasattr(metadata, "description") and isinstance(metadata.description, str):
- annotated_field_description: Optional[str] = metadata.description # type: ignore # noqa
+ annotated_field_description: Optional[str] = metadata.description # noqa
elif isinstance(metadata, dict):
- annotated_field_description: Optional[str] = metadata.get("description", "") # type: ignore # noqa
+ annotated_field_description: Optional[str] = metadata.get(
+ "description", ""
+ ) # noqa
else:
pass
response = cast_to_model(
- function_or_type.__origin__, # type: ignore
+ function_or_type.__origin__,
name=name,
description=annotated_field_description,
- field_name=annotated_field_name, # type: ignore
+ field_name=annotated_field_name,
)
response.__doc__ = annotated_field_description or ""
elif origin in {dict, list, tuple, set, frozenset}:
@@ -193,12 +111,17 @@ def cast_to_model(
return response
-def cast_to_json(
- function_or_type: Union[type, type[BaseModel], GenericAlias, Callable[..., Any]],
- name: Optional[str] = None,
- description: Optional[str] = None,
- field_name: Optional[str] = None,
-) -> dict[str, Any]:
- return model_json_schema(
- cast_to_model(function_or_type, name, description, field_name)
- )
+def parse_as(
+ type_: Any,
+ data: Any,
+ mode: Literal["python", "json", "strings"] = "python",
+) -> BaseModel:
+ """Parse a json string to a Pydantic model."""
+ adapter = TypeAdapter(type_)
+
+ if get_origin(type_) is list and isinstance(data, dict):
+ data = next(iter(data.values()))
+
+ parser = getattr(adapter, f"validate_{mode}")
+
+ return parser(data)
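
A quick sketch of `parse_as`, including the single-key-dict unwrapping for list types (values here are illustrative):

from pydantic import BaseModel

from marvin.utilities.pydantic import parse_as


class Fruit(BaseModel):
    name: str
    color: str


# Plain python data (default mode).
fruit = parse_as(Fruit, {"name": "banana", "color": "yellow"})

# A JSON string, validated with TypeAdapter.validate_json.
fruits = parse_as(list[Fruit], '[{"name": "banana", "color": "yellow"}]', mode="json")

# For list types, a single-key wrapper dict (e.g. a function-call payload) is unwrapped.
fruits = parse_as(list[Fruit], {"fruits": [{"name": "kiwi", "color": "green"}]})
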
diff --git a/src/marvin/utilities/slack.py b/src/marvin/utilities/slack.py
new file mode 100644
index 000000000..b7a331c10
--- /dev/null
+++ b/src/marvin/utilities/slack.py
@@ -0,0 +1,286 @@
+import os
+import re
+from typing import List, Optional, Union
+
+import httpx
+from pydantic import BaseModel, ValidationInfo, field_validator
+
+import marvin
+
+
+class EventBlockElement(BaseModel):
+ type: str
+ text: Optional[str] = None
+ user_id: Optional[str] = None
+
+
+class EventBlockElementGroup(BaseModel):
+ type: str
+ elements: List[EventBlockElement]
+
+
+class EventBlock(BaseModel):
+ type: str
+ block_id: str
+ elements: List[Union[EventBlockElement, EventBlockElementGroup]]
+
+
+class SlackEvent(BaseModel):
+ client_msg_id: Optional[str] = None
+ type: str
+ text: str
+ user: str
+ ts: str
+ team: str
+ channel: str
+ event_ts: str
+ thread_ts: Optional[str] = None
+ parent_user_id: Optional[str] = None
+ blocks: Optional[List[EventBlock]] = None
+
+
+class EventAuthorization(BaseModel):
+ enterprise_id: Optional[str] = None
+ team_id: str
+ user_id: str
+ is_bot: bool
+ is_enterprise_install: bool
+
+
+class SlackPayload(BaseModel):
+ token: str
+ type: str
+ team_id: Optional[str] = None
+ api_app_id: Optional[str] = None
+ event: Optional[SlackEvent] = None
+ event_id: Optional[str] = None
+ event_time: Optional[int] = None
+ authorizations: Optional[List[EventAuthorization]] = None
+ is_ext_shared_channel: Optional[bool] = None
+ event_context: Optional[str] = None
+ challenge: Optional[str] = None
+
+ @field_validator("event")
+ def validate_event(cls, v: Optional[SlackEvent], info: ValidationInfo) -> Optional[SlackEvent]:
+ if v is None and info.data.get("type") != "url_verification":
+ raise ValueError("event is required")
+ return v
+
+
+async def get_token() -> str:
+ """Get the Slack bot token from the environment."""
+ try:
+ token = marvin.settings.slack_api_token
+ except AttributeError:
+ token = os.getenv("MARVIN_SLACK_API_TOKEN")
+ if not token:
+ raise ValueError(
+ "`MARVIN_SLACK_API_TOKEN` not found in environment."
+ " Please set it in `~/.marvin/.env` or as an environment variable."
+ )
+ return token
+
+
+def convert_md_links_to_slack(text: str) -> str:
+ md_link_pattern = r"\[(?P<text>[^\]]+)]\((?P<url>[^\)]+)\)"
+
+ # converting Markdown links to Slack-style links
+ def to_slack_link(match):
+ return f'<{match.group("url")}|{match.group("text")}>'
+
+ # Replace Markdown links with Slack-style links
+ slack_text = re.sub(md_link_pattern, to_slack_link, text)
+
+ return slack_text
+
+
+async def post_slack_message(
+ message: str,
+ channel_id: str,
+ thread_ts: Union[str, None] = None,
+ auth_token: Union[str, None] = None,
+) -> httpx.Response:
+ if not auth_token:
+ auth_token = await get_token()
+
+ async with httpx.AsyncClient() as client:
+ response = await client.post(
+ "https://slack.com/api/chat.postMessage",
+ headers={"Authorization": f"Bearer {auth_token}"},
+ json={
+ "channel": channel_id,
+ "text": convert_md_links_to_slack(message),
+ "thread_ts": thread_ts,
+ },
+ )
+
+ response.raise_for_status()
+ return response
+
+
+async def get_thread_messages(channel: str, thread_ts: str) -> list:
+ """Get all messages from a slack thread."""
+ async with httpx.AsyncClient() as client:
+ response = await client.get(
+ "https://slack.com/api/conversations.replies",
+ headers={"Authorization": f"Bearer {await get_token()}"},
+ params={"channel": channel, "ts": thread_ts},
+ )
+ response.raise_for_status()
+ return response.json().get("messages", [])
+
+
+async def get_user_name(user_id: str) -> str:
+ async with httpx.AsyncClient() as client:
+ response = await client.get(
+ "https://slack.com/api/users.info",
+ params={"user": user_id},
+ headers={"Authorization": f"Bearer {await get_token()}"}, # noqa: E501
+ )
+ return (
+ response.json().get("user", {}).get("name", user_id)
+ if response.status_code == 200
+ else user_id
+ )
+
+
+async def get_channel_name(channel_id: str) -> str:
+ async with httpx.AsyncClient() as client:
+ response = await client.get(
+ "https://slack.com/api/conversations.info",
+ params={"channel": channel_id},
+ headers={"Authorization": f"Bearer {await get_token()}"}, # noqa: E501
+ )
+ return (
+ response.json().get("channel", {}).get("name", channel_id)
+ if response.status_code == 200
+ else channel_id
+ )
+
+
+async def fetch_current_message_text(channel: str, ts: str) -> str:
+ """Fetch the current text of a specific Slack message using its timestamp."""
+ async with httpx.AsyncClient() as client:
+ response = await client.get(
+ "https://slack.com/api/conversations.replies",
+ params={"channel": channel, "ts": ts},
+ headers={"Authorization": f"Bearer {await get_token()}"}, # noqa: E501
+ )
+ response.raise_for_status()
+ messages = response.json().get("messages", [])
+ if not messages:
+ raise ValueError("Message not found")
+
+ return messages[0]["text"]
+
+
+async def edit_slack_message(
+ new_text: str,
+ channel_id: str,
+ thread_ts: str,
+ mode: str = "append",
+ delimiter: Union[str, None] = None,
+) -> httpx.Response:
+ """Edit an existing Slack message by appending new text or replacing it.
+
+ Args:
+ new_text (str): The new text to append or replace in the message.
+ channel_id (str): The Slack channel ID.
+ thread_ts (str): The timestamp of the message to edit.
+ mode (str): The mode of text editing, 'append' (default) or 'replace'.
+ delimiter (str, optional): Text placed between the old and new text in 'append' mode.
+
+ Returns:
+ httpx.Response: The response from the Slack API.
+ """
+ if mode == "append":
+ current_text = await fetch_current_message_text(channel_id, thread_ts)
+ delimiter = "\n\n" if delimiter is None else delimiter
+ updated_text = f"{current_text}{delimiter}{convert_md_links_to_slack(new_text)}"
+ elif mode == "replace":
+ updated_text = convert_md_links_to_slack(new_text)
+ else:
+ raise ValueError("Invalid mode. Use 'append' or 'replace'.")
+
+ async with httpx.AsyncClient() as client:
+ response = await client.post(
+ "https://slack.com/api/chat.update",
+ headers={"Authorization": f"Bearer {await get_token()}"},
+ json={"channel": channel_id, "ts": thread_ts, "text": updated_text},
+ )
+
+ response.raise_for_status()
+ return response
+
+
+async def search_slack_messages(
+ query: str,
+ max_messages: int = 3,
+ channel: Union[str, None] = None,
+ user_auth_token: Union[str, None] = None,
+) -> list:
+ """
+ Search for messages in Slack workspace based on a query.
+
+ Args:
+ query (str): The search query.
+ max_messages (int): The maximum number of messages to retrieve.
+ channel (str, optional): The specific channel to search in. Defaults to None,
+ which searches all channels.
+ user_auth_token (str, optional): The token used for the request; defaults to
+ the token returned by `get_token()`.
+
+ Returns:
+ list: A list of message contents and permalinks matching the query.
+ """
+ all_messages = []
+ next_cursor = None
+
+ if not user_auth_token:
+ user_auth_token = await get_token()
+
+ async with httpx.AsyncClient() as client:
+ while len(all_messages) < max_messages:
+ params = {
+ "query": query,
+ "limit": min(max_messages - len(all_messages), 10),
+ }
+ if channel:
+ params["channel"] = channel
+ if next_cursor:
+ params["cursor"] = next_cursor
+
+ response = await client.get(
+ "https://slack.com/api/search.messages",
+ headers={"Authorization": f"Bearer {user_auth_token}"},
+ params=params,
+ )
+
+ response.raise_for_status()
+ data = response.json().get("messages", {}).get("matches", [])
+ for message in data:
+ all_messages.append(
+ {
+ "content": message.get("text", ""),
+ "permalink": message.get("permalink", ""),
+ }
+ )
+
+ next_cursor = (
+ response.json().get("response_metadata", {}).get("next_cursor")
+ )
+
+ if not next_cursor:
+ break
+
+ return all_messages[:max_messages]
+
+
+async def get_workspace_info(slack_bot_token: Union[str, None] = None) -> dict:
+ if not slack_bot_token:
+ slack_bot_token = await get_token()
+
+ async with httpx.AsyncClient() as client:
+ response = await client.get(
+ "https://slack.com/api/team.info",
+ headers={"Authorization": f"Bearer {slack_bot_token}"},
+ )
+ response.raise_for_status()
+ return response.json().get("team", {})
diff --git a/src/marvin/utilities/streaming.py b/src/marvin/utilities/streaming.py
deleted file mode 100644
index 15302a024..000000000
--- a/src/marvin/utilities/streaming.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import abc
-from typing import Callable, Optional
-
-from marvin.utilities.messages import Message
-from marvin.utilities.types import MarvinBaseModel
-
-
-class StreamHandler(MarvinBaseModel, abc.ABC):
- callback: Optional[Callable] = None
-
- @abc.abstractmethod
- def handle_streaming_response(self, api_response) -> Message:
- raise NotImplementedError()
diff --git a/src/marvin/utilities/strings.py b/src/marvin/utilities/strings.py
index e05ee52d2..46a140f07 100644
--- a/src/marvin/utilities/strings.py
+++ b/src/marvin/utilities/strings.py
@@ -1,66 +1,13 @@
-import inspect
-import re
-from datetime import datetime
-from zoneinfo import ZoneInfo
-
import tiktoken
-from jinja2 import (
- ChoiceLoader,
- Environment,
- StrictUndefined,
- pass_context,
- select_autoescape,
-)
-from markupsafe import Markup
-
-import marvin.utilities.async_utils
-
-NEWLINES_REGEX = re.compile(r"(\s*\n\s*)")
-MD_LINK_REGEX = r"\[(?P<text>[^\]]+)]\((?P<url>[^\)]+)\)"
-
-jinja_env = Environment(
- loader=ChoiceLoader(
- [
- # PackageLoader("marvin", "prompts")
- ]
- ),
- autoescape=select_autoescape(default_for_string=False),
- trim_blocks=True,
- lstrip_blocks=True,
- auto_reload=True,
- undefined=StrictUndefined,
-)
-
-jinja_env.globals.update(
- zip=zip,
- arun=marvin.utilities.async_utils.run_sync,
- now=lambda: datetime.now(ZoneInfo("UTC")),
-)
-
-
-@pass_context
-def render_filter(context, value):
- """
- Allows nested rendering of variables that may contain variables themselves
- e.g. {{ description | render }}
- """
- _template = context.eval_ctx.environment.from_string(value)
- result = _template.render(**context)
- if context.eval_ctx.autoescape:
- result = Markup(result)
- return result
-
-
-jinja_env.filters["render"] = render_filter
-def tokenize(text: str) -> list[int]:
- tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
+def tokenize(text: str, model: str = "gpt-3.5-turbo-1106") -> list[int]:
+ tokenizer = tiktoken.encoding_for_model(model)
return tokenizer.encode(text)
-def detokenize(tokens: list[int]) -> str:
- tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
+def detokenize(tokens: list[int], model: str = "gpt-3.5-turbo-1106") -> str:
+ tokenizer = tiktoken.encoding_for_model(model)
return tokenizer.decode(tokens)
@@ -71,78 +18,3 @@ def count_tokens(text: str) -> int:
def slice_tokens(text: str, n_tokens: int) -> str:
tokens = tokenize(text)
return detokenize(tokens[:n_tokens])
-
-
-def split_tokens(text: str, n_tokens: int) -> list[str]:
- tokens = tokenize(text)
- return [
- detokenize(tokens[i : i + n_tokens]) for i in range(0, len(tokens), n_tokens)
- ]
-
-
-def condense_newlines(text: str) -> str:
- def replace_whitespace(match):
- newlines_count = match.group().count("\n")
- if newlines_count <= 1:
- return " "
- else:
- return "\n" * newlines_count
-
- text = inspect.cleandoc(text)
- text = NEWLINES_REGEX.sub(replace_whitespace, text)
- return text.strip()
-
-
-def html_to_content(html: str) -> str:
- from bs4 import BeautifulSoup
-
- soup = BeautifulSoup(html, "html.parser")
-
- # Remove script and style elements
- for script in soup(["script", "style"]):
- script.extract()
-
- # Get text
- text = soup.get_text()
-
- # Condense newlines
- return condense_newlines(text)
-
-
-def convert_md_links_to_slack(text: str) -> str:
- # Convert Markdown links to Slack-style links
- md_link_regex = re.compile(r"\[(?P<text>[^\]]+)\]\((?P<url>[^\)]+)\)")
- text = md_link_regex.sub(r"<\g<url>|\g<text>>", text)
-
- text = re.sub(r"\*\*(.+?)\*\*", r"*\1*", text)
-
- return text
-
-
-def split_text_by_tokens(text: str, split_tokens: list[str]) -> list[tuple[str, str]]:
- cleaned_text = inspect.cleandoc(text)
-
- # Find all positions of tokens in the text
- positions = [
- (match.start(), match.end(), match.group().rstrip(":").strip())
- for token in split_tokens
- for match in re.finditer(re.escape(token) + r"(?::\s*)?", cleaned_text)
- ]
-
- # Sort positions by their start index
- positions.sort(key=lambda x: x[0])
-
- paired: list[tuple[str, str]] = []
- prev_end = 0
- prev_token = split_tokens[0]
- for start, end, token in positions:
- paired.append((prev_token, cleaned_text[prev_end:start].strip()))
- prev_end = end
- prev_token = token
-
- paired.append((prev_token, cleaned_text[prev_end:].strip()))
-
- # Remove pairs where the text is empty
- paired = [(token.replace(":", ""), text) for token, text in paired if text]
-
- return paired
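
The surviving string helpers keep their old behavior but now accept a model name for the tokenizer; a small sketch:

from marvin.utilities.strings import count_tokens, slice_tokens, tokenize

text = "Marvin is a lightweight AI engineering framework."

print(count_tokens(text))                  # token count under the default encoding
print(slice_tokens(text, 5))               # first 5 tokens, decoded back to text
print(len(tokenize(text, model="gpt-4")))  # an explicit model selects the encoding
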
diff --git a/src/marvin/utilities/tools.py b/src/marvin/utilities/tools.py
new file mode 100644
index 000000000..0f051f778
--- /dev/null
+++ b/src/marvin/utilities/tools.py
@@ -0,0 +1,54 @@
+import inspect
+import json
+
+from marvin.requests import Tool
+from marvin.utilities.asyncio import run_sync
+from marvin.utilities.logging import get_logger
+from marvin.utilities.pydantic import cast_callable_to_model
+
+logger = get_logger("Tools")
+
+
+def tool_from_function(fn: callable, name: str = None, description: str = None):
+ model = cast_callable_to_model(fn)
+ return Tool(
+ type="function",
+ function=dict(
+ name=name or fn.__name__,
+ description=description or fn.__doc__,
+ # use deprecated schema because this is based on a pydantic v1
+ # validate_arguments
+ parameters=model.schema(),
+ python_fn=fn,
+ ),
+ )
+
+
+def call_function_tool(
+ tools: list[Tool], function_name: str, function_arguments_json: str
+):
+ tool = next(
+ (
+ tool
+ for tool in tools
+ if isinstance(tool, Tool)
+ and tool.function
+ and tool.function.name == function_name
+ ),
+ None,
+ )
+ if not tool:
+ raise ValueError(f"Could not find function '{function_name}'")
+
+ arguments = json.loads(function_arguments_json)
+ logger.debug(f"Calling {tool.function.name} with arguments: {arguments}")
+ output = tool.function.python_fn(**arguments)
+ if inspect.isawaitable(output):
+ output = run_sync(output)
+ truncated_output = str(output)[:100]
+ if len(truncated_output) < len(str(output)):
+ truncated_output += "..."
+ logger.debug(f"{tool.function.name} returned: {truncated_output}")
+ if not isinstance(output, str):
+ output = json.dumps(output)
+ return output
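
A minimal sketch of the new tool helpers (assuming `marvin.requests.Tool` accepts the fields exactly as `tool_from_function` constructs them above):

from marvin.utilities.tools import call_function_tool, tool_from_function


def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


tool = tool_from_function(add)

# Dispatches by function name, parses the JSON arguments, and JSON-encodes
# non-string outputs.
result = call_function_tool(
    tools=[tool],
    function_name="add",
    function_arguments_json='{"a": 2, "b": 3}',
)
assert result == "5"
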
diff --git a/src/marvin/utilities/types.py b/src/marvin/utilities/types.py
deleted file mode 100644
index dbd7f1456..000000000
--- a/src/marvin/utilities/types.py
+++ /dev/null
@@ -1,128 +0,0 @@
-import inspect
-import logging
-from types import GenericAlias
-from typing import Any, Callable, _SpecialForm
-
-from marvin._compat import BaseModel, PrivateAttr, create_model
-from marvin.utilities.logging import get_logger
-
-
-class MarvinBaseModel(BaseModel):
- class Config:
- extra = "forbid"
-
-
-class LoggerMixin(BaseModel):
- """
- BaseModel mixin that adds a private `logger` attribute
- """
-
- _logger: logging.Logger = PrivateAttr()
-
- def __init__(self, **data):
- super().__init__(**data)
- self._logger = get_logger(type(self).__name__)
-
- @property
- def logger(self):
- return self._logger
-
-
-def function_to_model(
- function: Callable[..., Any], name: str = None, description: str = None
-) -> dict:
- """
- Converts a function's arguments into an OpenAPI schema by parsing it into a
- Pydantic model. To work, all arguments must have valid type annotations.
- """
- signature = inspect.signature(function)
-
- fields = {
- p: (
- signature.parameters[p].annotation,
- (
- signature.parameters[p].default
- if signature.parameters[p].default != inspect._empty
- else ...
- ),
- )
- for p in signature.parameters
- if p != getattr(function, "__self__", None)
- }
-
- # Create Pydantic model
- try:
- Model = create_model(name or function.__name__, **fields)
- except RuntimeError as exc:
- if "see `arbitrary_types_allowed` " in str(exc):
- raise ValueError(
- f"Error while inspecting {function.__name__} with signature"
- f" {signature}: {exc}"
- )
- else:
- raise
-
- return Model
-
-
-def function_to_schema(function: Callable[..., Any], name: str = None) -> dict:
- """
- Converts a function's arguments into an OpenAPI schema by parsing it into a
- Pydantic model. To work, all arguments must have valid type annotations.
- """
- Model = function_to_model(function, name=name)
-
- return Model.schema()
-
-
-def safe_issubclass(type_, classes):
- if isinstance(type_, type) and not isinstance(type_, GenericAlias):
- return issubclass(type_, classes)
- else:
- return False
-
-
-def type_to_schema(type_, set_root_type: bool = True) -> dict:
- if safe_issubclass(type_, BaseModel):
- schema = type_.schema()
- # if the docstring was updated at runtime, make it the description
- if type_.__doc__ and type_.__doc__ != schema.get("description"):
- schema["description"] = type_.__doc__
- return schema
-
- elif set_root_type:
-
- class Model(BaseModel):
- __root__: type_
-
- return Model.schema()
- else:
-
- class Model(BaseModel):
- data: type_
-
- return Model.schema()
-
-
-def genericalias_contains(genericalias, target_type):
- """
- Explore whether a type or generic alias contains a target type. The target
- types can be a single type or a tuple of types.
-
- Useful for seeing if a type contains a pydantic model, for example.
- """
- if isinstance(target_type, tuple):
- return any(genericalias_contains(genericalias, t) for t in target_type)
-
- if isinstance(genericalias, GenericAlias):
- if safe_issubclass(genericalias.__origin__, target_type):
- return True
- for arg in genericalias.__args__:
- if genericalias_contains(arg, target_type):
- return True
- elif isinstance(genericalias, _SpecialForm):
- return False
- else:
- return safe_issubclass(genericalias, target_type)
-
- return False
diff --git a/src/marvin/_framework/app/__init__.py b/tests/beta/__init__.py
similarity index 100%
rename from src/marvin/_framework/app/__init__.py
rename to tests/beta/__init__.py
diff --git a/src/marvin/_framework/config/__init__.py b/tests/beta/assistants/__init__.py
similarity index 100%
rename from src/marvin/_framework/config/__init__.py
rename to tests/beta/assistants/__init__.py
diff --git a/src/marvin/_framework/static/__init__.py b/tests/beta/assistants/test_assistants.py
similarity index 100%
rename from src/marvin/_framework/static/__init__.py
rename to tests/beta/assistants/test_assistants.py
diff --git a/src/marvin/cli/admin/scripts/__init__.py b/tests/cli/__init__.py
similarity index 100%
rename from src/marvin/cli/admin/scripts/__init__.py
rename to tests/cli/__init__.py
diff --git a/tests/cli/test_version.py b/tests/cli/test_version.py
new file mode 100644
index 000000000..61769f429
--- /dev/null
+++ b/tests/cli/test_version.py
@@ -0,0 +1,14 @@
+from typer.testing import CliRunner
+
+
+def test_marvin_version_command():
+ """Test the marvin version command."""
+ from marvin.cli import app
+
+ runner = CliRunner()
+ result = runner.invoke(app, ["version"])
+
+ assert result.exit_code == 0
+ assert "Version:" in result.output
+ assert "Python version:" in result.output
+ assert "OS/Arch:" in result.output
diff --git a/src/marvin/components/library/__init__.py b/tests/components/__init__.py
similarity index 100%
rename from src/marvin/components/library/__init__.py
rename to tests/components/__init__.py
diff --git a/tests/components/test_ai_classifier.py b/tests/components/test_ai_classifier.py
new file mode 100644
index 000000000..034f53bf0
--- /dev/null
+++ b/tests/components/test_ai_classifier.py
@@ -0,0 +1,47 @@
+from enum import Enum
+
+from marvin import ai_classifier
+from typing_extensions import Literal
+
+from tests.utils import pytest_mark_class
+
+Sentiment = Literal["Positive", "Negative"]
+
+
+class GitHubIssueTag(Enum):
+ BUG = "bug"
+ FEATURE = "feature"
+ ENHANCEMENT = "enhancement"
+ DOCS = "docs"
+
+
+@pytest_mark_class("llm")
+class TestAIClassifier:
+ class TestLiteral:
+ def test_ai_classifier_literal_return_type(self):
+ @ai_classifier
+ def sentiment(text: str) -> Sentiment:
+ """Classify sentiment"""
+
+ result = sentiment("Great!")
+
+ assert result == "Positive"
+
+ def test_ai_classifier_literal_return_type_with_docstring(self):
+ @ai_classifier
+ def sentiment(text: str) -> Sentiment:
+ """Classify sentiment - also its opposite day"""
+
+ result = sentiment("Great!")
+
+ assert result == "Negative"
+
+ class TestEnum:
+ def test_ai_classifier_enum_return_type(self):
+ @ai_classifier
+ def labeler(text: str) -> GitHubIssueTag:
+ """Classify GitHub issue tags"""
+
+ result = labeler("improve the docs you slugs")
+
+ assert result == GitHubIssueTag.DOCS
diff --git a/tests/components/test_ai_functions.py b/tests/components/test_ai_functions.py
new file mode 100644
index 000000000..070a0b97d
--- /dev/null
+++ b/tests/components/test_ai_functions.py
@@ -0,0 +1,205 @@
+import inspect
+from typing import Dict, List
+
+import marvin
+import pytest
+from marvin import ai_fn
+from pydantic import BaseModel
+
+from tests.utils import pytest_mark_class
+
+
+@ai_fn
+def list_fruit(n: int = 2) -> list[str]:
+ """Returns a list of `n` fruit"""
+
+
+@ai_fn
+def list_fruit_color(n: int, color: str = None) -> list[str]:
+ """Returns a list of `n` fruit that all have the provided `color`"""
+
+
+@pytest_mark_class("llm")
+class TestAIFunctions:
+ class TestBasics:
+ def test_list_fruit(self):
+ result = list_fruit()
+ assert len(result) == 2
+
+ def test_list_fruit_argument(self):
+ result = list_fruit(5)
+ assert len(result) == 5
+
+ async def test_list_fruit_async(self):
+ @ai_fn
+ async def list_fruit(n: int) -> list[str]:
+ """Returns a list of `n` fruit"""
+
+ coro = list_fruit(3)
+ assert inspect.iscoroutine(coro)
+ result = await coro
+ assert len(result) == 3
+
+ class TestAnnotations:
+ def test_list_fruit_with_generic_type_hints(self):
+ @ai_fn
+ def list_fruit(n: int) -> List[str]:
+ """Returns a list of `n` fruit"""
+
+ result = list_fruit(3)
+ assert len(result) == 3
+
+ def test_basemodel_return_annotation(self):
+ class Fruit(BaseModel):
+ name: str
+ color: str
+
+ @ai_fn
+ def get_fruit(description: str) -> Fruit:
+ """Returns a fruit with the provided description"""
+
+ fruit = get_fruit("loved by monkeys")
+ assert fruit.name.lower() == "banana"
+ assert fruit.color.lower() == "yellow"
+
+ @pytest.mark.parametrize("name,expected", [("banana", True), ("car", False)])
+ def test_bool_return_annotation(self, name, expected):
+ @ai_fn
+ def is_fruit(name: str) -> bool:
+ """Returns True if the provided name is a fruit"""
+
+ assert is_fruit(name) == expected
+
+ @pytest.mark.skipif(
+ marvin.settings.openai.chat.completions.model == "gpt-3.5-turbo-1106",
+ reason="3.5 turbo doesn't do well with unknown schemas",
+ )
+ def test_plain_dict_return_type(self):
+ @ai_fn
+ def describe_fruit(description: str) -> dict:
+ """guess the fruit and return the name and color"""
+
+ fruit = describe_fruit("the one thats loved by monkeys")
+ assert fruit["name"].lower() == "banana"
+ assert fruit["color"].lower() == "yellow"
+
+ @pytest.mark.skipif(
+ marvin.settings.openai.chat.completions.model == "gpt-3.5-turbo-1106",
+ reason="3.5 turbo doesn't do well with unknown schemas",
+ )
+ def test_annotated_dict_return_type(self):
+ @ai_fn
+ def describe_fruit(description: str) -> dict[str, str]:
+ """guess the fruit and return the name and color"""
+
+ fruit = describe_fruit("the one thats loved by monkeys")
+ assert fruit["name"].lower() == "banana"
+ assert fruit["color"].lower() == "yellow"
+
+ @pytest.mark.skipif(
+ marvin.settings.openai.chat.completions.model == "gpt-3.5-turbo-1106",
+ reason="3.5 turbo doesn't do well with unknown schemas",
+ )
+ def test_generic_dict_return_type(self):
+ @ai_fn
+ def describe_fruit(description: str) -> Dict[str, str]:
+ """guess the fruit and return the name and color"""
+
+ fruit = describe_fruit("the one thats loved by monkeys")
+ assert fruit["name"].lower() == "banana"
+ assert fruit["color"].lower() == "yellow"
+
+ def test_typed_dict_return_type(self):
+ from typing_extensions import TypedDict
+
+ class Fruit(TypedDict):
+ name: str
+ color: str
+
+ @ai_fn
+ def describe_fruit(description: str) -> Fruit:
+ """guess the fruit and return the name and color"""
+
+ fruit = describe_fruit("the one thats loved by monkeys")
+ assert fruit["name"].lower() == "banana"
+ assert fruit["color"].lower() == "yellow"
+
+ def test_int_return_type(self):
+ @ai_fn
+ def get_fruit(name: str) -> int:
+ """Returns the number of letters in the alluded fruit name"""
+
+ assert get_fruit("banana") == 6
+
+ def test_float_return_type(self):
+ @ai_fn
+ def get_pi(n: int) -> float:
+ """Return the first n digits of pi"""
+
+ assert get_pi(5) == 3.14159
+
+ def test_tuple_return_type(self):
+ @ai_fn
+ def get_fruit(name: str) -> tuple:
+ """Returns a tuple of fruit"""
+
+ assert get_fruit("alphabet of fruit, first 3, singular") == (
+ "apple",
+ "banana",
+ "cherry",
+ )
+
+ def test_set_return_type(self):
+ @ai_fn
+ def get_fruit_letters(name: str) -> set:
+ """Returns the letters in the provided fruit name"""
+
+ assert get_fruit_letters("banana") == {"a", "b", "n"}
+
+ def test_frozenset_return_type(self):
+ @ai_fn
+ def get_fruit_letters(name: str) -> frozenset:
+ """Returns the letters in the provided fruit name"""
+
+ assert get_fruit_letters("orange") == frozenset(
+ {"a", "e", "g", "n", "o", "r"}
+ )
+
+
+@pytest_mark_class("llm")
+class TestAIFunctionsMap:
+ def test_map(self):
+ result = list_fruit.map([2, 3])
+ assert len(result) == 2
+ assert len(result[0]) == 2
+ assert len(result[1]) == 3
+
+ async def test_amap(self):
+ result = await list_fruit.amap([2, 3])
+ assert len(result) == 2
+ assert len(result[0]) == 2
+ assert len(result[1]) == 3
+
+ def test_map_kwargs(self):
+ result = list_fruit.map(n=[2, 3])
+ assert len(result) == 2
+ assert len(result[0]) == 2
+ assert len(result[1]) == 3
+
+ def test_map_kwargs_and_args(self):
+ result = list_fruit_color.map([2, 3], color=["green", "red"])
+ assert len(result) == 2
+ assert len(result[0]) == 2
+ assert len(result[1]) == 3
+
+ def test_invalid_args(self):
+ with pytest.raises(TypeError):
+ list_fruit_color.map(2, color=["orange", "red"])
+
+ def test_invalid_kwargs(self):
+ with pytest.raises(TypeError):
+ list_fruit_color.map([2, 3], color=None)
+
+ async def test_invalid_async_map(self):
+ with pytest.raises(TypeError, match="can't be used in 'await' expression"):
+ await list_fruit_color.map(n=[2], color=["orange", "red"])
diff --git a/tests/test_components/test_ai_model.py b/tests/components/test_ai_model.py
similarity index 80%
rename from tests/test_components/test_ai_model.py
rename to tests/components/test_ai_model.py
index 24d37c06d..0fb9efa38 100644
--- a/tests/test_components/test_ai_model.py
+++ b/tests/components/test_ai_model.py
@@ -2,10 +2,9 @@
import pytest
from marvin import ai_model
-from marvin.utilities.messages import Message, Role
from pydantic import BaseModel, Field
-from tests.utils.mark import pytest_mark_class
+from tests.utils import pytest_mark_class
@pytest_mark_class("llm")
@@ -29,16 +28,16 @@ class Location(BaseModel):
longitude: float
city: str
state: str
- country: str
+ country: str = Field(..., description="The abbreviated country name")
x = Location("The capital city of the Cornhusker State.")
assert x.city == "Lincoln"
assert x.state == "Nebraska"
- assert "United" in x.country
- assert "States" in x.country
+ assert x.country in {"US", "USA", "U.S.", "U.S.A."}
assert x.latitude // 1 == 40
assert x.longitude // 1 == -97
+ @pytest.mark.xfail(reason="TODO: flaky on 3.5")
def test_depth(self):
from typing import List
@@ -61,6 +60,7 @@ class RentalHistory(BaseModel):
I lived in Palms, then Mar Vista, then Pico Robertson.
""")
+ @pytest.mark.flaky(max_runs=3)
def test_resume(self):
class Experience(BaseModel):
technology: str
@@ -86,23 +86,17 @@ class Resume(BaseModel):
assert not x.greater_than_ten_years_management_experience
assert len(x.technologies) == 2
- @pytest.mark.flaky(reruns=2)
def test_literal(self):
- class CertainPerson(BaseModel):
- name: Literal["Adam", "Nate", "Jeremiah"]
-
@ai_model
class LLMConference(BaseModel):
- speakers: List[CertainPerson]
+ speakers: list[Literal["Adam", "Nate", "Jeremiah"]]
x = LLMConference("""
The conference for best LLM framework will feature talks by
Adam, Nate, Jeremiah, Marvin, and Billy Bob Thornton.
""")
- assert len(set([speaker.name for speaker in x.speakers])) == 3
- assert set([speaker.name for speaker in x.speakers]) == set(
- ["Adam", "Nate", "Jeremiah"]
- )
+ assert len(set(x.speakers)) == 3
+ assert set(x.speakers) == set(["Adam", "Nate", "Jeremiah"])
@pytest.mark.xfail(reason="regression in OpenAI function-using models")
def test_history(self):
@@ -143,6 +137,7 @@ class Election(BaseModel):
)
)
+ @pytest.mark.skip(reason="old behavior, may revisit")
def test_correct_class_is_returned(self):
@ai_model
class Fruit(BaseModel):
@@ -153,36 +148,10 @@ class Fruit(BaseModel):
assert isinstance(fruit, Fruit)
- async def test_correct_class_is_returned_via_acall(self):
- @ai_model
- class Fruit(BaseModel):
- color: str
- name: str
-
- fruit = await Fruit.acall("loved by monkeys")
-
- assert isinstance(fruit, Fruit)
-
-
-@pytest_mark_class("llm")
-class TestAIModelsMessage:
- @pytest.mark.skip(reason="old behavior, may revisit")
- def test_arithmetic_message(self):
- @ai_model
- class Arithmetic(BaseModel):
- sum: float = Field(
- ..., description="The resolved sum of provided arguments"
- )
-
- x = Arithmetic("One plus six")
- assert x.sum == 7
- assert isinstance(x._message, Message)
- assert x._message.role == Role.FUNCTION_RESPONSE
-
+@pytest.mark.skip(reason="old behavior, may revisit")
@pytest_mark_class("llm")
class TestInstructions:
- @pytest.mark.skip(reason="old behavior, may revisit")
def test_instructions_error(self):
@ai_model
class Test(BaseModel):
@@ -280,31 +249,30 @@ class Arithmetic(BaseModel):
assert x[0].sum == 7
assert x[1].sum == 101
- def test_location(self):
+ @pytest.mark.flaky(max_runs=3)
+ def test_fix_misspellings(self):
@ai_model
class City(BaseModel):
- name: str = Field(description="The correct city name, e.g. Omaha") # noqa
+ """fix any misspellings of a city attributes"""
+
+ name: str = Field(
+ description=(
+ "The OFFICIAL, correctly-spelled name of a city - must be"
+ " capitalized. Do not include the state or country, or use any"
+ " abbreviations."
+ )
+ )
results = City.map(
[
"the windy city",
"chicago IL",
"Chicago",
- "Chcago",
+ "America's third-largest city",
"chicago, Illinois, USA",
- "chi-town",
+ "colloquially known as 'chi-town'",
]
)
assert len(results) == 6
for result in results:
assert result.name == "Chicago"
-
- def test_instructions(self):
- @ai_model
- class Translate(BaseModel):
- text: str
-
- result = Translate.map(["Hello", "Goodbye"], instructions="Translate to French")
- assert len(result) == 2
- assert result[0].text == "Bonjour"
- assert result[1].text == "Au revoir"
diff --git a/tests/conftest.py b/tests/conftest.py
index 96c498fdf..498585601 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,6 @@
import asyncio
import logging
+import os
import sys
import pytest
@@ -40,3 +41,49 @@ def event_loop(request):
if tasks and loop.is_running():
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
loop.close()
+
+
+class SetEnv:
+ def __init__(self):
+ self.envars = set()
+
+ def set(self, name, value):
+ self.envars.add(name)
+ os.environ[name] = value
+
+ def pop(self, name):
+ self.envars.remove(name)
+ os.environ.pop(name)
+
+ def clear(self):
+ for n in self.envars:
+ os.environ.pop(n)
+
+
+@pytest.fixture
+def env():
+ setenv = SetEnv()
+
+ yield setenv
+
+ setenv.clear()
+
+
+@pytest.fixture
+def docs_test_env():
+ setenv = SetEnv()
+
+ # # envs for basic usage example
+ # setenv.set('my_auth_key', 'xxx')
+ # setenv.set('my_api_key', 'xxx')
+
+ # envs for parsing environment variable values example
+ setenv.set("V0", "0")
+ setenv.set("SUB_MODEL", '{"v1": "json-1", "v2": "json-2"}')
+ setenv.set("SUB_MODEL__V2", "nested-2")
+ setenv.set("SUB_MODEL__V3", "3")
+ setenv.set("SUB_MODEL__DEEP__V4", "v4")
+
+ yield setenv
+
+ setenv.clear()
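
An illustrative (hypothetical) test showing how the new `env` fixture is intended to be used; the variable name is made up:

import os


def test_env_fixture_round_trip(env):
    env.set("SOME_HYPOTHETICAL_VAR", "value")
    assert os.environ["SOME_HYPOTHETICAL_VAR"] == "value"
    # teardown calls env.clear(), which pops every variable set during the test
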
diff --git a/tests/utils/package_size.sh b/tests/package_size.sh
similarity index 100%
rename from tests/utils/package_size.sh
rename to tests/package_size.sh
diff --git a/tests/test_chat_completion/test_sdk.py b/tests/test_chat_completion/test_sdk.py
deleted file mode 100644
index 0c0fcbabf..000000000
--- a/tests/test_chat_completion/test_sdk.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import pytest
-from marvin.core.ChatCompletion.abstract import Conversation
-
-from tests.utils.mark import pytest_mark_class
-
-
-class TestRegressions:
- def test_key_set_via_attr(self, monkeypatch):
- from marvin import openai
-
- monkeypatch.setattr(openai, "api_key", "test")
- v = openai.ChatCompletion().defaults.get("api_key")
- assert v == "test"
-
- @pytest.mark.parametrize("valid_env_var", ["MARVIN_OPENAI_API_KEY"])
- def test_key_set_via_env(self, monkeypatch, valid_env_var):
- monkeypatch.setenv(valid_env_var, "test")
- from marvin import openai
-
- v = openai.ChatCompletion().defaults.get("api_key")
- assert v == "test"
-
- def facet(self):
- messages = [{"role": "user", "content": "hey"}]
- from marvin import openai
-
- faceted = openai.ChatCompletion(messages=messages)
- faceted_request = faceted.prepare_request(messages=messages)
- assert faceted_request.messages == 2 * messages
-
-
-@pytest_mark_class("llm")
-class TestChatCompletion:
- def test_response_model(self):
- import pydantic
- from marvin import openai
-
- class Person(pydantic.BaseModel):
- name: str
- age: int
-
- response = openai.ChatCompletion().create(
- messages=[{"role": "user", "content": "Billy is 10 years old"}],
- response_model=Person,
- )
-
- model = response.to_model()
- assert model.name == "Billy"
- assert model.age == 10
-
- def test_streaming(self):
- from marvin import openai
-
- streamed_data = []
-
- def handler(message):
- streamed_data.append(message.content)
-
- completion = openai.ChatCompletion(stream_handler=handler).create(
- messages=[{"role": "user", "content": "say exactly 'hello'"}],
- )
-
- assert completion.response.choices[0].message.content == streamed_data[-1]
- assert "hello" in streamed_data[-1].lower()
- assert len(streamed_data) > 1
-
- async def test_streaming_async(self):
- from marvin import openai
-
- streamed_data = []
-
- async def handler(message):
- streamed_data.append(message.content)
-
- completion = await openai.ChatCompletion(stream_handler=handler).acreate(
- messages=[{"role": "user", "content": "say only 'hello'"}],
- )
- assert completion.response.choices[0].message.content == streamed_data[-1]
- assert "hello" in streamed_data[-1].lower()
- assert len(streamed_data) > 1
-
-
-@pytest_mark_class("llm")
-class TestChatCompletionChain:
- def test_chain(self):
- from marvin import openai
-
- convo = openai.ChatCompletion().chain(
- messages=[{"role": "user", "content": "Hello"}],
- )
-
- assert isinstance(convo, Conversation)
- assert len(convo.turns) == 1
-
- async def test_achain(self):
- from marvin import openai
-
- convo = await openai.ChatCompletion().achain(
- messages=[{"role": "user", "content": "Hello"}],
- )
-
- assert isinstance(convo, Conversation)
- assert len(convo.turns) == 1
diff --git a/tests/test_components/__init__.py b/tests/test_components/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/tests/test_components/test_ai_app.py b/tests/test_components/test_ai_app.py
deleted file mode 100644
index e5ae37dab..000000000
--- a/tests/test_components/test_ai_app.py
+++ /dev/null
@@ -1,291 +0,0 @@
-import jsonpatch
-import pytest
-from marvin._compat import model_dump
-from marvin.components.ai_application import (
- AIApplication,
- AppPlan,
- FreeformState,
- TaskState,
- UpdatePlan,
- UpdateState,
-)
-from marvin.tools import Tool
-from marvin.utilities.messages import Message
-
-from tests.utils.mark import pytest_mark_class
-
-
-class GetSchleeb(Tool):
- name: str = "get_schleeb"
-
- async def run(self):
- """Get the value of schleeb"""
- return 42
-
-
-class TestStateJSONPatch:
- def test_update_app_state_valid_patch(self):
- app = AIApplication(
- state=FreeformState(state={"foo": "bar"}), description="test app"
- )
- tool = UpdateState(app=app)
- tool.run([{"op": "replace", "path": "/state/foo", "value": "baz"}])
- assert model_dump(app.state) == {"state": {"foo": "baz"}}
-
- def test_update_app_state_invalid_patch(self):
- app = AIApplication(
- state=FreeformState(state={"foo": "bar"}), description="test app"
- )
- tool = UpdateState(app=app)
- with pytest.raises(jsonpatch.InvalidJsonPatch):
- tool.run([{"op": "invalid_op", "path": "/state/foo", "value": "baz"}])
- assert model_dump(app.state) == {"state": {"foo": "bar"}}
-
- def test_update_app_state_non_existent_path(self):
- app = AIApplication(
- state=FreeformState(state={"foo": "bar"}), description="test app"
- )
- tool = UpdateState(app=app)
- with pytest.raises(jsonpatch.JsonPatchConflict):
- tool.run([{"op": "replace", "path": "/state/baz", "value": "qux"}])
- assert model_dump(app.state) == {"state": {"foo": "bar"}}
-
-
-@pytest_mark_class("llm")
-class TestUpdateState:
- def test_keep_app_state(self):
- app = AIApplication(
- name="location tracker app",
- state=FreeformState(state={"San Francisco": {"visited": False}}),
- plan_enabled=False,
- description="keep track of where I've visited",
- )
-
- app("I just visited to San Francisco")
- assert bool(app.state.state.get("San Francisco", {}).get("visited"))
-
- app("oh also I visited San Jose!")
-
- assert bool(app.state.state.get("San Jose", {}).get("visited"))
-
- @pytest.mark.flaky(max_runs=3)
- def test_keep_app_state_undo_previous_patch(self):
- app = AIApplication(
- name="location tracker app",
- state=FreeformState(state={"San Francisco": {"visited": False}}),
- plan_enabled=False,
- description="keep track of where I've visited",
- )
-
- app("I just visited San Francisco")
- assert bool(app.state.state.get("San Francisco", {}).get("visited"))
-
- app(
- "sorry, scratch that, I did not visit San Francisco - but I did visit San"
- " Jose"
- )
-
- assert not bool(app.state.state.get("San Francisco", {}).get("visited"))
- assert bool(app.state.state.get("San Jose", {}).get("visited"))
-
-
-class TestPlanJSONPatch:
- def test_update_app_plan_valid_patch(self):
- app = AIApplication(
- plan=AppPlan(
- tasks=[{"id": 1, "description": "test task", "state": "IN_PROGRESS"}]
- ),
- description="test app",
- )
- tool = UpdatePlan(app=app)
- tool.run([{"op": "replace", "path": "/tasks/0/state", "value": "COMPLETED"}])
- assert model_dump(app.plan) == {
- "tasks": [
- {
- "id": 1,
- "description": "test task",
- "state": TaskState.COMPLETED,
- "upstream_task_ids": None,
- "parent_task_id": None,
- }
- ],
- "notes": [],
- }
-
- def test_update_app_plan_invalid_patch(self):
- app = AIApplication(
- plan=AppPlan(
- tasks=[{"id": 1, "description": "test task", "state": "IN_PROGRESS"}]
- ),
- description="test app",
- )
- tool = UpdatePlan(app=app)
- with pytest.raises(jsonpatch.JsonPatchException):
- tool.run(
- [{"op": "invalid_op", "path": "/tasks/0/state", "value": "COMPLETED"}]
- )
- assert model_dump(app.plan) == {
- "tasks": [
- {
- "id": 1,
- "description": "test task",
- "state": TaskState.IN_PROGRESS,
- "upstream_task_ids": None,
- "parent_task_id": None,
- }
- ],
- "notes": [],
- }
-
- def test_update_app_plan_non_existent_path(self):
- app = AIApplication(
- plan=AppPlan(
- tasks=[{"id": 1, "description": "test task", "state": "IN_PROGRESS"}]
- ),
- description="test app",
- )
- tool = UpdatePlan(app=app)
- with pytest.raises(jsonpatch.JsonPointerException):
- tool.run(
- [{"op": "replace", "path": "/tasks/1/state", "value": "COMPLETED"}]
- )
- assert model_dump(app.plan) == {
- "tasks": [
- {
- "id": 1,
- "description": "test task",
- "state": TaskState.IN_PROGRESS,
- "upstream_task_ids": None,
- "parent_task_id": None,
- }
- ],
- "notes": [],
- }
-
-
-@pytest_mark_class("llm")
-class TestUpdatePlan:
- @pytest.mark.flaky(max_runs=3)
- def test_keep_app_plan(self):
- app = AIApplication(
- name="Zoo planner app",
- plan=AppPlan(
- tasks=[
- {
- "id": 1,
- "description": "Visit tigers",
- "state": TaskState.IN_PROGRESS,
- },
- {
- "id": 2,
- "description": "Visit giraffes",
- "state": TaskState.PENDING,
- },
- ]
- ),
- state_enabled=False,
- description="plan and track my visit to the zoo",
- )
-
- app(
- "Actually I heard the tigers ate Carol Baskin's husband - I think I'll skip"
- " visiting them."
- )
-
- assert [task["state"] for task in app.plan.dict()["tasks"]] == [
- TaskState.SKIPPED,
- TaskState.PENDING,
- ]
-
- app("Dude i just visited the giraffes!")
-
- assert [task["state"] for task in app.plan.dict()["tasks"]] == [
- TaskState.SKIPPED,
- TaskState.COMPLETED,
- ]
-
-
-@pytest_mark_class("llm")
-class TestUseCallable:
- def test_use_sync_fn(self):
- def get_schleeb():
- return 42
-
- app = AIApplication(
- name="Schleeb app",
- tools=[get_schleeb],
- state_enabled=False,
- plan_enabled=False,
- description="answer user questions",
- )
-
- assert "42" in app("what is the value of schleeb?").content
-
- def test_use_async_fn(self):
- async def get_schleeb():
- return 42
-
- app = AIApplication(
- name="Schleeb app",
- tools=[get_schleeb],
- state_enabled=False,
- plan_enabled=False,
- description="answer user questions",
- )
-
- assert "42" in app("what is the value of schleeb?").content
-
-
-@pytest_mark_class("llm")
-class TestUseTool:
- def test_use_tool(self):
- app = AIApplication(
- name="Schleeb app",
- tools=[GetSchleeb()],
- state_enabled=False,
- plan_enabled=False,
- description="answer user questions",
- )
-
- assert "42" in app("what is the value of schleeb?").content
-
-
-@pytest_mark_class("llm")
-class TestStreaming:
- def test_streaming(self):
- external_state = {"content": []}
-
- app = AIApplication(
- name="streaming app",
- stream_handler=lambda m: external_state["content"].append(m.content),
- state_enabled=False,
- plan_enabled=False,
- )
-
- response = app(
- "say the words 'Hello world' EXACTLY as i have written them."
- " no other characters should be included, do not add any punctuation."
- )
-
- assert isinstance(response, Message)
- assert response.content == "Hello world"
-
- assert external_state["content"] == ["", "Hello", "Hello world", "Hello world"]
-
-
-@pytest_mark_class("llm")
-class TestMemory:
- def test_recall(self):
- app = AIApplication(
- name="memory app",
- state_enabled=False,
- plan_enabled=False,
- )
-
- app("I like pistachio ice cream")
-
- response = app(
- "reply only with the type of ice cream i like, it should be one word"
- )
-
- assert "pistachio" in response.content.lower()
diff --git a/tests/test_components/test_ai_classifier.py b/tests/test_components/test_ai_classifier.py
deleted file mode 100644
index ad0ab89a0..000000000
--- a/tests/test_components/test_ai_classifier.py
+++ /dev/null
@@ -1,109 +0,0 @@
-from enum import Enum
-
-import pytest
-from marvin import ai_classifier
-
-from tests.utils.mark import pytest_mark_class
-
-
-class TestAIClassifiersInitialization:
- def test_model(self):
- @ai_classifier(model="openai/gpt-4-test-model")
- class Sentiment(Enum):
- POSITIVE = "Positive"
- NEGATIVE = "Negative"
-
- assert (
- Sentiment.as_chat_completion("test").defaults.get("model")
- == "gpt-4-test-model"
- )
-
- def test_invalid_model(self):
- @ai_classifier(model="anthropic/claude-2")
- class Sentiment(Enum):
- POSITIVE = "Positive"
- NEGATIVE = "Negative"
-
- assert Sentiment.as_chat_completion("test").defaults.get("model") == "claude-2"
-
-
-@pytest_mark_class("llm")
-class TestAIClassifiers:
- def test_sentiment(self):
- @ai_classifier
- class Sentiment(Enum):
- POSITIVE = "Positive"
- NEGATIVE = "Negative"
-
- assert Sentiment("Great!") == Sentiment.POSITIVE
-
- def test_keys_are_passed_to_llm(self):
- @ai_classifier
- class Sentiment(Enum):
- POSITIVE = "option - 1"
- NEGATIVE = "option - 2"
-
- assert Sentiment("Great!") == Sentiment.POSITIVE
-
- def test_values_are_passed_to_llm(self):
- @ai_classifier
- class Sentiment(Enum):
- OPTION_1 = "Positive"
- OPITION_2 = "Negative"
-
- assert Sentiment("Great!") == Sentiment.OPTION_1
-
- def test_docstring_is_passed_to_llm(self):
- @ai_classifier
- class Sentiment(Enum):
- """It's opposite day"""
-
- POSITIVE = "Positive"
- NEGATIVE = "Negative"
-
- assert Sentiment("Great!") == Sentiment.NEGATIVE
-
- def test_instructions_are_passed_to_llm(self):
- @ai_classifier
- class Sentiment(Enum):
- POSITIVE = "Positive"
- NEGATIVE = "Negative"
-
- assert (
- Sentiment("Great!", instructions="today is opposite day")
- == Sentiment.NEGATIVE
- )
-
- def test_recover_complex_values(self):
- @ai_classifier
- class Sentiment(Enum):
- POSITIVE = {"value": "Positive"}
- NEGATIVE = {"value": "Negative"}
-
- result = Sentiment("Great!")
-
- assert result.value["value"] == "Positive"
-
-
-@pytest_mark_class("llm")
-class TestMapping:
- def test_mapping(self):
- @ai_classifier
- class Sentiment(Enum):
- POSITIVE = "Positive"
- NEGATIVE = "Negative"
-
- result = Sentiment.map(["good", "bad"])
- assert result == [Sentiment.POSITIVE, Sentiment.NEGATIVE]
-
- @pytest.mark.xfail(reason="Flaky with 3.5 turbo")
- def test_mapping_with_instructions(self):
- @ai_classifier
- class Sentiment(Enum):
- POSITIVE = "Positive"
- NEGATIVE = "Negative"
-
- result = Sentiment.map(
- ["good", "bad"], instructions="I want the opposite of the right answer"
- )
- assert result == [Sentiment.NEGATIVE, Sentiment.POSITIVE]
diff --git a/tests/test_components/test_ai_functions.py b/tests/test_components/test_ai_functions.py
deleted file mode 100644
index 5b7860202..000000000
--- a/tests/test_components/test_ai_functions.py
+++ /dev/null
@@ -1,169 +0,0 @@
-import inspect
-from typing import Dict, List
-
-import pytest
-from marvin import ai_fn
-from pydantic import BaseModel
-
-from tests.utils.mark import pytest_mark_class
-
-
-@ai_fn
-def list_fruit(n: int = 2) -> list[str]:
- """Returns a list of `n` fruit"""
-
-
-@ai_fn
-def list_fruit_color(n: int, color: str = None) -> list[str]:
- """Returns a list of `n` fruit that all have the provided `color`"""
-
-
-@pytest_mark_class("llm")
-class TestAIFunctions:
- def test_list_fruit(self):
- result = list_fruit()
- assert len(result) == 2
-
- def test_list_fruit_argument(self):
- result = list_fruit(5)
- assert len(result) == 5
-
- async def test_list_fruit_async(self):
- @ai_fn
- async def list_fruit(n: int) -> list[str]:
- """Returns a list of `n` fruit"""
-
- coro = list_fruit(3)
- assert inspect.iscoroutine(coro)
- result = await coro
- assert len(result) == 3
-
- def test_list_fruit_with_generic_type_hints(self):
- @ai_fn
- def list_fruit(n: int) -> List[str]:
- """Returns a list of `n` fruit"""
-
- result = list_fruit(3)
- assert len(result) == 3
-
- def test_basemodel_return_annotation(self):
- class Fruit(BaseModel):
- name: str
- color: str
-
- @ai_fn
- def get_fruit(description: str) -> Fruit:
- """Returns a fruit with the provided description"""
-
- fruit = get_fruit("loved by monkeys")
- assert fruit.name.lower() == "banana"
- assert fruit.color.lower() == "yellow"
-
- @pytest.mark.parametrize("name,expected", [("banana", True), ("car", False)])
- def test_bool_return_annotation(self, name, expected):
- @ai_fn
- def is_fruit(name: str) -> bool:
- """Returns True if the provided name is a fruit"""
-
- assert is_fruit(name) == expected
-
- def test_plain_dict_return_type(self):
- @ai_fn
- def get_fruit(name: str) -> dict:
- """Returns a fruit with the provided name and color"""
-
- fruit = get_fruit("banana")
- assert fruit["name"].lower() == "banana"
- assert fruit["color"].lower() == "yellow"
-
- def test_annotated_dict_return_type(self):
- @ai_fn
- def get_fruit(name: str) -> dict[str, str]:
- """Returns a fruit with the provided name and color"""
-
- fruit = get_fruit("banana")
- assert fruit["name"].lower() == "banana"
- assert fruit["color"].lower() == "yellow"
-
- def test_generic_dict_return_type(self):
- @ai_fn
- def get_fruit(name: str) -> Dict[str, str]:
- """Returns a fruit with the provided name and color"""
-
- fruit = get_fruit("banana")
- assert fruit["name"].lower() == "banana"
- assert fruit["color"].lower() == "yellow"
-
- def test_int_return_type(self):
- @ai_fn
- def get_fruit(name: str) -> int:
- """Returns the number of letters in the provided fruit name"""
-
- assert get_fruit("banana") == 6
-
- def test_float_return_type(self):
- @ai_fn
- def get_fruit(name: str) -> float:
- """Returns the number of letters in the provided fruit name"""
-
- assert get_fruit("banana") == 6.0
-
- def test_tuple_return_type(self):
- @ai_fn
- def get_fruit(name: str) -> tuple:
- """Returns the number of letters in the provided fruit name"""
-
- assert get_fruit("banana") == (6,)
-
- def test_set_return_type(self):
- @ai_fn
- def get_fruit(name: str) -> set:
- """Returns the letters in the provided fruit name"""
-
- assert get_fruit("banana") == {"a", "b", "n"}
-
- def test_frozenset_return_type(self):
- @ai_fn
- def get_fruit(name: str) -> frozenset:
- """Returns the letters in the provided fruit name"""
-
- assert get_fruit("banana") == frozenset({"a", "b", "n"})
-
-
-@pytest_mark_class("llm")
-class TestAIFunctionsMap:
- def test_map(self):
- result = list_fruit_color.map([2, 3])
- assert len(result) == 2
- assert len(result[0]) == 2
- assert len(result[1]) == 3
-
- async def test_amap(self):
- result = await list_fruit_color.amap([2, 3])
- assert len(result) == 2
- assert len(result[0]) == 2
- assert len(result[1]) == 3
-
- def test_map_kwargs(self):
- result = list_fruit_color.map(n=[2, 3])
- assert len(result) == 2
- assert len(result[0]) == 2
- assert len(result[1]) == 3
-
- def test_map_kwargs_and_args(self):
- result = list_fruit_color.map([2, 3], color=[None, "red"])
- assert len(result) == 2
- assert len(result[0]) == 2
- assert len(result[1]) == 3
-
- def test_invalid_args(self):
- with pytest.raises(TypeError):
- list_fruit_color.map(2, color=["orange", "red"])
-
- def test_invalid_kwargs(self):
- with pytest.raises(TypeError):
- list_fruit_color.map([2, 3], color=None)
-
- async def test_invalid_async_map(self):
- with pytest.raises(TypeError, match="can't be used in 'await' expression"):
- await list_fruit_color.map(n=[2], color=["orange", "red"])
diff --git a/tests/test_models/__init__.py b/tests/test_models/__init__.py
deleted file mode 100644
index cd6cb144a..000000000
--- a/tests/test_models/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import pytest
-import inspect
-from marvin.functions import Function
-
-
-class TestFunctions:
- def test_signature(self):
- def fn(x, y=10):
- return x + y
-
- f = Function(fn=fn)
- assert f.signature == inspect.signature(fn)
diff --git a/tests/test_models/test_functions.py b/tests/test_models/test_functions.py
deleted file mode 100644
index 741f8ab93..000000000
--- a/tests/test_models/test_functions.py
+++ /dev/null
@@ -1,114 +0,0 @@
-class TestFunctions:
- def test_signature(self):
- import inspect
-
- from marvin.functions import Function
-
- def fn(x: int, y: int = 10) -> int:
- return x + y
-
- f = Function(fn=fn)
- assert f.signature == inspect.signature(fn)
-
- def test_name(self):
- from marvin.functions import Function
-
- def fn(x: int, y: int = 10) -> int:
- return x + y
-
- f = Function(fn=fn)
- assert f.name == fn.__name__
-
- def test_name_custom(self):
- from marvin.functions import Function
-
- def fn(x: int, y: int = 10) -> int:
- return x + y
-
- f = Function(fn=fn, name="custom")
- assert f.name == "custom"
-
- def test_description(self):
- from marvin.functions import Function
-
- def fn(x: int, y: int = 10) -> int:
- """This is a description"""
- return x + y
-
- f = Function(fn=fn)
- assert f.description == fn.__doc__
-
- def test_description_custom(self):
- from marvin.functions import Function
-
- def fn(x: int, y: int = 10) -> int:
- """This is a description"""
- return x + y
-
- f = Function(fn=fn, description="custom")
- assert f.description == "custom"
-
- def test_source_code(self):
- import inspect
-
- from marvin.functions import Function
-
- def fn(x: int, y: int = 10) -> int:
- return x + y
-
- f = Function(fn=fn)
- assert f.source_code == inspect.cleandoc(inspect.getsource(fn))
-
- def test_return_annotation_native(self):
- import inspect
-
- from marvin.functions import Function
-
- def fn(x: int, y: int = 10) -> int:
- return x + y
-
- f = Function(fn=fn)
- assert f.return_annotation == inspect.signature(fn).return_annotation
-
- def test_return_annotation_pydantic(self):
- from marvin.functions import Function
- from pydantic import BaseModel
-
- class Foo(BaseModel):
- foo: int
- bar: int
-
- def fn(x: int, y: int = 10) -> Foo:
- return Foo(foo=x, bar=y)
-
- f = Function(fn=fn)
- assert f.return_annotation == Foo
-
- def test_arguments(self):
- from marvin.functions import Function
-
- def fn(x: int, y: int = 10) -> int:
- return x + y
-
- f = Function(fn=fn)
- assert f.arguments(1) == {"x": 1, "y": 10}
-
- def test_schema(self):
- from marvin.functions import Function
-
- def fn(x: int, y: int = 10) -> int:
- return x + y
-
- f = Function(fn=fn)
- assert f.schema() == {
- "name": "fn",
- "description": fn.__doc__,
- "parameters": {
- "type": "object",
- "properties": {
- "x": {"type": "integer", "title": "X"},
- "y": {"type": "integer", "title": "Y", "default": 10},
- },
- "required": ["x"],
- },
- }
diff --git a/tests/test_settings.py b/tests/test_settings.py
new file mode 100644
index 000000000..d8084aa89
--- /dev/null
+++ b/tests/test_settings.py
@@ -0,0 +1,41 @@
+from marvin.settings import AssistantSettings, Settings, SpeechSettings
+from pydantic_settings import SettingsConfigDict
+
+
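+# NOTE: `env` is a fixture supplied elsewhere in the test suite (not shown in
+# this diff); it exposes `env.set(name, value)` for setting environment
+# variables and, presumably, restores them after each test.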
+def test_api_key_initialization_from_env(env):
+ test_api_key = "test_api_key_123"
+ env.set("MARVIN_OPENAI_API_KEY", test_api_key)
+
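+    # Use the "marvin_" env prefix so that MARVIN_OPENAI_API_KEY is picked up
+    # as settings.openai.api_key (the assertion below checks exactly that).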
+ temp_model_config = SettingsConfigDict(env_prefix="marvin_")
+ settings = Settings(model_config=temp_model_config)
+
+ assert settings.openai.api_key.get_secret_value() == test_api_key
+
+
+def test_runtime_api_key_override(env):
+ override_api_key = "test_api_key_456"
+ env.set("MARVIN_OPENAI_API_KEY", override_api_key)
+
+ temp_model_config = SettingsConfigDict(env_prefix="marvin_")
+ settings = Settings(model_config=temp_model_config)
+
+ assert settings.openai.api_key.get_secret_value() == override_api_key
+
+ settings.openai.api_key = "test_api_key_789"
+
+ assert settings.openai.api_key.get_secret_value() == "test_api_key_789"
+
+
+class TestSpeechSettings:
+ def test_speech_settings_default(self):
+ settings = SpeechSettings()
+ assert settings.model == "tts-1-hd"
+ assert settings.voice == "alloy"
+ assert settings.response_format == "mp3"
+ assert settings.speed == 1.0
+
+
+class TestAssistantSettings:
+ def test_assistant_settings_default(self):
+ settings = AssistantSettings()
+ assert settings.model == "gpt-4-1106-preview"
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 000000000..01c9673ad
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,28 @@
+import pytest
+
+
+def pytest_mark_class(*markers: str):
+ """Mark all test methods in a class with the provided markers
+
+    Only the outermost class needs to be decorated; nested classes whose names
+    start with "Test" are marked recursively.
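+
+    Usage sketch (class and test names below are illustrative):
+
+        @pytest_mark_class("llm")
+        class TestSomething:
+            def test_case(self):
+                ...
+
+    Every `test*` method, including those on nested `Test*` classes, ends up
+    marked with `pytest.mark.llm`.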
+ """
+
+ def mark_test_methods(cls):
+ for attr_name, attr_value in cls.__dict__.items():
+ # mark all test methods with the provided markers
+ if callable(attr_value) and attr_name.startswith("test"):
+ for marker in markers:
+ marked_func = getattr(pytest.mark, marker)(attr_value)
+ setattr(cls, attr_name, marked_func)
+ # recursively mark nested classes
+ elif isinstance(attr_value, type) and attr_value.__name__.startswith(
+ "Test"
+ ):
+ mark_test_methods(attr_value)
+
+ def decorator(cls):
+ mark_test_methods(cls)
+ return cls
+
+ return decorator
diff --git a/tests/utils/mark.py b/tests/utils/mark.py
deleted file mode 100644
index 5043a7b35..000000000
--- a/tests/utils/mark.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import pytest
-
-
-def pytest_mark_class(marker):
- def decorator(cls):
- for attr_name, attr_value in cls.__dict__.items():
- if callable(attr_value) and attr_name.startswith("test"):
- setattr(cls, attr_name, pytest.mark.llm(attr_value))
- return cls
-
- return decorator