diff --git a/src/cohere/aws_client.py b/src/cohere/aws_client.py index f34d60a54..52755e036 100644 --- a/src/cohere/aws_client.py +++ b/src/cohere/aws_client.py @@ -118,8 +118,7 @@ def stream_generator(response: httpx.Response, endpoint: str) -> typing.Iterator regex = r"{[^\}]*}" for _text in response.iter_lines(): - match = re.search(regex, _text) - if match: + if match := re.search(regex, _text): obj = json.loads(match.group()) if "bytes" in obj: base64_payload = base64.b64decode(obj["bytes"]).decode("utf-8") @@ -256,9 +255,9 @@ def get_url( stream: bool, ) -> str: if platform == "bedrock": - endpoint = "invoke" if not stream else "invoke-with-response-stream" + endpoint = "invoke-with-response-stream" if stream else "invoke" return f"https://{platform}-runtime.{aws_region}.amazonaws.com/model/{model}/{endpoint}" elif platform == "sagemaker": - endpoint = "invocations" if not stream else "invocations-response-stream" + endpoint = "invocations-response-stream" if stream else "invocations" return f"https://runtime.sagemaker.{aws_region}.amazonaws.com/endpoints/{model}/{endpoint}" return "" diff --git a/src/cohere/utils.py b/src/cohere/utils.py index de7ab65a3..e38f57ea3 100644 --- a/src/cohere/utils.py +++ b/src/cohere/utils.py @@ -23,7 +23,7 @@ def get_success_states(): def get_failed_states(): - return {"unknown", "failed", "skipped", "cancelled", "failed"} + return {"unknown", "skipped", "cancelled", "failed"} def get_id( @@ -37,12 +37,14 @@ def get_validation_status(awaitable: typing.Union[EmbedJob, DatasetsGetResponse] def get_job(cohere: typing.Any, - awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]) -> \ - typing.Union[ + awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]) -> typing.Union[ EmbedJob, DatasetsGetResponse]: - if awaitable.__class__.__name__ == "EmbedJob" or awaitable.__class__.__name__ == "CreateEmbedJobResponse": + if awaitable.__class__.__name__ in ["EmbedJob", "CreateEmbedJobResponse"]: return cohere.embed_jobs.get(id=get_id(awaitable)) - elif awaitable.__class__.__name__ == "DatasetsGetResponse" or awaitable.__class__.__name__ == "DatasetsCreateResponse": + elif awaitable.__class__.__name__ in [ + "DatasetsGetResponse", + "DatasetsCreateResponse", + ]: return cohere.datasets.get(id=get_id(awaitable)) else: raise ValueError(f"Unexpected awaitable type {awaitable}") diff --git a/tests/test_async_client.py b/tests/test_async_client.py index 187d8d53b..5c488f462 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -56,9 +56,9 @@ async def test_chat_stream(self) -> None: if chat_event.event_type == "text-generation": print(chat_event.text) - self.assertTrue("text-generation" in events) - self.assertTrue("stream-start" in events) - self.assertTrue("stream-end" in events) + self.assertIn("text-generation" in events) + self.assertIn("stream-start" in events) + self.assertIn("stream-end" in events) async def test_stream_equals_true(self) -> None: with self.assertRaises(ValueError): @@ -343,14 +343,13 @@ async def test_tool_use(self) -> None: """ ) - if tool_parameters_response.tool_calls is not None: - self.assertEqual( - tool_parameters_response.tool_calls[0].name, "sales_database") - self.assertEqual(tool_parameters_response.tool_calls[0].parameters, { - "day": "2023-09-29"}) - else: + if tool_parameters_response.tool_calls is None: raise ValueError("Expected tool calls to be present") + self.assertEqual( + tool_parameters_response.tool_calls[0].name, "sales_database") + self.assertEqual(tool_parameters_response.tool_calls[0].parameters, { + "day": "2023-09-29"}) local_tools = { "sales_database": lambda day: { "number_of_sales": 120, diff --git a/tests/test_client.py b/tests/test_client.py index 9a839877e..1427bedfa 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -405,14 +405,13 @@ def test_tool_use(self) -> None: """ ) - if tool_parameters_response.tool_calls is not None: - self.assertEqual( - tool_parameters_response.tool_calls[0].name, "sales_database") - self.assertEqual(tool_parameters_response.tool_calls[0].parameters, { - "day": "2023-09-29"}) - else: + if tool_parameters_response.tool_calls is None: raise ValueError("Expected tool calls to be present") + self.assertEqual( + tool_parameters_response.tool_calls[0].name, "sales_database") + self.assertEqual(tool_parameters_response.tool_calls[0].parameters, { + "day": "2023-09-29"}) local_tools = { "sales_database": lambda day: { "number_of_sales": 120, diff --git a/tests/test_client_v2.py b/tests/test_client_v2.py index 235fc9436..0cb03bb32 100644 --- a/tests/test_client_v2.py +++ b/tests/test_client_v2.py @@ -29,11 +29,11 @@ def test_chat_stream(self) -> None: if chat_event.type == "content-delta": print(chat_event.delta.message) - self.assertTrue("message-start" in events) - self.assertTrue("content-start" in events) - self.assertTrue("content-delta" in events) - self.assertTrue("content-end" in events) - self.assertTrue("message-end" in events) + self.assertIn("message-start" in events) + self.assertIn("content-start" in events) + self.assertIn("content-delta" in events) + self.assertIn("content-end" in events) + self.assertIn("message-end" in events) @unittest.skip("Skip v2 test for now") def test_chat_documents(self) -> None: