Feature add initial support for images #150

Merged: 4 commits, Jun 27, 2024
19 changes: 14 additions & 5 deletions council/llm/anthropic_messages_llm.py
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Iterable, List, Literal, Sequence
+from typing import Any, Iterable, List, Literal, Sequence

 from anthropic import Anthropic
 from anthropic._types import NOT_GIVEN
@@ -39,16 +39,25 @@ def post_chat_request(self, messages: Sequence[LLMMessage]) -> List[str]:
     @staticmethod
     def _to_anthropic_messages(messages: Sequence[LLMMessage]) -> Iterable[MessageParam]:
         result: List[MessageParam] = []
-        temp_content = ""
+        temp_content: List[Any] = []
         role: Literal["user", "assistant"] = "user"

         for message in messages:
             if message.is_of_role(LLMMessageRole.System):
-                temp_content += message.content
+                temp_content.append({"type": "text", "text": message.content})
             else:
-                temp_content += message.content
+                temp_content.append({"type": "text", "text": message.content})
+                for data in message.data:
+                    if data.is_image:
+                        temp_content.append(
+                            {
+                                "type": "image",
+                                "source": {"type": "base64", "media_type": data.mime_type, "data": data.content},
+                            }
+                        )
+
                 result.append(MessageParam(role=role, content=temp_content))
-                temp_content = ""
+                temp_content = []
                 role = "assistant" if role == "user" else "user"

         if temp_content:
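
For illustration only (not part of the diff), a sketch of what this conversion produces for a user message with an attached image. The file path is hypothetical, and the enclosing class is assumed to be the Anthropic messages LLM defined in this module:

from council.llm import LLMMessage

message = LLMMessage.user_message("What is in the image?")
message.add_content(path="tests/data/example.png")  # hypothetical file; content is stored base64-encoded
# _to_anthropic_messages([message]) yields one MessageParam:
# {"role": "user", "content": [
#     {"type": "text", "text": "What is in the image?"},
#     {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "<base64>"}},
# ]}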
28 changes: 24 additions & 4 deletions council/llm/gemini_llm.py
@@ -13,6 +13,7 @@
     LLMProviders,
     LLMResult,
 )
+from google.ai.generativelanguage import FileData
 from google.ai.generativelanguage_v1 import HarmCategory  # type: ignore
 from google.generativeai.types import HarmBlockThreshold  # type: ignore
@@ -71,11 +72,30 @@ def _to_chat_history(messages: Sequence[LLMMessage]) -> Tuple[List[Any], Any]:
         history = []
         for message in messages[:-1]:
             if message.is_of_role(LLMMessageRole.System):
-                history.append({"role": "user", "parts": [{"text": f"System Prompt: {message.content}"}]})
+                history.append({"role": "user", "parts": GeminiLLM._get_parts(message)})
                 history.append({"role": "model", "parts": [{"text": "Understood"}]})
             elif message.is_of_role(LLMMessageRole.User):
-                history.append({"role": "user", "parts": [{"text": message.content}]})
+                history.append({"role": "user", "parts": GeminiLLM._get_parts(message)})
             elif message.is_of_role(LLMMessageRole.Assistant):
                 history.append({"role": "model", "parts": [{"text": message.content}]})
-        last = messages[-1].content
-        return history, last
+
+        last_msg = messages[-1]
+        return history, {"role": "user", "parts": GeminiLLM._get_parts(last_msg)}
+
+    @staticmethod
+    def _get_parts(message: LLMMessage) -> List[Any]:
+        parts: List[Any] = []
+        if message.is_of_role(LLMMessageRole.System):
+            parts.append({"text": f"System Prompt: {message.content}"})
+        elif message.is_of_role(LLMMessageRole.User):
+            parts.append({"text": message.content})
+        elif message.is_of_role(LLMMessageRole.Assistant):
+            parts.append({"text": message.content})
+
+        for data in message.data:
+            if data.is_url:
+                fd = FileData({"mime_type": data.mime_type, "file_uri": data.content})
+                parts.append({"file_data": fd})
+            else:
+                parts.append({"inline_data": {"mime_type": data.mime_type, "data": data.content}})
+        return parts
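
For illustration only (not part of the diff), a sketch of the parts built by _get_parts for a user message that carries an image URL; the URL is hypothetical:

from council.llm import LLMMessage

message = LLMMessage.user_message("Describe this image")
message.add_content(url="https://example.com/picture.png")
# GeminiLLM._get_parts(message) returns:
# [{"text": "Describe this image"},
#  {"file_data": FileData({"mime_type": "image/png", "file_uri": "https://example.com/picture.png"})}]
# Note: data.mime_type strips the "text/url:" prefix, so FileData receives "image/png".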
82 changes: 77 additions & 5 deletions council/llm/llm_message.py
@@ -1,6 +1,8 @@
 from __future__ import annotations

 import abc
+import base64
+import mimetypes
 from enum import Enum
 from typing import Iterable, List, Optional, Sequence
@@ -28,6 +30,56 @@ class LLMMessageRole(str, Enum):
     """

+
+class LLMMessageData:
+    """
+    Represents the data of a message.
+    """
+
+    def __init__(self, content: str, mime_type: str) -> None:
+        self._content = content
+        self._mime_type = mime_type
+
+    @property
+    def content(self) -> str:
+        return self._content
+
+    @property
+    def mime_type(self) -> str:
+        result = self._mime_type.split(":")[-1]
+        return result
+
+    @property
+    def is_image(self) -> bool:
+        return self._mime_type.startswith("image/")
+
+    @property
+    def is_url(self) -> bool:
+        return self._mime_type.startswith("text/url")
+
+    def __str__(self):
+        return f"content length={len(self.content)}, mime_type={self.mime_type}"
+
+    @classmethod
+    def from_file(cls, path: str) -> LLMMessageData:
+        """
+        Create message data from a file.
+        """
+        mime_type, _ = mimetypes.guess_type(path)
+        if mime_type is None:
+            mime_type = "image/unknown"
+
+        with open(path, "rb") as f:
+            return cls(content=base64.b64encode(f.read()).decode("utf-8"), mime_type=mime_type)
+
+    @classmethod
+    def from_uri(cls, uri: str) -> LLMMessageData:
+        """
+        Create message data from a URI.
+        """
+        mime_type, _ = mimetypes.guess_type(uri)
+        return cls(content=uri, mime_type=f"text/url:{mime_type}")
+
+
 class LLMMessage:
     """
     Represents chat messages. Used in the payload
@@ -45,6 +97,7 @@ def __init__(self, role: LLMMessageRole, content: str, name: Optional[str] = Non
         self._role = role
         self._content = content
         self._name = name
+        self._data: List[LLMMessageData] = []

     @staticmethod
     def system_message(content: str, name: Optional[str] = None) -> LLMMessage:
@@ -79,11 +132,25 @@ def assistant_message(content: str, name: Optional[str] = None) -> LLMMessage:
         """
         return LLMMessage(role=LLMMessageRole.Assistant, content=content, name=name)

-    def dict(self) -> dict[str, str]:
-        result = {"role": self._role.value, "content": self._content}
-        if self._name is not None:
-            result["name"] = self._name
-        return result
+    @property
+    def data(self) -> Sequence[LLMMessageData]:
+        """
+        Get the list of data associated with this message
+        """
+        return self._data
+
+    def add_content(self, *, path: Optional[str] = None, url: Optional[str] = None) -> None:
+        """
+        Add content (such as an image) to the message, from a file path or a URL.
+        """
+        data: Optional[LLMMessageData] = None
+        if path is not None:
+            data = LLMMessageData.from_file(path=path)
+        elif url is not None:
+            data = LLMMessageData.from_uri(uri=url)
+
+        if data is not None:
+            self._data.append(data)

     @property
     def content(self) -> str:
@@ -100,6 +167,11 @@ def role(self) -> LLMMessageRole:
         """Retrieve the role of this instance"""
         return self._role

+    @property
+    def has_data(self) -> bool:
+        """Check if this message has data associated with it"""
+        return bool(self._data)
+
     def is_of_role(self, role: LLMMessageRole) -> bool:
         """Check the role of this instance"""
         return self._role == role
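
For illustration only (not part of the diff), a usage sketch of the new message-data API. The path and URL are hypothetical, and it is assumed LLMMessageData is importable from council.llm alongside LLMMessage:

from council.llm import LLMMessage, LLMMessageData

message = LLMMessage.user_message("What is in the image?")
message.add_content(path="tests/data/example.png")  # file is read and base64-encoded
message.has_data           # True
message.data[0].is_image   # True, guessed mime type is "image/png"

data = LLMMessageData.from_uri("https://example.com/photo.jpg")
data.is_url      # True, the stored mime type is "text/url:image/jpeg"
data.mime_type   # "image/jpeg", the property keeps only the part after the last ":"
data.is_image    # False, is_image checks the raw stored type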
31 changes: 27 additions & 4 deletions council/llm/openai_chat_completions_llm.py
@@ -1,10 +1,11 @@
 from __future__ import annotations

-from typing import Any, List, Optional, Protocol, Sequence
+from typing import Any, Dict, List, Optional, Protocol, Sequence

 import httpx
 from council.contexts import Consumption, LLMContext

+from ..utils import truncate_dict_values_to_str
 from . import ChatGPTConfigurationBase
 from .llm_base import LLMBase, LLMResult
 from .llm_exception import LLMCallException
@@ -140,12 +141,14 @@ def __init__(
         self._provider = provider

     def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any) -> LLMResult:
-        payload = self._configuration.build_default_payload()
-        payload["messages"] = [message.dict() for message in messages]
-
+        payload = self._build_payload(messages)
         for key, value in kwargs.items():
             payload[key] = value

-        context.logger.debug(f'message="Sending chat GPT completions request to {self._name}" payload="{payload}"')
+        context.logger.debug(
+            f'message="Sending chat GPT completions request to {self._name}" payload="{truncate_dict_values_to_str(payload, 100)}"'
+        )
         r = self._post_request(payload)
         context.logger.debug(
             f'message="Got chat GPT completions result from {self._name}" id="{r.id}" model="{r.model}" {r.usage}'
@@ -158,3 +161,23 @@ def _post_request(self, payload) -> OpenAIChatCompletionsResult:
             raise LLMCallException(response.status_code, response.text, self._name)

         return OpenAIChatCompletionsResult.from_dict(response.json())
+
+    def _build_payload(self, messages: Sequence[LLMMessage]):
+        payload = self._configuration.build_default_payload()
+        msgs = []
+        for message in messages:
+            content: List[Dict[str, Any]] = [{"type": "text", "text": message.content}]
+            result: Dict[str, Any] = {"role": message.role.value}
+            if message.name is not None:
+                result["name"] = message.name
+            for data in message.data:
+                if data.is_image:
+                    content.append(
+                        {"type": "image_url", "image_url": {"url": f"data:{data.mime_type};base64,{data.content}"}}
+                    )
+                elif data.is_url:
+                    content.append({"type": "image_url", "image_url": {"url": f"{data.content}"}})
+            result["content"] = content
+            msgs.append(result)
+        payload["messages"] = msgs
+        return payload
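
For illustration only (not part of the diff), a sketch of the messages entry that _build_payload produces for a message with an attached image; the file name is hypothetical and the encoded data is abridged:

message = LLMMessage.user_message("What is in the image?")
message.add_content(path="picture.png")
# payload["messages"] == [{
#     "role": "user",
#     "content": [
#         {"type": "text", "text": "What is in the image?"},
#         {"type": "image_url", "image_url": {"url": "data:image/png;base64,<data>"}},
#     ],
# }]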
1 change: 1 addition & 0 deletions council/prompt/__init__.py
@@ -1 +1,2 @@
 from .prompt_builder import PromptBuilder
+from .llm_prompt_config_object import LLMPromptConfigObject, LLMPromptConfigSpec
1 change: 1 addition & 0 deletions council/utils/__init__.py
@@ -12,3 +12,4 @@
 from .data_object import DataObject, DataObjectSpecBase
 from .code_parser import CodeParser
 from .env import OsEnviron
+from .utils import truncate_dict_values_to_str
34 changes: 34 additions & 0 deletions council/utils/utils.py
@@ -0,0 +1,34 @@
+from typing import Dict
+
+
+def truncate_dict_values_to_str(data: Dict, max_length: int = 20):
+    """
+    Truncates dictionary values that are longer than max_length and returns a string representation.
+    The truncated value shows both the start and end of the original value. Handles nested dictionaries recursively.
+
+    Parameters:
+        data (dict): The dictionary with values to be truncated.
+        max_length (int): The maximum length of each value before truncation. Default is 20.
+
+    Returns:
+        str: A string representation of the dictionary with truncated values.
+    """
+
+    def truncate_value(value):
+        if isinstance(value, dict):
+            return truncate_dict_values_to_str(value, max_length)
+        elif isinstance(value, list):
+            return [truncate_value(item) for item in value]
+        elif isinstance(value, str) and len(value) > max_length:
+            half_length = (max_length - 3) // 2
+            return value[:half_length] + "..." + value[-half_length:]
+        else:
+            return value
+
+    truncated_items = []
+
+    for key, value in data.items():
+        truncated_value = truncate_value(value)
+        truncated_items.append(f"{key}: {truncated_value}")
+
+    return "{ " + ", ".join(truncated_items) + " }"
6 changes: 6 additions & 0 deletions tests/__init__.py
@@ -1 +1,7 @@
 """Init file."""
+
+import os
+
+
+def get_data_filename(filename: str):
+    return os.path.join(os.path.dirname(__file__), "data", filename)
[Two binary files added under tests/data cannot be displayed in the diff view.]
26 changes: 22 additions & 4 deletions tests/integration/llm/test_anthropic_llm.py
@@ -1,10 +1,12 @@
 import unittest

+import dotenv
+
 from council import LLMContext
 from council.llm import LLMMessage, AnthropicLLM
 from council.utils import OsEnviron

+from tests import get_data_filename
+

 class TestAnthropicLLM(unittest.TestCase):
     def test_completion(self):
@@ -15,7 +17,7 @@ def test_completion(self):
         context = LLMContext.empty()
         result = instance.post_chat_request(context, messages)

-        assert "Paris" in result.choices[0]
+        assert "Paris" in result.first_choice

     def test_message(self):
         messages = [LLMMessage.user_message("what is the capital of France?")]
@@ -25,11 +27,27 @@ def test_message(self):
         context = LLMContext.empty()
         result = instance.post_chat_request(context, messages)

-        assert "Paris" in result.choices[0]
+        assert "Paris" in result.first_choice

         with OsEnviron("ANTHROPIC_LLM_MODEL", "claude-3-haiku-20240307"):
             instance = AnthropicLLM.from_env()
             context = LLMContext.empty()
             result = instance.post_chat_request(context, messages)

-            assert "Paris" in result.choices[0]
+            assert "Paris" in result.first_choice
+
+        with OsEnviron("ANTHROPIC_LLM_MODEL", "claude-3-5-sonnet-20240620"):
+            instance = AnthropicLLM.from_env()
+            context = LLMContext.empty()
+            result = instance.post_chat_request(context, messages)
+
+            assert "Paris" in result.first_choice
+
+    def test_with_png_image(self):
+        dotenv.load_dotenv()
+        with OsEnviron("ANTHROPIC_LLM_MODEL", "claude-3-5-sonnet-20240620"):
+            instance = AnthropicLLM.from_env()
+            message = LLMMessage.user_message("What is in the image?")
+            message.add_content(path=get_data_filename("Gfp-wisconsin-madison-the-nature-boardwalk.png"))
+            result = instance.post_chat_request(LLMContext.empty(), [message])
+            print(result.first_choice)