Skip to content

Commit

Permalink
feat(llamabot)🚀: Enhance llamabot with new message handling and API f…
Browse files Browse the repository at this point in the history
…unctionalities

- Add new message creation functions and bot classes to handle different types of interactions.
- Implement high-level API functions for user and system message creation.
- Introduce processing of multiple message types including text, images, and URLs.
- Update SimpleBot and StructuredBot to handle multiple messages and improve interaction flow.
- Add new notebook examples to demonstrate usage of vision models and structured data extraction.
  • Loading branch information
ericmjl committed Dec 6, 2024
1 parent ee7ef35 commit 820cd8f
Show file tree
Hide file tree
Showing 7 changed files with 446 additions and 18 deletions.
125 changes: 125 additions & 0 deletions llamabot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,22 @@
from llamabot import some_function
Use it to control the top-level API of your Python data science project.
The module provides several high-level functions and classes for working with LLMs:
- Message creation functions: `user()` and `system()`
- Bot classes: SimpleBot, StructuredBot, ChatBot, ImageBot, QueryBot
- Prompt management: `prompt` decorator
- Experimentation: `Experiment` and `metric`
- Recording: `PromptRecorder`
"""

import os
from pathlib import Path
from typing import Union
import mimetypes

import httpx

from loguru import logger

Expand All @@ -20,6 +32,7 @@
from .experiments import Experiment, metric
from .prompt_manager import prompt
from .recorder import PromptRecorder
from .components.messages import HumanMessage, ImageMessage, SystemMessage

# Configure logger
log_level = os.getenv("LOG_LEVEL", "WARNING").upper()
Expand Down Expand Up @@ -49,3 +62,115 @@

# Ensure ~/.llamabot directory exists.
# NOTE(review): this looks like a module-level statement in the package
# __init__, so it would run on every import -- confirm. `exist_ok=True` makes
# the call a no-op when the directory is already there, and `parents=True`
# creates any missing parent directories.
(Path.home() / ".llamabot").mkdir(parents=True, exist_ok=True)


# High-level API
def user(
    *content: Union[str, Path]
) -> Union[HumanMessage, ImageMessage, list[Union[HumanMessage, ImageMessage]]]:
    """Create one or more user messages from the given content.

    This function provides a flexible way to create user messages from various
    types of content:

    - Plain text strings become HumanMessages
    - Image file paths become ImageMessages
    - URLs are first tried as images (ImageMessage); if that attempt raises
      ``httpx.HTTPError`` or ``ValueError`` the URL is kept as the text of a
      HumanMessage instead
    - Text file paths become HumanMessages with the file contents
    - Multiple inputs return a list of messages

    A string is only treated as a file path when it names an existing regular
    file. Strings that merely happen to match a directory name (including the
    empty string, which resolves to the current directory) or that are too long
    to be a valid file name are treated as plain text.

    Examples:

    >>> user("Hello, world!")  # Simple text message
    HumanMessage(content="Hello, world!")
    >>> user("image.png")  # Local image file
    ImageMessage(content="<base64-encoded-content>")
    >>> user("https://example.com/image.jpg")  # Image URL
    ImageMessage(content="<base64-encoded-content>")
    >>> user("text.txt")  # Text file
    HumanMessage(content="<file-contents>")
    >>> user("msg1", "msg2")  # Multiple messages
    [HumanMessage(content="msg1"), HumanMessage(content="msg2")]

    :param content: One or more pieces of content to convert into messages.
        Can be strings (text/URLs/file paths) or Paths to files.
    :return: A single message when exactly one item is given, otherwise a list
        of messages (an empty list when called with no arguments).
    :raises FileNotFoundError: If an explicitly passed ``Path`` doesn't exist
    :raises ValueError: Propagated from ``ImageMessage`` if it rejects an
        image file
    """

    def _handle_path(path: Path) -> Union[HumanMessage, ImageMessage]:
        """Turn an existing file into an image or text message based on its MIME type."""
        if not path.exists():
            raise FileNotFoundError(f"File not found: {path}")

        mime_type, _ = mimetypes.guess_type(str(path))
        if mime_type and mime_type.startswith("image/"):
            return ImageMessage(content=path)
        return HumanMessage(content=path.read_text())

    def _handle_url(url: str) -> Union[HumanMessage, ImageMessage]:
        """Try to load the URL as an image, falling back to the URL as plain text."""
        try:
            return ImageMessage(content=url)
        except (httpx.HTTPError, ValueError):
            return HumanMessage(content=url)

    def _existing_file(text: str) -> Union[Path, None]:
        """Return ``Path(text)`` if it names an existing regular file, else ``None``.

        ``is_file()`` (rather than ``exists()``) keeps directory names such as
        ``""``/``"."``/``"docs"`` from being mistaken for readable files.
        Arbitrary prompt text can also be an invalid path (e.g. longer than the
        OS file-name limit), which surfaces as ``OSError``/``ValueError`` from
        the filesystem probe; such text is simply not a file.
        """
        try:
            candidate = Path(text)
            return candidate if candidate.is_file() else None
        except (OSError, ValueError):
            return None

    def _handle_single_content(
        item: Union[str, Path]
    ) -> Union[HumanMessage, ImageMessage]:
        """Convert a single content item into the appropriate message type.

        Dispatch order: ``Path`` objects, then URL strings, then strings that
        name an existing file, and finally plain text.

        :param item: The content item to process, either a string or Path object
        :return: Either a HumanMessage or ImageMessage depending on the content type
        :raises FileNotFoundError: If an explicitly passed ``Path`` doesn't exist
        :raises ValueError: Propagated from ``ImageMessage`` for a rejected image file
        """
        # Path objects are an explicit request to read a file.
        if isinstance(item, Path):
            return _handle_path(item)

        if isinstance(item, str):
            if item.startswith(("http://", "https://")):
                return _handle_url(item)

            file_path = _existing_file(item)
            if file_path is not None:
                return _handle_path(file_path)

        # Anything else is passed through as message text.
        return HumanMessage(content=item)

    # A single input yields a bare message rather than a one-element list.
    if len(content) == 1:
        return _handle_single_content(content[0])

    return [_handle_single_content(item) for item in content]


def system(content: str) -> SystemMessage:
    """Wrap an instruction string in a ``SystemMessage``.

    System messages set the behavior, role, or context for the LLM; they act as
    high-level instructions that guide the model's responses.

    Examples:

    >>> system("You are a helpful assistant.")
    SystemMessage(content="You are a helpful assistant.")
    >>> system("Respond in the style of Shakespeare.")
    SystemMessage(content="Respond in the style of Shakespeare.")

    :param content: The instruction or context to give to the LLM
    :return: A SystemMessage containing the provided content
    """
    message = SystemMessage(content=content)
    return message
13 changes: 6 additions & 7 deletions llamabot/bot/simplebot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
HumanMessage,
SystemMessage,
BaseMessage,
process_messages,
)
from llamabot.recorder import autorecord, sqlite_log
from llamabot.config import default_language_model
Expand Down Expand Up @@ -73,17 +74,15 @@ def __init__(
self.stream_target = "none"

def __call__(
self, human_message: Union[str, BaseMessage]
self, *human_messages: Union[str, BaseMessage, list[Union[str, BaseMessage]]]
) -> Union[AIMessage, Generator]:
"""Call the SimpleBot.
:param human_message: The human message to use.
:return: The response to the human message, primed by the system prompt.
:param human_messages: One or more human messages to use, or lists of messages.
:return: The response to the human messages, primed by the system prompt.
"""
if isinstance(human_message, str):
human_message = HumanMessage(content=human_message)

messages = [self.system_prompt, human_message]
processed_messages = process_messages(human_messages)
messages = [self.system_prompt] + processed_messages
match self.stream_target:
case "stdout":
return self.stream_stdout(messages)
Expand Down
26 changes: 15 additions & 11 deletions llamabot/bot/structuredbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
AIMessage,
BaseMessage,
HumanMessage,
process_messages,
)
from pydantic import BaseModel, ValidationError
from llamabot.config import default_language_model
Expand Down Expand Up @@ -92,31 +93,34 @@ def _extract_json_from_response(self, response: AIMessage):

def __call__(
self,
message: Union[str, BaseMessage],
*messages: Union[str, BaseMessage, list[Union[str, BaseMessage]]],
num_attempts: int = 10,
verbose: bool = False,
) -> BaseModel | None:
"""Process the input message and return an instance of the Pydantic model.
"""Process the input messages and return an instance of the Pydantic model.
:param message: The text on which to parse to generate the structured response.
:param messages: One or more messages to process. Can be strings or BaseMessage objects.
:param num_attempts: Number of attempts to try getting a valid response.
:param verbose: Whether to show verbose output.
:return: An instance of the specified Pydantic model.
"""
if isinstance(message, str):
message = HumanMessage(content=message)
processed_messages = process_messages(messages)

messages = [
# Compose the full message list
full_messages = [
self.system_prompt,
self.task_message(),
message,
*processed_messages,
]

# we'll attempt to get the response from the model and validate it
for attempt in range(num_attempts):
try:
match self.stream_target:
case "stdout":
response = self.stream_stdout(messages)
response = self.stream_stdout(full_messages)
case "none":
response = self.stream_none(messages)
response = self.stream_none(full_messages)

# parse the response, and validate it against the pydantic model
codeblock = self._extract_json_from_response(response)
Expand All @@ -125,10 +129,10 @@ def __call__(

except ValidationError as e:
# we're on our last try, so we raise the error
if attempt == num_attempts:
if attempt == num_attempts - 1:
raise e

# Otherwise, if we failed, give the LLM the validation error and try again.
if verbose:
logger.info(e)
messages.extend([response, self.get_validation_error_message(e)])
full_messages.extend([response, self.get_validation_error_message(e)])
95 changes: 95 additions & 0 deletions llamabot/components/messages.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
"""Definitions for the different types of messages that can be sent."""

import base64
import mimetypes
from pathlib import Path
from typing import Union

import httpx
from pydantic import BaseModel, Field


Expand Down Expand Up @@ -73,6 +79,61 @@ class RetrievedMessage(BaseMessage):
role: str = "system"


class ImageMessage(BaseMessage):
    """A message containing an image.

    The image is loaded when the message is constructed. Its bytes are stored
    base64-encoded in ``content`` and its MIME type is kept on the instance as
    ``_mime_type`` so that :meth:`model_dump` can build a ``data:`` URL.

    :param content: Path to image file or URL of image
    :param role: Role of the message sender, defaults to "user"
    :raises FileNotFoundError: If a local image path does not exist
    :raises ValueError: If the local file or the URL response is not an image
    :raises httpx.HTTPError: If the URL cannot be fetched or returns an error status
    """

    content: str
    role: str = "user"

    def __init__(self, content: Union[str, Path], role: str = "user"):
        if isinstance(content, str) and content.startswith(("http://", "https://")):
            # Remote image: download it, failing loudly on 4xx/5xx instead of
            # encoding an error page as if it were image data.
            response = httpx.get(content, follow_redirects=True)
            response.raise_for_status()
            image_bytes = response.content

            # Content-Type may be absent or carry parameters ("image/png; ...").
            header = response.headers.get("content-type", "")
            mime_type = header.split(";", 1)[0].strip().lower()
            if mime_type in ("", "application/octet-stream"):
                # Server gave no usable type; fall back to the URL's extension.
                mime_type = mimetypes.guess_type(content)[0] or ""
            if not mime_type.startswith("image/"):
                raise ValueError(
                    f"URL did not return an image (content-type: {header!r}): {content}"
                )
        else:
            # Local image: accept both Path objects and path strings.
            path = Path(content)
            if not path.exists():
                raise FileNotFoundError(f"Image file not found: {path}")

            mime_type = mimetypes.guess_type(path)[0]
            if not mime_type or not mime_type.startswith("image/"):
                raise ValueError(f"Not a valid image file: {path}")

            image_bytes = path.read_bytes()

        encoded = base64.b64encode(image_bytes).decode("utf-8")
        super().__init__(content=encoded, role=role)
        # Set after super().__init__() so the model is fully initialised first.
        self._mime_type = mime_type

    def model_dump(self):
        """Convert message to format expected by LiteLLM and OpenAI.

        :return: A dict with the message ``role`` and a single ``image_url``
            content part whose URL is a base64 ``data:`` URL of the image.
        """
        return {
            "role": self.role,
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:{self._mime_type};base64,{self.content}"
                    },
                }
            ],
        }


def retrieve_messages_up_to_budget(
messages: list[BaseMessage], character_budget: int
) -> list[BaseMessage]:
Expand All @@ -96,3 +157,37 @@ def retrieve_messages_up_to_budget(
break
retrieved_messages.append(message)
return retrieved_messages


def process_messages(
    messages: tuple[Union[str, BaseMessage, list[Union[str, BaseMessage]]], ...]
) -> list[BaseMessage]:
    """Process a tuple of messages into a flat list of BaseMessage objects.

    Lists (including nested lists) are flattened in order, strings are wrapped
    in ``HumanMessage``, and any other object is passed through unchanged.

    :param messages: Tuple of messages to process, typically the ``*args`` of a
        bot's ``__call__``. Each element may be a string, a BaseMessage, or a
        list of those.
    :return: List of BaseMessage objects in their original order
    """
    processed_messages: list[BaseMessage] = []

    def process_message(msg: Union[str, BaseMessage, list]) -> None:
        """Append ``msg`` to ``processed_messages``, flattening lists recursively.

        :param msg: Message to process - can be a string, BaseMessage, or list of messages
        """
        if isinstance(msg, list):
            for m in msg:
                process_message(m)
        elif isinstance(msg, str):
            processed_messages.append(HumanMessage(content=msg))
        else:
            # Assumed to already be a message object; it is not validated here.
            processed_messages.append(msg)

    for msg in messages:
        process_message(msg)

    return processed_messages
Loading

0 comments on commit 820cd8f

Please sign in to comment.