Manage the 2 APIs

aflament committed Mar 14, 2024
1 parent 1811a09 commit f71a8b5
Showing 7 changed files with 169 additions and 73 deletions.
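At a glance: AnthropicLLM now delegates to one of two client wrappers behind the new AnthropicAPIClientWrapper interface, using the legacy completions API for claude-2 and the newer messages API for everything else. A rough usage sketch, with the model pinned via ANTHROPIC_LLM_MODEL as in the updated tests below (any other required environment variables, such as the API key, are assumed to be in place):

import dotenv
from council import LLMContext
from council.llm import AnthropicLLM, LLMMessage

dotenv.load_dotenv()  # loads the Anthropic API key, model, etc. from .env
llm = AnthropicLLM.from_env()  # "claude-2" -> completions API; otherwise -> messages API
result = llm.post_chat_request(LLMContext.empty(), [LLMMessage.user_message("Hello")])
print(result.choices[0])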
12 changes: 12 additions & 0 deletions council/llm/anthropic.py
@@ -0,0 +1,12 @@
import abc
from abc import ABC
from typing import Sequence, List

from council.llm import LLMMessage


class AnthropicAPIClientWrapper(ABC):

    @abc.abstractmethod
    def post_chat_request(self, messages: Sequence[LLMMessage]) -> List[str]:
        pass
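Any concrete wrapper only has to implement post_chat_request. A minimal hypothetical subclass (not part of this commit), just to show the contract that the two implementations below fulfill:

from typing import Sequence, List

from council.llm import LLMMessage
from council.llm.anthropic import AnthropicAPIClientWrapper


class EchoWrapper(AnthropicAPIClientWrapper):
    # Toy wrapper: returns the message contents instead of calling the API.
    def post_chat_request(self, messages: Sequence[LLMMessage]) -> List[str]:
        return [message.content for message in messages]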
58 changes: 58 additions & 0 deletions council/llm/anthropic_completion_llm.py
@@ -0,0 +1,58 @@
from typing import Sequence, List

from anthropic import Anthropic
from anthropic._types import NOT_GIVEN

from council.llm import AnthropicLLMConfiguration, LLMMessage, LLMMessageRole
from council.llm.anthropic import AnthropicAPIClientWrapper

_HUMAN_TURN = Anthropic.HUMAN_PROMPT
_ASSISTANT_TURN = Anthropic.AI_PROMPT


class AnthropicCompletionLLM(AnthropicAPIClientWrapper):
    """
    Implementation for an Anthropic LLM with LEGACY completion.
    Notes:
        More details: https://docs.anthropic.com/claude/docs
        and https://docs.anthropic.com/claude/reference/complete_post
    """

    def __init__(self, config: AnthropicLLMConfiguration, client: Anthropic) -> None:
        self._config = config
        self._client = client

    def post_chat_request(self, messages: Sequence[LLMMessage]) -> List[str]:
        prompt = self._to_anthropic_messages(messages)
        result = self._client.completions.create(
            prompt=prompt,
            model=self._config.model.unwrap(),
            max_tokens_to_sample=self._config.max_tokens.unwrap(),
            timeout=self._config.timeout.value,
            temperature=self._config.temperature.unwrap_or(NOT_GIVEN),
            top_k=self._config.top_k.unwrap_or(NOT_GIVEN),
            top_p=self._config.top_p.unwrap_or(NOT_GIVEN),
        )
        return [result.completion]

    @staticmethod
    def _to_anthropic_messages(messages: Sequence[LLMMessage]) -> str:
        messages_count = len(messages)
        if messages_count == 0:
            raise Exception("No message to process.")

        result = []
        if messages[0].is_of_role(LLMMessageRole.System) and messages_count > 1:
            result.append(f"{_HUMAN_TURN} {messages[0].content}\n{messages[1].content}")
            remaining = messages[2:]
        else:
            result.append(f"{_HUMAN_TURN} {messages[0].content}")
            remaining = messages[1:]

        for item in remaining:
            prefix = _HUMAN_TURN if item.is_of_role(LLMMessageRole.User) else _ASSISTANT_TURN
            result.append(f"{prefix} {item.content}")
        result.append(_ASSISTANT_TURN)

        return "".join(result)
80 changes: 13 additions & 67 deletions council/llm/anthropic_llm.py
@@ -1,15 +1,13 @@
from __future__ import annotations

-from typing import Any, Sequence, Optional, List, Dict
+from typing import Any, Sequence, Optional, List

from anthropic import Anthropic, APITimeoutError, APIStatusError
-from anthropic._types import NOT_GIVEN

from council.contexts import LLMContext, Consumption
from council.llm import (
    LLMBase,
    LLMMessage,
-    LLMMessageRole,
    LLMResult,
    LLMCallTimeoutException,
    LLMCallException,
@@ -18,9 +16,10 @@
    LLMConfigObject,
    LLMProviders,
)
+from .anthropic import AnthropicAPIClientWrapper

-_HUMAN_TURN = Anthropic.HUMAN_PROMPT
-_ASSISTANT_TURN = Anthropic.AI_PROMPT
+from .anthropic_completion_llm import AnthropicCompletionLLM
+from .anthropic_messages_llm import AnthropicMessagesLLM


class AnthropicTokenCounter(LLMessageTokenCounterBase):
@@ -35,14 +34,6 @@ def count_messages_token(self, messages: Sequence[LLMMessage]) -> int:


class AnthropicLLM(LLMBase):
-    """
-    Implementation for an Anthropic LLM.
-    Notes:
-        More details: https://docs.anthropic.com/claude/docs
-        and https://docs.anthropic.com/claude/reference/complete_post
-    """
-
    def __init__(self, config: AnthropicLLMConfiguration, name: Optional[str] = None) -> None:
        """
        Initialize a new instance.
@@ -53,78 +44,33 @@ def __init__(self, config: AnthropicLLMConfiguration, name: Optional[str] = None) -> None:
        super().__init__(name=name or f"{self.__class__.__name__}")
        self.config = config
        self._client = Anthropic(api_key=config.api_key.value, max_retries=0)
+        self._api = self._get_api_wrapper()

    def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any) -> LLMResult:
        try:
-            messages_formatted = self._to_anthropic_messages_v2(messages)
-            completion = self._client.messages.create(
-                messages=messages_formatted,
-                model=self.config.model.unwrap(),
-                max_tokens=self.config.max_tokens.unwrap(),
-                timeout=self.config.timeout.value,
-                temperature=self.config.temperature.unwrap_or(NOT_GIVEN),
-                top_k=self.config.top_k.unwrap_or(NOT_GIVEN),
-                top_p=self.config.top_p.unwrap_or(NOT_GIVEN),
-            )
-            response = completion.content[0].text
+            response = self._api.post_chat_request(messages=messages)
            prompt_text = "\n".join([msg.content for msg in messages])
-            return LLMResult(choices=[response], consumptions=self.to_consumptions(prompt_text, response))
+            return LLMResult(choices=response, consumptions=self.to_consumptions(prompt_text, response))
        except APITimeoutError as e:
            raise LLMCallTimeoutException(self.config.timeout.value, self._name) from e
        except APIStatusError as e:
            raise LLMCallException(code=e.status_code, error=e.message, llm_name=self._name) from e

-    def to_consumptions(self, prompt: str, response: str) -> Sequence[Consumption]:
+    def to_consumptions(self, prompt: str, responses: List[str]) -> Sequence[Consumption]:
        model = self.config.model.unwrap()
        prompt_tokens = self._client.count_tokens(prompt)
-        completion_tokens = self._client.count_tokens(response)
+        completion_tokens = sum(self._client.count_tokens(r) for r in responses)
        return [
            Consumption(1, "call", f"{model}"),
            Consumption(prompt_tokens, "token", f"{model}:prompt_tokens"),
            Consumption(completion_tokens, "token", f"{model}:completion_tokens"),
            Consumption(prompt_tokens + completion_tokens, "token", f"{model}:total_tokens"),
        ]

-    @staticmethod
-    def _to_anthropic_messages_v2(messages: Sequence[LLMMessage]) -> List[Dict[str, str]]:
-        result = []
-        temp_content = ""
-        role = "user"
-
-        for message in messages:
-            if message.is_of_role(LLMMessageRole.System):
-                temp_content += message.content
-            else:
-                temp_content += message.content
-                result.append({"role": role, "content": temp_content})
-                temp_content = ""
-                role = "assistant" if role == "user" else "user"
-
-        if temp_content:
-            result.append({"role": role, "content": temp_content})
-
-        return result
-
-    @staticmethod
-    def _to_anthropic_messages(messages: Sequence[LLMMessage]) -> str:
-        messages_count = len(messages)
-        if messages_count == 0:
-            raise Exception("No message to process.")
-
-        result = []
-        if messages[0].is_of_role(LLMMessageRole.System) and messages_count > 1:
-            result.append(f"{_HUMAN_TURN} {messages[0].content}\n{messages[1].content}")
-            remaining = messages[2:]
-        else:
-            result.append(f"{_HUMAN_TURN} {messages[0].content}")
-            remaining = messages[1:]
-
-        for item in remaining:
-            prefix = _HUMAN_TURN if item.is_of_role(LLMMessageRole.User) else _ASSISTANT_TURN
-            result.append(f"{prefix} {item.content}")
-        result.append(_ASSISTANT_TURN)
-
-        return "".join(result)
+    def _get_api_wrapper(self) -> AnthropicAPIClientWrapper:
+        if self.config.model.value == "claude-2":
+            return AnthropicCompletionLLM(client=self._client, config=self.config)
+        return AnthropicMessagesLLM(client=self._client, config=self.config)

    @staticmethod
    def from_env() -> AnthropicLLM:
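Note that the dispatch in _get_api_wrapper is an exact string match: only a model value of precisely "claude-2" takes the legacy completions path, so "claude-2.1" (see the updated YAML below) and the claude-3 family go through the messages API. A small hypothetical illustration of that rule:

# Illustration only; mirrors the comparison in _get_api_wrapper above.
for model in ("claude-2", "claude-2.1", "claude-3-haiku-20240307"):
    api = "completions (legacy)" if model == "claude-2" else "messages"
    print(f"{model} -> {api}")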
61 changes: 61 additions & 0 deletions council/llm/anthropic_messages_llm.py
@@ -0,0 +1,61 @@
from __future__ import annotations

from typing import Sequence, List, Iterable, Literal

from anthropic import Anthropic
from anthropic._types import NOT_GIVEN
from anthropic.types import MessageParam

from council.llm import (
    LLMMessage,
    LLMMessageRole,
    AnthropicLLMConfiguration,
)
from council.llm.anthropic import AnthropicAPIClientWrapper


class AnthropicMessagesLLM(AnthropicAPIClientWrapper):
    """
    Implementation for an Anthropic LLM.
    Notes:
        More details: https://docs.anthropic.com/claude/docs
        and https://docs.anthropic.com/claude/reference/messages_post
    """

    def __init__(self, config: AnthropicLLMConfiguration, client: Anthropic) -> None:
        self._config = config
        self._client = client

    def post_chat_request(self, messages: Sequence[LLMMessage]) -> List[str]:
        messages_formatted = self._to_anthropic_messages(messages)
        completion = self._client.messages.create(
            messages=messages_formatted,
            model=self._config.model.unwrap(),
            max_tokens=self._config.max_tokens.unwrap(),
            timeout=self._config.timeout.value,
            temperature=self._config.temperature.unwrap_or(NOT_GIVEN),
            top_k=self._config.top_k.unwrap_or(NOT_GIVEN),
            top_p=self._config.top_p.unwrap_or(NOT_GIVEN),
        )
        return [content.text for content in completion.content]

    @staticmethod
    def _to_anthropic_messages(messages: Sequence[LLMMessage]) -> Iterable[MessageParam]:
        result: List[MessageParam] = []
        temp_content = ""
        role: Literal["user", "assistant"] = "user"

        for message in messages:
            if message.is_of_role(LLMMessageRole.System):
                temp_content += message.content
            else:
                temp_content += message.content
                result.append(MessageParam(role=role, content=temp_content))
                temp_content = ""
                role = "assistant" if role == "user" else "user"

        if temp_content:
            result.append(MessageParam(role=role, content=temp_content))

        return result
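To make the mapping concrete: system content is buffered and concatenated, without a separator, into the next non-system turn, and roles then alternate strictly starting from "user" rather than being read from each message, matching the messages API's required user/assistant alternation. A hypothetical conversation (the system_message and assistant_message constructors are assumed):

from council.llm import LLMMessage

# Hypothetical input; the expected output follows _to_anthropic_messages above.
messages = [
    LLMMessage.system_message("You are terse."),
    LLMMessage.user_message("Hi"),
    LLMMessage.assistant_message("Hello!"),
    LLMMessage.user_message("What is the capital of France?"),
]
# _to_anthropic_messages(messages) ->
# [
#     {"role": "user", "content": "You are terse.Hi"},
#     {"role": "assistant", "content": "Hello!"},
#     {"role": "user", "content": "What is the capital of France?"},
# ]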
2 changes: 1 addition & 1 deletion requirements.txt
@@ -7,7 +7,7 @@ progressbar==2.5
tiktoken==0.5.1

# LLMs
-anthropic>=0.5.0
+anthropic>=0.20.0

# Skills
## Google
2 changes: 1 addition & 1 deletion tests/data/anthropic-llmodel.yaml
@@ -9,7 +9,7 @@ spec:
  provider:
    name: CML-Anthropic
    anthropicSpec:
-      model: claude-2
+      model: claude-2.1
      timeout: 60
      maxTokens: 8192
      apiKey:
27 changes: 23 additions & 4 deletions tests/integration/llm/test_anthropic_llm.py
@@ -3,14 +3,33 @@
import dotenv
from council import LLMContext
from council.llm import LLMMessage, AnthropicLLM
+from council.utils import OsEnviron


class TestAnthropicLLM(unittest.TestCase):
    def test_completion(self):
        messages = [LLMMessage.user_message("what is the capital of France?")]
        dotenv.load_dotenv()
-        instance = AnthropicLLM.from_env()
-        context = LLMContext.empty()
-        result = instance.post_chat_request(context, messages)
+        with OsEnviron("ANTHROPIC_LLM_MODEL", "claude-2"):
+            instance = AnthropicLLM.from_env()
+            context = LLMContext.empty()
+            result = instance.post_chat_request(context, messages)

-        assert "Paris" in result.choices[0]
+            assert "Paris" in result.choices[0]
+
+    def test_message(self):
+        messages = [LLMMessage.user_message("what is the capital of France?")]
+        dotenv.load_dotenv()
+        with OsEnviron("ANTHROPIC_LLM_MODEL", "claude-2.1"):
+            instance = AnthropicLLM.from_env()
+            context = LLMContext.empty()
+            result = instance.post_chat_request(context, messages)
+
+            assert "Paris" in result.choices[0]
+
+        with OsEnviron("ANTHROPIC_LLM_MODEL", "claude-3-haiku-20240307"):
+            instance = AnthropicLLM.from_env()
+            context = LLMContext.empty()
+            result = instance.post_chat_request(context, messages)
+
+            assert "Paris" in result.choices[0]
