diff --git a/council/llm/anthropic.py b/council/llm/anthropic.py
new file mode 100644
index 00000000..39f73733
--- /dev/null
+++ b/council/llm/anthropic.py
@@ -0,0 +1,12 @@
+import abc
+from abc import ABC
+from typing import Sequence, List
+
+from council.llm import LLMMessage
+
+
+class AnthropicAPIClientWrapper(ABC):
+
+    @abc.abstractmethod
+    def post_chat_request(self, messages: Sequence[LLMMessage]) -> List[str]:
+        pass
diff --git a/council/llm/anthropic_completion_llm.py b/council/llm/anthropic_completion_llm.py
new file mode 100644
index 00000000..05e65016
--- /dev/null
+++ b/council/llm/anthropic_completion_llm.py
@@ -0,0 +1,58 @@
+from typing import Sequence, List
+
+from anthropic import Anthropic
+from anthropic._types import NOT_GIVEN
+
+from council.llm import AnthropicLLMConfiguration, LLMMessage, LLMMessageRole
+from council.llm.anthropic import AnthropicAPIClientWrapper
+
+_HUMAN_TURN = Anthropic.HUMAN_PROMPT
+_ASSISTANT_TURN = Anthropic.AI_PROMPT
+
+
+class AnthropicCompletionLLM(AnthropicAPIClientWrapper):
+    """
+    Implementation for an Anthropic LLM with LEGACY completion.
+
+    Notes:
+        More details: https://docs.anthropic.com/claude/docs
+        and https://docs.anthropic.com/claude/reference/complete_post
+    """
+
+    def __init__(self, config: AnthropicLLMConfiguration, client: Anthropic) -> None:
+        self._config = config
+        self._client = client
+
+    def post_chat_request(self, messages: Sequence[LLMMessage]) -> List[str]:
+        prompt = self._to_anthropic_messages(messages)
+        result = self._client.completions.create(
+            prompt=prompt,
+            model=self._config.model.unwrap(),
+            max_tokens_to_sample=self._config.max_tokens.unwrap(),
+            timeout=self._config.timeout.value,
+            temperature=self._config.temperature.unwrap_or(NOT_GIVEN),
+            top_k=self._config.top_k.unwrap_or(NOT_GIVEN),
+            top_p=self._config.top_p.unwrap_or(NOT_GIVEN),
+        )
+        return [result.completion]
+
+    @staticmethod
+    def _to_anthropic_messages(messages: Sequence[LLMMessage]) -> str:
+        messages_count = len(messages)
+        if messages_count == 0:
+            raise Exception("No message to process.")
+
+        result = []
+        if messages[0].is_of_role(LLMMessageRole.System) and messages_count > 1:
+            result.append(f"{_HUMAN_TURN} {messages[0].content}\n{messages[1].content}")
+            remaining = messages[2:]
+        else:
+            result.append(f"{_HUMAN_TURN} {messages[0].content}")
+            remaining = messages[1:]
+
+        for item in remaining:
+            prefix = _HUMAN_TURN if item.is_of_role(LLMMessageRole.User) else _ASSISTANT_TURN
+            result.append(f"{prefix} {item.content}")
+        result.append(_ASSISTANT_TURN)
+
+        return "".join(result)
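Note on the legacy completion path above (a minimal sketch, not part of the patch): _to_anthropic_messages folds the chat history into a single prompt string built from the SDK turn markers, merging a leading system message into the first human turn and always closing with an open assistant turn.

# Sketch only, not part of the patch; assumes council's LLMMessage helpers
# (system_message / user_message) and the anthropic SDK constants
# HUMAN_PROMPT == "\n\nHuman:" and AI_PROMPT == "\n\nAssistant:".
from council.llm import LLMMessage

messages = [
    LLMMessage.system_message("You are concise."),
    LLMMessage.user_message("what is the capital of France?"),
]
# AnthropicCompletionLLM._to_anthropic_messages(messages) returns roughly:
#   "\n\nHuman: You are concise.\nwhat is the capital of France?\n\nAssistant:"
# i.e. the system message is merged into the first human turn and the prompt
# ends with an open assistant turn for the model to complete.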
diff --git a/council/llm/anthropic_llm.py b/council/llm/anthropic_llm.py
index 556cb71c..d71b88f1 100644
--- a/council/llm/anthropic_llm.py
+++ b/council/llm/anthropic_llm.py
@@ -1,15 +1,13 @@
 from __future__ import annotations
 
-from typing import Any, Sequence, Optional, List, Dict
+from typing import Any, Sequence, Optional, List
 
 from anthropic import Anthropic, APITimeoutError, APIStatusError
-from anthropic._types import NOT_GIVEN
 
 from council.contexts import LLMContext, Consumption
 from council.llm import (
     LLMBase,
     LLMMessage,
-    LLMMessageRole,
     LLMResult,
     LLMCallTimeoutException,
     LLMCallException,
@@ -18,9 +16,10 @@
     LLMConfigObject,
     LLMProviders,
 )
+from .anthropic import AnthropicAPIClientWrapper
 
-_HUMAN_TURN = Anthropic.HUMAN_PROMPT
-_ASSISTANT_TURN = Anthropic.AI_PROMPT
+from .anthropic_completion_llm import AnthropicCompletionLLM
+from .anthropic_messages_llm import AnthropicMessagesLLM
 
 
 class AnthropicTokenCounter(LLMessageTokenCounterBase):
@@ -35,14 +34,6 @@ def count_messages_token(self, messages: Sequence[LLMMessage]) -> int:
 
 
 class AnthropicLLM(LLMBase):
-    """
-    Implementation for an Anthropic LLM.
-
-    Notes:
-        More details: https://docs.anthropic.com/claude/docs
-        and https://docs.anthropic.com/claude/reference/complete_post
-    """
-
     def __init__(self, config: AnthropicLLMConfiguration, name: Optional[str] = None) -> None:
         """
         Initialize a new instance.
@@ -53,31 +44,22 @@ def __init__(self, config: AnthropicLLMConfiguration, name: Optional[str] = None) -> None:
         super().__init__(name=name or f"{self.__class__.__name__}")
         self.config = config
         self._client = Anthropic(api_key=config.api_key.value, max_retries=0)
+        self._api = self._get_api_wrapper()
 
     def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any) -> LLMResult:
         try:
-            messages_formatted = self._to_anthropic_messages_v2(messages)
-            completion = self._client.messages.create(
-                messages=messages_formatted,
-                model=self.config.model.unwrap(),
-                max_tokens=self.config.max_tokens.unwrap(),
-                timeout=self.config.timeout.value,
-                temperature=self.config.temperature.unwrap_or(NOT_GIVEN),
-                top_k=self.config.top_k.unwrap_or(NOT_GIVEN),
-                top_p=self.config.top_p.unwrap_or(NOT_GIVEN),
-            )
-            response = completion.content[0].text
+            response = self._api.post_chat_request(messages=messages)
             prompt_text = "\n".join([msg.content for msg in messages])
-            return LLMResult(choices=[response], consumptions=self.to_consumptions(prompt_text, response))
+            return LLMResult(choices=response, consumptions=self.to_consumptions(prompt_text, response))
         except APITimeoutError as e:
             raise LLMCallTimeoutException(self.config.timeout.value, self._name) from e
         except APIStatusError as e:
            raise LLMCallException(code=e.status_code, error=e.message, llm_name=self._name) from e
 
-    def to_consumptions(self, prompt: str, response: str) -> Sequence[Consumption]:
+    def to_consumptions(self, prompt: str, responses: List[str]) -> Sequence[Consumption]:
         model = self.config.model.unwrap()
         prompt_tokens = self._client.count_tokens(prompt)
-        completion_tokens = self._client.count_tokens(response)
+        completion_tokens = sum(self._client.count_tokens(r) for r in responses)
         return [
             Consumption(1, "call", f"{model}"),
             Consumption(prompt_tokens, "token", f"{model}:prompt_tokens"),
@@ -85,46 +67,10 @@ def to_consumptions(self, prompt: str, response: str) -> Sequence[Consumption]:
             Consumption(prompt_tokens + completion_tokens, "token", f"{model}:total_tokens"),
         ]
 
-    @staticmethod
-    def _to_anthropic_messages_v2(messages: Sequence[LLMMessage]) -> List[Dict[str, str]]:
-        result = []
-        temp_content = ""
-        role = "user"
-
-        for message in messages:
-            if message.is_of_role(LLMMessageRole.System):
-                temp_content += message.content
-            else:
-                temp_content += message.content
-                result.append({"role": role, "content": temp_content})
-                temp_content = ""
-                role = "assistant" if role == "user" else "user"
-
-        if temp_content:
-            result.append({"role": role, "content": temp_content})
-
-        return result
-
-    @staticmethod
-    def _to_anthropic_messages(messages: Sequence[LLMMessage]) -> str:
-        messages_count = len(messages)
-        if messages_count == 0:
-            raise Exception("No message to process.")
-
-        result = []
-        if messages[0].is_of_role(LLMMessageRole.System) and messages_count > 1:
-            result.append(f"{_HUMAN_TURN} {messages[0].content}\n{messages[1].content}")
-            remaining = messages[2:]
-        else:
-            result.append(f"{_HUMAN_TURN} {messages[0].content}")
-            remaining = messages[1:]
-
-        for item in remaining:
-            prefix = _HUMAN_TURN if item.is_of_role(LLMMessageRole.User) else _ASSISTANT_TURN
-            result.append(f"{prefix} {item.content}")
-        result.append(_ASSISTANT_TURN)
-
-        return "".join(result)
+    def _get_api_wrapper(self) -> AnthropicAPIClientWrapper:
+        if self.config.model.value == "claude-2":
+            return AnthropicCompletionLLM(client=self._client, config=self.config)
+        return AnthropicMessagesLLM(client=self._client, config=self.config)
 
     @staticmethod
     def from_env() -> AnthropicLLM:
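With the change above, AnthropicLLM only selects a wrapper per model and delegates the actual request. A rough usage sketch, not part of the patch, assuming the environment variables the integration tests rely on (an Anthropic API key plus ANTHROPIC_LLM_MODEL):

# Sketch only, not part of the patch.
from council import LLMContext
from council.llm import AnthropicLLM, LLMMessage

llm = AnthropicLLM.from_env()  # model taken from ANTHROPIC_LLM_MODEL
# "claude-2"                         -> AnthropicCompletionLLM (legacy Completions API)
# anything else (claude-2.1, claude-3-*) -> AnthropicMessagesLLM (Messages API)
result = llm.post_chat_request(LLMContext.empty(), [LLMMessage.user_message("what is the capital of France?")])
print(result.choices[0])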
diff --git a/council/llm/anthropic_messages_llm.py b/council/llm/anthropic_messages_llm.py
new file mode 100644
index 00000000..e4c1fef7
--- /dev/null
+++ b/council/llm/anthropic_messages_llm.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+from typing import Sequence, List, Iterable, Literal
+
+from anthropic import Anthropic
+from anthropic._types import NOT_GIVEN
+from anthropic.types import MessageParam
+
+from council.llm import (
+    LLMMessage,
+    LLMMessageRole,
+    AnthropicLLMConfiguration,
+)
+from council.llm.anthropic import AnthropicAPIClientWrapper
+
+
+class AnthropicMessagesLLM(AnthropicAPIClientWrapper):
+    """
+    Implementation for an Anthropic LLM.
+
+    Notes:
+        More details: https://docs.anthropic.com/claude/docs
+        and https://docs.anthropic.com/claude/reference/messages_post
+    """
+
+    def __init__(self, config: AnthropicLLMConfiguration, client: Anthropic) -> None:
+        self._config = config
+        self._client = client
+
+    def post_chat_request(self, messages: Sequence[LLMMessage]) -> List[str]:
+        messages_formatted = self._to_anthropic_messages(messages)
+        completion = self._client.messages.create(
+            messages=messages_formatted,
+            model=self._config.model.unwrap(),
+            max_tokens=self._config.max_tokens.unwrap(),
+            timeout=self._config.timeout.value,
+            temperature=self._config.temperature.unwrap_or(NOT_GIVEN),
+            top_k=self._config.top_k.unwrap_or(NOT_GIVEN),
+            top_p=self._config.top_p.unwrap_or(NOT_GIVEN),
+        )
+        return [content.text for content in completion.content]
+
+    @staticmethod
+    def _to_anthropic_messages(messages: Sequence[LLMMessage]) -> Iterable[MessageParam]:
+        result: List[MessageParam] = []
+        temp_content = ""
+        role: Literal["user", "assistant"] = "user"
+
+        for message in messages:
+            if message.is_of_role(LLMMessageRole.System):
+                temp_content += message.content
+            else:
+                temp_content += message.content
+                result.append(MessageParam(role=role, content=temp_content))
+                temp_content = ""
+                role = "assistant" if role == "user" else "user"
+
+        if temp_content:
+            result.append(MessageParam(role=role, content=temp_content))
+
+        return result
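For the Messages API wrapper above, _to_anthropic_messages merges system content into the next user turn and emits alternating user/assistant entries. A minimal sketch of the mapping, not part of the patch, assuming the same LLMMessage helpers as above:

# Sketch only, not part of the patch.
messages = [
    LLMMessage.system_message("You are concise."),
    LLMMessage.user_message("what is the capital of France?"),
    LLMMessage.assistant_message("Paris."),
    LLMMessage.user_message("And of Germany?"),
]
# AnthropicMessagesLLM._to_anthropic_messages(messages) yields roughly:
# [
#     {"role": "user", "content": "You are concise.what is the capital of France?"},
#     {"role": "assistant", "content": "Paris."},
#     {"role": "user", "content": "And of Germany?"},
# ]
# The system message is concatenated (with no separator) onto the following
# non-system message before that turn is flushed; roles then alternate.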
diff --git a/requirements.txt b/requirements.txt
index 846a633c..173b7613 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,7 @@ progressbar==2.5
 tiktoken==0.5.1
 
 # LLMs
-anthropic>=0.5.0
+anthropic>=0.20.0
 
 # Skills
 ## Google
diff --git a/tests/data/anthropic-llmodel.yaml b/tests/data/anthropic-llmodel.yaml
index 8c46b4a4..075dc4fe 100644
--- a/tests/data/anthropic-llmodel.yaml
+++ b/tests/data/anthropic-llmodel.yaml
@@ -9,7 +9,7 @@ spec:
   provider:
     name: CML-Anthropic
     anthropicSpec:
-      model: claude-2
+      model: claude-2.1
       timeout: 60
       maxTokens: 8192
       apiKey:
diff --git a/tests/integration/llm/test_anthropic_llm.py b/tests/integration/llm/test_anthropic_llm.py
index 4e39dfe9..de78d474 100644
--- a/tests/integration/llm/test_anthropic_llm.py
+++ b/tests/integration/llm/test_anthropic_llm.py
@@ -3,14 +3,33 @@
 import dotenv
 
 from council import LLMContext
 from council.llm import LLMMessage, AnthropicLLM
+from council.utils import OsEnviron
 
 
 class TestAnthropicLLM(unittest.TestCase):
     def test_completion(self):
         messages = [LLMMessage.user_message("what is the capital of France?")]
         dotenv.load_dotenv()
-        instance = AnthropicLLM.from_env()
-        context = LLMContext.empty()
-        result = instance.post_chat_request(context, messages)
+        with OsEnviron("ANTHROPIC_LLM_MODEL", "claude-2"):
+            instance = AnthropicLLM.from_env()
+            context = LLMContext.empty()
+            result = instance.post_chat_request(context, messages)
 
-        assert "Paris" in result.choices[0]
+            assert "Paris" in result.choices[0]
+
+    def test_message(self):
+        messages = [LLMMessage.user_message("what is the capital of France?")]
+        dotenv.load_dotenv()
+        with OsEnviron("ANTHROPIC_LLM_MODEL", "claude-2.1"):
+            instance = AnthropicLLM.from_env()
+            context = LLMContext.empty()
+            result = instance.post_chat_request(context, messages)
+
+            assert "Paris" in result.choices[0]
+
+        with OsEnviron("ANTHROPIC_LLM_MODEL", "claude-3-haiku-20240307"):
+            instance = AnthropicLLM.from_env()
+            context = LLMContext.empty()
+            result = instance.post_chat_request(context, messages)
+
+            assert "Paris" in result.choices[0]