2
2
from typing import List , Optional , Tuple
3
3
from uuid import uuid4 as uuid
4
4
5
+ from codegate .db import models as db_models
5
6
from codegate .db .connection import DbReader , DbRecorder
6
- from codegate .db .models import (
7
- ActiveWorkspace ,
8
- MuxRule ,
9
- Session ,
10
- WorkspaceRow ,
11
- WorkspaceWithSessionInfo ,
12
- )
7
+ from codegate .muxing import models as mux_models
13
8
from codegate .muxing import rulematcher
14
9
15
10
@@ -40,7 +35,7 @@ class WorkspaceCrud:
40
35
def __init__ (self ):
41
36
self ._db_reader = DbReader ()
42
37
43
- async def add_workspace (self , new_workspace_name : str ) -> WorkspaceRow :
38
+ async def add_workspace (self , new_workspace_name : str ) -> db_models . WorkspaceRow :
44
39
"""
45
40
Add a workspace
46
41
@@ -57,7 +52,7 @@ async def add_workspace(self, new_workspace_name: str) -> WorkspaceRow:
57
52
58
53
async def rename_workspace (
59
54
self , old_workspace_name : str , new_workspace_name : str
60
- ) -> WorkspaceRow :
55
+ ) -> db_models . WorkspaceRow :
61
56
"""
62
57
Rename a workspace
63
58
@@ -79,33 +74,33 @@ async def rename_workspace(
79
74
if not ws :
80
75
raise WorkspaceDoesNotExistError (f"Workspace { old_workspace_name } does not exist." )
81
76
db_recorder = DbRecorder ()
82
- new_ws = WorkspaceRow (
77
+ new_ws = db_models . WorkspaceRow (
83
78
id = ws .id , name = new_workspace_name , custom_instructions = ws .custom_instructions
84
79
)
85
80
workspace_renamed = await db_recorder .update_workspace (new_ws )
86
81
return workspace_renamed
87
82
88
- async def get_workspaces (self ) -> List [WorkspaceWithSessionInfo ]:
83
+ async def get_workspaces (self ) -> List [db_models . WorkspaceWithSessionInfo ]:
89
84
"""
90
85
Get all workspaces
91
86
"""
92
87
return await self ._db_reader .get_workspaces ()
93
88
94
- async def get_archived_workspaces (self ) -> List [WorkspaceRow ]:
89
+ async def get_archived_workspaces (self ) -> List [db_models . WorkspaceRow ]:
95
90
"""
96
91
Get all archived workspaces
97
92
"""
98
93
return await self ._db_reader .get_archived_workspaces ()
99
94
100
- async def get_active_workspace (self ) -> Optional [ActiveWorkspace ]:
95
+ async def get_active_workspace (self ) -> Optional [db_models . ActiveWorkspace ]:
101
96
"""
102
97
Get the active workspace
103
98
"""
104
99
return await self ._db_reader .get_active_workspace ()
105
100
106
101
async def _is_workspace_active (
107
102
self , workspace_name : str
108
- ) -> Tuple [bool , Optional [Session ], Optional [WorkspaceRow ]]:
103
+ ) -> Tuple [bool , Optional [db_models . Session ], Optional [db_models . WorkspaceRow ]]:
109
104
"""
110
105
Check if the workspace is active alongside the session and workspace objects
111
106
"""
@@ -155,13 +150,13 @@ async def recover_workspace(self, workspace_name: str):
155
150
156
151
async def update_workspace_custom_instructions (
157
152
self , workspace_name : str , custom_instr_lst : List [str ]
158
- ) -> WorkspaceRow :
153
+ ) -> db_models . WorkspaceRow :
159
154
selected_workspace = await self ._db_reader .get_workspace_by_name (workspace_name )
160
155
if not selected_workspace :
161
156
raise WorkspaceDoesNotExistError (f"Workspace { workspace_name } does not exist." )
162
157
163
158
custom_instructions = " " .join (custom_instr_lst )
164
- workspace_update = WorkspaceRow (
159
+ workspace_update = db_models . WorkspaceRow (
165
160
id = selected_workspace .id ,
166
161
name = selected_workspace .name ,
167
162
custom_instructions = custom_instructions ,
@@ -217,17 +212,13 @@ async def hard_delete_workspace(self, workspace_name: str):
217
212
raise WorkspaceCrudError (f"Error deleting workspace { workspace_name } " )
218
213
return
219
214
220
- async def get_workspace_by_name (self , workspace_name : str ) -> WorkspaceRow :
215
+ async def get_workspace_by_name (self , workspace_name : str ) -> db_models . WorkspaceRow :
221
216
workspace = await self ._db_reader .get_workspace_by_name (workspace_name )
222
217
if not workspace :
223
218
raise WorkspaceDoesNotExistError (f"Workspace { workspace_name } does not exist." )
224
219
return workspace
225
220
226
- # Can't use type hints since the models are not yet defined
227
- # Note that I'm explicitly importing the models here to avoid circular imports.
228
- async def get_muxes (self , workspace_name : str ):
229
- from codegate .api import v1_models
230
-
221
+ async def get_muxes (self , workspace_name : str ) -> List [mux_models .MuxRule ]:
231
222
# Verify if workspace exists
232
223
workspace = await self ._db_reader .get_workspace_by_name (workspace_name )
233
224
if not workspace :
@@ -239,7 +230,7 @@ async def get_muxes(self, workspace_name: str):
239
230
# These are already sorted by priority
240
231
for dbmux in dbmuxes :
241
232
muxes .append (
242
- v1_models .MuxRule (
233
+ mux_models .MuxRule (
243
234
provider_id = dbmux .provider_endpoint_id ,
244
235
model = dbmux .provider_model_name ,
245
236
matcher_type = dbmux .matcher_type ,
@@ -249,10 +240,7 @@ async def get_muxes(self, workspace_name: str):
249
240
250
241
return muxes
251
242
252
- # Can't use type hints since the models are not yet defined
253
- async def set_muxes (self , workspace_name : str , muxes ):
254
- from codegate .api import v1_models
255
-
243
+ async def set_muxes (self , workspace_name : str , muxes : mux_models .MuxRule ) -> None :
256
244
# Verify if workspace exists
257
245
workspace = await self ._db_reader .get_workspace_by_name (workspace_name )
258
246
if not workspace :
@@ -265,7 +253,7 @@ async def set_muxes(self, workspace_name: str, muxes):
265
253
# Add the new muxes
266
254
priority = 0
267
255
268
- muxes_with_routes : List [Tuple [v1_models .MuxRule , rulematcher .ModelRoute ]] = []
256
+ muxes_with_routes : List [Tuple [mux_models .MuxRule , rulematcher .ModelRoute ]] = []
269
257
270
258
# Verify all models are valid
271
259
for mux in muxes :
@@ -275,7 +263,7 @@ async def set_muxes(self, workspace_name: str, muxes):
275
263
matchers : List [rulematcher .MuxingRuleMatcher ] = []
276
264
277
265
for mux , route in muxes_with_routes :
278
- new_mux = MuxRule (
266
+ new_mux = db_models . MuxRule (
279
267
id = str (uuid ()),
280
268
provider_endpoint_id = mux .provider_id ,
281
269
provider_model_name = mux .model ,
@@ -294,7 +282,7 @@ async def set_muxes(self, workspace_name: str, muxes):
294
282
mux_registry = await rulematcher .get_muxing_rules_registry ()
295
283
await mux_registry .set_ws_rules (workspace_name , matchers )
296
284
297
- async def get_routing_for_mux (self , mux ) -> rulematcher .ModelRoute :
285
+ async def get_routing_for_mux (self , mux : mux_models . MuxRule ) -> rulematcher .ModelRoute :
298
286
"""Get the routing for a mux
299
287
300
288
Note that this particular mux object is the API model, not the database model.
@@ -322,7 +310,7 @@ async def get_routing_for_mux(self, mux) -> rulematcher.ModelRoute:
322
310
auth_material = dbauth ,
323
311
)
324
312
325
- async def get_routing_for_db_mux (self , mux : MuxRule ) -> rulematcher .ModelRoute :
313
+ async def get_routing_for_db_mux (self , mux : db_models . MuxRule ) -> rulematcher .ModelRoute :
326
314
"""Get the routing for a mux
327
315
328
316
Note that this particular mux object is the database model, not the API model.
0 commit comments