Skip to content

Commit

Permalink
Add muxing models (#948)
Browse files Browse the repository at this point in the history
Small refactor to be able to have typing and make the code
more readable
  • Loading branch information
aponcedeleonch authored Feb 6, 2025
1 parent 0d9396c commit 8da7955
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 56 deletions.
5 changes: 3 additions & 2 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError

import codegate.muxing.models as mux_models
from codegate import __version__
from codegate.api import v1_models, v1_processing
from codegate.db.connection import AlreadyExistsError, DbReader
Expand Down Expand Up @@ -477,7 +478,7 @@ async def delete_workspace_custom_instructions(workspace_name: str):
)
async def get_workspace_muxes(
workspace_name: str,
) -> List[v1_models.MuxRule]:
) -> List[mux_models.MuxRule]:
"""Get the mux rules of a workspace.
The list is ordered in order of priority. That is, the first rule in the list
Expand All @@ -501,7 +502,7 @@ async def get_workspace_muxes(
)
async def set_workspace_muxes(
workspace_name: str,
request: List[v1_models.MuxRule],
request: List[mux_models.MuxRule],
):
"""Set the mux rules of a workspace."""
try:
Expand Down
23 changes: 0 additions & 23 deletions src/codegate/api/v1_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,26 +267,3 @@ class ModelByProvider(pydantic.BaseModel):

def __str__(self):
return f"{self.provider_name} / {self.name}"


class MuxMatcherType(str, Enum):
"""
Represents the different types of matchers we support.
"""

# Always match this prompt
catch_all = "catch_all"


class MuxRule(pydantic.BaseModel):
"""
Represents a mux rule for a provider.
"""

provider_id: str
model: str
# The type of matcher to use
matcher_type: MuxMatcherType
# The actual matcher to use. Note that
# this depends on the matcher type.
matcher: Optional[str] = None
27 changes: 27 additions & 0 deletions src/codegate/muxing/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from enum import Enum
from typing import Optional

import pydantic


class MuxMatcherType(str, Enum):
"""
Represents the different types of matchers we support.
"""

# Always match this prompt
catch_all = "catch_all"


class MuxRule(pydantic.BaseModel):
"""
Represents a mux rule for a provider.
"""

provider_id: str
model: str
# The type of matcher to use
matcher_type: MuxMatcherType
# The actual matcher to use. Note that
# this depends on the matcher type.
matcher: Optional[str] = None
50 changes: 19 additions & 31 deletions src/codegate/workspaces/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,9 @@
from typing import List, Optional, Tuple
from uuid import uuid4 as uuid

from codegate.db import models as db_models
from codegate.db.connection import DbReader, DbRecorder
from codegate.db.models import (
ActiveWorkspace,
MuxRule,
Session,
WorkspaceRow,
WorkspaceWithSessionInfo,
)
from codegate.muxing import models as mux_models
from codegate.muxing import rulematcher


Expand Down Expand Up @@ -40,7 +35,7 @@ class WorkspaceCrud:
def __init__(self):
self._db_reader = DbReader()

async def add_workspace(self, new_workspace_name: str) -> WorkspaceRow:
async def add_workspace(self, new_workspace_name: str) -> db_models.WorkspaceRow:
"""
Add a workspace
Expand All @@ -57,7 +52,7 @@ async def add_workspace(self, new_workspace_name: str) -> WorkspaceRow:

async def rename_workspace(
self, old_workspace_name: str, new_workspace_name: str
) -> WorkspaceRow:
) -> db_models.WorkspaceRow:
"""
Rename a workspace
Expand All @@ -79,33 +74,33 @@ async def rename_workspace(
if not ws:
raise WorkspaceDoesNotExistError(f"Workspace {old_workspace_name} does not exist.")
db_recorder = DbRecorder()
new_ws = WorkspaceRow(
new_ws = db_models.WorkspaceRow(
id=ws.id, name=new_workspace_name, custom_instructions=ws.custom_instructions
)
workspace_renamed = await db_recorder.update_workspace(new_ws)
return workspace_renamed

async def get_workspaces(self) -> List[WorkspaceWithSessionInfo]:
async def get_workspaces(self) -> List[db_models.WorkspaceWithSessionInfo]:
"""
Get all workspaces
"""
return await self._db_reader.get_workspaces()

async def get_archived_workspaces(self) -> List[WorkspaceRow]:
async def get_archived_workspaces(self) -> List[db_models.WorkspaceRow]:
"""
Get all archived workspaces
"""
return await self._db_reader.get_archived_workspaces()

async def get_active_workspace(self) -> Optional[ActiveWorkspace]:
async def get_active_workspace(self) -> Optional[db_models.ActiveWorkspace]:
"""
Get the active workspace
"""
return await self._db_reader.get_active_workspace()

async def _is_workspace_active(
self, workspace_name: str
) -> Tuple[bool, Optional[Session], Optional[WorkspaceRow]]:
) -> Tuple[bool, Optional[db_models.Session], Optional[db_models.WorkspaceRow]]:
"""
Check if the workspace is active alongside the session and workspace objects
"""
Expand Down Expand Up @@ -155,13 +150,13 @@ async def recover_workspace(self, workspace_name: str):

async def update_workspace_custom_instructions(
self, workspace_name: str, custom_instr_lst: List[str]
) -> WorkspaceRow:
) -> db_models.WorkspaceRow:
selected_workspace = await self._db_reader.get_workspace_by_name(workspace_name)
if not selected_workspace:
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")

custom_instructions = " ".join(custom_instr_lst)
workspace_update = WorkspaceRow(
workspace_update = db_models.WorkspaceRow(
id=selected_workspace.id,
name=selected_workspace.name,
custom_instructions=custom_instructions,
Expand Down Expand Up @@ -217,17 +212,13 @@ async def hard_delete_workspace(self, workspace_name: str):
raise WorkspaceCrudError(f"Error deleting workspace {workspace_name}")
return

async def get_workspace_by_name(self, workspace_name: str) -> WorkspaceRow:
async def get_workspace_by_name(self, workspace_name: str) -> db_models.WorkspaceRow:
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
if not workspace:
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
return workspace

# Can't use type hints since the models are not yet defined
# Note that I'm explicitly importing the models here to avoid circular imports.
async def get_muxes(self, workspace_name: str):
from codegate.api import v1_models

async def get_muxes(self, workspace_name: str) -> List[mux_models.MuxRule]:
# Verify if workspace exists
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
if not workspace:
Expand All @@ -239,7 +230,7 @@ async def get_muxes(self, workspace_name: str):
# These are already sorted by priority
for dbmux in dbmuxes:
muxes.append(
v1_models.MuxRule(
mux_models.MuxRule(
provider_id=dbmux.provider_endpoint_id,
model=dbmux.provider_model_name,
matcher_type=dbmux.matcher_type,
Expand All @@ -249,10 +240,7 @@ async def get_muxes(self, workspace_name: str):

return muxes

# Can't use type hints since the models are not yet defined
async def set_muxes(self, workspace_name: str, muxes):
from codegate.api import v1_models

async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> None:
# Verify if workspace exists
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
if not workspace:
Expand All @@ -265,7 +253,7 @@ async def set_muxes(self, workspace_name: str, muxes):
# Add the new muxes
priority = 0

muxes_with_routes: List[Tuple[v1_models.MuxRule, rulematcher.ModelRoute]] = []
muxes_with_routes: List[Tuple[mux_models.MuxRule, rulematcher.ModelRoute]] = []

# Verify all models are valid
for mux in muxes:
Expand All @@ -275,7 +263,7 @@ async def set_muxes(self, workspace_name: str, muxes):
matchers: List[rulematcher.MuxingRuleMatcher] = []

for mux, route in muxes_with_routes:
new_mux = MuxRule(
new_mux = db_models.MuxRule(
id=str(uuid()),
provider_endpoint_id=mux.provider_id,
provider_model_name=mux.model,
Expand All @@ -294,7 +282,7 @@ async def set_muxes(self, workspace_name: str, muxes):
mux_registry = await rulematcher.get_muxing_rules_registry()
await mux_registry.set_ws_rules(workspace_name, matchers)

async def get_routing_for_mux(self, mux) -> rulematcher.ModelRoute:
async def get_routing_for_mux(self, mux: mux_models.MuxRule) -> rulematcher.ModelRoute:
"""Get the routing for a mux
Note that this particular mux object is the API model, not the database model.
Expand Down Expand Up @@ -322,7 +310,7 @@ async def get_routing_for_mux(self, mux) -> rulematcher.ModelRoute:
auth_material=dbauth,
)

async def get_routing_for_db_mux(self, mux: MuxRule) -> rulematcher.ModelRoute:
async def get_routing_for_db_mux(self, mux: db_models.MuxRule) -> rulematcher.ModelRoute:
"""Get the routing for a mux
Note that this particular mux object is the database model, not the API model.
Expand Down

0 comments on commit 8da7955

Please sign in to comment.