From af142a8add92f544daa4c9f56767084963bb7835 Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Mon, 12 Aug 2024 18:27:56 -0700 Subject: [PATCH 1/8] Update aws_client.py --- src/cohere/aws_client.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/cohere/aws_client.py b/src/cohere/aws_client.py index ee64266b4..5c1dc8ce3 100644 --- a/src/cohere/aws_client.py +++ b/src/cohere/aws_client.py @@ -111,8 +111,7 @@ def stream_generator(response: httpx.Response, endpoint: str) -> typing.Iterator regex = r"{[^\}]*}" for _text in response.iter_lines(): - match = re.search(regex, _text) - if match: + if match := re.search(regex, _text): obj = json.loads(match.group()) if "bytes" in obj: base64_payload = base64.b64decode(obj["bytes"]).decode("utf-8") From aa33eb95f83d1fe3328ae6773f9d551be0a8dd71 Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Mon, 12 Aug 2024 18:28:03 -0700 Subject: [PATCH 2/8] Update aws_client.py --- src/cohere/aws_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cohere/aws_client.py b/src/cohere/aws_client.py index 5c1dc8ce3..52a8d4a6c 100644 --- a/src/cohere/aws_client.py +++ b/src/cohere/aws_client.py @@ -247,9 +247,9 @@ def get_url( stream: bool, ) -> str: if platform == "bedrock": - endpoint = "invoke" if not stream else "invoke-with-response-stream" + endpoint = "invoke-with-response-stream" if stream else "invoke" return f"https://{platform}-runtime.{aws_region}.amazonaws.com/model/{model}/{endpoint}" elif platform == "sagemaker": - endpoint = "invocations" if not stream else "invocations-response-stream" + endpoint = "invocations-response-stream" if stream else "invocations" return f"https://runtime.sagemaker.{aws_region}.amazonaws.com/endpoints/{model}/{endpoint}" return "" From 725770398ff5fd2f470cfc980fdba5d375947285 Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Mon, 12 Aug 2024 18:28:11 -0700 Subject: [PATCH 3/8] duplicate list item --- src/cohere/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cohere/utils.py b/src/cohere/utils.py index 9d2ed5c09..1c81ec368 100644 --- a/src/cohere/utils.py +++ b/src/cohere/utils.py @@ -22,7 +22,7 @@ def get_success_states(): def get_failed_states(): - return {"unknown", "failed", "skipped", "cancelled", "failed"} + return {"unknown", "skipped", "cancelled", "failed"} def get_id( From 4a2b80399c8dd988f3c5b2c25845efb5dfc3dc0d Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Mon, 12 Aug 2024 18:28:16 -0700 Subject: [PATCH 4/8] Update utils.py --- src/cohere/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/cohere/utils.py b/src/cohere/utils.py index 1c81ec368..b79a34bbd 100644 --- a/src/cohere/utils.py +++ b/src/cohere/utils.py @@ -36,12 +36,14 @@ def get_validation_status(awaitable: typing.Union[EmbedJob, DatasetsGetResponse] def get_job(cohere: typing.Any, - awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]) -> \ - typing.Union[ + awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]) -> typing.Union[ EmbedJob, DatasetsGetResponse]: - if awaitable.__class__.__name__ == "EmbedJob" or awaitable.__class__.__name__ == "CreateEmbedJobResponse": + if awaitable.__class__.__name__ in ["EmbedJob", "CreateEmbedJobResponse"]: return cohere.embed_jobs.get(id=get_id(awaitable)) - elif awaitable.__class__.__name__ == "DatasetsGetResponse" or awaitable.__class__.__name__ == "DatasetsCreateResponse": + elif awaitable.__class__.__name__ in [ + "DatasetsGetResponse", + "DatasetsCreateResponse", + ]: return cohere.datasets.get(id=get_id(awaitable)) else: raise ValueError(f"Unexpected awaitable type {awaitable}") From 11bb46b9d12c77146669a5e1d1331e81480155fb Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Mon, 12 Aug 2024 18:28:20 -0700 Subject: [PATCH 5/8] Update test_async_client.py --- tests/test_async_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_async_client.py b/tests/test_async_client.py index 63ecb086c..80da59cf8 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -56,9 +56,9 @@ async def test_chat_stream(self) -> None: if chat_event.event_type == "text-generation": print(chat_event.text) - self.assertTrue("text-generation" in events) - self.assertTrue("stream-start" in events) - self.assertTrue("stream-end" in events) + self.assertIn("text-generation" in events) + self.assertIn("stream-start" in events) + self.assertIn("stream-end" in events) async def test_stream_equals_true(self) -> None: with self.assertRaises(ValueError): From bd3933a0f7d9fe52a36cec224e03c24eecaccf8d Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Mon, 12 Aug 2024 18:28:22 -0700 Subject: [PATCH 6/8] Update test_async_client.py --- tests/test_async_client.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/test_async_client.py b/tests/test_async_client.py index 80da59cf8..f45ad347e 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -343,14 +343,13 @@ async def test_tool_use(self) -> None: """ ) - if tool_parameters_response.tool_calls is not None: - self.assertEqual( - tool_parameters_response.tool_calls[0].name, "sales_database") - self.assertEqual(tool_parameters_response.tool_calls[0].parameters, { - "day": "2023-09-29"}) - else: + if tool_parameters_response.tool_calls is None: raise ValueError("Expected tool calls to be present") + self.assertEqual( + tool_parameters_response.tool_calls[0].name, "sales_database") + self.assertEqual(tool_parameters_response.tool_calls[0].parameters, { + "day": "2023-09-29"}) local_tools = { "sales_database": lambda day: { "number_of_sales": 120, From 66264e8c3a83d4af06f32af8af3a2a517af1bc0f Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Mon, 12 Aug 2024 18:28:26 -0700 Subject: [PATCH 7/8] Update test_client_v2.py --- tests/test_client_v2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_client_v2.py b/tests/test_client_v2.py index 66f84d522..da8a84a8f 100644 --- a/tests/test_client_v2.py +++ b/tests/test_client_v2.py @@ -44,11 +44,11 @@ def test_chat_stream(self) -> None: if chat_event.type == "content-delta": print(chat_event.delta.message) - self.assertTrue("message-start" in events) - self.assertTrue("content-start" in events) - self.assertTrue("content-delta" in events) - self.assertTrue("content-end" in events) - self.assertTrue("message-end" in events) + self.assertIn("message-start" in events) + self.assertIn("content-start" in events) + self.assertIn("content-delta" in events) + self.assertIn("content-end" in events) + self.assertIn("message-end" in events) @unittest.skip("Skip v2 test for now") def test_chat_documents(self) -> None: From b366ff9fc61bc692533ab0f507665c6c9bbc2adb Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Mon, 12 Aug 2024 18:28:29 -0700 Subject: [PATCH 8/8] Update test_client.py --- tests/test_client.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index bf0581529..cbd592217 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -380,14 +380,13 @@ def test_tool_use(self) -> None: """ ) - if tool_parameters_response.tool_calls is not None: - self.assertEqual( - tool_parameters_response.tool_calls[0].name, "sales_database") - self.assertEqual(tool_parameters_response.tool_calls[0].parameters, { - "day": "2023-09-29"}) - else: + if tool_parameters_response.tool_calls is None: raise ValueError("Expected tool calls to be present") + self.assertEqual( + tool_parameters_response.tool_calls[0].name, "sales_database") + self.assertEqual(tool_parameters_response.tool_calls[0].parameters, { + "day": "2023-09-29"}) local_tools = { "sales_database": lambda day: { "number_of_sales": 120,