Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: python lint errors patched #561

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
7 changes: 3 additions & 4 deletions src/cohere/aws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ def stream_generator(response: httpx.Response, endpoint: str) -> typing.Iterator
regex = r"{[^\}]*}"

for _text in response.iter_lines():
match = re.search(regex, _text)
if match:
if match := re.search(regex, _text):
obj = json.loads(match.group())
if "bytes" in obj:
base64_payload = base64.b64decode(obj["bytes"]).decode("utf-8")
Expand Down Expand Up @@ -256,9 +255,9 @@ def get_url(
stream: bool,
) -> str:
if platform == "bedrock":
endpoint = "invoke" if not stream else "invoke-with-response-stream"
endpoint = "invoke-with-response-stream" if stream else "invoke"
return f"https://{platform}-runtime.{aws_region}.amazonaws.com/model/{model}/{endpoint}"
elif platform == "sagemaker":
endpoint = "invocations" if not stream else "invocations-response-stream"
endpoint = "invocations-response-stream" if stream else "invocations"
return f"https://runtime.sagemaker.{aws_region}.amazonaws.com/endpoints/{model}/{endpoint}"
return ""
12 changes: 7 additions & 5 deletions src/cohere/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def get_success_states():


def get_failed_states():
return {"unknown", "failed", "skipped", "cancelled", "failed"}
return {"unknown", "skipped", "cancelled", "failed"}


def get_id(
Expand All @@ -37,12 +37,14 @@ def get_validation_status(awaitable: typing.Union[EmbedJob, DatasetsGetResponse]


def get_job(cohere: typing.Any,
awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]) -> \
typing.Union[
awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]) -> typing.Union[
EmbedJob, DatasetsGetResponse]:
if awaitable.__class__.__name__ == "EmbedJob" or awaitable.__class__.__name__ == "CreateEmbedJobResponse":
if awaitable.__class__.__name__ in ["EmbedJob", "CreateEmbedJobResponse"]:
return cohere.embed_jobs.get(id=get_id(awaitable))
elif awaitable.__class__.__name__ == "DatasetsGetResponse" or awaitable.__class__.__name__ == "DatasetsCreateResponse":
elif awaitable.__class__.__name__ in [
"DatasetsGetResponse",
"DatasetsCreateResponse",
]:
return cohere.datasets.get(id=get_id(awaitable))
else:
raise ValueError(f"Unexpected awaitable type {awaitable}")
Expand Down
17 changes: 8 additions & 9 deletions tests/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ async def test_chat_stream(self) -> None:
if chat_event.event_type == "text-generation":
print(chat_event.text)

self.assertTrue("text-generation" in events)
self.assertTrue("stream-start" in events)
self.assertTrue("stream-end" in events)
self.assertIn("text-generation" in events)
self.assertIn("stream-start" in events)
self.assertIn("stream-end" in events)

async def test_stream_equals_true(self) -> None:
with self.assertRaises(ValueError):
Expand Down Expand Up @@ -343,14 +343,13 @@ async def test_tool_use(self) -> None:
"""
)

if tool_parameters_response.tool_calls is not None:
self.assertEqual(
tool_parameters_response.tool_calls[0].name, "sales_database")
self.assertEqual(tool_parameters_response.tool_calls[0].parameters, {
"day": "2023-09-29"})
else:
if tool_parameters_response.tool_calls is None:
raise ValueError("Expected tool calls to be present")

self.assertEqual(
tool_parameters_response.tool_calls[0].name, "sales_database")
self.assertEqual(tool_parameters_response.tool_calls[0].parameters, {
"day": "2023-09-29"})
local_tools = {
"sales_database": lambda day: {
"number_of_sales": 120,
Expand Down
11 changes: 5 additions & 6 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,14 +405,13 @@ def test_tool_use(self) -> None:
"""
)

if tool_parameters_response.tool_calls is not None:
self.assertEqual(
tool_parameters_response.tool_calls[0].name, "sales_database")
self.assertEqual(tool_parameters_response.tool_calls[0].parameters, {
"day": "2023-09-29"})
else:
if tool_parameters_response.tool_calls is None:
raise ValueError("Expected tool calls to be present")

self.assertEqual(
tool_parameters_response.tool_calls[0].name, "sales_database")
self.assertEqual(tool_parameters_response.tool_calls[0].parameters, {
"day": "2023-09-29"})
local_tools = {
"sales_database": lambda day: {
"number_of_sales": 120,
Expand Down
10 changes: 5 additions & 5 deletions tests/test_client_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ def test_chat_stream(self) -> None:
if chat_event.type == "content-delta":
print(chat_event.delta.message)

self.assertTrue("message-start" in events)
self.assertTrue("content-start" in events)
self.assertTrue("content-delta" in events)
self.assertTrue("content-end" in events)
self.assertTrue("message-end" in events)
self.assertIn("message-start" in events)
self.assertIn("content-start" in events)
self.assertIn("content-delta" in events)
self.assertIn("content-end" in events)
self.assertIn("message-end" in events)

@unittest.skip("Skip v2 test for now")
def test_chat_documents(self) -> None:
Expand Down