Skip to content

Commit

Permalink
🔧💥 #15 Apply API breaking changes from Glide v0.0.4-rc.1 (#16)
Browse files Browse the repository at this point in the history
Applying breaking changes from EinStack/glide#236:

- changed the request/response field name format to snake_case
- renamed router_id, model_name, provider_id fields
- introduced error name in the error responses
  • Loading branch information
roma-glushko authored May 13, 2024
1 parent fcef3b9 commit de79eae
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/activity-notifications.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name: Pull Request Activity Notifications

on:
pull_request:
pull_request_target:
types: [opened, closed, reopened]

jobs:
Expand Down
4 changes: 2 additions & 2 deletions examples/lang/chat_stream_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def chat_stream() -> None:
continue

if err := message.error:
print(f"💥ERR: {err.message} (code: {err.err_code})")
print(f"💥ERR ({err.name}): {err.message}")
print("🧹 Restarting the stream")
continue

Expand All @@ -50,7 +50,7 @@ async def chat_stream() -> None:

if last_msg and last_msg.chunk and last_msg.finish_reason:
# LLM gen context
provider_name = last_msg.chunk.provider_name
provider_name = last_msg.chunk.provider_id
model_name = last_msg.chunk.model_name
finish_reason = last_msg.finish_reason

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ requires-python = ">=3.8"
[project.urls]
Homepage = "https://glide.einstack.ai/"
Documentation = "https://glide.einstack.ai/"
Repository = "https://github.com/me/spam.git"
Issues = "https://github.com/EinStack/glide-python"
Repository = "https://github.com/EinStack/glide-py.git"
Issues = "https://github.com/EinStack/glide-py/issues/"

[tool.pdm.version]
source = "scm"
Expand Down
20 changes: 18 additions & 2 deletions src/glide/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,22 @@ class GlideClientError(GlideError):
Occurs when there is an issue with sending a Glide request
"""

def __init__(self, message: str, err_name: str) -> None:
super().__init__(message)

self.err_name = err_name


class GlideServerError(GlideError):
"""
Occurs when there is an issue with sending a Glide request related to Glide server issues
"""

def __init__(self, message: str, err_name: str) -> None:
super().__init__(message)

self.err_name = err_name


class GlideClientMismatch(GlideError):
"""
Expand All @@ -29,7 +45,7 @@ class GlideChatStreamError(GlideError):
Occurs when chat stream ends with an error
"""

def __init__(self, message: str, err_code: str) -> None:
def __init__(self, message: str, err_name: str) -> None:
super().__init__(message)

self.err_code = err_code
self.err_name = err_name
46 changes: 28 additions & 18 deletions src/glide/lang/router_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,15 @@
GlideClientError,
GlideClientMismatch,
GlideChatStreamError,
GlideServerError,
)
from glide.lang import schemas
from glide.lang.schemas import ChatStreamRequest, ChatStreamMessage, ChatRequestId
from glide.lang.schemas import (
ChatStreamRequest,
ChatStreamMessage,
ChatRequestId,
ChatError,
)
from glide.logging import logger
from glide.typing import RouterId

Expand Down Expand Up @@ -82,8 +88,8 @@ async def chat_stream(
if err := message.ended_with_err:
# fail only on fatal errors that indicate stream stop
raise GlideChatStreamError(
f"Chat stream {req.id} ended with an error: {err.message} (code: {err.err_code})",
err.err_code,
f"Chat stream {req.id} ended with an error ({err.name}): {err.message}",
err.name,
)

yield message # returns content chunk and some error messages
Expand Down Expand Up @@ -113,7 +119,7 @@ async def _sender(self) -> None:

await self._ws_client.send(chat_request.json())
except asyncio.CancelledError:
# TODO: log
logger.debug("chat stream sender task is canceled")
break

async def _receiver(self) -> None:
Expand All @@ -136,6 +142,7 @@ async def _receiver(self) -> None:
exc_info=True,
)
except asyncio.CancelledError:
logger.debug("chat stream receiver task is canceled")
break
except Exception as e:
logger.exception(e)
Expand Down Expand Up @@ -213,33 +220,36 @@ async def chat(
"""
Send a chat request to a specified language router
"""
try:
headers = {}
headers = {}

if self._user_agent:
headers["User-Agent"] = self._user_agent
if self._user_agent:
headers["User-Agent"] = self._user_agent

try:
resp = await self._http_client.post(
f"/language/{router_id}/chat",
headers=headers,
json=request.dict(by_alias=True),
)

except httpx.NetworkError as e:
raise GlideUnavailable() from e
if resp.is_error:
err_data = ChatError(**resp.json())

if not resp.is_success:
raise GlideClientError(
f"Failed to send a chat request: {resp.text} (status_code: {resp.status_code})"
)
if resp.is_client_error:
raise GlideClientError(err_data.message, err_data.name)

try:
raw_response = resp.json()
if resp.is_server_error:
raise GlideServerError(err_data.message, err_data.name)

raw_resp = resp.json()

return schemas.ChatResponse(**raw_response)
return schemas.ChatResponse(**raw_resp)
except httpx.NetworkError as e:
raise GlideUnavailable() from e
except pydantic.ValidationError as err:
raise GlideClientMismatch(
"Failed to validate Glide API response. Please make sure Glide API and client versions are compatible"
"Failed to validate Glide API response. "
"Please make sure Glide API and client versions are compatible"
) from err

def stream_client(self, router_id: RouterId) -> AsyncStreamChatClient:
Expand Down
27 changes: 16 additions & 11 deletions src/glide/lang/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@
from pydantic import Field

from glide.schames import Schema
from glide.typing import RouterId, ProviderName, ModelName
from glide.typing import RouterId, ProviderId, ModelName

ChatRequestId = str
Metadata = Dict[str, Any]


class ChatError(Schema):
name: str
message: str


class FinishReason(str, Enum):
# generation is finished successfully without interruptions
COMPLETE = "complete"
Expand Down Expand Up @@ -45,7 +50,7 @@ class ModelMessageOverride(Schema):
class ChatRequest(Schema):
message: ChatMessage
message_history: List[ChatMessage] = Field(default_factory=list)
override: Optional[ModelMessageOverride] = None
override_params: Optional[ModelMessageOverride] = None


class TokenUsage(Schema):
Expand All @@ -55,26 +60,26 @@ class TokenUsage(Schema):


class ModelResponse(Schema):
response_id: Dict[str, str]
metadata: Dict[str, str]
message: ChatMessage
token_count: TokenUsage
token_usage: TokenUsage


class ChatResponse(Schema):
id: ChatRequestId
created: datetime
provider: ProviderName
router: RouterId
created_at: datetime
provider_id: ProviderId
router_id: RouterId
model_id: str
model: ModelName
model_name: ModelName
model_response: ModelResponse


class ChatStreamRequest(Schema):
id: ChatRequestId = Field(default_factory=lambda: str(uuid.uuid4()))
message: ChatMessage
message_history: List[ChatMessage] = Field(default_factory=list)
override: Optional[ModelMessageOverride] = None
override_params: Optional[ModelMessageOverride] = None
metadata: Optional[Metadata] = None


Expand All @@ -90,7 +95,7 @@ class ChatStreamChunk(Schema):

model_id: str

provider_name: ProviderName
provider_id: ProviderId
model_name: ModelName

model_response: ModelChunkResponse
Expand All @@ -99,7 +104,7 @@ class ChatStreamChunk(Schema):

class ChatStreamError(Schema):
id: ChatRequestId
err_code: str
name: str
message: str
finish_reason: Optional[FinishReason] = None

Expand Down
2 changes: 0 additions & 2 deletions src/glide/schames.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright EinStack
# SPDX-License-Identifier: APACHE-2.0
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel


class Schema(BaseModel):
Expand All @@ -10,7 +9,6 @@ class Schema(BaseModel):
"""

model_config = ConfigDict(
alias_generator=to_camel,
populate_by_name=True,
from_attributes=True,
protected_namespaces=(),
Expand Down
2 changes: 1 addition & 1 deletion src/glide/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
# SPDX-License-Identifier: APACHE-2.0

RouterId = str
ProviderName = str
ProviderId = str
ModelName = str

0 comments on commit de79eae

Please sign in to comment.