Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 8da7955

Browse files
Add muxing models (#948)
Small refactor to be able to have typing and make the code more readable
1 parent 0d9396c commit 8da7955

File tree

4 files changed

+49
-56
lines changed

4 files changed

+49
-56
lines changed

src/codegate/api/v1.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from fastapi.routing import APIRoute
99
from pydantic import BaseModel, ValidationError
1010

11+
import codegate.muxing.models as mux_models
1112
from codegate import __version__
1213
from codegate.api import v1_models, v1_processing
1314
from codegate.db.connection import AlreadyExistsError, DbReader
@@ -477,7 +478,7 @@ async def delete_workspace_custom_instructions(workspace_name: str):
477478
)
478479
async def get_workspace_muxes(
479480
workspace_name: str,
480-
) -> List[v1_models.MuxRule]:
481+
) -> List[mux_models.MuxRule]:
481482
"""Get the mux rules of a workspace.
482483
483484
The list is ordered in order of priority. That is, the first rule in the list
@@ -501,7 +502,7 @@ async def get_workspace_muxes(
501502
)
502503
async def set_workspace_muxes(
503504
workspace_name: str,
504-
request: List[v1_models.MuxRule],
505+
request: List[mux_models.MuxRule],
505506
):
506507
"""Set the mux rules of a workspace."""
507508
try:

src/codegate/api/v1_models.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -267,26 +267,3 @@ class ModelByProvider(pydantic.BaseModel):
267267

268268
def __str__(self):
269269
return f"{self.provider_name} / {self.name}"
270-
271-
272-
class MuxMatcherType(str, Enum):
273-
"""
274-
Represents the different types of matchers we support.
275-
"""
276-
277-
# Always match this prompt
278-
catch_all = "catch_all"
279-
280-
281-
class MuxRule(pydantic.BaseModel):
282-
"""
283-
Represents a mux rule for a provider.
284-
"""
285-
286-
provider_id: str
287-
model: str
288-
# The type of matcher to use
289-
matcher_type: MuxMatcherType
290-
# The actual matcher to use. Note that
291-
# this depends on the matcher type.
292-
matcher: Optional[str] = None

src/codegate/muxing/models.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from enum import Enum
2+
from typing import Optional
3+
4+
import pydantic
5+
6+
7+
class MuxMatcherType(str, Enum):
8+
"""
9+
Represents the different types of matchers we support.
10+
"""
11+
12+
# Always match this prompt
13+
catch_all = "catch_all"
14+
15+
16+
class MuxRule(pydantic.BaseModel):
17+
"""
18+
Represents a mux rule for a provider.
19+
"""
20+
21+
provider_id: str
22+
model: str
23+
# The type of matcher to use
24+
matcher_type: MuxMatcherType
25+
# The actual matcher to use. Note that
26+
# this depends on the matcher type.
27+
matcher: Optional[str] = None

src/codegate/workspaces/crud.py

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,9 @@
22
from typing import List, Optional, Tuple
33
from uuid import uuid4 as uuid
44

5+
from codegate.db import models as db_models
56
from codegate.db.connection import DbReader, DbRecorder
6-
from codegate.db.models import (
7-
ActiveWorkspace,
8-
MuxRule,
9-
Session,
10-
WorkspaceRow,
11-
WorkspaceWithSessionInfo,
12-
)
7+
from codegate.muxing import models as mux_models
138
from codegate.muxing import rulematcher
149

1510

@@ -40,7 +35,7 @@ class WorkspaceCrud:
4035
def __init__(self):
4136
self._db_reader = DbReader()
4237

43-
async def add_workspace(self, new_workspace_name: str) -> WorkspaceRow:
38+
async def add_workspace(self, new_workspace_name: str) -> db_models.WorkspaceRow:
4439
"""
4540
Add a workspace
4641
@@ -57,7 +52,7 @@ async def add_workspace(self, new_workspace_name: str) -> WorkspaceRow:
5752

5853
async def rename_workspace(
5954
self, old_workspace_name: str, new_workspace_name: str
60-
) -> WorkspaceRow:
55+
) -> db_models.WorkspaceRow:
6156
"""
6257
Rename a workspace
6358
@@ -79,33 +74,33 @@ async def rename_workspace(
7974
if not ws:
8075
raise WorkspaceDoesNotExistError(f"Workspace {old_workspace_name} does not exist.")
8176
db_recorder = DbRecorder()
82-
new_ws = WorkspaceRow(
77+
new_ws = db_models.WorkspaceRow(
8378
id=ws.id, name=new_workspace_name, custom_instructions=ws.custom_instructions
8479
)
8580
workspace_renamed = await db_recorder.update_workspace(new_ws)
8681
return workspace_renamed
8782

88-
async def get_workspaces(self) -> List[WorkspaceWithSessionInfo]:
83+
async def get_workspaces(self) -> List[db_models.WorkspaceWithSessionInfo]:
8984
"""
9085
Get all workspaces
9186
"""
9287
return await self._db_reader.get_workspaces()
9388

94-
async def get_archived_workspaces(self) -> List[WorkspaceRow]:
89+
async def get_archived_workspaces(self) -> List[db_models.WorkspaceRow]:
9590
"""
9691
Get all archived workspaces
9792
"""
9893
return await self._db_reader.get_archived_workspaces()
9994

100-
async def get_active_workspace(self) -> Optional[ActiveWorkspace]:
95+
async def get_active_workspace(self) -> Optional[db_models.ActiveWorkspace]:
10196
"""
10297
Get the active workspace
10398
"""
10499
return await self._db_reader.get_active_workspace()
105100

106101
async def _is_workspace_active(
107102
self, workspace_name: str
108-
) -> Tuple[bool, Optional[Session], Optional[WorkspaceRow]]:
103+
) -> Tuple[bool, Optional[db_models.Session], Optional[db_models.WorkspaceRow]]:
109104
"""
110105
Check if the workspace is active alongside the session and workspace objects
111106
"""
@@ -155,13 +150,13 @@ async def recover_workspace(self, workspace_name: str):
155150

156151
async def update_workspace_custom_instructions(
157152
self, workspace_name: str, custom_instr_lst: List[str]
158-
) -> WorkspaceRow:
153+
) -> db_models.WorkspaceRow:
159154
selected_workspace = await self._db_reader.get_workspace_by_name(workspace_name)
160155
if not selected_workspace:
161156
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
162157

163158
custom_instructions = " ".join(custom_instr_lst)
164-
workspace_update = WorkspaceRow(
159+
workspace_update = db_models.WorkspaceRow(
165160
id=selected_workspace.id,
166161
name=selected_workspace.name,
167162
custom_instructions=custom_instructions,
@@ -217,17 +212,13 @@ async def hard_delete_workspace(self, workspace_name: str):
217212
raise WorkspaceCrudError(f"Error deleting workspace {workspace_name}")
218213
return
219214

220-
async def get_workspace_by_name(self, workspace_name: str) -> WorkspaceRow:
215+
async def get_workspace_by_name(self, workspace_name: str) -> db_models.WorkspaceRow:
221216
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
222217
if not workspace:
223218
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
224219
return workspace
225220

226-
# Can't use type hints since the models are not yet defined
227-
# Note that I'm explicitly importing the models here to avoid circular imports.
228-
async def get_muxes(self, workspace_name: str):
229-
from codegate.api import v1_models
230-
221+
async def get_muxes(self, workspace_name: str) -> List[mux_models.MuxRule]:
231222
# Verify if workspace exists
232223
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
233224
if not workspace:
@@ -239,7 +230,7 @@ async def get_muxes(self, workspace_name: str):
239230
# These are already sorted by priority
240231
for dbmux in dbmuxes:
241232
muxes.append(
242-
v1_models.MuxRule(
233+
mux_models.MuxRule(
243234
provider_id=dbmux.provider_endpoint_id,
244235
model=dbmux.provider_model_name,
245236
matcher_type=dbmux.matcher_type,
@@ -249,10 +240,7 @@ async def get_muxes(self, workspace_name: str):
249240

250241
return muxes
251242

252-
# Can't use type hints since the models are not yet defined
253-
async def set_muxes(self, workspace_name: str, muxes):
254-
from codegate.api import v1_models
255-
243+
async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> None:
256244
# Verify if workspace exists
257245
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
258246
if not workspace:
@@ -265,7 +253,7 @@ async def set_muxes(self, workspace_name: str, muxes):
265253
# Add the new muxes
266254
priority = 0
267255

268-
muxes_with_routes: List[Tuple[v1_models.MuxRule, rulematcher.ModelRoute]] = []
256+
muxes_with_routes: List[Tuple[mux_models.MuxRule, rulematcher.ModelRoute]] = []
269257

270258
# Verify all models are valid
271259
for mux in muxes:
@@ -275,7 +263,7 @@ async def set_muxes(self, workspace_name: str, muxes):
275263
matchers: List[rulematcher.MuxingRuleMatcher] = []
276264

277265
for mux, route in muxes_with_routes:
278-
new_mux = MuxRule(
266+
new_mux = db_models.MuxRule(
279267
id=str(uuid()),
280268
provider_endpoint_id=mux.provider_id,
281269
provider_model_name=mux.model,
@@ -294,7 +282,7 @@ async def set_muxes(self, workspace_name: str, muxes):
294282
mux_registry = await rulematcher.get_muxing_rules_registry()
295283
await mux_registry.set_ws_rules(workspace_name, matchers)
296284

297-
async def get_routing_for_mux(self, mux) -> rulematcher.ModelRoute:
285+
async def get_routing_for_mux(self, mux: mux_models.MuxRule) -> rulematcher.ModelRoute:
298286
"""Get the routing for a mux
299287
300288
Note that this particular mux object is the API model, not the database model.
@@ -322,7 +310,7 @@ async def get_routing_for_mux(self, mux) -> rulematcher.ModelRoute:
322310
auth_material=dbauth,
323311
)
324312

325-
async def get_routing_for_db_mux(self, mux: MuxRule) -> rulematcher.ModelRoute:
313+
async def get_routing_for_db_mux(self, mux: db_models.MuxRule) -> rulematcher.ModelRoute:
326314
"""Get the routing for a mux
327315
328316
Note that this particular mux object is the database model, not the API model.

0 commit comments

Comments
 (0)