diff --git a/src/cohere/types/__init__.py b/src/cohere/types/__init__.py index 7e2c8a12c..106547ce1 100644 --- a/src/cohere/types/__init__.py +++ b/src/cohere/types/__init__.py @@ -110,7 +110,6 @@ from .detokenize_response import DetokenizeResponse from .document import Document from .document_content import DocumentContent -from .document_source import DocumentSource from .embed_by_type_response import EmbedByTypeResponse from .embed_by_type_response_embeddings import EmbedByTypeResponseEmbeddings from .embed_floats_response import EmbedFloatsResponse @@ -206,12 +205,8 @@ from .summarize_request_format import SummarizeRequestFormat from .summarize_request_length import SummarizeRequestLength from .summarize_response import SummarizeResponse -from .system_message import SystemMessage from .system_message_content import SystemMessageContent from .system_message_content_item import SystemMessageContentItem, TextSystemMessageContentItem -from .text_content import TextContent -from .text_response_format import TextResponseFormat -from .text_response_format_v2 import TextResponseFormatV2 from .texts import Texts from .texts_truncate import TextsTruncate from .tokenize_response import TokenizeResponse @@ -222,12 +217,10 @@ from .tool_call_v2 import ToolCallV2 from .tool_call_v2function import ToolCallV2Function from .tool_content import DocumentToolContent, TextToolContent, ToolContent -from .tool_message import ToolMessage from .tool_message_v2 import ToolMessageV2 from .tool_message_v2tool_content import ToolMessageV2ToolContent from .tool_parameter_definitions_value import ToolParameterDefinitionsValue from .tool_result import ToolResult -from .tool_source import ToolSource from .tool_v2 import ToolV2 from .tool_v2function import ToolV2Function from .unprocessable_entity_error_body import UnprocessableEntityErrorBody @@ -235,7 +228,6 @@ from .usage import Usage from .usage_billed_units import UsageBilledUnits from .usage_tokens import UsageTokens -from .user_message import UserMessage from .user_message_content import UserMessageContent __all__ = [ diff --git a/src/cohere/utils.py b/src/cohere/utils.py index 6d33559ae..de7ab65a3 100644 --- a/src/cohere/utils.py +++ b/src/cohere/utils.py @@ -8,7 +8,7 @@ import requests from fastavro import parse_schema, reader, writer -from . import EmbedResponse, EmbedResponse_EmbeddingsFloats, EmbedResponse_EmbeddingsByType, ApiMeta, \ +from . import EmbedResponse, EmbeddingsFloatsEmbedResponse, EmbeddingsByTypeEmbedResponse, ApiMeta, \ EmbedByTypeResponseEmbeddings, ApiMetaBilledUnits, EmbedJob, CreateEmbedJobResponse, Dataset from .datasets import DatasetsCreateResponse, DatasetsGetResponse from .overrides import get_fields @@ -194,7 +194,7 @@ def merge_embed_responses(responses: typing.List[EmbedResponse]) -> EmbedRespons ] if responses[0].response_type == "embeddings_floats": - embeddings_floats = typing.cast(typing.List[EmbedResponse_EmbeddingsFloats], responses) + embeddings_floats = typing.cast(typing.List[EmbeddingsFloatsEmbedResponse], responses) embeddings = [ embedding @@ -202,7 +202,7 @@ def merge_embed_responses(responses: typing.List[EmbedResponse]) -> EmbedRespons for embedding in embeddings_floats.embeddings ] - return EmbedResponse_EmbeddingsFloats( + return EmbeddingsFloatsEmbedResponse( response_type="embeddings_floats", id=response_id, texts=texts, @@ -210,7 +210,7 @@ def merge_embed_responses(responses: typing.List[EmbedResponse]) -> EmbedRespons meta=meta ) else: - embeddings_type = typing.cast(typing.List[EmbedResponse_EmbeddingsByType], responses) + embeddings_type = typing.cast(typing.List[EmbeddingsByTypeEmbedResponse], responses) embeddings_by_type = [ response.embeddings @@ -231,7 +231,7 @@ def merge_embed_responses(responses: typing.List[EmbedResponse]) -> EmbedRespons embeddings_by_type_merged = EmbedByTypeResponseEmbeddings.parse_obj(merged_dicts) - return EmbedResponse_EmbeddingsByType( + return EmbeddingsByTypeEmbedResponse( response_type="embeddings_by_type", id=response_id, embeddings=embeddings_by_type_merged, diff --git a/tests/test_async_client.py b/tests/test_async_client.py index 63ecb086c..187d8d53b 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -3,7 +3,7 @@ import cohere from cohere import ChatConnector, ClassifyExample, CreateConnectorServiceAuth, Tool, \ - ToolParameterDefinitionsValue, ToolResult, Message_User, Message_Chatbot + ToolParameterDefinitionsValue, ToolResult, UserMessage, ChatbotMessage package_dir = os.path.dirname(os.path.abspath(__file__)) embed_job = os.path.join(package_dir, 'embed_job.jsonl') @@ -26,9 +26,9 @@ async def test_context_manager(self) -> None: async def test_chat(self) -> None: chat = await self.co.chat( chat_history=[ - Message_User( + UserMessage( message="Who discovered gravity?"), - Message_Chatbot(message="The man who is widely credited with discovering " + ChatbotMessage(message="The man who is widely credited with discovering " "gravity is Sir Isaac Newton") ], message="What year was he born?", @@ -40,9 +40,9 @@ async def test_chat(self) -> None: async def test_chat_stream(self) -> None: stream = self.co.chat_stream( chat_history=[ - Message_User( + UserMessage( message="Who discovered gravity?"), - Message_Chatbot(message="The man who is widely credited with discovering " + ChatbotMessage(message="The man who is widely credited with discovering " "gravity is Sir Isaac Newton") ], message="What year was he born?", diff --git a/tests/test_client.py b/tests/test_client.py index aaee3923e..9a839877e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,7 +4,7 @@ import cohere from cohere import ChatConnector, ClassifyExample, CreateConnectorServiceAuth, Tool, \ - ToolParameterDefinitionsValue, ToolResult, Message_Chatbot, Message_User, ResponseFormat_JsonObject + ToolParameterDefinitionsValue, ToolResult, ChatbotMessage, UserMessage, JsonObjectResponseFormat co = cohere.Client(timeout=10000) @@ -25,9 +25,9 @@ def test_context_manager(self) -> None: def test_chat(self) -> None: chat = co.chat( chat_history=[ - Message_User( + UserMessage( message="Who discovered gravity?"), - Message_Chatbot(message="The man who is widely credited with discovering " + ChatbotMessage(message="The man who is widely credited with discovering " "gravity is Sir Isaac Newton") ], message="What year was he born?", @@ -40,7 +40,7 @@ def test_chat(self) -> None: def test_response_format(self) -> None: chat = co.chat( message="imagine a character from the tv show severance", - response_format=ResponseFormat_JsonObject( + response_format=JsonObjectResponseFormat( schema={ "type": "object", "properties": { @@ -61,9 +61,9 @@ def test_response_format(self) -> None: def test_chat_stream(self) -> None: stream = co.chat_stream( chat_history=[ - Message_User( + UserMessage( message="Who discovered gravity?"), - Message_Chatbot(message="The man who is widely credited with discovering " + ChatbotMessage(message="The man who is widely credited with discovering " "gravity is Sir Isaac Newton") ], message="What year was he born?", diff --git a/tests/test_client_v2.py b/tests/test_client_v2.py index 06f654a43..ffb5825b7 100644 --- a/tests/test_client_v2.py +++ b/tests/test_client_v2.py @@ -3,7 +3,7 @@ import unittest import cohere -from cohere import ToolMessage2, UserMessage, AssistantMessage +from cohere import ToolMessage, UserMessage, AssistantMessage co = cohere.ClientV2(timeout=10000) @@ -14,12 +14,12 @@ class TestClientV2(unittest.TestCase): def test_chat(self) -> None: - response = co.chat(model="command-r-plus", messages=[cohere.v2.ChatMessage2_User(content="hello world!")]) + response = co.chat(model="command-r-plus", messages=[cohere.UserMessage(message="hello world!")]) print(response.message) def test_chat_stream(self) -> None: - stream = co.chat_stream(model="command-r-plus", messages=[cohere.v2.ChatMessage2_User(content="hello world!")]) + stream = co.chat_stream(model="command-r-plus", messages=[cohere.UserMessage(message="hello world!")]) events = set() @@ -43,8 +43,8 @@ def test_chat_documents(self) -> None: {"title": "widget sales 2021", "text": "4 million"}, ] response = co.chat( - messages=cohere.v2.UserMessage( - content=cohere.v2.TextContent(text="how many widges were sold in 2020?"), + messages=cohere.UserChatMessageV2( + content=cohere.TextContent(text="how many widges were sold in 2020?"), documents=documents, ), ) @@ -67,17 +67,17 @@ def test_chat_tools(self) -> None: "required": ["location"], }, } - tools = [cohere.v2.Tool2(type="function", function=get_weather_tool)] - messages: typing.List[typing.Union[UserMessage, AssistantMessage, None, ToolMessage2]] = [ - cohere.v2.UserMessage(content="what is the weather in Toronto?") + tools = [cohere.ToolV2(type="function", function=get_weather_tool)] + messages: cohere.ChatMessages = [ + cohere.UserChatMessageV2(content="what is the weather in Toronto?") ] res = co.chat(model="command-r-plus", tools=tools, messages=messages) # call the get_weather tool tool_result = {"temperature": "30C"} - tool_content = [cohere.v2.Content(output=tool_result, text="The weather in Toronto is 30C")] + tool_content = [cohere.Content(output=tool_result, text="The weather in Toronto is 30C")] messages.append(res.message) - messages.append(cohere.v2.ToolMessage2(tool_call_id=res.message.tool_calls[0].id, tool_content=tool_content)) + messages.append(cohere.ToolChatMessageV2(tool_call_id=res.message.tool_calls[0].id, tool_content=tool_content)) res = co.chat(tools=tools, messages=messages) print(res.message) diff --git a/tests/test_embed_utils.py b/tests/test_embed_utils.py index 17813658a..40c712177 100644 --- a/tests/test_embed_utils.py +++ b/tests/test_embed_utils.py @@ -1,10 +1,10 @@ import unittest -from cohere import EmbedResponse_EmbeddingsByType, EmbedByTypeResponseEmbeddings, ApiMeta, ApiMetaBilledUnits, \ - ApiMetaApiVersion, EmbedResponse_EmbeddingsFloats +from cohere import EmbeddingsByTypeEmbedResponse, EmbedByTypeResponseEmbeddings, ApiMeta, ApiMetaBilledUnits, \ + ApiMetaApiVersion, EmbeddingsFloatsEmbedResponse from cohere.utils import merge_embed_responses -ebt_1 = EmbedResponse_EmbeddingsByType( +ebt_1 = EmbeddingsByTypeEmbedResponse( response_type="embeddings_by_type", id="1", embeddings=EmbedByTypeResponseEmbeddings( @@ -27,7 +27,7 @@ ) ) -ebt_2 = EmbedResponse_EmbeddingsByType( +ebt_2 = EmbeddingsByTypeEmbedResponse( response_type="embeddings_by_type", id="2", embeddings=EmbedByTypeResponseEmbeddings( @@ -50,7 +50,7 @@ ) ) -ebt_partial_1 = EmbedResponse_EmbeddingsByType( +ebt_partial_1 = EmbeddingsByTypeEmbedResponse( response_type="embeddings_by_type", id="1", embeddings=EmbedByTypeResponseEmbeddings( @@ -71,7 +71,7 @@ ) ) -ebt_partial_2 = EmbedResponse_EmbeddingsByType( +ebt_partial_2 = EmbeddingsByTypeEmbedResponse( response_type="embeddings_by_type", id="2", embeddings=EmbedByTypeResponseEmbeddings( @@ -92,7 +92,7 @@ ) ) -ebf_1 = EmbedResponse_EmbeddingsFloats( +ebf_1 = EmbeddingsFloatsEmbedResponse( response_type="embeddings_floats", id="1", texts=["hello", "goodbye"], @@ -109,7 +109,7 @@ ) ) -ebf_2 = EmbedResponse_EmbeddingsFloats( +ebf_2 = EmbeddingsFloatsEmbedResponse( response_type="embeddings_floats", id="2", texts=["bye", "seeya"], @@ -139,7 +139,7 @@ def test_merge_embeddings_by_type(self) -> None: raise Exception("this is just for mpy") self.assertEqual(set(resp.meta.warnings or []), {"test_warning_1", "test_warning_2"}) - self.assertEqual(resp, EmbedResponse_EmbeddingsByType( + self.assertEqual(resp, EmbeddingsByTypeEmbedResponse( response_type="embeddings_by_type", id="1, 2", embeddings=EmbedByTypeResponseEmbeddings( @@ -172,7 +172,7 @@ def test_merge_embeddings_floats(self) -> None: raise Exception("this is just for mpy") self.assertEqual(set(resp.meta.warnings or []), {"test_warning_1", "test_warning_2"}) - self.assertEqual(resp, EmbedResponse_EmbeddingsFloats( + self.assertEqual(resp, EmbeddingsFloatsEmbedResponse( response_type="embeddings_floats", id="1, 2", texts=["hello", "goodbye", "bye", "seeya"], @@ -199,7 +199,7 @@ def test_merge_partial_embeddings_floats(self) -> None: raise Exception("this is just for mpy") self.assertEqual(set(resp.meta.warnings or []), {"test_warning_1", "test_warning_2"}) - self.assertEqual(resp, EmbedResponse_EmbeddingsByType( + self.assertEqual(resp, EmbeddingsByTypeEmbedResponse( response_type="embeddings_by_type", id="1, 2", embeddings=EmbedByTypeResponseEmbeddings(